# Triton FP16 baseline benchmark

This notebook compares `conv_gemm.baseline_layers.TritonConv2d` with `nn.Conv2d` in FP16, using identical inputs and weights. The results serve as the baseline for speed/accuracy comparisons.


In [1]:
import pathlib
import sys

# Prepend the parent of the current working directory (presumably the folder
# containing this notebook) to the import path so that the project package
# `conv_gemm` imported below can be resolved.
sys.path.insert(0, str(pathlib.Path().resolve().parent))

In [2]:
import torch
import time
import pandas as pd
from torch import nn

from conv_gemm.baseline_layers.triton_conv2d import TritonConv2d as BaselineTritonConv2d


In [3]:
# The benchmark times CUDA kernels, so a GPU is mandatory. An explicit raise
# (rather than `assert`) keeps the check active even when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA device required')
device = 'cuda'
# Fixed seed so random inputs and layer initializations are repeatable.
torch.manual_seed(0)
print('device:', device)


device: cuda


In [4]:
# Single smoke-test configuration: batch of one 32x32 image, 3 -> 8 channels,
# 3x3 kernel with padding K // 2 so the spatial size is preserved.
N, Cin, Cout = 1, 3, 8
H = W = 32
K = 3

# FP16 input plus the two layers under comparison. The reference layer is cast
# to half explicitly; the Triton layer is only moved to the device.
x = torch.randn(N, Cin, H, W, dtype=torch.float16, device=device)
conv_ref = nn.Conv2d(Cin, Cout, kernel_size=K, padding=K // 2, bias=True).to(device).half()
conv_tri = BaselineTritonConv2d(Cin, Cout, kernel_size=K, padding=K // 2, bias=True).to(device)


In [5]:
# Give the Triton layer exactly the reference layer's parameters so that any
# output difference comes from the kernel, not from random initialization.
with torch.no_grad():
    conv_tri.weight.copy_(conv_ref.weight)
    # Bias is copied only when both layers actually have one.
    if all(b is not None for b in (conv_ref.bias, conv_tri.bias)):
        conv_tri.bias.copy_(conv_ref.bias)


In [6]:
# One forward pass through each implementation; the output shapes are printed
# so they can be checked before the values are compared.
y_ref, y_tri = conv_ref(x), conv_tri(x)

print('y_ref.shape:', y_ref.shape)
print('y_tri.shape:', y_tri.shape)


y_ref.shape: torch.Size([1, 8, 32, 32])
y_tri.shape: torch.Size([1, 8, 32, 32])


In [7]:
# Element-wise absolute difference between the reference and Triton outputs.
# Shapes are checked explicitly because broadcasting would otherwise silently
# hide a mismatch. Both outputs are upcast to float32 so the difference itself
# is not rounded back to fp16 before max/mean are taken.
if y_ref.shape != y_tri.shape:
    raise ValueError(f'shape mismatch: torch {tuple(y_ref.shape)} vs triton {tuple(y_tri.shape)}')
err = (y_ref.float() - y_tri.float()).abs()
print('max error:', err.max().item())
print('mean error:', err.mean().item())


max error: 0.0009765625
mean error: 7.933378219604492e-05


In [8]:
def bench_ms(fn, iters=200, *, warmup=10, no_grad=True):
    """Return the mean wall-clock time of one ``fn()`` call, in milliseconds.

    ``fn`` is first called ``warmup`` times untimed, so one-off first-call costs
    (e.g. kernel compilation, allocator growth) do not pollute the measurement.
    It is then called ``iters`` times between two device synchronizations.

    Args:
        fn: zero-argument callable to time.
        iters: number of timed calls; must be positive.
        warmup: number of untimed calls made before timing starts; must be >= 0.
        no_grad: when True (default), every call runs under ``torch.no_grad()``
            so only inference cost is measured, not autograd bookkeeping.

    Returns:
        Average milliseconds per call over the ``iters`` timed calls.

    Raises:
        ValueError: if ``iters`` is not positive or ``warmup`` is negative.
    """
    import contextlib

    if iters <= 0:
        raise ValueError(f'iters must be positive, got {iters}')
    if warmup < 0:
        raise ValueError(f'warmup must be non-negative, got {warmup}')

    # CUDA launches are asynchronous: synchronize before starting and before
    # stopping the clock so the interval covers all queued GPU work. On hosts
    # without CUDA there is nothing to wait for.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    grad_ctx = torch.no_grad() if no_grad else contextlib.nullcontext()

    with grad_ctx:
        for _ in range(warmup):
            fn()
        sync()
        t0 = time.perf_counter()
        for _ in range(iters):
            fn()
        sync()
        elapsed = time.perf_counter() - t0
    return elapsed * 1000.0 / iters

# Time both layers on the smoke-test input with the same 200-iteration budget.
print('PyTorch FP16 (ms):', bench_ms(lambda: conv_ref(x), iters=200))
print('Triton baseline FP16 (ms):', bench_ms(lambda: conv_tri(x), iters=200))


PyTorch FP16 (ms): 0.10113624501173035
Triton baseline FP16 (ms): 0.688816340007179


## Measurement over a grid of configurations

The function below runs the same comparison for every combination of image size, batch size, kernel size and channel count.


In [9]:
def _make_fp16_pair(N, Cin, Cout, H, W, K):
    """Build an FP16 input plus reference and Triton convs sharing parameters."""
    x = torch.randn(N, Cin, H, W, device=device, dtype=torch.float16)
    conv_ref = nn.Conv2d(Cin, Cout, kernel_size=K, padding=K // 2, bias=True).to(device).half()
    conv_tri = BaselineTritonConv2d(Cin, Cout, kernel_size=K, padding=K // 2, bias=True).to(device)
    # Copy the reference parameters so any output difference comes from the
    # kernels, not from different random initialization.
    with torch.no_grad():
        conv_tri.weight.copy_(conv_ref.weight)
        if conv_ref.bias is not None and conv_tri.bias is not None:
            conv_tri.bias.copy_(conv_ref.bias)
    return x, conv_ref, conv_tri


def _failure_row(cfg, note, err_max=None, err_mean=None):
    """Print a failure note and return its result row with timings left as None."""
    print('  ', note)
    return [*cfg, None, None, None, err_max, err_mean, note]


def run_fp16_baseline_bench(
    image_sizes=(32, 64, 112, 224),
    batch_sizes=(1, 2, 4, 8),
    channels=((1, 1), (1, 3), (3, 8), (8, 16), (16, 32), (32, 64)),
    kernels=(1, 3, 5, 7, 9, 11),
    iters=50,
):
    """Sweep ``nn.Conv2d`` vs. the baseline Triton conv in FP16 over a config grid.

    For every combination of square image size, batch size, (Cin, Cout) pair
    and kernel size K (padding K // 2), both layers get identical parameters.
    Each is run once to compare outputs, then timed with ``bench_ms``.

    Args:
        image_sizes: square spatial sizes H (= W) to test.
        batch_sizes: batch sizes N to test.
        channels: (Cin, Cout) pairs to test.
        kernels: square kernel sizes K; configs with K larger than the image
            are skipped.
        iters: timed iterations per layer, forwarded to ``bench_ms``.

    Returns:
        A DataFrame with one row per tested config. Its columns are img, N,
        Cin, Cout, K, t_torch_fp16_ms, t_triton_fp16_ms, speedup, err_max,
        err_mean and note. ``speedup`` is t_torch / t_triton, so values above 1
        mean the Triton layer is faster. A failed config keeps ``None`` for
        whatever could not be measured, and ``note`` records the reason.
    """
    import itertools

    rows = []
    # product() iterates in the same order as nesting image -> batch ->
    # channels -> kernel loops.
    for H, N, (Cin, Cout), K in itertools.product(image_sizes, batch_sizes, channels, kernels):
        W = H
        if K > H or K > W:
            continue

        print(f'[bench] img={H} N={N} Cin={Cin} Cout={Cout} K={K}')
        cfg = [H, N, Cin, Cout, K]

        # Allocation can itself fail (e.g. CUDA OOM, a RuntimeError subclass);
        # record it instead of aborting the whole sweep.
        try:
            x, conv_ref, conv_tri = _make_fp16_pair(N, Cin, Cout, H, W, K)
        except RuntimeError as e:
            rows.append(_failure_row(cfg, f'setup_fail: {e}'))
            continue

        try:
            with torch.no_grad():
                y_ref = conv_ref(x)
        except RuntimeError as e:
            rows.append(_failure_row(cfg, f'torch_fail: {e}'))
            continue

        try:
            with torch.no_grad():
                y_tri = conv_tri(x)
        except RuntimeError as e:
            rows.append(_failure_row(cfg, f'triton_fail: {e}'))
            continue

        # Without this check, mismatched shapes would either broadcast silently
        # or raise an uncaught error that kills the sweep.
        if y_ref.shape != y_tri.shape:
            note = f'shape_mismatch: torch {tuple(y_ref.shape)} vs triton {tuple(y_tri.shape)}'
            rows.append(_failure_row(cfg, note))
            continue

        # Upcast before subtracting so the error is not rounded to fp16.
        err = (y_ref.float() - y_tri.float()).abs()
        err_max = err.max().item()
        err_mean = err.mean().item()

        # Time under no_grad, matching the correctness pass above. Default
        # arguments bind the current layers/input to each lambda.
        try:
            with torch.no_grad():
                t_ref = bench_ms(lambda m=conv_ref, inp=x: m(inp), iters=iters)
                t_tri = bench_ms(lambda m=conv_tri, inp=x: m(inp), iters=iters)
        except RuntimeError as e:
            rows.append(_failure_row(cfg, f'bench_fail: {e}', err_max, err_mean))
            continue

        speedup = t_ref / t_tri if t_tri > 0 else None
        rows.append([*cfg, t_ref, t_tri, speedup, err_max, err_mean, None])

    return pd.DataFrame(rows, columns=[
        'img', 'N', 'Cin', 'Cout', 'K',
        't_torch_fp16_ms', 't_triton_fp16_ms', 'speedup',
        'err_max', 'err_mean', 'note',
    ])


In [10]:
# (Cin, Cout) pairs to sweep, from single-channel up to 32 -> 64.
channels_cfg = ((1, 1), (1, 3), (3, 8), (8, 16), (16, 32), (32, 64))

# Full grid: 4 image sizes x 4 batch sizes x 6 channel pairs x 6 kernel sizes,
# with 30 timed iterations per layer.
df = run_fp16_baseline_bench(
    iters=30,
    kernels=(1, 3, 5, 7, 9, 11),
    channels=channels_cfg,
    batch_sizes=(1, 2, 4, 8),
    image_sizes=(32, 64, 112, 224),
)
# A bare name on the last line makes the notebook render the results table.
df


[bench] img=32 N=1 Cin=1 Cout=1 K=1
[bench] img=32 N=1 Cin=1 Cout=1 K=3
[bench] img=32 N=1 Cin=1 Cout=1 K=5
[bench] img=32 N=1 Cin=1 Cout=1 K=7
[bench] img=32 N=1 Cin=1 Cout=1 K=9
[bench] img=32 N=1 Cin=1 Cout=1 K=11
[bench] img=32 N=1 Cin=1 Cout=3 K=1
[bench] img=32 N=1 Cin=1 Cout=3 K=3
[bench] img=32 N=1 Cin=1 Cout=3 K=5
[bench] img=32 N=1 Cin=1 Cout=3 K=7
[bench] img=32 N=1 Cin=1 Cout=3 K=9
[bench] img=32 N=1 Cin=1 Cout=3 K=11
[bench] img=32 N=1 Cin=3 Cout=8 K=1
[bench] img=32 N=1 Cin=3 Cout=8 K=3
[bench] img=32 N=1 Cin=3 Cout=8 K=5
[bench] img=32 N=1 Cin=3 Cout=8 K=7
[bench] img=32 N=1 Cin=3 Cout=8 K=9
[bench] img=32 N=1 Cin=3 Cout=8 K=11
[bench] img=32 N=1 Cin=8 Cout=16 K=1
[bench] img=32 N=1 Cin=8 Cout=16 K=3
[bench] img=32 N=1 Cin=8 Cout=16 K=5
[bench] img=32 N=1 Cin=8 Cout=16 K=7
[bench] img=32 N=1 Cin=8 Cout=16 K=9
[bench] img=32 N=1 Cin=8 Cout=16 K=11
[bench] img=32 N=1 Cin=16 Cout=32 K=1
[bench] img=32 N=1 Cin=16 Cout=32 K=3
[bench] img=32 N=1 Cin=16 Cout=32 K=5
[bench] img=

Unnamed: 0,img,N,Cin,Cout,K,t_torch_fp16_ms,t_triton_fp16_ms,speedup,err_max,err_mean,note
0,32,1,1,1,1,0.078749,0.723765,0.108804,0.000977,0.000014,
1,32,1,1,1,3,0.107890,1.195479,0.090248,0.001953,0.000052,
2,32,1,1,1,5,0.270074,2.113724,0.127772,0.000977,0.000090,
3,32,1,1,1,7,0.078025,1.723411,0.045274,0.000977,0.000054,
4,32,1,1,1,9,0.163178,1.579105,0.103336,0.000977,0.000092,
...,...,...,...,...,...,...,...,...,...,...,...
571,224,8,32,64,3,2.699882,14.586037,0.185100,0.001953,0.000081,
572,224,8,32,64,5,4.720472,24.734845,0.190843,0.001953,0.000081,
573,224,8,32,64,7,8.088431,741.839068,0.010903,0.001953,0.000078,
574,224,8,32,64,9,11.852006,1258.304451,0.009419,0.001953,0.000081,


In [11]:
# Keep only configs where both timings (and therefore the speedup) were
# measured, then show the 30 with the highest torch/triton time ratio.
df_valid = df.dropna(subset=['t_torch_fp16_ms', 't_triton_fp16_ms', 'speedup'])
df_top = df_valid.sort_values(by='speedup', ascending=False).iloc[:30]
df_top


Unnamed: 0,img,N,Cin,Cout,K,t_torch_fp16_ms,t_triton_fp16_ms,speedup,err_max,err_mean,note
263,64,8,1,3,11,0.671052,0.609302,1.101346,0.001953,9.3e-05,
413,112,8,3,8,11,1.923076,2.155135,0.892323,0.001953,6.9e-05,
557,224,8,3,8,11,7.378039,8.302525,0.88865,0.001953,7.9e-05,
556,224,8,3,8,9,5.178869,5.939555,0.871929,0.001953,7.9e-05,
412,112,8,3,8,9,1.331805,1.542914,0.863175,0.001953,7.3e-05,
268,64,8,3,8,9,0.451604,0.728869,0.619595,0.001953,8.1e-05,
269,64,8,3,8,11,0.642886,1.082386,0.593953,0.001953,6.8e-05,
303,112,1,3,8,7,0.485141,0.84596,0.57348,0.001953,8.1e-05,
521,224,4,3,8,11,2.437493,4.406784,0.553123,0.001953,7.9e-05,
520,224,4,3,8,9,1.662396,3.058153,0.543595,0.001953,9e-05,
