In [10]:
import time
import torch
import torch.nn as nn
import sys
import sys
sys.path.append('../../')
from lib.factorization.layers import LowRankLinear, LowRankConv2d

import time
import torch
import torch.nn as nn

# LowRankLinear and LowRankConv2d must be in scope at this point; this cell
# imports them from lib.factorization.layers above.

# --- Runtime setup -----------------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float32
torch.manual_seed(0)
if device.type == "cuda":
    # Let cuDNN auto-tune its algorithm choice for the fixed input shapes used below.
    torch.backends.cudnn.benchmark = True

# --- Linear problem size: batch, input features, output features -------------
B_lin = 1024
IN_lin = 2048
OUT_lin = 2048

# --- Conv problem size: batch, channels in/out, spatial size, kernel size ----
B_c = 32
Cin = 256
Cout = 512
H = 14
W = 14
K = 3
stride = padding = groups = 1

# --- Ranks to sweep ----------------------------------------------------------
ranks_conv = [1, 2, 4, 5, 6, 7, 8, 16, 32, 64, 128, 130, 140, 146, 180, 256]
ranks_lin = [2 ** exponent for exponent in range(12)]  # 1, 2, 4, ..., 2048

# --- Iteration counts consumed by bench() ------------------------------------
warmup_iters = 10
bench_iters = 30

def bench(model, x):
    """Return the mean wall-clock time of ``model(x)`` in milliseconds per call.

    ``model`` and ``x`` are moved to the module-level ``device``. The model is
    called ``warmup_iters`` times untimed, then ``bench_iters`` times under
    ``time.perf_counter``. On CUDA the device is synchronized after the
    warm-up (so the clock starts with an idle GPU) and again after the timed
    loop (so all queued kernels are included in the measurement).
    """
    on_cuda = device.type == "cuda"

    model.to(device)
    x = x.to(device)

    # Untimed passes.
    for _ in range(warmup_iters):
        y = model(x)
    if on_cuda:
        torch.cuda.synchronize()

    # Timed passes.
    started = time.perf_counter()
    for _ in range(bench_iters):
        y = model(x)
    if on_cuda:
        torch.cuda.synchronize()
    finished = time.perf_counter()

    elapsed_ms = (finished - started) * 1000.0
    return elapsed_ms / bench_iters  # ms/iter

# Linear: time a dense nn.Linear baseline, then LowRankLinear at every rank in
# ranks_lin, and report speedup plus parameter / FLOP ratios against the baseline.
x_lin = torch.randn(B_lin, IN_lin, device=device, dtype=dtype)
dense_linear = nn.Linear(IN_lin, OUT_lin).to(device, dtype)
dense_ms = bench(dense_linear, x_lin)
print(f"nn.Linear             | {dense_ms:.2f} ms/iter | baseline")

# The dense reference counts do not depend on the rank, so compute them once.
# Only weight matrices are counted; bias terms are left out of both sides.
nparams_orig = IN_lin * OUT_lin
nflops_orig = 2 * B_lin * IN_lin * OUT_lin  # 2 FLOPs per multiply-add

for r in ranks_lin:
    lr_lin = LowRankLinear(IN_lin, OUT_lin, r).to(device, dtype)
    ms = bench(lr_lin, x_lin)
    speedup = dense_ms / ms
    # Counts r * (IN + OUT) weights, i.e. presumably an IN -> r -> OUT
    # factorization; confirm against the LowRankLinear implementation.
    nparams_lr = r * (IN_lin + OUT_lin)
    nflops_lr = 2 * B_lin * r * (IN_lin + OUT_lin)
    print(f"LowRankLinear(rank={r:3d}) | {ms:.2f} ms/iter | speedup x{speedup:.2f} | params {nparams_lr/nparams_orig:.2%}, FLOPs {nflops_lr/nflops_orig:.2%} of original")

