# ResNet18 Triton per-layer benchmark (FP16)

Собираем ResNet18 полностью на `conv_gemm.layers.TritonConv2d` в режиме fp16 и сравниваем чистый forward каждой свёртки с PyTorch-версией. В конце сохраняем таблицу — из неё потом можно выбрать слои для дальнейшей спарсификации.


In [1]:
import sys, pathlib
# Put the parent of the current working directory at the front of sys.path so
# that packages located there can be imported by the cells below.
# NOTE(review): presumably the notebook is launched from a sub-directory of the
# repository and the parent holds the `conv_gemm` package — confirm the layout.
sys.path.insert(0, str(pathlib.Path().resolve().parent))

In [2]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
import copy, time
from typing import Dict, List
from conv_gemm.layers.triton_conv2d import TritonConv2d

# Turn off TF32 for CUDA matmuls and for cuDNN so neither backend may use it.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Use the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Half precision is only selected on CUDA; the CPU fallback stays in float32.
dtype = torch.float16 if device.type == 'cuda' else torch.float32
# Fixed seed so the randomly initialised weights and inputs are reproducible.
torch.manual_seed(0)
print('device:', device)
print('dtype for activations:', dtype)


device: cuda
dtype for activations: torch.float16


In [3]:
def make_triton_from_conv(conv: nn.Conv2d, *, precision_mode: str = 'fp16_infer',
                        triton_blocks=(64, 64, 32), triton_launch=(4, 2)) -> TritonConv2d:
    """Build a ``TritonConv2d`` that mirrors an existing ``nn.Conv2d``.

    The new layer takes its geometry (channels, kernel size, stride, padding,
    dilation, presence of a bias) from ``conv``, is moved to the device of
    ``conv.weight`` and receives a copy of the weights (and bias, when both
    layers have one) cast to the new layer's parameter dtype.

    Args:
        conv: Source convolution to mirror.
        precision_mode: Forwarded unchanged to ``TritonConv2d``.
        triton_blocks: Indexable whose items 0, 1 and 2 are passed as
            ``BLOCK_M``, ``BLOCK_N`` and ``BLOCK_K``.
        triton_launch: Indexable whose items 0 and 1 are passed as
            ``NUM_WARPS`` and ``NUM_STAGES``.

    Returns:
        The freshly constructed ``TritonConv2d`` holding the copied parameters.
    """
    source_has_bias = conv.bias is not None
    layer_kwargs = dict(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=source_has_bias,
        BLOCK_M=triton_blocks[0],
        BLOCK_N=triton_blocks[1],
        BLOCK_K=triton_blocks[2],
        NUM_WARPS=triton_launch[0],
        NUM_STAGES=triton_launch[1],
        precision_mode=precision_mode,
    )
    triton_conv = TritonConv2d(**layer_kwargs).to(conv.weight.device)

    # Parameter copies are plain data transfers; keep autograd out of them.
    with torch.no_grad():
        weight_src = conv.weight.data.to(triton_conv.weight.dtype)
        triton_conv.weight.copy_(weight_src)
        if source_has_bias and triton_conv.bias is not None:
            bias_src = conv.bias.data.to(triton_conv.bias.dtype)
            triton_conv.bias.copy_(bias_src)
    return triton_conv


def replace_convs_with_triton(module: nn.Module, *, precision_mode: str = 'fp16_infer',
                              triton_blocks=(64, 64, 32), triton_launch=(4, 2)) -> nn.Module:
    """Swap every non-grouped ``nn.Conv2d`` under ``module`` for a Triton layer, in place.

    Direct children that are ``nn.Conv2d`` with ``groups == 1`` are replaced by
    the result of ``make_triton_from_conv``; every other child (including
    grouped convolutions) is descended into recursively.

    Args:
        module: Root of the module tree to rewrite.
        precision_mode: Forwarded to ``make_triton_from_conv``.
        triton_blocks: Forwarded to ``make_triton_from_conv``.
        triton_launch: Forwarded to ``make_triton_from_conv``.

    Returns:
        The same ``module`` object that was passed in.
    """
    shared_kwargs = dict(precision_mode=precision_mode,
                         triton_blocks=triton_blocks, triton_launch=triton_launch)
    for child_name, submodule in module.named_children():
        swappable = isinstance(submodule, nn.Conv2d) and submodule.groups == 1
        if not swappable:
            # Not a plain convolution: look for convolutions further down.
            replace_convs_with_triton(submodule, **shared_kwargs)
            continue
        setattr(module, child_name, make_triton_from_conv(submodule, **shared_kwargs))
    return module


