In [1]:
import os

# Debug aid: CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so a
# failing kernel is reported at its own call site. It also serialises every launch,
# which skews the latencies this notebook measures. Apply it only as a default so a
# value exported before the kernel starts (e.g. CUDA_LAUNCH_BLOCKING=0) wins.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

In [2]:
# Standard library
import math
import os
import time

# Set ahead of the torch/triton imports so anything that reads it at import time sees it.
# NOTE(review): PyTorch's error text describes TORCH_USE_CUDA_DSA as a compile-time
# option — confirm that setting it as an environment variable has any effect.
os.environ["TORCH_USE_CUDA_DSA"] = "1"

# Third-party
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton

# Project-local kernel wrappers
from img2col_int8_kernel import img2col_int8
from gemm_int8_kernel import gemm_int8_tc
from col2img_int8_kernel import col2img_int32

In [3]:
# img2col INT8 (passed positionally to img2col_int8 by conv2d_int8_forward)
# NOTE(review): *_WARPS / *_STAGES are presumably the kernels' warp and pipeline-stage
# counts — confirm against the kernel modules.
I2C_BLOCK_M, I2C_BLOCK_K = 128, 32
I2C_WARPS, I2C_STAGES = 2, 3

# GEMM INT8 (passed to gemm_int8_tc; GEMM_BLOCK_K is also the K-padding granularity)
GEMM_BLOCK_M, GEMM_BLOCK_N, GEMM_BLOCK_K = 64, 128, 64
GEMM_WARPS, GEMM_STAGES = 4, 3

# col2img INT32
# NOTE(review): no cell visible in this notebook reads the C2I_* values — confirm they
# are still needed.
C2I_BLOCK_M, C2I_BLOCK_K = 32, 32
C2I_WARPS, C2I_STAGES = 4, 3


In [4]:
def bench_mean(fn, iters=50, warmup=10):
    """Return the mean wall-clock time, in seconds, of one call to ``fn``.

    ``fn`` is first called ``warmup`` times without timing, then ``iters`` times
    under a ``time.perf_counter`` stopwatch. ``torch.cuda.synchronize()`` is called
    after each phase, so the measured interval ends only once the device has
    finished the submitted work.

    Args:
        fn: Zero-argument callable to benchmark.
        iters: Number of timed calls; the elapsed time is divided by this value.
        warmup: Number of untimed calls made before the measurement starts.

    Returns:
        Elapsed seconds divided by ``iters``.
    """
    def _run_and_sync(count):
        # Issue `count` calls back to back, then wait for the GPU to drain.
        for _ in range(count):
            fn()
        torch.cuda.synchronize()

    _run_and_sync(warmup)

    started = time.perf_counter()
    _run_and_sync(iters)
    elapsed = time.perf_counter() - started

    return elapsed / iters


