In [1]:
import torch
import time
import pandas as pd
from torch import nn

from conv_gemm.baseline_layers.triton_conv2d_int8 import TritonConv2dINT8
from conv_gemm.triton_kernels.int8.int8_quant import quantize_int8_sym_tensor  

In [2]:
# Fix the global PyTorch RNG seed so the random inputs/weights are reproducible.
torch.manual_seed(0)
# Every tensor and layer created by the benchmark below is placed on this device.
device = "cuda"


def bench_ms(fn, iters=50, warmup=0):
    """
    Measure the average wall-clock time of ``fn()`` in milliseconds.

    Args:
        fn: Zero-argument callable to time.
        iters: Number of timed calls; must be a positive integer.
        warmup: Number of untimed calls made before the measurement starts
            (default 0, i.e. no warm-up).

    Returns:
        Average time per call in milliseconds (float).

    Raises:
        ValueError: If ``iters`` is not positive.
    """
    if iters <= 0:
        raise ValueError(f"iters must be a positive integer, got {iters!r}")

    # Only synchronize when a CUDA runtime is present, so the helper also
    # works on CPU-only hosts instead of raising.
    use_cuda = torch.cuda.is_available()

    for _ in range(warmup):
        fn()

    # Drain previously queued GPU work so it is not attributed to `fn`.
    if use_cuda:
        torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted and may have coarse resolution.
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    # Wait for the asynchronously launched kernels before stopping the clock.
    if use_cuda:
        torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - t0

    return elapsed_s * 1000.0 / iters

In [3]:
@torch.no_grad()
def load_fp32_weights_and_calibrate_activations_sym(
    int8_layer: TritonConv2dINT8,
    w_f: torch.Tensor,      # FP32/FP16 Conv2d weights: [Cout, Cin, Kh, Kw]
    b_f: torch.Tensor | None,
    act_sample: torch.Tensor,  # FP16/FP32 activations used for calibration
):
    """
    Post-training quantization (PTQ) calibration for one INT8 layer.

    The quantized weights and the scales are computed here, outside the layer;
    the layer only receives the already computed parameters through its
    ``load_quant_params`` method, which is responsible for storing them.

    Args:
        int8_layer: Target layer that receives the quantization parameters.
        w_f: Floating-point convolution weights.
        b_f: Optional floating-point bias; ``None`` is forwarded as ``None``.
        act_sample: Sample of activations used to derive the activation scale.
    """

    # Step 1: quantize the weights (real 4D tensor only, WITHOUT K_pad!).
    weights_q, weights_scale, _ = quantize_int8_sym_tensor(w_f.float())
    weights_q = weights_q.to(torch.int8)

    # Step 2: derive the activation scale from the calibration sample; the
    # quantized sample itself is not needed here.
    _, activations_scale, _ = quantize_int8_sym_tensor(act_sample.float())

    # The bias is passed on as float32 when present.
    if b_f is None:
        bias_fp32 = None
    else:
        bias_fp32 = b_f.float()

    # Step 3: hand everything over to the layer.
    int8_layer.load_quant_params(
        w_q=weights_q,
        w_scale=weights_scale,
        act_scale=activations_scale,
        bias=bias_fp32,
    )

In [4]:
#                      FULL BENCH
# ============================================================
def run_int8_conv_bench_sym_only(
    image_sizes=(32, 64, 112, 224, 512),
    batch_sizes=(1, 2, 4),
    channels=((1, 1), (1, 3), (3, 8), (8, 16), (16, 32)),
    kernels=(1, 3, 5, 7, 9, 11),
    iters=100,
):
    """
    Benchmark the symmetric INT8 convolution layer against an FP16 ``nn.Conv2d``.

    For every combination of image size, batch size, channel pair and kernel
    size, a random FP16 input and an FP16 reference convolution are created,
    the INT8 layer is calibrated from the reference weights and that same
    input, and both accuracy and timing are recorded.

    Args:
        image_sizes: Square input sizes (height == width).
        batch_sizes: Batch sizes to sweep.
        channels: ``(Cin, Cout)`` pairs to sweep.
        kernels: Square kernel sizes; kernels larger than the image are skipped.
        iters: Number of timed iterations passed to ``bench_ms``.

    Returns:
        pandas.DataFrame with one row per benchmarked configuration and the
        columns ``img, N, Cin, Cout, K, t_fp16_ms, t_int8_ms, speedup,
        err_max, err_mean, note``. ``t_int8_ms`` includes the activation
        quantization step. If timing fails, the three timing columns are
        ``None`` and ``note`` holds the exception type and message; otherwise
        ``note`` is ``None``.
    """
    # Local import keeps the block self-contained (stdlib only).
    import itertools

    columns = [
        "img", "N", "Cin", "Cout", "K",
        "t_fp16_ms", "t_int8_ms", "speedup",
        "err_max", "err_mean", "note",
    ]
    rows = []

    # Same iteration order as four nested loops: image -> batch -> channels -> kernel.
    for H, N, (Cin, Cout), K in itertools.product(
        image_sizes, batch_sizes, channels, kernels
    ):
        W = H

        # Skip kernels that are larger than the input.
        if K > H or K > W:
            continue

        print(f"[bench] img={H}, N={N}, Cin={Cin}, Cout={Cout}, K={K}")

        # FP16 input
        x_fp16 = torch.randn(N, Cin, H, W, device=device).half()

        # FP16 baseline convolution
        conv_ref = nn.Conv2d(
            Cin, Cout, kernel_size=K,
            stride=1, padding=K // 2, bias=True
        ).to(device).half()

        # INT8 layer with the same geometry as the baseline
        conv_int8 = TritonConv2dINT8(
            in_channels=Cin,
            out_channels=Cout,
            kernel_size=K,
            stride=1,
            padding=K // 2,
            dilation=1,
            bias=True,
        ).to(device)

        # PTQ (symmetric): quantize the reference weights and calibrate the
        # activation scale on the very input used for the measurement.
        load_fp32_weights_and_calibrate_activations_sym(
            int8_layer=conv_int8,
            w_f=conv_ref.weight,
            b_f=conv_ref.bias,
            act_sample=x_fp16,
        )

        # Accuracy: compare both outputs in float32.
        with torch.no_grad():
            y_ref = conv_ref(x_fp16).float()
            x_q = quantize_int8_sym_tensor(x_fp16)[0]  # quantized tensor only
            y_int8 = conv_int8(x_q).float()

        diff = (y_ref - y_int8).abs()
        err_max = diff.max().item()
        err_mean = diff.mean().item()

        # Performance. The INT8 step deliberately includes quantizing the
        # input, so the comparison covers the whole FP16 -> INT8 conv path.
        def int8_step():
            x_q_local = quantize_int8_sym_tensor(x_fp16)[0]
            conv_int8(x_q_local)

        note = None
        try:
            # Time under no_grad, like the accuracy pass above, so autograd
            # bookkeeping is not part of the measured inference time.
            with torch.no_grad():
                t_fp16 = bench_ms(lambda: conv_ref(x_fp16), iters)
                t_int8 = bench_ms(int8_step, iters)
            speedup = t_fp16 / t_int8
        except Exception as e:
            # Keep the sweep running, but record why this configuration has
            # no timings instead of silently dropping the error.
            t_fp16 = t_int8 = speedup = None
            note = f"{type(e).__name__}: {e}"

        rows.append([
            H, N, Cin, Cout, K,
            t_fp16, t_int8, speedup,
            err_max, err_mean, note,
        ])

    return pd.DataFrame(rows, columns=columns)


