In [1]:
# %% Sparse-by-K (structured) for Conv2d -> im2col -> GEMM
import torch, torch.nn.functional as F
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
try: torch.set_float32_matmul_precision("high")
except: pass

In [2]:
device = "cuda"
dtype  = torch.float16 

# ----- ПАРАМЕТРЫ -----
N, C, H, W   = 4, 64, 128, 128     
OUT, KH, KW  = 128, 3, 3          
SH, SW       = 1, 1
PH, PW       = 1, 1
DH, DW       = 1, 1

In [3]:
target_sparsity = 0.5

In [5]:
g = torch.Generator(device=device).manual_seed(0)
x = torch.randn(N, C, H, W, device=device, dtype=dtype, generator=g)
w = torch.randn(OUT, C, KH, KW, device=device, dtype=dtype, generator=g)
b = torch.randn(OUT, device=device, dtype=dtype, generator=g)

# ----- ВЫХОДНЫЕ РАЗМЕРЫ -----
H_OUT = (H + 2*PH - DH*(KH-1) - 1)//SH + 1
W_OUT = (W + 2*PW - DW*(KW-1) - 1)//SW + 1
L     = H_OUT * W_OUT
K     = C * KH * KW
print(f"Input:  {tuple(x.shape)}, Weights: {tuple(w.shape)}, Bias: {tuple(b.shape)}")
print(f"Output: (N, OUT, H_OUT, W_OUT) = ({N}, {OUT}, {H_OUT}, {W_OUT})")
print(f"K = C*KH*KW = {K}, L = H_OUT*W_OUT = {L}")

Input:  (4, 64, 128, 128), Weights: (128, 64, 3, 3), Bias: (128,)
Output: (N, OUT, H_OUT, W_OUT) = (4, 128, 128, 128)
K = C*KH*KW = 576, L = H_OUT*W_OUT = 16384


In [6]:
# ===================== БАЗА: cuDNN Conv2d =====================
for _ in range(5):
    y_ref = F.conv2d(x, w, b, stride=(SH,SW), padding=(PH,PW), dilation=(DH,DW))

torch.cuda.synchronize()
e0s, e0e = torch.cuda.Event(True), torch.cuda.Event(True)
e0s.record()
for _ in range(20):
    y_ref = F.conv2d(x, w, b, stride=(SH,SW), padding=(PH,PW), dilation=(DH,DW))
e0e.record(); torch.cuda.synchronize()
t_conv_ms = e0s.elapsed_time(e0e) / 20.0

In [9]:
import torch
import torch.nn.functional as F

def _out_hw(H, W, KH, KW, SH, SW, PH, PW, DH, DW):
    H_OUT = (H + 2*PH - DH*(KH-1) - 1)//SH + 1
    W_OUT = (W + 2*PW - DW*(KW-1) - 1)//SW + 1
    return H_OUT, W_OUT