In [8]:
def conv2d_int8_forward(
    x, w, bias,
    stride, padding, dilation
):
    """Compute a 2-D convolution via INT8 quantisation, img2col and an INT8 GEMM.

    Input and weights are each quantised symmetrically with a single per-tensor
    scale (``max(|t|) / 127``, floored at 1e-8). The quantised input is unfolded by
    ``img2col_int8``, multiplied with the reshaped weights by ``gemm_int8_tc``, and
    the accumulator is rescaled by ``sx * sw`` before the bias is added.

    Tile and launch settings are read from the notebook-level ``I2C_*`` and
    ``GEMM_*`` constants.

    Args:
        x: Input tensor of shape (N, Cin, H, W); converted with ``.float()``.
        w: Weight tensor of shape (Cout, Cin, Kh, Kw); converted with ``.float()``.
        bias: ``None``, or a tensor with ``Cout`` elements.
        stride: An int, or an ``(h, w)`` pair.
        padding: An int, or an ``(h, w)`` pair.
        dilation: An int, or an ``(h, w)`` pair.

    Returns:
        Contiguous float32 tensor of shape (N, Cout, Ho, Wo).

    Raises:
        ValueError: If ``x`` or ``w`` is not 4-D, or the output would be empty.
        AssertionError: If the channel counts of ``x`` and ``w`` differ.
    """
    def _pair(value):
        # A bare int applies to both spatial dimensions, as in nn.Conv2d.
        return (value, value) if isinstance(value, int) else tuple(value)

    if x.dim() != 4 or w.dim() != 4:
        raise ValueError(
            f"expected 4-D input and weight, got {x.dim()}-D and {w.dim()}-D"
        )

    # Work in float32 for the scale computation and rounding.
    x_f = x.float()
    w_f = w.float()

    N, Cin, H, W = x_f.shape
    Cout, Cin2, Kh, Kw = w_f.shape
    assert Cin == Cin2, f"input has {Cin} channels but weight expects {Cin2}"

    Sh, Sw = _pair(stride)
    Ph, Pw = _pair(padding)
    Dh, Dw = _pair(dilation)

    # Standard convolution output-size formula.
    Ho = (H + 2*Ph - Dh*(Kh-1) - 1)//Sh + 1
    Wo = (W + 2*Pw - Dw*(Kw-1) - 1)//Sw + 1
    if Ho <= 0 or Wo <= 0:
        raise ValueError(f"convolution output would be empty (Ho={Ho}, Wo={Wo})")

    # ====== K-padding ======
    # Round the reduction length up to a multiple of 4, then to a whole number of
    # GEMM K-tiles.
    K_real = Cin * Kh * Kw
    K4 = (K_real + 3) // 4 * 4
    K_pad = (K4 + GEMM_BLOCK_K - 1)//GEMM_BLOCK_K * GEMM_BLOCK_K

    # ====== quant ======
    # .item() pulls each maximum back to the host, so both scales are Python floats.
    sx = max(x_f.abs().max().item()/127, 1e-8)
    sw = max(w_f.abs().max().item()/127, 1e-8)

    x_q = torch.clamp((x_f/sx).round(), -128, 127).to(torch.int8)
    w_q = torch.clamp((w_f/sw).round(), -128, 127).to(torch.int8)

    # ====== IMG2COL ======
    # The second return value of img2col_int8 is not used here.
    cols_q, _ = img2col_int8(
        x_q,
        Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw,
        K_pad,
        I2C_BLOCK_M, I2C_BLOCK_K, I2C_WARPS, I2C_STAGES
    )

    # ====== PREPARE WEIGHTS ======
    # reshape (rather than view) also handles non-contiguous weights such as
    # channels_last; for contiguous weights it is the same zero-copy view.
    Wq = w_q.reshape(Cout, K_real).t().contiguous()
    if K_pad > K_real:
        # Zero rows contribute nothing to the integer dot products.
        pad = torch.zeros((K_pad-K_real, Cout), dtype=torch.int8, device=Wq.device)
        Wq = torch.cat([Wq, pad], dim=0)

    # ====== GEMM ======
    # NOTE(review): the reshape at the end assumes the result holds N*Ho*Wo rows of
    # Cout values in (n, ho, wo) order — confirm against gemm_int8_tc / img2col_int8.
    C_i32 = gemm_int8_tc(
        cols_q, Wq,
        GEMM_BLOCK_M, GEMM_BLOCK_N, GEMM_BLOCK_K,
        GEMM_WARPS, GEMM_STAGES
    )

    # ====== DEQUANT ======
    y = C_i32.float() * (sx * sw)
    if bias is not None:
        y += bias.float().reshape(1, Cout)

    # (N*Ho*Wo, Cout) -> (N, Ho, Wo, Cout) -> (N, Cout, Ho, Wo)
    return y.reshape(N, Ho, Wo, Cout).permute(0,3,1,2).contiguous()


