In [1]:
import torch
import time
from torch import nn

# === импорт твоего слоя INT8 ===
from conv_gemm.layers.triton_conv2d_int8 import TritonConv2dINT8

In [2]:
torch.manual_seed(0)
device = "cuda"

In [3]:
N = 1
Cin = 3
Cout = 8
H = W = 32
K = 3

x = torch.randn(N, Cin, H, W, device=device, dtype=torch.float32)
conv_ref = nn.Conv2d(Cin, Cout, K, padding=K//2, bias=True).to(device)

In [4]:
conv_int8 = TritonConv2dINT8(
    Cin, Cout, K,
    padding=K//2,
    precision_mode="int8_infer"   # чистый INT8 инференс
).to(device)

# base forvard

In [9]:
with torch.no_grad():
    conv_int8.weight.copy_(conv_ref.weight)
    if conv_ref.bias is not None:
        conv_int8.bias.copy_(conv_ref.bias)

In [10]:
# FP32 reference
y_ref = conv_ref(x)

# INT8 Triton
y_int8 = conv_int8(x)

print("y_ref.shape:", y_ref.shape)
print("y_int8.shape:", y_int8.shape)

y_ref.shape: torch.Size([1, 8, 32, 32])
y_int8.shape: torch.Size([1, 8, 32, 32])


In [11]:
# ============================================================
#                   ОЦЕНКА ТОЧНОСТИ
# ============================================================

err = (y_ref - y_int8).abs()
print("\n=== ACCURACY CHECK ===")
print("max error:", err.max().item())
print("mean error:", err.mean().item())


=== ACCURACY CHECK ===
max error: 2.6488137245178223
mean error: 0.44135573506355286


In [12]:
def bench(fn, iters=200):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) * 1000 / iters  # ms

t_ref = bench(lambda: conv_ref(x))
t_int8 = bench(lambda: conv_int8(x))

print("\n=== SPEED (ms) ===")
print(f"PyTorch FP32 Conv2D:   {t_ref:.3f} ms")
print(f"Triton INT8 Conv2D:    {t_int8:.3f} ms")
print(f"Speedup: {t_ref / t_int8:.3f}x")


=== SPEED (ms) ===
PyTorch FP32 Conv2D:   0.045 ms
Triton INT8 Conv2D:    0.545 ms
Speedup: 0.083x


# BENCH FORVARD

In [13]:
import torch
import time
import pandas as pd
from torch import nn

from conv_gemm.layers.triton_conv2d_int8 import TritonConv2dINT8  # путь подгони под себя

device = "cuda"
torch.manual_seed(0)


def bench_ms(fn, iters=50):
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) * 1000.0 / iters


