In [1]:
import pathlib
import sys

# Put the directory two levels above the current working directory at the
# front of the import path (presumably the repository root, so that the
# project-local `conv_gemm` package imported below can be found - verify
# against the repo layout).
sys.path.insert(0, str(pathlib.Path.cwd().resolve().parent.parent))

In [2]:
import triton
import triton.language as tl
import torch
import time
import pandas as pd
from conv_gemm.triton_kernels.fp16.gemm_kernel import triton_gemm

In [3]:
@torch.no_grad()
def bench_once_gemm_fp16_vs_torch(
    M, K, N,
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=100,
    device="cuda",
):
    """Time one Triton FP16 GEMM tile configuration against PyTorch's matmul.

    Random FP16 operands ``A`` (M x K) and ``B`` (K x N) are created on
    ``device``. Each implementation is warmed up, then timed over ``iters``
    back-to-back calls with a device synchronization before and after the
    timed loop, so the reported time is the mean wall-clock time per call.

    The PyTorch baseline is ``(A @ B).float()``, i.e. it includes the cast of
    the result to FP32. The two results are NOT compared for numerical
    agreement here; this function only measures speed.

    Args:
        M, K, N: GEMM problem sizes.
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes forwarded to ``triton_gemm``.
        num_warps, num_stages: launch parameters forwarded to ``triton_gemm``.
        iters: number of timed calls per implementation; must be >= 1.
        device: device on which the operands are allocated. The timing relies
            on ``torch.cuda.synchronize``, so this must be a CUDA device.

    Returns:
        dict with the configuration, ``t_triton_ms`` / ``t_torch_ms`` (mean
        milliseconds per call), ``speed_vs_torch`` (PyTorch time divided by
        Triton time, > 1 means Triton is faster) and ``bw_triton_GBs`` /
        ``bw_torch_GBs`` (nominal GB/s: bytes of A, B and that path's output
        tensor divided by the mean time).

    Raises:
        ValueError: if ``iters`` is smaller than 1.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    A_f16 = torch.randn(M, K, device=device, dtype=torch.float16)
    B_f16 = torch.randn(K, N, device=device, dtype=torch.float16)

    # Synchronize the device that actually holds the operands. A bare
    # torch.cuda.synchronize() waits on the *current* device, which is the
    # wrong one when e.g. device="cuda:1" while device 0 is current.
    sync_device = A_f16.device

    def _call_torch():
        # Torch FP16 GEMM, result cast to FP32.
        return (A_f16 @ B_f16).float()

    def _call_triton():
        # Triton FP16 GEMM with the tile configuration under test.
        return triton_gemm(
            A_f16, B_f16,
            use_fp16=True,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    def _time_call(fn, warmup=5):
        # Return (mean seconds per call, output of the last timed call).
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize(sync_device)

        t0 = time.perf_counter()
        for _ in range(iters):
            out = fn()
        torch.cuda.synchronize(sync_device)
        return (time.perf_counter() - t0) / iters, out

    def _nbytes(t):
        # Size from the tensor's real dtype rather than an assumed 2 or 4 bytes.
        return t.numel() * t.element_size()

    t_torch, C_ref = _time_call(_call_torch)
    t_triton, C_tr = _time_call(_call_triton)

    # Nominal traffic: read A and B once, write that path's output once.
    input_bytes = _nbytes(A_f16) + _nbytes(B_f16)
    bw_triton = float(input_bytes + _nbytes(C_tr)) / t_triton / 1e9
    bw_torch = float(input_bytes + _nbytes(C_ref)) / t_torch / 1e9

    return {
        "M": M, "K": K, "N": N,
        "BLOCK_M": BLOCK_M,
        "BLOCK_N": BLOCK_N,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_torch_ms": t_torch * 1e3,
        "speed_vs_torch": t_torch / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_torch_GBs": bw_torch,
    }


In [7]:
@torch.no_grad()
def tune_gemm_fp16_tiles_for_shape(
    M, K, N,
    blocks_M=(32, 64, 128),
    blocks_N=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(1, 2, 4, 8),
    stages=(2, 3, 4),
    iters=200,
    device="cuda",
):
    """Benchmark every tile/launch combination for one GEMM shape.

    Walks the Cartesian product of ``blocks_M`` x ``blocks_N`` x ``blocks_K``
    x ``warps`` x ``stages`` (``blocks_M`` varies slowest, ``stages`` fastest)
    and runs ``bench_once_gemm_fp16_vs_torch`` for each combination. A
    combination whose benchmark raises is reported with a ``[SKIP]`` line and
    left out of the result; every successful one is printed and recorded.

    Args:
        M, K, N: GEMM problem sizes.
        blocks_M, blocks_N, blocks_K: candidate tile sizes.
        warps: candidate ``num_warps`` values.
        stages: candidate ``num_stages`` values.
        iters: timed iterations per benchmark, forwarded unchanged.
        device: device string forwarded unchanged.

    Returns:
        pandas.DataFrame with one row per successfully benchmarked combination.

    Raises:
        RuntimeError: if no combination could be benchmarked.
    """
    # Lazy generator: same iteration order as five nested for-loops.
    grid = (
        (bm, bn, bk, n_warps, n_stages)
        for bm in blocks_M
        for bn in blocks_N
        for bk in blocks_K
        for n_warps in warps
        for n_stages in stages
    )

    rows = []
    for bm, bn, bk, n_warps, n_stages in grid:
        label = f"BM={bm}, BN={bn}, BK={bk}, W={n_warps}, S={n_stages}"

        try:
            row = bench_once_gemm_fp16_vs_torch(
                M, K, N,
                BLOCK_M=bm,
                BLOCK_N=bn,
                BLOCK_K=bk,
                num_warps=n_warps,
                num_stages=n_stages,
                iters=iters,
                device=device,
            )
        except Exception as err:
            # Best effort: a failing combination is reported and skipped
            # rather than aborting the whole sweep.
            print(f"[SKIP] {label}: {err}")
            continue

        print(
            f"{label}: "
            f"t_triton={row['t_triton_ms']:.3f} ms, "
            f"speed_vs_torch={row['speed_vs_torch']:.3f}x, "
        )
        rows.append(row)

    if len(rows) == 0:
        raise RuntimeError("No valid tile configs found for this GEMM FP16 shape")

    return pd.DataFrame(rows)


In [8]:
# Sweep tile sizes, warp counts and pipeline stages for the
# 4096 x 1024 x 1024 (M x K x N) problem; one DataFrame row per
# configuration that benchmarked successfully.
df_fp16 = tune_gemm_fp16_tiles_for_shape(
    4096, 1024, 1024,
    blocks_M=(32, 64, 128),
    blocks_N=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters=200,
    device="cuda",
)


BM=32, BN=32, BK=32, W=2, S=2: t_triton=0.340 ms, speed_vs_torch=0.920x, 
BM=32, BN=32, BK=32, W=2, S=3: t_triton=0.462 ms, speed_vs_torch=0.608x, 
BM=32, BN=32, BK=32, W=4, S=2: t_triton=0.340 ms, speed_vs_torch=0.789x, 
BM=32, BN=32, BK=32, W=4, S=3: t_triton=0.450 ms, speed_vs_torch=0.597x, 
BM=32, BN=32, BK=32, W=8, S=2: t_triton=0.472 ms, speed_vs_torch=0.563x, 
BM=32, BN=32, BK=32, W=8, S=3: t_triton=0.549 ms, speed_vs_torch=0.486x, 
BM=32, BN=32, BK=64, W=2, S=2: t_triton=0.321 ms, speed_vs_torch=0.870x, 
BM=32, BN=32, BK=64, W=2, S=3: t_triton=0.353 ms, speed_vs_torch=0.776x, 
BM=32, BN=32, BK=64, W=4, S=2: t_triton=0.344 ms, speed_vs_torch=0.788x, 
BM=32, BN=32, BK=64, W=4, S=3: t_triton=0.336 ms, speed_vs_torch=0.807x, 
BM=32, BN=32, BK=64, W=8, S=2: t_triton=0.357 ms, speed_vs_torch=0.772x, 
BM=32, BN=32, BK=64, W=8, S=3: t_triton=0.362 ms, speed_vs_torch=0.762x, 
BM=32, BN=32, BK=128, W=2, S=2: t_triton=0.415 ms, speed_vs_torch=0.663x, 
BM=32, BN=32, BK=128, W=2, S=3: t_tri

In [9]:
# Show the ten configurations with the lowest measured Triton time.
df_fp16.sort_values(by="t_triton_ms", ascending=True).head(n=10)

Unnamed: 0,M,K,N,BLOCK_M,BLOCK_N,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch_ms,speed_vs_torch,bw_triton_GBs,bw_torch_GBs
93,4096,1024,1024,64,128,32,4,3,0.212696,0.277762,1.305912,128.178156,98.152237
99,4096,1024,1024,64,128,64,4,3,0.221131,0.269691,1.219598,123.288768,101.089643
137,4096,1024,1024,128,64,64,8,3,0.223396,0.272501,1.21981,122.038595,100.047226
129,4096,1024,1024,128,64,32,4,3,0.223889,0.278304,1.243042,121.76998,97.961284
135,4096,1024,1024,128,64,64,4,3,0.225016,0.272817,1.212434,121.159951,99.931194
98,4096,1024,1024,64,128,64,4,2,0.227905,0.267245,1.172617,119.624494,102.01499
92,4096,1024,1024,64,128,32,4,2,0.228575,0.267142,1.168728,119.273622,102.054244
37,4096,1024,1024,32,128,32,2,3,0.229122,0.270004,1.178428,118.988851,100.972493
38,4096,1024,1024,32,128,32,4,2,0.229209,0.269264,1.174753,118.943754,101.250053
101,4096,1024,1024,64,128,64,8,3,0.229492,0.276581,1.205188,118.796873,98.571244


In [11]:
# Fastest configuration from the sweep output above for M=4096, K=1024, N=1024
# (row 93: about 0.213 ms Triton vs 0.278 ms PyTorch in that run).
FP16_GEMM_BEST_BLOCK_M = 64
FP16_GEMM_BEST_BLOCK_N = 128
FP16_GEMM_BEST_BLOCK_K = 32
FP16_GEMM_BEST_WARPS = 4
FP16_GEMM_BEST_STAGES = 3