In [2]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent.parent))

In [3]:
import triton
import triton.language as tl
import torch
import time
import pandas as pd
from conv_gemm.triton_kernels.fp16.gemm_kernel import triton_gemm

In [4]:
@torch.no_grad()
def bench_once_gemm_fp16_vs_torch(
    M, K, N,
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=100,
    device="cuda",
):

    A_f16 = torch.randn(M, K, device=device, dtype=torch.float16)
    B_f16 = torch.randn(K, N, device=device, dtype=torch.float16)

    # Torch FP16 GEMM
    def _call_torch():
        return (A_f16 @ B_f16).float()

    # warmup
    for _ in range(5):
        _call_torch()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        C_ref = _call_torch()
    torch.cuda.synchronize()
    t_torch = (time.perf_counter() - t0) / iters

    # Triton FP16 GEMM
    def _call_triton():
        return triton_gemm(
            A_f16, B_f16,
            use_fp16=True,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    for _ in range(5):
        C_tr = _call_triton()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        C_tr = _call_triton()
    torch.cuda.synchronize()
    t_triton = (time.perf_counter() - t0) / iters


    diff = (C_tr - C_ref).abs()
    max_abs_err16 = diff.max().item()
    mean_abs_err16 = diff.mean().item()

    
    # Bandwidth
    bytes_moved = (
        A_f16.numel() * 2 +
        B_f16.numel() * 2 +
        C_tr.numel() * 4
    )
    bytes_moved = float(bytes_moved)

    bw_triton = bytes_moved / t_triton / 1e9
    bw_torch  = bytes_moved / t_torch  / 1e9

    return {
        "M": M, "K": K, "N": N,
        "BLOCK_M": BLOCK_M,
        "BLOCK_N": BLOCK_N,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_torch16_ms": t_torch * 1e3,
        "speed_vs_torch16": t_torch / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_torch_GBs": bw_torch,
        "max_abs_err": max_abs_err16,
        "mean_abs_err": mean_abs_err16,
    }


In [9]:
@torch.no_grad()
def tune_gemm_fp16_tiles_for_shape(
    M, K, N,
    blocks_M=(32, 64, 128),
    blocks_N=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(1, 2, 4, 8),
    stages=(2, 3, 4),
    iters=200,
    device="cuda",
):
    records = []
    for BM in blocks_M:
        for BN in blocks_N:
            for BK in blocks_K:
                for W in warps:
                    for S in stages:

                        try:
                            rec = bench_once_gemm_fp16_vs_torch(
                                M, K, N,
                                BLOCK_M=BM,
                                BLOCK_N=BN,
                                BLOCK_K=BK,
                                num_warps=W,
                                num_stages=S,
                                iters=iters,
                                device=device,
                            )
                        except Exception as e:
                            print(f"[SKIP] BM={BM}, BN={BN}, BK={BK}, W={W}, S={S}: {e}")
                            continue

                        print(
                            f"BM={BM}, BN={BN}, BK={BK}, W={W}, S={S}: "
                            f"t_triton={rec['t_triton_ms']:.3f} ms, "
                            f"speed_vs_torch={rec['speed_vs_torch16']:.3f}x, "
                        )
                        records.append(rec)

    if not records:
        raise RuntimeError("No valid tile configs found for this GEMM FP16 shape")

    return pd.DataFrame(records)


In [10]:
df_fp16 = tune_gemm_fp16_tiles_for_shape(
    M=4096, K=1024, N=1024,
    blocks_M=(32, 64, 128),
    blocks_N=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters=200,
    device="cuda",
)


BM=32, BN=32, BK=32, W=2, S=2: t_triton=0.385 ms, speed_vs_torch=0.897x, 
BM=32, BN=32, BK=32, W=2, S=3: t_triton=0.409 ms, speed_vs_torch=0.665x, 
BM=32, BN=32, BK=32, W=4, S=2: t_triton=0.355 ms, speed_vs_torch=0.783x, 
BM=32, BN=32, BK=32, W=4, S=3: t_triton=0.485 ms, speed_vs_torch=0.573x, 
BM=32, BN=32, BK=32, W=8, S=2: t_triton=0.474 ms, speed_vs_torch=0.562x, 
BM=32, BN=32, BK=32, W=8, S=3: t_triton=0.558 ms, speed_vs_torch=0.480x, 
BM=32, BN=32, BK=64, W=2, S=2: t_triton=0.379 ms, speed_vs_torch=0.736x, 
BM=32, BN=32, BK=64, W=2, S=3: t_triton=0.330 ms, speed_vs_torch=0.828x, 
BM=32, BN=32, BK=64, W=4, S=2: t_triton=0.351 ms, speed_vs_torch=0.774x, 
BM=32, BN=32, BK=64, W=4, S=3: t_triton=0.338 ms, speed_vs_torch=0.820x, 
BM=32, BN=32, BK=64, W=8, S=2: t_triton=0.411 ms, speed_vs_torch=0.676x, 
BM=32, BN=32, BK=64, W=8, S=3: t_triton=0.334 ms, speed_vs_torch=0.835x, 
BM=32, BN=32, BK=128, W=2, S=2: t_triton=0.403 ms, speed_vs_torch=0.687x, 
BM=32, BN=32, BK=128, W=2, S=3: t_tri

In [11]:
cols = [
    "BLOCK_M", "BLOCK_K",
    "num_warps", "num_stages",
    "t_triton_ms", "t_torch16_ms",
    "speed_vs_torch16",
     "mean_abs_err"
]

df_fp16_filtered = df_fp16[cols].sort_values("speed_vs_torch16", ascending=False).head(5).reset_index(drop=True)
df_fp16_filtered

Unnamed: 0,BLOCK_M,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch16_ms,speed_vs_torch16,mean_abs_err
0,32,32,2,3,0.230908,0.330703,1.432182,0.004498
1,128,64,8,3,0.220453,0.297433,1.34919,0.004496
2,32,64,8,2,0.242241,0.325687,1.344476,0.004494
3,64,32,2,3,0.26768,0.342614,1.279939,0.004496
4,64,64,4,2,0.223943,0.285073,1.272973,0.004489


In [11]:
FP16_GEMM_BEST_BLOCK_M = 64
FP16_GEMM_BEST_BLOCK_N = 128
FP16_GEMM_BEST_BLOCK_K = 32
FP16_GEMM_BEST_WARPS   = 4
FP16_GEMM_BEST_STAGES  = 3