In [9]:
import torch
import torch.nn as nn
import pandas as pd
import triton
import triton.runtime
import triton.runtime.jit as jit
def run_int8_forward_bench(
    image_sizes=(56, 112, 224),
    batch_sizes=(1, 2, 4),
    channels=((1, 1), (1, 3), (3, 8), (8, 16), (16, 32), (32, 64)),
    kernels=(1, 3, 7, 9, 11, 13),
    iters=50
):
    """Benchmark ``conv2d_int8_forward`` against an FP16 ``nn.Conv2d`` on CUDA.

    Every combination of image size, batch size, channel pair and kernel size is
    run with stride 1 and padding ``ks // 2`` on a random FP16 input. Both paths are
    timed with ``bench_mean``, and the INT8 result is compared with the FP16 output.

    Args:
        image_sizes: Square image side lengths (H == W).
        batch_sizes: Batch sizes N.
        channels: ``(Cin, Cout)`` pairs.
        kernels: Square kernel sizes.
        iters: Number of timed iterations passed to ``bench_mean``.

    Returns:
        ``pandas.DataFrame`` with one row per configuration and the columns
        ``img, N, Cin, Cout, K, Conv2D FP16 (ms), Int8 (ms), speedup, err_max``.
        ``speedup`` is FP16 time divided by INT8 time; ``err_max`` is the largest
        absolute difference between the two outputs. The last three columns are
        ``None`` when the INT8 path returned ``None``.
    """
    from itertools import product

    def _release_caches():
        # Drain the device, then drop cached state between configurations.
        torch.cuda.synchronize()
        # NOTE(review): specialize_impl_cache is not part of Triton's documented API
        # and may be absent in other versions, so it is looked up defensively.
        cache = getattr(jit, "specialize_impl_cache", None)
        if cache is not None:
            cache.clear()
        torch.cuda.empty_cache()

    rows = []
    total = len(image_sizes)*len(batch_sizes)*len(channels)*len(kernels)
    # Same iteration order as nesting the four loops: image size outermost,
    # kernel size innermost.
    sweep = product(image_sizes, batch_sizes, channels, kernels)

    # Forward-only benchmark: the conv parameters require grad by default, so disable
    # autograd to keep graph bookkeeping out of both timings.
    with torch.no_grad():
        for step, (img, N, (Cin, Cout), ks) in enumerate(sweep, start=1):
            print(f"[{step}/{total}] img={img}  N={N}  Cin={Cin} Cout={Cout} K={ks}")

            H = W = img
            pad = ks // 2

            x = torch.randn(N, Cin, H, W, device='cuda', dtype=torch.float16)

            conv = nn.Conv2d(
                Cin, Cout, ks,
                stride=1,
                padding=pad,
                bias=True
            ).cuda().half()

            w = conv.weight
            b = conv.bias

            def run_int8():
                # One definition shared by the correctness run and the timed runs.
                return conv2d_int8_forward(
                    x, w, b,
                    stride=(1, 1),
                    padding=(pad, pad),
                    dilation=(1, 1)
                )

            # === FP16 baseline ===
            t_fp16 = bench_mean(lambda: conv(x), iters)
            y_ref = conv(x).float()

            # === INT8 forward ===
            y_int = run_int8()

            row = [img, N, Cin, Cout, ks, t_fp16*1e3, None, None, None]

            # A None result is treated as "configuration not supported": only the
            # FP16 timing is recorded for it.
            if y_int is not None:
                # === INT8 benchmark ===
                t_int8 = bench_mean(run_int8, iters)
                err = (y_ref - y_int.float()).abs().max().item()
                row[6:] = [t_int8*1e3, t_fp16 / t_int8, err]

            rows.append(row)
            _release_caches()

    df = pd.DataFrame(rows, columns=[
        "img","N","Cin","Cout","K",
        "Conv2D FP16 (ms)","Int8 (ms)",
        "speedup","err_max"
    ])
    return df


