In [1]:
import torch
import time
import pandas as pd
from torch import nn

from conv_gemm.baseline_layers.triton_conv2d_int8 import TritonConv2dINT8
from conv_gemm.triton_kernels.int8.int8_quant import quantize_int8_sym_tensor  

In [2]:
device = "cuda"
torch.manual_seed(0)


def bench_ms(fn, iters=50, warmup=10):
    """Return the mean latency of ``fn()`` in milliseconds.

    Args:
        fn: zero-argument callable that performs the (CUDA) work to measure.
        iters: number of timed calls; must be positive.
        warmup: number of untimed calls made before measuring, to absorb
            one-off costs (compilation, caching, algorithm selection).

    Returns:
        Average elapsed time per call, in milliseconds.

    Raises:
        ValueError: if ``iters`` is not positive.
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    # CUDA launches are asynchronous: synchronize so the timer brackets
    # completed GPU work. Skipped when CUDA is unavailable so the helper
    # also works for CPU-only callables.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    # warmup
    for _ in range(warmup):
        fn()
    sync()

    # perf_counter is monotonic and high-resolution, unlike time.time()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - t0) * 1000.0 / iters


In [3]:
@torch.no_grad()
def load_fp32_weights_and_calibrate_activations_sym(
    int8_layer: TritonConv2dINT8,
    w_f: torch.Tensor,      # FP32/FP16 Conv2d weights: [Cout, Cin, Kh, Kw]
    b_f: torch.Tensor | None,
    act_sample: torch.Tensor,  # FP16/FP32 activations used for calibration
):
    """Symmetric INT8 post-training quantization (PTQ) for one conv layer.

    Quantized weights and scales are computed here, outside the layer.
    Only the finished parameters are handed to the layer, which copies
    them into its own buffers via ``load_quant_params``.

    Args:
        int8_layer: layer receiving the quantization parameters.
        w_f: floating-point Conv2d weights, the real 4D tensor
            ([Cout, Cin, Kh, Kw]); no K padding is applied here.
        b_f: optional floating-point bias; forwarded as FP32.
        act_sample: activation sample whose symmetric quantization scale
            becomes the layer's ``act_scale``.
    """
    # Weights: symmetric INT8 quantization of the 4D tensor.
    weight_q, weight_scale, _unused = quantize_int8_sym_tensor(w_f.float())

    # Activations: only the scale is kept; the quantized sample is discarded.
    _, activation_scale, _ = quantize_int8_sym_tensor(act_sample.float())

    bias_fp32 = None if b_f is None else b_f.float()

    # The layer performs the actual copy into its buffers.
    int8_layer.load_quant_params(
        w_q=weight_q.to(torch.int8),
        w_scale=weight_scale,
        act_scale=activation_scale,
        bias=bias_fp32,
    )

In [16]:
# ============================================================
#     FULL BENCH — INT8 vs FP16 (fair pipeline)
# ============================================================
def _bench_single_conv_config(H, W, N, Cin, Cout, K, iters):
    """Measure accuracy and latency of FP16 nn.Conv2d vs TritonConv2dINT8 for one shape.

    Returns:
        Tuple ``(t_fp16_ms, t_int8_ms, speedup, err_max, err_mean, note)``.
        Fields that could not be computed stay ``None``. If any step raises,
        the exception is recorded in ``note`` (and printed) instead of
        aborting the whole sweep.
    """
    t_fp16 = t_int8 = speedup = err_max = err_mean = note = None
    try:
        # FP16 input; also used as the activation calibration sample
        x_fp16 = torch.randn(N, Cin, H, W, device=device).half()

        # FP16 baseline (cuDNN)
        conv_ref = nn.Conv2d(
            Cin, Cout, kernel_size=K,
            stride=1, padding=K // 2, bias=True,
        ).to(device).half()

        # INT8 layer with the same geometry
        conv_int8 = TritonConv2dINT8(
            in_channels=Cin,
            out_channels=Cout,
            kernel_size=K,
            stride=1,
            padding=K // 2,
            dilation=1,
            bias=True,
        ).to(device)

        # PTQ (symmetric): quantize weights + calibrate act_scale, once, before timing
        load_fp32_weights_and_calibrate_activations_sym(
            int8_layer=conv_int8,
            w_f=conv_ref.weight,
            b_f=conv_ref.bias,
            act_sample=x_fp16,
        )

        with torch.no_grad():
            # FP16 reference forward
            y_ref = conv_ref(x_fp16).float()

            # Quantize the input once with the calibrated act_scale. The same
            # tensor feeds the accuracy check and the timed INT8 loop, so the
            # INT8 timing measures the conv only, without the quantizer.
            x_q_static = conv_int8.quantize_input(x_fp16)
            y_int8 = conv_int8(x_q_static).float()

            diff = (y_ref - y_int8).abs()
            err_max = diff.max().item()
            err_mean = diff.mean().item()

            # Both timings run under no_grad. Otherwise the FP16 baseline,
            # whose nn.Conv2d parameters require grad, would also pay for
            # autograd graph construction and skew the comparison.
            t_fp16 = bench_ms(lambda: conv_ref(x_fp16), iters)
            t_int8 = bench_ms(lambda: conv_int8(x_q_static), iters)

        speedup = t_fp16 / t_int8 if t_int8 > 0 else None
    except Exception as exc:  # record the failure and move on to the next config
        note = f"{type(exc).__name__}: {exc}"
        print(f"[bench] FAILED img={H}, N={N}, Cin={Cin}, Cout={Cout}, K={K}: {note}")

    return t_fp16, t_int8, speedup, err_max, err_mean, note


def run_int8_conv_bench_sym_only(
    image_sizes=(32, 64, 112, 224, 512),
    batch_sizes=(1, 2, 4),
    channels=((1, 1), (1, 3), (3, 8), (8, 16), (16, 32)),
    kernels=(1, 3, 5, 7, 9, 11),
    iters=100,
    verbose=True,
):
    """Benchmark symmetric INT8 convolution against the FP16 cuDNN baseline.

    Notes:
        - PTQ (weight quantization + act_scale calibration) runs ONCE per
          config, before any timing.
        - The input is quantized via ``conv_int8.quantize_input`` (applying
          the calibrated scale). The resulting ``x_q_static`` is reused for
          timing, so only the conv itself is measured.
        - Both paths are timed under ``torch.no_grad()``.

    Args:
        image_sizes: square input sizes (H = W) to sweep.
        batch_sizes: batch sizes N to sweep.
        channels: (Cin, Cout) pairs to sweep.
        kernels: square kernel sizes K; kernels larger than the input are skipped.
        iters: timed iterations per measurement (passed to ``bench_ms``).
        verbose: print one progress line per config.

    Returns:
        DataFrame with one row per config and columns ``img, N, Cin, Cout, K,
        t_fp16_ms, t_int8_ms, speedup, err_max, err_mean, note``. ``note``
        holds the exception text for configs that failed.
    """
    rows = []

    for H in image_sizes:
        W = H  # square inputs only

        for N in batch_sizes:
            for Cin, Cout in channels:
                for K in kernels:

                    # skip kernels larger than the input
                    if K > H or K > W:
                        continue

                    if verbose:
                        print(f"[bench] img={H}, N={N}, Cin={Cin}, Cout={Cout}, K={K}")

                    metrics = _bench_single_conv_config(H, W, N, Cin, Cout, K, iters)
                    rows.append([H, N, Cin, Cout, K, *metrics])

    df = pd.DataFrame(
        rows,
        columns=[
            "img", "N", "Cin", "Cout", "K",
            "t_fp16_ms", "t_int8_ms", "speedup",
            "err_max", "err_mean", "note",
        ],
    )

    return df


In [None]:
# (Cin, Cout) channel pairs to sweep
channels_cfg = (
    (1, 1), (1, 3), (3, 8),
    (8, 16), (16, 32), (32, 64),
)

# Full sweep configuration, collected in one place for easy tweaking
bench_kwargs = dict(
    image_sizes=(32, 64, 128, 224),
    batch_sizes=(1, 2, 4, 8),
    channels=channels_cfg,
    kernels=(1, 3, 5, 7, 9, 11),
    iters=100,
)

df = run_int8_conv_bench_sym_only(**bench_kwargs)

[bench] img=32, N=1, Cin=1, Cout=1, K=1
[bench] img=32, N=1, Cin=1, Cout=1, K=3
[bench] img=32, N=1, Cin=1, Cout=1, K=5
[bench] img=32, N=1, Cin=1, Cout=1, K=7
[bench] img=32, N=1, Cin=1, Cout=1, K=9
[bench] img=32, N=1, Cin=1, Cout=1, K=11
[bench] img=32, N=1, Cin=1, Cout=3, K=1
[bench] img=32, N=1, Cin=1, Cout=3, K=3
[bench] img=32, N=1, Cin=1, Cout=3, K=5
[bench] img=32, N=1, Cin=1, Cout=3, K=7
[bench] img=32, N=1, Cin=1, Cout=3, K=9
[bench] img=32, N=1, Cin=1, Cout=3, K=11
[bench] img=32, N=1, Cin=3, Cout=8, K=1
[bench] img=32, N=1, Cin=3, Cout=8, K=3
[bench] img=32, N=1, Cin=3, Cout=8, K=5
[bench] img=32, N=1, Cin=3, Cout=8, K=7
[bench] img=32, N=1, Cin=3, Cout=8, K=9
[bench] img=32, N=1, Cin=3, Cout=8, K=11
[bench] img=32, N=1, Cin=8, Cout=16, K=1
[bench] img=32, N=1, Cin=8, Cout=16, K=3
[bench] img=32, N=1, Cin=8, Cout=16, K=5
[bench] img=32, N=1, Cin=8, Cout=16, K=7
[bench] img=32, N=1, Cin=8, Cout=16, K=9
[bench] img=32, N=1, Cin=8, Cout=16, K=11
[bench] img=32, N=1, Cin=16, C

In [None]:
# Top configurations by speedup, restricted to rows where both timings succeeded
timing_cols = ["t_fp16_ms", "t_int8_ms", "speedup"]
df_valid = df.dropna(subset=timing_cols)

df_top = df_valid.sort_values(by="speedup", ascending=False).head(n=30)
df_top

______________________________________________________________________