In [11]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent.parent))

In [12]:
import triton
import triton.language as tl
import torch
import time
import pandas as pd
from conv_gemm.triton_kernels.fp16.gemm_kernel import triton_gemm

In [13]:
@torch.no_grad()
def bench_once_gemm_fp16_vs_torch(
    M, K, N,
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=100,
    device="cuda",
):
    """Time one Triton FP16 GEMM configuration against ``torch`` FP16 matmul.

    Random ``(M, K)`` and ``(K, N)`` float16 operands are created on ``device``.
    Each back-end is warmed up for 5 calls and then timed over ``iters`` calls
    with wall-clock time bracketed by CUDA synchronisation.

    Args:
        M, K, N: GEMM problem size (``A`` is ``M x K``, ``B`` is ``K x N``).
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes forwarded to ``triton_gemm``.
        num_warps, num_stages: launch parameters forwarded to ``triton_gemm``.
        iters: number of timed calls per back-end; must be at least 1.
        device: device on which the operands are allocated and timed.

    Returns:
        dict with the problem size, the tile configuration, average latencies
        in milliseconds (``t_triton_ms``, ``t_torch16_ms``), the speed ratio
        ``speed_vs_torch16`` (torch time / Triton time), effective bandwidths
        in GB/s, and the max / mean absolute difference between both outputs.

    Raises:
        ValueError: if ``iters`` is smaller than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError later on.
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    # Synchronise on the device that is actually benchmarked, which may not be
    # the current CUDA device (e.g. device="cuda:1").
    sync_device = torch.device(device)

    A_f16 = torch.randn(M, K, device=device, dtype=torch.float16)
    B_f16 = torch.randn(K, N, device=device, dtype=torch.float16)

    def _time_call(fn, warmup=5):
        """Return (average seconds per call, result of the last call)."""
        out = None
        for _ in range(warmup):
            out = fn()
        torch.cuda.synchronize(sync_device)

        t0 = time.perf_counter()
        for _ in range(iters):
            out = fn()
        torch.cuda.synchronize(sync_device)
        return (time.perf_counter() - t0) / iters, out

    # Torch FP16 GEMM
    # NOTE(review): the .float() cast is part of the timed region, so the torch
    # baseline pays for an extra cast on top of the matmul. Presumably this is
    # meant to match an fp32 Triton output -- confirm against triton_gemm.
    def _call_torch():
        return (A_f16 @ B_f16).float()

    # Triton FP16 GEMM
    def _call_triton():
        return triton_gemm(
            A_f16, B_f16,
            use_fp16=True,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    t_torch, C_ref = _time_call(_call_torch)
    t_triton, C_tr = _time_call(_call_triton)

    diff = (C_tr - C_ref).abs()
    max_abs_err16 = diff.max().item()
    mean_abs_err16 = diff.mean().item()

    # Bandwidth: bytes for reading A and B once and writing C once. Element
    # sizes come from the tensors so the figure follows the real output dtype.
    bytes_moved = float(
        A_f16.numel() * A_f16.element_size()
        + B_f16.numel() * B_f16.element_size()
        + C_tr.numel() * C_tr.element_size()
    )

    bw_triton = bytes_moved / t_triton / 1e9
    bw_torch = bytes_moved / t_torch / 1e9

    return {
        "M": M, "K": K, "N": N,
        "BLOCK_M": BLOCK_M,
        "BLOCK_N": BLOCK_N,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_torch16_ms": t_torch * 1e3,
        "speed_vs_torch16": t_torch / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_torch_GBs": bw_torch,
        "max_abs_err": max_abs_err16,
        "mean_abs_err": mean_abs_err16,
    }


In [14]:
@torch.no_grad()
def tune_gemm_fp16_tiles_for_shape(
    M, K, N,
    blocks_M=(32, 64, 128),
    blocks_N=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(1, 2, 4, 8),
    stages=(2, 3, 4),
    iters=200,
    device="cuda",
):
    """Benchmark every tile / warp / stage combination for one GEMM shape.

    Each combination of ``blocks_M x blocks_N x blocks_K x warps x stages`` is
    passed to ``bench_once_gemm_fp16_vs_torch``. A combination that raises is
    reported with a ``[SKIP]`` line and left out of the result.

    Args:
        M, K, N: GEMM problem size forwarded to the benchmark.
        blocks_M, blocks_N, blocks_K: candidate tile sizes per dimension.
        warps: candidate ``num_warps`` values.
        stages: candidate ``num_stages`` values.
        iters: timed iterations per configuration.
        device: device forwarded to the benchmark.

    Returns:
        pandas.DataFrame with one row per configuration that ran successfully.

    Raises:
        RuntimeError: if no configuration produced a result.
    """
    # Same visiting order as five nested loops, expressed as one lazy stream.
    search_space = (
        (tile_m, tile_n, tile_k, n_warps, n_stages)
        for tile_m in blocks_M
        for tile_n in blocks_N
        for tile_k in blocks_K
        for n_warps in warps
        for n_stages in stages
    )

    results = []
    for tile_m, tile_n, tile_k, n_warps, n_stages in search_space:
        # Shared label for both the skip message and the progress line.
        tag = f"BM={tile_m}, BN={tile_n}, BK={tile_k}, W={n_warps}, S={n_stages}"

        try:
            row = bench_once_gemm_fp16_vs_torch(
                M, K, N,
                BLOCK_M=tile_m,
                BLOCK_N=tile_n,
                BLOCK_K=tile_k,
                num_warps=n_warps,
                num_stages=n_stages,
                iters=iters,
                device=device,
            )
        except Exception as err:
            # Best-effort sweep: report the failing configuration and move on.
            print(f"[SKIP] {tag}: {err}")
            continue

        print(
            f"{tag}: "
            f"t_triton={row['t_triton_ms']:.3f} ms, "
            f"speed_vs_torch={row['speed_vs_torch16']:.3f}x, "
        )
        results.append(row)

    if len(results) == 0:
        raise RuntimeError("No valid tile configs found for this GEMM FP16 shape")

    return pd.DataFrame(results)


In [15]:
# Sweep tile sizes, warps and pipeline stages for the 4096 x 1024 x 1024 GEMM;
# one row per configuration that ran is collected into df_fp16.
df_fp16 = tune_gemm_fp16_tiles_for_shape(
    4096, 1024, 1024,  # M, K, N
    blocks_M=(32, 64, 128),
    blocks_N=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters=200,
    device="cuda",
)


BM=32, BN=32, BK=32, W=2, S=2: t_triton=0.342 ms, speed_vs_torch=0.851x, 
BM=32, BN=32, BK=32, W=2, S=3: t_triton=0.437 ms, speed_vs_torch=0.628x, 
BM=32, BN=32, BK=32, W=4, S=2: t_triton=0.332 ms, speed_vs_torch=0.786x, 
BM=32, BN=32, BK=32, W=4, S=3: t_triton=0.435 ms, speed_vs_torch=0.609x, 
BM=32, BN=32, BK=32, W=8, S=2: t_triton=0.454 ms, speed_vs_torch=0.580x, 
BM=32, BN=32, BK=32, W=8, S=3: t_triton=0.542 ms, speed_vs_torch=0.484x, 
BM=32, BN=32, BK=64, W=2, S=2: t_triton=0.314 ms, speed_vs_torch=0.856x, 
BM=32, BN=32, BK=64, W=2, S=3: t_triton=0.339 ms, speed_vs_torch=0.786x, 
BM=32, BN=32, BK=64, W=4, S=2: t_triton=0.328 ms, speed_vs_torch=0.816x, 
BM=32, BN=32, BK=64, W=4, S=3: t_triton=0.328 ms, speed_vs_torch=0.815x, 
BM=32, BN=32, BK=64, W=8, S=2: t_triton=0.336 ms, speed_vs_torch=0.807x, 
BM=32, BN=32, BK=64, W=8, S=3: t_triton=0.355 ms, speed_vs_torch=0.764x, 
BM=32, BN=32, BK=128, W=2, S=2: t_triton=0.401 ms, speed_vs_torch=0.682x, 
BM=32, BN=32, BK=128, W=2, S=3: t_tri

In [17]:
# Label every row with the GEMM shape it was measured on ("M/K/N"). The label
# is built from the M, K and N columns recorded by the benchmark, so it cannot
# drift from the shape that was actually tuned.
df_fp16["shape_info"] = (
    df_fp16["M"].astype(str)
    + "/"
    + df_fp16["K"].astype(str)
    + "/"
    + df_fp16["N"].astype(str)
)

# Columns shown in the summary table.
cols = [
    "shape_info",
    "BLOCK_M", "BLOCK_N", "BLOCK_K",
    "num_warps", "num_stages",
    "t_triton_ms", "t_torch16_ms",
    "speed_vs_torch16",
    "mean_abs_err",
]

# Five fastest configurations relative to torch, best first.
df_fp16_filtered = (
    df_fp16[cols]
    .sort_values("speed_vs_torch16", ascending=False)
    .head(5)
    .reset_index(drop=True)
)
df_fp16_filtered

Unnamed: 0,shape_info,BLOCK_M,BLOCK_N,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch16_ms,speed_vs_torch16,mean_abs_err
0,4096/1024/1024,64,128,32,4,3,0.209433,0.267012,1.274928,0.004495
1,4096/1024/1024,64,128,64,4,2,0.227704,0.276041,1.212277,0.004497
2,4096/1024/1024,32,128,32,4,3,0.226807,0.274289,1.209349,0.004498
3,4096/1024/1024,64,128,64,4,3,0.218102,0.262456,1.203364,0.004496
4,4096/1024/1024,64,128,32,8,3,0.240096,0.288143,1.200114,0.004501


In [18]:
# Replace the positional index with the kernel name on every row.
df_fp16_filtered.index = ["GEMM_FP16"] * df_fp16_filtered.shape[0]
df_fp16_filtered

Unnamed: 0,shape_info,BLOCK_M,BLOCK_N,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch16_ms,speed_vs_torch16,mean_abs_err
GEMM_FP16,4096/1024/1024,64,128,32,4,3,0.209433,0.267012,1.274928,0.004495
GEMM_FP16,4096/1024/1024,64,128,64,4,2,0.227704,0.276041,1.212277,0.004497
GEMM_FP16,4096/1024/1024,32,128,32,4,3,0.226807,0.274289,1.209349,0.004498
GEMM_FP16,4096/1024/1024,64,128,64,4,3,0.218102,0.262456,1.203364,0.004496
GEMM_FP16,4096/1024/1024,64,128,32,8,3,0.240096,0.288143,1.200114,0.004501


In [8]:
# Best FP16 GEMM launch configuration. The values match the top-ranked row of
# the tuning table for shape 4096/1024/1024; re-check after re-running the sweep.
FP16_GEMM_BEST_BLOCK_M = 64
FP16_GEMM_BEST_BLOCK_N = 128
FP16_GEMM_BEST_BLOCK_K = 32
FP16_GEMM_BEST_WARPS = 4
FP16_GEMM_BEST_STAGES = 3