In [1]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent.parent))

In [2]:
import time
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
import pandas as pd
from conv_gemm.triton_kernels.int8.img2col_int8_kernel import img2col_int8_kernel

In [3]:
def img2col_int8(
    x_q,
    Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw,
    K_pad,
    BLOCK_M,
    BLOCK_K,
    num_warps,
    num_stages,
):
    N, Cin, H, W = x_q.shape

    Ho = (H + 2*Ph - Dh*(Kh - 1) - 1)//Sh + 1
    Wo = (W + 2*Pw - Dw*(Kw - 1) - 1)//Sw + 1
    M = N * Ho * Wo
    K_real = Cin * Kh * Kw

    cols_q = torch.empty((M, K_pad), dtype=torch.int8, device=x_q.device)

    sN, sC, sH, sW = x_q.stride()

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(K_pad, BLOCK_K))

    img2col_int8_kernel[grid](
        x_q, cols_q,
        N, Cin, H, W,
        Kh, Kw,
        Sh, Sw,
        Ph, Pw,
        Dh, Dw,
        Ho, Wo,
        sN, sC, sH, sW,
        K_real, K_pad,
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        num_warps=num_warps,
        num_stages=num_stages
    )

    return cols_q, (Ho, Wo)

In [4]:
@torch.no_grad()
def bench_once_img2col_int8_vs_unfold(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    BLOCK_M,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=100,
    device="cuda",
):

    x_q = torch.randint(-128, 127, (N, Cin, H, W),
                        device=device, dtype=torch.int8)
    x_f = x_q.float()

    Ho = (H + 2 * Ph - Dh * (Kh - 1) - 1) // Sh + 1
    Wo = (W + 2 * Pw - Dw * (Kw - 1) - 1) // Sw + 1

    K = Cin * Kh * Kw
    M = N * Ho * Wo
    K_pad = ((K + 3) // 4) * 4  

    #  F.unfold 
    def _call_torch():
        return F.unfold(
            x_f,
            kernel_size=(Kh, Kw),
            dilation=(Dh, Dw),
            padding=(Ph, Pw),
            stride=(Sh, Sw),
        )

    for _ in range(5):
        _ = _call_torch()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        unf = _call_torch()
    torch.cuda.synchronize()
    t_torch = (time.perf_counter() - t0) / iters

    N2, KC, L = unf.shape
    assert N2 == N
    assert KC == K
    assert L == Ho * Wo

    cols_ref = unf.permute(0, 2, 1).contiguous().view(M, K)  # [M, K]

    # img2col_int8
    if K_pad != K:
        cols_ref_pad = torch.zeros((M, K_pad), device=device, dtype=cols_ref.dtype)
        cols_ref_pad[:, :K] = cols_ref
    else:
        cols_ref_pad = cols_ref

    cols_ref_q = cols_ref_pad.to(torch.int8)
    
    def _call_triton():
        cols_q, (Ho2, Wo2) = img2col_int8(
            x_q,
            Kh, Kw,
            Sh, Sw,
            Ph, Pw,
            Dh, Dw,
            K_pad,
            BLOCK_M=BLOCK_M,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        assert Ho2 == Ho and Wo2 == Wo
        assert cols_q.shape == (M, K_pad)
        return cols_q

    for _ in range(5):
        cols_q = _call_triton()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        cols_q = _call_triton()
    torch.cuda.synchronize()
    t_triton = (time.perf_counter() - t0) / iters

    # bandwidth 
    bytes_moved = 2.0 * M * K_pad  
    bw_triton = bytes_moved / t_triton / 1e9
    bw_torch  = bytes_moved / t_torch / 1e9

    return {
        "M": M, "K": K_pad, "N": N,
        "BLOCK_M": BLOCK_M,
        "BLOCK_N": 0,              
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_torch_ms": t_torch * 1e3,
        "speed_vs_torch": t_torch / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_torch_GBs": bw_torch,
    }

In [5]:
@torch.no_grad()
def tune_img2col_int8_tiles_for_shape(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    blocks_M=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(1, 2, 4, 8),
    stages=(2, 3, 4),
    iters=200,
    device="cuda",
):
    records = []
    for BM in blocks_M:
        for BK in blocks_K:
            for W in warps:
                for S in stages:
                    try:
                        rec = bench_once_img2col_int8_vs_unfold(
                            N, Cin, H, W,
                            Kh, Kw,
                            Sh, Sw,
                            Ph, Pw,
                            Dh, Dw,
                            BLOCK_M=BM,
                            BLOCK_K=BK,
                            num_warps=W,
                            num_stages=S,
                            iters=iters,
                            device=device,
                        )
                    except RuntimeError as e:
                        print(f"[SKIP] BM={BM}, BK={BK}, W={W}, S={S}: {e}")
                        continue

                    print(
                        f"BM={BM}, BK={BK}, W={W}, S={S}: "
                        f"t_triton={rec['t_triton_ms']:.3f} ms, "
                        f"speed_vs_torch={rec['speed_vs_torch']:.3f}x, "
                    )
                    records.append(rec)

    if not records:
        raise RuntimeError("No valid tile configs found for this IMG2COL shape")

    df = pd.DataFrame(records)
    return df


In [8]:
df_i2c = tune_img2col_int8_tiles_for_shape(
    N=16, Cin=1, H=256, W=256,
    Kh=11, Kw=11,
    Sh=1, Sw=1,
    Ph=5, Pw=5,
    Dh=1, Dw=1,
    blocks_M=(32, 64, 128, ),
    blocks_K=(32, 64, 128, ),
    warps=(2, 4, 8, 16),
    stages=(2, 3),
    iters=200,
    device="cuda",
)

BM=32, BK=32, W=2, S=2: t_triton=0.112 ms, speed_vs_torch=3.818x, 
BM=32, BK=32, W=2, S=3: t_triton=0.110 ms, speed_vs_torch=3.870x, 
BM=32, BK=32, W=4, S=2: t_triton=0.205 ms, speed_vs_torch=3.671x, 
BM=32, BK=32, W=4, S=3: t_triton=0.206 ms, speed_vs_torch=3.643x, 
BM=32, BK=32, W=8, S=2: t_triton=0.407 ms, speed_vs_torch=1.851x, 
BM=32, BK=32, W=8, S=3: t_triton=0.412 ms, speed_vs_torch=1.866x, 
BM=32, BK=32, W=16, S=2: t_triton=0.803 ms, speed_vs_torch=0.957x, 
BM=32, BK=32, W=16, S=3: t_triton=0.672 ms, speed_vs_torch=1.133x, 
BM=32, BK=64, W=2, S=2: t_triton=0.053 ms, speed_vs_torch=3.828x, 
BM=32, BK=64, W=2, S=3: t_triton=0.054 ms, speed_vs_torch=3.829x, 
BM=32, BK=64, W=4, S=2: t_triton=0.109 ms, speed_vs_torch=3.473x, 
BM=32, BK=64, W=4, S=3: t_triton=0.103 ms, speed_vs_torch=3.664x, 
BM=32, BK=64, W=8, S=2: t_triton=0.207 ms, speed_vs_torch=1.788x, 
BM=32, BK=64, W=8, S=3: t_triton=0.406 ms, speed_vs_torch=1.069x, 
BM=32, BK=64, W=16, S=2: t_triton=0.813 ms, speed_vs_torch=0

In [9]:
df_i2c.sort_values("t_triton_ms").head(10)

Unnamed: 0,M,K,N,BLOCK_M,BLOCK_N,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch_ms,speed_vs_torch,bw_triton_GBs,bw_torch_GBs
8,8192,124,16,32,0,64,2,2,0.053481,0.20472,3.827891,37.98743,9.923854
9,8192,124,16,32,0,64,2,3,0.053515,0.204915,3.829088,37.963214,9.914426
11,16384,124,16,32,0,64,4,3,0.103473,0.379096,3.663727,39.268602,10.718212
10,16384,124,16,32,0,64,4,2,0.108805,0.377879,3.472992,37.34414,10.752729
1,8192,124,16,32,0,32,2,3,0.110279,0.426819,3.870359,18.422534,4.759903
0,8192,124,16,32,0,32,2,2,0.11176,0.426679,3.817825,18.178423,4.761461
24,8192,124,16,64,0,32,2,2,0.128552,0.414733,3.22619,15.803864,4.898616
32,8192,124,16,64,0,64,2,2,0.129532,0.422118,3.258802,15.684305,4.812906
49,8192,124,16,128,0,32,2,3,0.130104,0.422792,3.249634,15.615274,4.805241
25,8192,124,16,64,0,32,2,3,0.130912,0.445858,3.405774,15.518905,4.556645


In [74]:
INT8_IMG2COL_BEST_BLOCK_M = 32
INT8_IMG2COL_BEST_BLOCK_N = 0
INT8_IMG2COL_BEST_BLOCK_K = 64
INT8_IMG2COL_BEST_WARPS   = 2
INT8_IMG2COL_BEST_STAGES  = 2