# Conv: time a dense nn.Conv2d baseline, then LowRankConv2d at every rank in
# ranks_conv, and report speedup plus parameter / FLOP ratios against the baseline.
x_conv = torch.randn(B_c, Cin, H, W, device=device, dtype=dtype)
dense_conv = nn.Conv2d(Cin, Cout, K, stride=stride, padding=padding, groups=groups).to(device, dtype)
# Note: dense_ms is rebound here and now holds the conv baseline.
dense_ms = bench(dense_conv, x_conv)
print(f"\nnn.Conv2d             | {dense_ms:.2f} ms/iter | baseline")

# Output spatial size from the nn.Conv2d shape formula (dilation = 1).
# FLOPs scale with the number of *output* positions, which equals H x W only
# for stride 1 with "same" padding; deriving it keeps the counts right if
# stride / padding / K are changed above.
H_out = (H + 2 * padding - K) // stride + 1
W_out = (W + 2 * padding - K) // stride + 1

# The dense reference counts do not depend on the rank, so compute them once.
# The formulas assume groups == 1 and ignore bias terms.
nparams_orig = Cin * K * K * Cout
flops_orig = 2 * B_c * Cin * K * K * H_out * W_out * Cout  # 2 FLOPs per multiply-add

for r in ranks_conv:
    lr_conv = LowRankConv2d(Cin, Cout, K, r, stride=stride, padding=padding, groups=groups).to(device, dtype)
    ms = bench(lr_conv, x_conv)
    speedup = dense_ms / ms
    # Counts a K x K factor (Cin -> r) followed by a 1 x 1 factor (r -> Cout).
    # NOTE(review): assumes stride/padding are applied in the K x K factor so
    # both factors produce H_out x W_out positions; confirm against LowRankConv2d.
    nparams_lr = Cin * r * K * K + r * Cout
    flops_lr = 2 * B_c * Cin * K * K * H_out * W_out * r + 2 * B_c * r * H_out * W_out * Cout
    print(f"LowRankConv2d(rank={r:3d}) | {ms:.2f} ms/iter | speedup x{speedup:.2f} | params {nparams_lr/nparams_orig:.2%}, FLOPs {flops_lr/flops_orig:.2%} of original")


nn.Linear             | 2.15 ms/iter | baseline
LowRankLinear(rank=  1) | 0.13 ms/iter | speedup x16.84 | params 0.10%, FLOPs 0.10% of original
LowRankLinear(rank=  2) | 0.13 ms/iter | speedup x16.77 | params 0.20%, FLOPs 0.20% of original
LowRankLinear(rank=  4) | 0.12 ms/iter | speedup x17.29 | params 0.39%, FLOPs 0.39% of original
LowRankLinear(rank=  8) | 0.13 ms/iter | speedup x16.71 | params 0.78%, FLOPs 0.78% of original
LowRankLinear(rank= 16) | 0.13 ms/iter | speedup x16.65 | params 1.56%, FLOPs 1.56% of original
LowRankLinear(rank= 32) | 0.14 ms/iter | speedup x15.24 | params 3.12%, FLOPs 3.12% of original
LowRankLinear(rank= 64) | 0.19 ms/iter | speedup x11.64 | params 6.25%, FLOPs 6.25% of original
LowRankLinear(rank=128) | 0.35 ms/iter | speedup x6.14 | params 12.50%, FLOPs 12.50% of original
LowRankLinear(rank=256) | 0.58 ms/iter | speedup x3.70 | params 25.00%, FLOPs 25.00% of original
LowRankLinear(rank=512) | 1.12 ms/iter | speedup x1.92 | params 50.00%, FLOPs 50.00% o

In [14]:
# Shell escape: print the GPU model, driver/CUDA version and current memory usage.
!nvidia-smi


Wed Sep 10 17:15:19 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.99                 Driver Version: 555.99         CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 2070      WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   51C    P3             34W /  115W |    7861MiB /   8192MiB |     17%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                