def count_conv_layers(model: nn.Module) -> int:
    """Return how many modules in ``model`` (itself included) are ``nn.Conv2d``."""
    return sum(isinstance(layer, nn.Conv2d) for layer in model.modules())


In [4]:
# Reference model: ResNet18 built with weights=None, moved to the selected
# device, switched to eval mode and cast to the working dtype.
base_torch = resnet18(weights=None).to(device).eval().to(dtype)
# Deep copy so the Triton variant starts from exactly the same parameters
# while leaving the reference model untouched.
triton_resnet = copy.deepcopy(base_torch).to(device).eval()
# Replace the eligible convolutions of the copy in place.
replace_convs_with_triton(triton_resnet, precision_mode='fp16_infer')
# Re-apply eval mode so the newly inserted layers are in eval mode as well.
triton_resnet.eval()

# Random input batch of shape (sample_batch, 3, 224, 224) used to capture
# the per-layer activations.
sample_batch = 4
sample = torch.randn(sample_batch, 3, 224, 224, device=device, dtype=dtype)
print('total conv layers:', count_conv_layers(base_torch))


total conv layers: 20


In [5]:
@torch.no_grad()
def collect_conv_inputs(model: nn.Module, sample: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Run one forward pass and record the input tensor of every ``nn.Conv2d``.

    The model is switched to eval mode, a forward hook is attached to each
    ``nn.Conv2d`` found by ``named_modules()``, ``model(sample)`` is executed
    once, and the hooks are removed again.

    Args:
        model: Network to probe. It is left in eval mode.
        sample: Input passed to ``model`` for the single forward pass.

    Returns:
        Mapping from the module's qualified name to the first positional input
        it received, detached and made contiguous. If a layer runs several
        times, the input of its last call is kept.

    Raises:
        Whatever ``model(sample)`` raises. The hooks are removed in that case
        too, so a failed probe does not leave the model instrumented.
    """
    model.eval()
    buffers: Dict[str, torch.Tensor] = {}
    hooks: List[torch.utils.hooks.RemovableHandle] = []

    def make_hook(name):
        # Factory so each hook closes over its own layer name.
        def hook(module, inp, out):
            # NOTE(review): contiguous() does not copy an already-contiguous
            # tensor, so the stored buffer may alias the live activation —
            # confirm no later in-place op rewrites a convolution's input.
            buffers[name] = inp[0].detach().contiguous()
        return hook

    try:
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                hooks.append(module.register_forward_hook(make_hook(name)))

        model(sample)
    finally:
        # Always detach the hooks, even when the forward pass fails.
        for h in hooks:
            h.remove()
    return buffers

# Record the per-layer inputs once on the reference model; the stored tensors
# are reused later as inputs for the per-layer benchmark.
conv_inputs = collect_conv_inputs(base_torch, sample)
print('captured conv tensors:', len(conv_inputs))


captured conv tensors: 20


In [6]:
def bench_ms(fn, iters=50, warmup=10):
    """Return the average wall-clock time of ``fn()`` in milliseconds.

    ``fn`` is first called ``warmup`` times without being timed, so one-off
    costs of the first calls (lazy initialisation, kernel compilation, cache
    population) do not end up in the measurement. It is then called ``iters``
    times between two timestamps. When CUDA is available the device is
    synchronised before the clock starts and before it stops, because CUDA
    work is queued asynchronously; on a CPU-only setup no synchronisation is
    attempted.

    Args:
        fn: Zero-argument callable to time.
        iters: Number of timed calls; must be positive.
        warmup: Number of untimed calls made beforehand (0 disables warm-up).

    Returns:
        Mean duration of one call in milliseconds.

    Raises:
        ValueError: If ``iters`` is not positive.
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    use_cuda = torch.cuda.is_available()

    for _ in range(warmup):
        fn()

    if use_cuda:
        # Drain queued work (including the warm-up) before starting the clock.
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        # Wait for the timed kernels to actually finish.
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000.0 / iters


def run_once(module: nn.Module, tensor: torch.Tensor):
    """Execute a single forward pass of ``module`` on ``tensor`` with autograd disabled.

    The output is discarded; the function returns ``None``.
    """
    with torch.no_grad():
        module(tensor)


def run_conv_benchmarks(torch_model: nn.Module, triton_model: nn.Module,
                         inputs: Dict[str, torch.Tensor], iters=80):
    """Time every convolution of ``torch_model`` against its namesake in ``triton_model``.

    Only ``nn.Conv2d`` modules of ``torch_model`` are considered, and only
    those whose qualified name also exists in ``triton_model`` and in
    ``inputs``. Both layers are timed with ``bench_ms`` on the same stored
    input tensor.

    Args:
        torch_model: Reference model whose ``nn.Conv2d`` layers are timed.
        triton_model: Model holding the replacement layers under the same names.
        inputs: Mapping from layer name to the 4-D tensor fed to that layer.
        iters: Iteration count forwarded to ``bench_ms``.

    Returns:
        A list with one dict per benchmarked layer, in ``named_modules()``
        order, with the keys ``name``, ``input_shape``, ``Cin``, ``Cout``,
        ``kernel``, ``stride``, ``torch_ms``, ``triton_ms`` and ``speedup``
        (``torch_ms / triton_ms``, or NaN when ``triton_ms`` is not positive).
    """
    triton_by_name = dict(triton_model.named_modules())
    rows = []

    for layer_name, torch_conv in torch_model.named_modules():
        benchmarkable = (
            isinstance(torch_conv, nn.Conv2d)
            and layer_name in triton_by_name
            and layer_name in inputs
        )
        if not benchmarkable:
            continue

        triton_conv = triton_by_name[layer_name]
        activation = inputs[layer_name]

        torch_ms = bench_ms(lambda: run_once(torch_conv, activation), iters=iters)
        triton_ms = bench_ms(lambda: run_once(triton_conv, activation), iters=iters)

        if triton_ms > 0:
            speedup = torch_ms / triton_ms
        else:
            speedup = float('nan')

        batch, channels, height, width = activation.shape
        # Only the first element of kernel_size / stride is reported.
        rows.append({
            'name': layer_name,
            'input_shape': f"{batch}x{channels}x{height}x{width}",
            'Cin': torch_conv.in_channels,
            'Cout': torch_conv.out_channels,
            'kernel': torch_conv.kernel_size[0],
            'stride': torch_conv.stride[0],
            'torch_ms': torch_ms,
            'triton_ms': triton_ms,
            'speedup': speedup,
        })
    return rows


In [7]:
# Benchmark each captured convolution: reference layer vs. its replacement,
# both fed the same stored input, 60 timed iterations per measurement.
results = run_conv_benchmarks(base_torch, triton_resnet, conv_inputs, iters=60)
print(f'Profiled convs: {len(results)}')


Profiled convs: 20


In [8]:
# Show the ten layers with the highest speedup as a table when pandas is
# installed; otherwise fall back to printing the first ten raw result rows
# (in benchmark order, unsorted).
# Only the import is guarded: errors raised while building or displaying the
# table are real problems and must not be reported as "pandas unavailable".
try:
    import pandas as pd
except ImportError as exc:
    print('pandas unavailable:', exc)
    for row in results[:10]:
        print(row)
else:
    df = pd.DataFrame(results)
    if df.empty:
        # An empty frame has no 'speedup' column to sort by.
        print('No results to display')
    else:
        display(df.sort_values('speedup', ascending=False).head(10))


Unnamed: 0,name,input_shape,Cin,Cout,kernel,stride,torch_ms,triton_ms,speedup
15,layer4.0.conv1,4x256x14x14,256,512,3,2,0.156695,0.248366,0.630904
19,layer4.1.conv2,4x512x7x7,512,512,3,1,0.091402,0.174898,0.522604
18,layer4.1.conv1,4x512x7x7,512,512,3,1,0.09151,0.176002,0.519938
16,layer4.0.conv2,4x512x7x7,512,512,3,1,0.091682,0.179665,0.510296
14,layer3.1.conv2,4x256x14x14,256,256,3,1,0.062229,0.173947,0.357747
13,layer3.1.conv1,4x256x14x14,256,256,3,1,0.06197,0.176494,0.351117
9,layer2.1.conv2,4x128x28x28,128,128,3,1,0.058614,0.172173,0.340438
8,layer2.1.conv1,4x128x28x28,128,128,3,1,0.057795,0.174359,0.331469
6,layer2.0.conv2,4x128x28x28,128,128,3,1,0.05786,0.174566,0.331452
3,layer1.1.conv1,4x64x56x56,64,64,3,1,0.05662,0.174992,0.32356


In [9]:
# Keep only rows with a usable speedup: NaN is the one float value that does
# not compare equal to itself, so the self-comparison filters NaN entries out.
valid = [r for r in results if r['speedup'] == r['speedup']]
# Ten layers whose speedup is closest to 1.0, i.e. closest to parity between
# the two implementations.
closest = sorted(valid, key=lambda r: abs(1.0 - r['speedup']))[:10]
for row in closest:
    print(f"{row['name']:<20} | input {row['input_shape']} | k={row['kernel']} stride={row['stride']} |"
          f" torch={row['torch_ms']:.3f} ms | triton={row['triton_ms']:.3f} ms | speedup={row['speedup']:.3f}x")


layer4.0.conv1       | input 4x256x14x14 | k=3 stride=2 | torch=0.157 ms | triton=0.248 ms | speedup=0.631x
layer4.1.conv2       | input 4x512x7x7 | k=3 stride=1 | torch=0.091 ms | triton=0.175 ms | speedup=0.523x
layer4.1.conv1       | input 4x512x7x7 | k=3 stride=1 | torch=0.092 ms | triton=0.176 ms | speedup=0.520x
layer4.0.conv2       | input 4x512x7x7 | k=3 stride=1 | torch=0.092 ms | triton=0.180 ms | speedup=0.510x
layer3.1.conv2       | input 4x256x14x14 | k=3 stride=1 | torch=0.062 ms | triton=0.174 ms | speedup=0.358x
layer3.1.conv1       | input 4x256x14x14 | k=3 stride=1 | torch=0.062 ms | triton=0.176 ms | speedup=0.351x
layer2.1.conv2       | input 4x128x28x28 | k=3 stride=1 | torch=0.059 ms | triton=0.172 ms | speedup=0.340x
layer2.1.conv1       | input 4x128x28x28 | k=3 stride=1 | torch=0.058 ms | triton=0.174 ms | speedup=0.331x
layer2.0.conv2       | input 4x128x28x28 | k=3 stride=1 | torch=0.058 ms | triton=0.175 ms | speedup=0.331x
layer1.1.conv1       | input 4x64x56x56 | k=3 stride=1 | torch=0.057 ms | triton=0.175 ms | speedup=0.324x

In [10]:
import csv
import pathlib

# Persist the per-layer results as CSV. The path is relative to the current
# working directory of the notebook kernel.
save_path = 'notebooks/resnet18_layer_bench_fp16.csv'
# Column order follows the key order of the first result row.
fieldnames = list(results[0].keys()) if results else []
if fieldnames:
    # Create the target directory when it is missing so that open() does not
    # fail with FileNotFoundError on a fresh checkout or a different cwd.
    pathlib.Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    # newline='' is what the csv module expects for files it writes; the
    # encoding is pinned so the output does not depend on the platform default.
    with open(save_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)
    print(f'Saved {len(results)} rows to {save_path}')
else:
    print('No results to save')


Saved 20 rows to notebooks/resnet18_layer_bench_fp16.csv