In [10]:
# Run the sweep: 4 image sizes x 2 batch sizes x 4 channel pairs x 5 kernel sizes
# = 160 configurations, each timed over 50 iterations.
# NOTE(review): the progress log saved with this cell stops at step 25 of 160, so the
# stored output does not show a completed run — re-run before relying on `df`.
df = run_int8_forward_bench(
    image_sizes=[112,224,512,1024],
    batch_sizes=[1,2],
    channels=[(1,1),(1,3),(3,8),(8,16)],
    kernels=[1,3,7,9,11],
    iters=50
)



[1/160] img=112  N=1  Cin=1 Cout=1 K=1
[2/160] img=112  N=1  Cin=1 Cout=1 K=3
[3/160] img=112  N=1  Cin=1 Cout=1 K=7
[4/160] img=112  N=1  Cin=1 Cout=1 K=9
[5/160] img=112  N=1  Cin=1 Cout=1 K=11
[6/160] img=112  N=1  Cin=1 Cout=3 K=1
[7/160] img=112  N=1  Cin=1 Cout=3 K=3
[8/160] img=112  N=1  Cin=1 Cout=3 K=7
[9/160] img=112  N=1  Cin=1 Cout=3 K=9
[10/160] img=112  N=1  Cin=1 Cout=3 K=11
[11/160] img=112  N=1  Cin=3 Cout=8 K=1
[12/160] img=112  N=1  Cin=3 Cout=8 K=3
[13/160] img=112  N=1  Cin=3 Cout=8 K=7
[14/160] img=112  N=1  Cin=3 Cout=8 K=9
[15/160] img=112  N=1  Cin=3 Cout=8 K=11
[16/160] img=112  N=1  Cin=8 Cout=16 K=1
[17/160] img=112  N=1  Cin=8 Cout=16 K=3
[18/160] img=112  N=1  Cin=8 Cout=16 K=7
[19/160] img=112  N=1  Cin=8 Cout=16 K=9
[20/160] img=112  N=1  Cin=8 Cout=16 K=11
[21/160] img=112  N=2  Cin=1 Cout=1 K=1
[22/160] img=112  N=2  Cin=1 Cout=1 K=3
[23/160] img=112  N=2  Cin=1 Cout=1 K=7
[24/160] img=112  N=2  Cin=1 Cout=1 K=9
[25/160] img=112  N=2  Cin=1 Cout=1 K=11

In [12]:
# Rank configurations by speedup (FP16 time / INT8 time); values below 1 mean the INT8
# path was slower than the FP16 baseline.
# NOTE(review): the saved output has an `Unnamed: 0` column and row labels up to 91,
# while the sweep log stops at step 25 and the execution counts skip from 10 to 12.
# The displayed frame presumably came from a different `df` (e.g. reloaded from a
# CSV in a removed cell) — verify with Restart Kernel -> Run All.
df_sorted = df.sort_values(by="speedup", ascending=False)
df_sorted.head(50)

Unnamed: 0,img,N,Cin,Cout,K,Conv2D FP16 (ms),Int8 (ms),speedup,err_max
4,112,1,1,1,11,0.378325,0.551725,0.685713,0.026297
0,112,1,1,1,1,0.25769,0.500004,0.515376,0.008919
53,224,1,3,8,9,0.489019,1.205028,0.405815,0.032138
33,112,2,3,8,9,0.43203,1.175181,0.367628,0.028213
38,112,2,8,16,9,0.472144,1.355743,0.348255,0.030536
91,512,1,3,8,3,0.564765,1.859998,0.303637,0.035408
75,224,2,8,16,1,0.278038,0.930263,0.298881,0.042226
69,224,2,1,3,11,0.405824,1.525819,0.265971,0.029239
32,112,2,3,8,7,0.274263,1.114813,0.246017,0.029469
64,224,2,1,1,11,0.290742,1.182929,0.245781,0.03235