def run_int8_conv_bench(
    image_sizes=(32, 64, 112, 224, 512),
    batch_sizes=(1, 2, 3, 4),
    channels=((1, 1), (1, 3), (3, 8), (8, 16), (16, 32)),
    kernels=(1, 3, 5, 7, 9, 11),
    iters=100,
):
    """
    channels: кортеж пар (Cin, Cout),
      например: ((1,1), (1,3), (3,8), (8,16), (16,32))
    """
    rows = []

    for H in image_sizes:
        W = H
        for N in batch_sizes:
            for (Cin, Cout) in channels:
                for K in kernels:
                    # Нормальная свёртка не умеет kernel > spatial
                    if K > H or K > W:
                        continue

                    print(f"[bench] img={H} N={N} Cin={Cin} Cout={Cout} K={K}")

                    # Создаём вход и две свёртки
                    x = torch.randn(N, Cin, H, W, device=device, dtype=torch.float32)

                    conv_ref = nn.Conv2d(
                        Cin, Cout, kernel_size=K,
                        stride=1,
                        padding=K // 2,
                        bias=True,
                    ).to(device)

                    conv_int8 = TritonConv2dINT8(
                        in_channels=Cin,
                        out_channels=Cout,
                        kernel_size=K,
                        stride=1,
                        padding=K // 2,
                        dilation=1,
                        bias=True,
                        precision_mode="int8_infer",   # чистый инференс
                        use_weight_shadow=False,
                    ).to(device)

                    # Копируем веса/биасы, чтобы честно сравнивать
                    with torch.no_grad():
                        conv_int8.weight.copy_(conv_ref.weight)
                        if conv_ref.bias is not None:
                            conv_int8.bias.copy_(conv_ref.bias)

                    # FP32 reference
                    try:
                        with torch.no_grad():
                            y_ref = conv_ref(x)
                    except Exception as e:
                        print("  [SKIP] FP32 conv failed:", e)
                        rows.append([H, N, Cin, Cout, K,
                                     None, None, None, None, None, str(e)])
                        continue

                    # INT8 forward (один прогон для ошибки)
                    try:
                        with torch.no_grad():
                            y_int8 = conv_int8(x)
                    except Exception as e:
                        print("  [SKIP] INT8 conv failed:", e)
                        rows.append([H, N, Cin, Cout, K,
                                     None, None, None, None, None, f"int8_fail: {e}"])
                        continue

                    # Ошибки
                    err = (y_ref - y_int8).abs()
                    err_max = err.max().item()
                    err_mean = err.mean().item()

                    # Бенч
                    try:
                        t_ref = bench_ms(lambda: conv_ref(x), iters=iters)
                        t_int8 = bench_ms(lambda: conv_int8(x), iters=iters)
                        speedup = t_ref / t_int8 if t_int8 > 0 else None
                    except Exception as e:
                        print("  [WARN] bench failed:", e)
                        t_ref = t_int8 = speedup = None

                    rows.append([
                        H, N, Cin, Cout, K,
                        t_ref, t_int8, speedup,
                        err_max, err_mean,
                        None,
                    ])

    df = pd.DataFrame(
        rows,
        columns=[
            "img", "N", "Cin", "Cout", "K",
            "t_ref_ms", "t_int8_ms", "speedup",
            "err_max", "err_mean", "note",
        ],
    )
    return df




In [14]:
channels_cfg = (
    (1, 1),
    (1, 3),
    (3, 8),
    (8, 16),
    (16, 32),
    (32, 64),
)

df = run_int8_conv_bench(
    image_sizes=(32, 64, 112, 224),
    batch_sizes=(1, 2, 4, 8),
    channels=channels_cfg,
    kernels=(1, 3, 5, 7, 9, 11),
    iters=30,
)

[bench] img=32 N=1 Cin=1 Cout=1 K=1
[bench] img=32 N=1 Cin=1 Cout=1 K=3
[bench] img=32 N=1 Cin=1 Cout=1 K=5
[bench] img=32 N=1 Cin=1 Cout=1 K=7
[bench] img=32 N=1 Cin=1 Cout=1 K=9
[bench] img=32 N=1 Cin=1 Cout=1 K=11
[bench] img=32 N=1 Cin=1 Cout=3 K=1
[bench] img=32 N=1 Cin=1 Cout=3 K=3
[bench] img=32 N=1 Cin=1 Cout=3 K=5
[bench] img=32 N=1 Cin=1 Cout=3 K=7
[bench] img=32 N=1 Cin=1 Cout=3 K=9
[bench] img=32 N=1 Cin=1 Cout=3 K=11
[bench] img=32 N=1 Cin=3 Cout=8 K=1
[bench] img=32 N=1 Cin=3 Cout=8 K=3
[bench] img=32 N=1 Cin=3 Cout=8 K=5
[bench] img=32 N=1 Cin=3 Cout=8 K=7
[bench] img=32 N=1 Cin=3 Cout=8 K=9
[bench] img=32 N=1 Cin=3 Cout=8 K=11
[bench] img=32 N=1 Cin=8 Cout=16 K=1
[bench] img=32 N=1 Cin=8 Cout=16 K=3
[bench] img=32 N=1 Cin=8 Cout=16 K=5
[bench] img=32 N=1 Cin=8 Cout=16 K=7
[bench] img=32 N=1 Cin=8 Cout=16 K=9
[bench] img=32 N=1 Cin=8 Cout=16 K=11
[bench] img=32 N=1 Cin=16 Cout=32 K=1
[bench] img=32 N=1 Cin=16 Cout=32 K=3
[bench] img=32 N=1 Cin=16 Cout=32 K=5
[bench] img=

