In [2]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent.parent))

In [15]:
import time
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
import pandas as pd
from conv_gemm.triton_kernels.int8.col2img_int8_kernel import col2img_int32_kernel

In [16]:
def col2img_int32(
    cols_i32: torch.Tensor,
    N: int, Cin: int,
    H: int, W: int,
    Kh: int, Kw: int,
    Sh: int, Sw: int,
    Ph: int, Pw: int,
    Dh: int, Dw: int,
    BLOCK_M: int,
    BLOCK_K: int,
    num_warps: int = 4,
    num_stages: int = 2,
):

    assert cols_i32.is_cuda
    assert cols_i32.dtype == torch.int32
    cols_i32 = cols_i32.contiguous()

    Ho = (H + 2 * Ph - Dh * (Kh - 1) - 1) // Sh + 1
    Wo = (W + 2 * Pw - Dw * (Kw - 1) - 1) // Sw + 1
    M = N * Ho * Wo
    K = Cin * Kh * Kw

    assert cols_i32.shape == (M, K), f"cols shape {cols_i32.shape}, expected {(M, K)}"

    x_i32 = torch.zeros((N, Cin, H, W), device=cols_i32.device, dtype=torch.int32)
    sN, sC, sH, sW = x_i32.stride()

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))

    col2img_int32_kernel[grid](
        cols_i32, x_i32,
        N, Cin, H, W,
        Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw,
        Ho, Wo,
        sN, sC, sH, sW,
        K,
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_i32

In [17]:
@torch.no_grad()
def bench_once_col2img_int32_vs_fold(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    BLOCK_M,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=100,
    device="cuda",
):

    Ho = (H + 2 * Ph - Dh * (Kh - 1) - 1) // Sh + 1
    Wo = (W + 2 * Pw - Dw * (Kw - 1) - 1) // Sw + 1
    assert Ho > 0 and Wo > 0, f"Invalid Ho,Wo = {(Ho, Wo)}"

    M = N * Ho * Wo
    K = Cin * Kh * Kw

    cols_i32 = torch.randint(
        low=-128, high=127,
        size=(M, K),
        device=device,
        dtype=torch.int32,
    )

    # Torch  F.fold 
    cols_f = cols_i32.float()
    cols_fold = cols_f.view(N, Ho * Wo, K).permute(0, 2, 1).contiguous()  # [N, K, L]

    def _call_torch():
        return F.fold(
            cols_fold,
            output_size=(H, W),
            kernel_size=(Kh, Kw),
            dilation=(Dh, Dw),
            padding=(Ph, Pw),
            stride=(Sh, Sw),
        )  


    for _ in range(5):
        x_ref = _call_torch()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        x_ref = _call_torch()
    torch.cuda.synchronize()
    t_torch = (time.perf_counter() - t0) / iters

    # Triton
    def _call_triton():
        x_i32 = col2img_int32(
            cols_i32,
            N, Cin,
            H, W,
            Kh, Kw,
            Sh, Sw,
            Ph, Pw,
            Dh, Dw,
            BLOCK_M=BLOCK_M,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return x_i32


    for _ in range(5):
        x_i32 = _call_triton()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        x_i32 = _call_triton()
    torch.cuda.synchronize()
    t_triton = (time.perf_counter() - t0) / iters


    # bandwidth 
    bytes_moved = (cols_i32.numel() + x_i32.numel()) * 4.0  # int32
    bw_triton = bytes_moved / t_triton / 1e9
    bw_torch  = bytes_moved / t_torch  / 1e9

    return {
        "N": N,
        "Cin": Cin,
        "H": H,
        "W": W,
        "Kh": Kh,
        "Kw": Kw,
        "M": M,
        "K": K,
        "BLOCK_M": BLOCK_M,
        "BLOCK_N": 0,  
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_torch_ms": t_torch * 1e3,
        "speed_vs_torch": t_torch / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_torch_GBs": bw_torch,
    }


In [18]:
@torch.no_grad()
def tune_col2img_int32_tiles_for_shape(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    blocks_M=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(1, 2, 4, 8),
    stages=(2, 3, 4),
    iters=200,
    device="cuda",
):
    records = []
    for BM in blocks_M:
        for BK in blocks_K:
            for Wp in warps:
                for S in stages:
                    try:
                        rec = bench_once_col2img_int32_vs_fold(
                            N, Cin, H, W,
                            Kh, Kw,
                            Sh, Sw,
                            Ph, Pw,
                            Dh, Dw,
                            BLOCK_M=BM,
                            BLOCK_K=BK,
                            num_warps=Wp,
                            num_stages=S,
                            iters=iters,
                            device=device,
                        )
                    except RuntimeError as e:
                        print(f"[SKIP] BM={BM}, BK={BK}, W={Wp}, S={S}: {e}")
                        continue

                    print(
                        f"BM={BM}, BK={BK}, W={Wp}, S={S}: "
                        f"t_triton={rec['t_triton_ms']:.3f} ms, "
                        f"speed_vs_torch={rec['speed_vs_torch']:.3f}x, "
                    )
                    records.append(rec)

    if not records:
        raise RuntimeError("No valid tile configs found for this COL2IMG shape")

    df = pd.DataFrame(records)
    return df

In [22]:
df_c2i = tune_col2img_int32_tiles_for_shape(
    N=16, Cin=1, H=256, W=256,
    Kh=11, Kw=11,
    Sh=1, Sw=1,
    Ph=5, Pw=5,
    Dh=1, Dw=1,
    blocks_M=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters=200,
    device="cuda",
)

BM=32, BK=32, W=2, S=2: t_triton=4.030 ms, speed_vs_torch=0.849x, 
BM=32, BK=32, W=2, S=3: t_triton=2.787 ms, speed_vs_torch=1.334x, 
BM=32, BK=32, W=4, S=2: t_triton=1.659 ms, speed_vs_torch=1.425x, 
BM=32, BK=32, W=4, S=3: t_triton=1.681 ms, speed_vs_torch=1.404x, 
BM=32, BK=32, W=8, S=2: t_triton=1.774 ms, speed_vs_torch=1.328x, 
BM=32, BK=32, W=8, S=3: t_triton=1.764 ms, speed_vs_torch=1.337x, 
BM=32, BK=64, W=2, S=2: t_triton=5.503 ms, speed_vs_torch=0.429x, 
BM=32, BK=64, W=2, S=3: t_triton=4.812 ms, speed_vs_torch=0.490x, 
BM=32, BK=64, W=4, S=2: t_triton=2.085 ms, speed_vs_torch=1.163x, 
BM=32, BK=64, W=4, S=3: t_triton=2.021 ms, speed_vs_torch=1.195x, 
BM=32, BK=64, W=8, S=2: t_triton=1.699 ms, speed_vs_torch=1.388x, 
BM=32, BK=64, W=8, S=3: t_triton=2.357 ms, speed_vs_torch=1.007x, 
BM=32, BK=128, W=2, S=2: t_triton=7.052 ms, speed_vs_torch=0.336x, 
BM=32, BK=128, W=2, S=3: t_triton=7.635 ms, speed_vs_torch=0.308x, 
BM=32, BK=128, W=4, S=2: t_triton=8.896 ms, speed_vs_torch=0

In [23]:
df_c2i.sort_values("t_triton_ms").head(10)

Unnamed: 0,N,Cin,H,W,Kh,Kw,M,K,BLOCK_M,BLOCK_N,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch_ms,speed_vs_torch,bw_triton_GBs,bw_torch_GBs
2,16,1,256,256,11,11,1048576,121,32,0,32,4,2,1.659329,2.364605,1.425037,308.380712,216.401904
3,16,1,256,256,11,11,1048576,121,32,0,32,4,3,1.680546,2.359115,1.403779,304.487482,216.905575
23,16,1,256,256,11,11,1048576,121,64,0,32,8,3,1.684777,2.980211,1.768906,303.722782,171.700947
10,16,1,256,256,11,11,1048576,121,32,0,64,8,2,1.699415,2.359619,1.388489,301.10657,216.859196
22,16,1,256,256,11,11,1048576,121,64,0,32,8,2,1.711444,2.376568,1.388633,298.990304,215.312651
5,16,1,256,256,11,11,1048576,121,32,0,32,8,3,1.763631,2.357735,1.336865,290.142992,217.032448
4,16,1,256,256,11,11,1048576,121,32,0,32,8,2,1.773947,2.355675,1.327929,288.455699,217.222263
9,16,1,256,256,11,11,1048576,121,32,0,64,4,3,2.020703,2.413804,1.194537,253.231202,211.99114
21,16,1,256,256,11,11,1048576,121,64,0,32,4,3,2.070018,2.34117,1.13099,247.198326,218.568101
8,16,1,256,256,11,11,1048576,121,32,0,64,4,2,2.085397,2.425691,1.163179,245.375338,210.952311


In [24]:
INT8_COL2IMG_BEST_BLOCK_M = 32
INT8_COL2IMG_BEST_BLOCK_N = 0
INT8_COL2IMG_BEST_BLOCK_K = 32
INT8_COL2IMG_BEST_WARPS = 4
INT8_COL2IMG_BEST_STAGES = 2