In [5]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent.parent))

In [8]:
import triton
import triton.language as tl
import torch
import time
import pandas as pd
from conv_gemm.triton_kernels.int8.gemm_int8_kernel import gemm_int8_tc_kernel

In [9]:
def gemm_int8_tc(
    A_q: torch.Tensor,   
    B_q: torch.Tensor,  
    *,
    BLOCK_M: int = 64,
    BLOCK_N: int = 64,
    BLOCK_K: int = 32,
    num_warps: int = 4,
    num_stages: int = 2,
):
    if not A_q.is_contiguous():
        A_q = A_q.contiguous()
    if not B_q.is_contiguous():
        B_q = B_q.contiguous()

    M, K1 = A_q.shape
    K2, N = B_q.shape
    assert K1 == K2, f"K mismatch: {K1} vs {K2}"

    assert K1 % 4 == 0, f"K={K1} must be divisible by 4 for INT8 dot"
    assert BLOCK_K % 4 == 0, f"BLOCK_K={BLOCK_K} must be divisible by 4"

    C_i32 = torch.empty((M, N), dtype=torch.int32, device=A_q.device)

    a_m, a_k = A_q.stride()
    b_k, b_n = B_q.stride()
    c_m, c_n = C_i32.stride()

    grid = (
        triton.cdiv(M, BLOCK_M),
        triton.cdiv(N, BLOCK_N),
    )

    gemm_int8_tc_kernel[grid](
        A_q, B_q, C_i32,
        M, N, K1,
        a_m, a_k,
        b_k, b_n,
        c_m, c_n,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return C_i32

# title search

In [10]:
@torch.no_grad()
def bench_once_gemm_int8_vs_torch(
    M, K, N,
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=100,
    device="cuda",
):
    """
    INT8 Triton GEMM vs torch FP16 matmul
    """
    A_q = torch.randint(-128, 127, (M, K), device=device, dtype=torch.int8)
    B_q = torch.randint(-128, 127, (K, N), device=device, dtype=torch.int8)

    # FP16 baseline
    A_f16 = A_q.to(torch.float16)
    B_f16 = B_q.to(torch.float16)

    with torch.no_grad():
        C_ref = (A_f16 @ B_f16).float() 

    def _call_triton():
        C_i32 = gemm_int8_tc(
            A_q, B_q,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return C_i32
        
    # triton matmul int8
    for _ in range(5):
        _ = _call_triton()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        C_i32 = _call_triton()
    torch.cuda.synchronize()
    t_triton = (time.perf_counter() - t0) / iters

    # torch matmul FP16
    def _call_torch():
        return A_f16 @ B_f16
    for _ in range(5):
        _ = _call_torch()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        C_ref2 = _call_torch()
    torch.cuda.synchronize()
    t_torch = (time.perf_counter() - t0) / iters

    # bandwidth
    bytes_moved = A_q.numel() + B_q.numel()      
    bytes_moved += C_i32.numel() * 4            
    bytes_moved = float(bytes_moved)

    bw_triton = bytes_moved / t_triton / 1e9
    bw_torch  = bytes_moved / t_torch  / 1e9

    return {
        "M": M, "K": K, "N": N,
        "BLOCK_M": BLOCK_M,
        "BLOCK_N": BLOCK_N,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_torch_ms": t_torch * 1e3,
        "speed_vs_torch": t_torch / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_torch_GBs": bw_torch,
    }


In [11]:
@torch.no_grad()
def tune_gemm_int8_tiles_for_shape(
    M, K, N,
    blocks_M=(32, 64, 128),
    blocks_N=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(1, 2, 4, 8),
    stages=(2, 3,4),
    iters=200,
    device="cuda",
):
    records = []
    for BM in blocks_M:
        for BN in blocks_N:
            for BK in blocks_K:
                if (K % 4 != 0) or (BK % 4 != 0):
                    print(f"[SKIP] BM={BM}, BN={BN}, BK={BK}: K/BK not multiple of 4")
                    continue

                for W in warps:
                    for S in stages:
                        try:
                            rec = bench_once_gemm_int8_vs_torch(
                                M, K, N,
                                BLOCK_M=BM,
                                BLOCK_N=BN,
                                BLOCK_K=BK,
                                num_warps=W,
                                num_stages=S,
                                iters=iters,
                                device=device,
                            )
                        except RuntimeError as e:
                            print(f"[SKIP] BM={BM}, BN={BN}, BK={BK}, W={W}, S={S}: {e}")
                            continue

                        print(
                            f"BM={BM}, BN={BN}, BK={BK}, W={W}, S={S}: "
                            f"t_triton={rec['t_triton_ms']:.3f} ms, "
                            f"speed_vs_torch={rec['speed_vs_torch']:.3f}x, "
                        )
                        records.append(rec)

    if not records:
        raise RuntimeError("No valid tile configs found for this GEMM shape")

    df = pd.DataFrame(records)
    return df

In [12]:
df_gemm_tiles = tune_gemm_int8_tiles_for_shape(
    M=4096, K=1024, N=1024,
    blocks_M=(32, 64, 128),
    blocks_N=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters=200,
    device="cuda",
)

BM=32, BN=32, BK=32, W=2, S=2: t_triton=0.241 ms, speed_vs_torch=0.961x, 
BM=32, BN=32, BK=32, W=2, S=3: t_triton=0.245 ms, speed_vs_torch=0.941x, 
BM=32, BN=32, BK=32, W=4, S=2: t_triton=0.401 ms, speed_vs_torch=0.563x, 
BM=32, BN=32, BK=32, W=4, S=3: t_triton=0.410 ms, speed_vs_torch=0.530x, 
BM=32, BN=32, BK=32, W=8, S=2: t_triton=0.491 ms, speed_vs_torch=0.445x, 
BM=32, BN=32, BK=32, W=8, S=3: t_triton=0.470 ms, speed_vs_torch=0.462x, 
BM=32, BN=32, BK=64, W=2, S=2: t_triton=0.219 ms, speed_vs_torch=1.090x, 
BM=32, BN=32, BK=64, W=2, S=3: t_triton=0.224 ms, speed_vs_torch=1.039x, 
BM=32, BN=32, BK=64, W=4, S=2: t_triton=0.331 ms, speed_vs_torch=0.730x, 
BM=32, BN=32, BK=64, W=4, S=3: t_triton=0.327 ms, speed_vs_torch=0.665x, 
BM=32, BN=32, BK=64, W=8, S=2: t_triton=0.373 ms, speed_vs_torch=0.617x, 
BM=32, BN=32, BK=64, W=8, S=3: t_triton=0.386 ms, speed_vs_torch=0.561x, 
BM=32, BN=32, BK=128, W=2, S=2: t_triton=0.213 ms, speed_vs_torch=1.101x, 
BM=32, BN=32, BK=128, W=2, S=3: t_tri

In [13]:
df_gemm_tiles.sort_values("t_triton_ms").head(10)

Unnamed: 0,M,K,N,BLOCK_M,BLOCK_N,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch_ms,speed_vs_torch,bw_triton_GBs,bw_torch_GBs
153,4096,1024,1024,128,128,64,4,3,0.101361,0.220684,2.177198,217.243762,99.781338
97,4096,1024,1024,64,128,64,2,3,0.10648,0.216821,2.036252,206.799329,101.558822
99,4096,1024,1024,64,128,64,4,3,0.106584,0.222467,2.087237,206.597942,98.981529
134,4096,1024,1024,128,64,64,4,2,0.107135,0.227395,2.122513,205.536362,96.836346
98,4096,1024,1024,64,128,64,4,2,0.110027,0.221197,2.010384,200.133003,99.549662
105,4096,1024,1024,64,128,128,4,3,0.111288,0.223643,2.009592,197.866555,98.461052
152,4096,1024,1024,128,128,64,4,2,0.111548,0.228416,2.04769,197.404218,96.403371
91,4096,1024,1024,64,128,32,2,3,0.112947,0.218499,1.934527,194.959849,100.779092
133,4096,1024,1024,128,64,64,2,3,0.113013,0.228455,2.021493,194.845785,96.38707
147,4096,1024,1024,128,128,32,4,3,0.113029,0.220613,1.951824,194.818247,99.813422


In [11]:
INT8_GEMM_BEST_BLOCK_M = 128
INT8_GEMM_BEST_BLOCK_N = 128
INT8_GEMM_BEST_BLOCK_K = 64
INT8_GEMM_BEST_WARPS   = 4
INT8_GEMM_BEST_STAGES  = 3