In [15]:
# Топ по ускорению среди конфигов, где всё отработало
df_valid = df.dropna(subset=["t_ref_ms", "t_int8_ms", "speedup"])
df_top = df_valid.sort_values("speedup", ascending=False).head(30)
df_top

Unnamed: 0,img,N,Cin,Cout,K,t_ref_ms,t_int8_ms,speedup,err_max,err_mean,note
242,64,4,16,32,5,0.694116,0.592287,1.171924,2.656394,0.452488,
278,64,8,16,32,5,0.509389,0.974933,0.522486,3.487287,0.447657,
521,224,4,3,8,11,1.052221,4.354588,0.241635,2.957939,0.452319,
557,224,8,3,8,11,2.083238,8.77552,0.237392,2.881968,0.449797,
71,32,2,32,64,11,0.16516,0.733606,0.225135,2.454377,0.416937,
177,64,1,32,64,7,0.13454,0.608317,0.221167,2.725905,0.447466,
556,224,8,3,8,9,1.34333,6.215819,0.216115,3.06981,0.451914,
35,32,1,32,64,11,0.118732,0.554172,0.214252,2.567429,0.419619,
485,224,2,3,8,11,0.514722,2.420974,0.212609,2.999856,0.459054,
558,224,8,8,16,1,0.293859,1.394796,0.210682,3.578571,0.43917,


# base backvard

In [5]:
def bench_backward_ms(module, x, iters=50):
    """
    Меряем время полного прохода: forward + backward по loss = y.sum().
    Градиенты по весам и входу считаются, но не используются.
    """
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        # новый вход с requires_grad, чтобы каждый раз строился граф
        x_in = x.detach().clone().requires_grad_(True)
        module.zero_grad(set_to_none=True)

        y = module(x_in)
        loss = y.sum()
        loss.backward()
    torch.cuda.synchronize()
    return (time.time() - t0) * 1000.0 / iters  # ms

In [6]:
# ==== CHECK GRADS (один прогон) ====
x_ref = x.detach().clone().requires_grad_(True)
x_int = x.detach().clone().requires_grad_(True)

conv_ref.zero_grad(set_to_none=True)
conv_int8.zero_grad(set_to_none=True)

y_ref = conv_ref(x_ref)
y_int = conv_int8(x_int)

loss_ref = y_ref.sum()
loss_int = y_int.sum()

loss_ref.backward()
loss_int.backward()

# градиенты по входу
dx_err = (x_ref.grad - x_int.grad).abs()
dx_err_max = dx_err.max().item()
dx_err_mean = dx_err.mean().item()

# градиенты по весам
dw_err = (conv_ref.weight.grad - conv_int8.weight.grad).abs()
dw_err_max = dw_err.max().item()
dw_err_mean = dw_err.mean().item()

print("\n=== GRAD CHECK ===")
print(f"dX  max err: {dx_err_max:.6e}, mean err: {dx_err_mean:.6e}")
print(f"dW  max err: {dw_err_max:.6e}, mean err: {dw_err_mean:.6e}")


CompilationError: at 38:15:
            mask=mask_a, other=0
        )
        b = tl.load(
            B_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn,
            mask=mask_b, other=0
        )

        if USE_FP16:
            a = a.to(tl.float16)
            b = b.to(tl.float16)
        # acc остаётся fp32
        acc += tl.dot(a, b, allow_tf32=False)
               ^