@torch.inference_mode()
def benchmark_total(x, w, b,
                    KH, KW, SH, SW, PH, PW, DH, DW,
                    iters_total=50,
                    prune_topk_ratio=None,   # например 0.5 оставит 50% K-столбцов
                    int8_emulation=False,    # True — эмулировать INT8 (квантизация, dequant, GEMM FP16)
                    seed=0):
    assert x.is_cuda and w.is_cuda and b.is_cuda
    torch.manual_seed(seed)
    torch.cuda.synchronize()

    # ---- формы
    N, C_in, H, W = x.shape
    OUT = w.shape[0]
    K = C_in * KH * KW
    H_OUT, W_OUT = _out_hw(H, W, KH, KW, SH, SW, PH, PW, DH, DW)
    L = H_OUT * W_OUT

    # ---- baseline: цельная conv2d (TOTAL)
    for _ in range(3):
        _ = F.conv2d(x, w, b, stride=(SH,SW), padding=(PH,PW), dilation=(DH,DW))
    torch.cuda.synchronize()
    ecs, ece = torch.cuda.Event(True), torch.cuda.Event(True)
    ecs.record()
    for _ in range(max(10, iters_total//2)):
        _ = F.conv2d(x, w, b, stride=(SH,SW), padding=(PH,PW), dilation=(DH,DW))
    ece.record(); torch.cuda.synchronize()
    t_conv2d_total_ms = ecs.elapsed_time(ece) / max(10, iters_total//2)

    results = {
        "H_OUT": H_OUT, "W_OUT": W_OUT, "L": L, "K": K, "OUT": OUT,
        "CONV2D_TOTAL_ms": t_conv2d_total_ms
    }

    # ---- 1) DENSE FP16 TOTAL: unfold + (prep) + GEMM + bias + fold   (всё внутри одного цикла)
    for _ in range(3):
        X_col = F.unfold(x, (KH,KW), dilation=(DH,DW), padding=(PH,PW), stride=(SH,SW))     # (N,K,L)
        A2d = X_col.transpose(1, 2).reshape(-1, K).contiguous()                             # (N*L, K)
        B2d = w.view(OUT, -1).transpose(0, 1).contiguous()                                  # (K, OUT)
        Y2d = A2d @ B2d
        Y2d_bias = Y2d + b.view(1, OUT)
        Y_col = Y2d_bias.view(N, L, OUT).transpose(1, 2).contiguous()
        _ = F.fold(Y_col, (H_OUT, W_OUT), kernel_size=1)
    torch.cuda.synchronize()

    e1s, e1e = torch.cuda.Event(True), torch.cuda.Event(True)
    e1s.record()
    for _ in range(iters_total):
        # unfold
        X_col = F.unfold(x, (KH,KW), dilation=(DH,DW), padding=(PH,PW), stride=(SH,SW))     # (N,K,L)
        # prep матриц для GEMM
        A2d = X_col.transpose(1, 2).reshape(-1, K).contiguous()                             # (N*L, K)
        B2d = w.view(OUT, -1).transpose(0, 1).contiguous()                                  # (K, OUT)
        # GEMM
        Y2d = A2d @ B2d                                                                     # (N*L, OUT)
        # bias + fold
        Y2d_bias = Y2d + b.view(1, OUT)
        Y_col = Y2d_bias.view(N, L, OUT).transpose(1, 2).contiguous()                       # (N,OUT,L)
        _ = F.fold(Y_col, (H_OUT, W_OUT), kernel_size=1)                                    # (N,OUT,H_OUT,W_OUT)
    e1e.record(); torch.cuda.synchronize()
    results["DENSE_FP16_TOTAL_ms"] = e1s.elapsed_time(e1e) / iters_total

    # ---- 2) PRUNED-K TOTAL: считаем «спецификацию» (маску top-|K|), режем столбцы, тот же путь
    if prune_topk_ratio is not None and 0 < prune_topk_ratio < 1.0:
        K_keep = max(1, int(round(K * prune_topk_ratio)))

        # прогрев + вычисление важности по L2 (по столбцам K)
        with torch.cuda.amp.autocast(dtype=torch.float16):
            W2d = w.view(OUT, -1).transpose(0, 1).contiguous()                              # (K, OUT)
            # важность по столбцу K как L2 по OUT
            importance = (W2d.float().pow(2).sum(dim=1))                                    # (K,)
            keep_idx = torch.topk(importance, k=K_keep, dim=0).indices.sort()[0]            # отсортируем индексы

        # прогрев полного цикла
        for _ in range(3):
            X_col = F.unfold(x, (KH,KW), dilation=(DH,DW), padding=(PH,PW), stride=(SH,SW)) # (N,K,L)
            A2d = X_col.transpose(1, 2).reshape(-1, K).contiguous()                         # (N*L,K)
            A2d_k = A2d.index_select(dim=1, index=keep_idx)                                 # (N*L, K_keep)
            B2d_k = W2d.index_select(dim=0, index=keep_idx).contiguous()                    # (K_keep, OUT)
            Y2d = A2d_k @ B2d_k
            Y2d_bias = Y2d + b.view(1, OUT)
            Y_col = Y2d_bias.view(N, L, OUT).transpose(1, 2).contiguous()
            _ = F.fold(Y_col, (H_OUT, W_OUT), kernel_size=1)

        torch.cuda.synchronize()
        e2s, e2e = torch.cuda.Event(True), torch.cuda.Event(True)
        e2s.record()
        for _ in range(iters_total):
            # unfold
            X_col = F.unfold(x, (KH,KW), dilation=(DH,DW), padding=(PH,PW), stride=(SH,SW)) # (N,K,L)
            # calc importance + topk (включаем в TOTAL по твоему требованию)
            W2d = w.view(OUT, -1).transpose(0, 1).contiguous()                              # (K,OUT)
            importance = (W2d.float().pow(2).sum(dim=1))
            keep_idx = torch.topk(importance, k=K_keep, dim=0).indices.sort()[0]
            # подготовка A,B под срез
            A2d = X_col.transpose(1, 2).reshape(-1, K).contiguous()
            A2d_k = A2d.index_select(dim=1, index=keep_idx)                                 # (N*L, K_keep)
            B2d_k = W2d.index_select(dim=0, index=keep_idx).contiguous()                    # (K_keep, OUT)
            # GEMM
            Y2d = A2d_k @ B2d_k
            # bias + fold
            Y2d_bias = Y2d + b.view(1, OUT)
            Y_col = Y2d_bias.view(N, L, OUT).transpose(1, 2).contiguous()
            _ = F.fold(Y_col, (H_OUT, W_OUT), kernel_size=1)
        e2e.record(); torch.cuda.synchronize()

        results["PRUNED_K_ratio"] = prune_topk_ratio
        results["PRUNED_K_TOTAL_ms"] = e2s.elapsed_time(e2e) / iters_total
        results["PRUNED_K_keep"] = K_keep

    # ---- 3) INT8 EMULATION TOTAL: квант A/B → dequant scale → GEMM в FP16 → bias → fold
    # ВНИМАНИЕ: это корректность INT8, НО не скорость настоящего INT8 на CUDA.
    if int8_emulation:
        def quant_per_tensor(xf):
            # симметричная PTQ: scale = max(|x|)/127 (per-tensor)
            maxv = xf.abs().amax()
            scale = (maxv / 127.0).clamp(min=1e-12)
            qi = torch.clamp((xf / scale).round(), -128, 127).to(torch.int8)
            return qi, scale

        for _ in range(3):
            # unfold
            X_col = F.unfold(x, (KH,KW), dilation=(DH,DW), padding=(PH,PW), stride=(SH,SW)) # (N,K,L)
            A2d = X_col.transpose(1, 2).reshape(-1, K).contiguous()                         # (N*L, K)
            B2d = w.view(OUT, -1).transpose(0, 1).contiguous()                              # (K, OUT)
            # квант
            Ai, sA = quant_per_tensor(A2d.float())
            Bi, sB = quant_per_tensor(B2d.float())
            # dequant в FP16 для GEMM (эмуляция)
            A_deq = (Ai.float() * sA).half().contiguous()
            B_deq = (Bi.float() * sB).half().contiguous()
            Y2d = A_deq @ B_deq
            Y2d_bias = Y2d + b.view(1, OUT)
            Y_col = Y2d_bias.view(N, L, OUT).transpose(1, 2).contiguous()
            _ = F.fold(Y_col, (H_OUT, W_OUT), kernel_size=1)

        torch.cuda.synchronize()
        e3s, e3e = torch.cuda.Event(True), torch.cuda.Event(True)
        e3s.record()
        for _ in range(iters_total):
            X_col = F.unfold(x, (KH,KW), dilation=(DH,DW), padding=(PH,PW), stride=(SH,SW))
            A2d = X_col.transpose(1, 2).reshape(-1, K).contiguous()
            B2d = w.view(OUT, -1).transpose(0, 1).contiguous()
            Ai, sA = quant_per_tensor(A2d.float())
            Bi, sB = quant_per_tensor(B2d.float())
            A_deq = (Ai.float() * sA).half().contiguous()
            B_deq = (Bi.float() * sB).half().contiguous()
            Y2d = A_deq @ B_deq
            Y2d_bias = Y2d + b.view(1, OUT)
            Y_col = Y2d_bias.view(N, L, OUT).transpose(1, 2).contiguous()
            _ = F.fold(Y_col, (H_OUT, W_OUT), kernel_size=1)
        e3e.record(); torch.cuda.synchronize()

        results["INT8_EMU_TOTAL_ms"] = e3s.elapsed_time(e3e) / iters_total

        # посчитаем ошибку относительно dense
        with torch.cuda.amp.autocast(dtype=torch.float16):
            # эталон (один прогон)
            Xc = F.unfold(x, (KH,KW), dilation=(DH,DW), padding=(PH,PW), stride=(SH,SW))
            A2d = Xc.transpose(1, 2).reshape(-1, K).contiguous()
            B2d = w.view(OUT, -1).transpose(0, 1).contiguous()
            Y_ref = A2d @ B2d
            Y_ref = (Y_ref + b.view(1, OUT)).view(N, L, OUT).transpose(1, 2).contiguous()
            y_ref = F.fold(Y_ref, (H_OUT, W_OUT), kernel_size=1)

            # int8 эмуляция (один прогон)
            Ai, sA = quant_per_tensor(A2d.float())
            Bi, sB = quant_per_tensor(B2d.float())
            A_deq = (Ai.float() * sA).half().contiguous()
            B_deq = (Bi.float() * sB).half().contiguous()
            Y_q = A_deq @ B_deq
            Y_q = (Y_q + b.view(1, OUT)).view(N, L, OUT).transpose(1, 2).contiguous()
            y_q = F.fold(Y_q, (H_OUT, W_OUT), kernel_size=1)

        err = (y_q.float() - y_ref.float())
        results["INT8_EMU_max_abs_err"] = err.abs().amax().item()
        results["INT8_EMU_mse"] = (err.pow(2).mean().item())

    # ---- GFLOP/s для текущего GEMM dense (информативно)
    flops_dense = 2.0 * (N*L) * K * OUT
    results["DENSE_GEMM_GFLOP_s_est"] = (flops_dense / 1e9) / (results["DENSE_FP16_TOTAL_ms"] / 1e3)

    return results

res = benchmark_total(x, w, b, KH,KW, SH,SW, PH,PW, DH,DW,
                      iters_total=50,
                      prune_topk_ratio=0.5,   # или None
                      int8_emulation=True)
print(res)


  with torch.cuda.amp.autocast(dtype=torch.float16):


{'H_OUT': 128, 'W_OUT': 128, 'L': 16384, 'K': 576, 'OUT': 128, 'CONV2D_TOTAL_ms': 0.429752311706543, 'DENSE_FP16_TOTAL_ms': 2.4207154846191408, 'PRUNED_K_ratio': 0.5, 'PRUNED_K_TOTAL_ms': 2.400809020996094, 'PRUNED_K_keep': 288, 'INT8_EMU_TOTAL_ms': 6.909661865234375, 'INT8_EMU_max_abs_err': 2.109375, 'INT8_EMU_mse': 0.14916929602622986, 'DENSE_GEMM_GFLOP_s_est': 3992.074441379639}


  with torch.cuda.amp.autocast(dtype=torch.float16):