In [5]:
# (Cin, Cout) pairs swept by the benchmark, from single-channel up to 32 -> 64.
channels_cfg = ((1, 1), (1, 3), (3, 8), (8, 16), (16, 32), (32, 64))

# Run the full sweep; one DataFrame row per benchmarked configuration.
df = run_int8_conv_bench_sym_only(
    iters=100,
    kernels=(1, 3, 5, 7, 9, 11),
    channels=channels_cfg,
    batch_sizes=(1, 2, 4, 8),
    image_sizes=(32, 64, 128),
)

[bench] img=32, N=1, Cin=1, Cout=1, K=1
[bench] img=32, N=1, Cin=1, Cout=1, K=3
[bench] img=32, N=1, Cin=1, Cout=1, K=5
[bench] img=32, N=1, Cin=1, Cout=1, K=7
[bench] img=32, N=1, Cin=1, Cout=1, K=9
[bench] img=32, N=1, Cin=1, Cout=1, K=11
[bench] img=32, N=1, Cin=1, Cout=3, K=1
[bench] img=32, N=1, Cin=1, Cout=3, K=3
[bench] img=32, N=1, Cin=1, Cout=3, K=5
[bench] img=32, N=1, Cin=1, Cout=3, K=7
[bench] img=32, N=1, Cin=1, Cout=3, K=9
[bench] img=32, N=1, Cin=1, Cout=3, K=11
[bench] img=32, N=1, Cin=3, Cout=8, K=1
[bench] img=32, N=1, Cin=3, Cout=8, K=3
[bench] img=32, N=1, Cin=3, Cout=8, K=5
[bench] img=32, N=1, Cin=3, Cout=8, K=7
[bench] img=32, N=1, Cin=3, Cout=8, K=9
[bench] img=32, N=1, Cin=3, Cout=8, K=11
[bench] img=32, N=1, Cin=8, Cout=16, K=1
[bench] img=32, N=1, Cin=8, Cout=16, K=3
[bench] img=32, N=1, Cin=8, Cout=16, K=5
[bench] img=32, N=1, Cin=8, Cout=16, K=7
[bench] img=32, N=1, Cin=8, Cout=16, K=9
[bench] img=32, N=1, Cin=8, Cout=16, K=11
[bench] img=32, N=1, Cin=16, C

In [6]:
# Keep only the configurations for which all timing columns were measured.
df_valid = df.dropna(subset=["t_fp16_ms", "t_int8_ms", "speedup"])

# The 30 valid configurations with the highest speedup, best first.
df_top = df_valid.sort_values(by="speedup", ascending=False).iloc[:30]
df_top

Unnamed: 0,img,N,Cin,Cout,K,t_fp16_ms,t_int8_ms,speedup,err_max,err_mean,note
107,32,4,32,64,11,0.21621,1.045249,0.20685,0.026095,0.00434,
101,32,4,16,32,11,0.11739,0.763383,0.153776,0.028369,0.004782,
29,32,1,16,32,11,0.062397,0.408018,0.152926,0.022163,0.003889,
57,32,2,8,16,7,0.049257,0.332632,0.148083,0.02286,0.004046,
233,64,4,3,8,11,0.089536,0.607915,0.147283,0.027238,0.004427,
31,32,1,32,64,3,0.047379,0.32393,0.146262,0.029725,0.005281,
21,32,1,8,16,7,0.050526,0.345798,0.146113,0.020164,0.004076,
305,128,1,3,8,11,0.088284,0.60941,0.144868,0.024246,0.004432,
413,128,8,3,8,11,0.421617,2.919846,0.144397,0.033809,0.005308,
26,32,1,16,32,5,0.046914,0.325119,0.144297,0.022798,0.0042,
