# ResNet18 Baseline Conv2d Benchmark

Эта тетрадь сравнивает nn.Conv2d vs baseline TritonConv2d в ResNet18 на разных batch size и измеряет производительность отдельных ядер (3×3, 5×5, 7×7).


## Цели и критерии
- Выполнить требования README: Conv2d → img2col → GEMM, измерить ускорение и память.
- Обучить ResNet18 с nn.Conv2d и baseline TritonConv2d на нескольких batch size.
- Сравнить forward/backward время и использование GPU памяти.


In [1]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent))

In [2]:
import copy
import json
import math
import random
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

from conv_gemm.baseline_layers.triton_conv2d import TritonConv2d

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise RuntimeError("CUDA GPU is required for this benchmark")

seed = 42
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = True

data_root = Path("../data").resolve()
data_root.mkdir(parents=True, exist_ok=True)

config = {
    "data_root": str(data_root),
    "num_classes": 10,
    "batch_sizes": [32, 64, 96, 128, 160, 192, 256],
    "num_workers": 4,
    "train_subset": 8192,
    "lr": 1e-3,
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "warmup_steps": 5,
    "benchmark_steps": 40,
    "baseline_conv": {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_K": 64,
        "NUM_WARPS": 4,
        "NUM_STAGES": 2,
    },
}
print(json.dumps(config, indent=2))


{
  "data_root": "/mnt/d/VSCode-Projects/conv2d-img2col-gemm/data",
  "num_classes": 10,
  "batch_sizes": [
    32,
    64,
    96,
    128,
    160,
    192,
    256
  ],
  "num_workers": 4,
  "train_subset": 8192,
  "lr": 0.001,
  "momentum": 0.9,
  "weight_decay": 0.0005,
  "warmup_steps": 5,
  "benchmark_steps": 40,
  "baseline_conv": {
    "BLOCK_M": 64,
    "BLOCK_N": 64,
    "BLOCK_K": 64,
    "NUM_WARPS": 4,
    "NUM_STAGES": 2
  }
}


In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

full_train = torchvision.datasets.CIFAR10(
    root=config["data_root"], train=True, download=True, transform=transform_train
)
if config["train_subset"] is not None and config["train_subset"] < len(full_train):
    g = torch.Generator().manual_seed(seed)
    subset_idx = torch.randperm(len(full_train), generator=g)[: config["train_subset"]]
    train_dataset = torch.utils.data.Subset(full_train, subset_idx)
else:
    train_dataset = full_train


def make_loader(batch_size: int) -> DataLoader:
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=config["num_workers"],
        pin_memory=True,
    )

train_loaders = {bs: make_loader(bs) for bs in config["batch_sizes"]}
print({bs: len(loader) for bs, loader in train_loaders.items()})

{32: 256, 64: 128, 96: 85, 128: 64, 160: 51, 192: 42, 256: 32}


In [5]:
def make_triton_conv(src: nn.Conv2d, cfg: dict) -> TritonConv2d:
    if src.groups != 1:
        raise ValueError("Baseline TritonConv2d currently supports groups=1 only")
    layer = TritonConv2d(
        in_channels=src.in_channels,
        out_channels=src.out_channels,
        kernel_size=src.kernel_size,
        stride=src.stride,
        padding=src.padding,
        dilation=src.dilation,
        bias=(src.bias is not None),
        **cfg,
    ).to(src.weight.device)
    with torch.no_grad():
        layer.weight.copy_(src.weight.detach().to(layer.weight.dtype))
        if layer.bias is not None and src.bias is not None:
            layer.bias.copy_(src.bias.detach().to(layer.bias.dtype))
    return layer


def replace_convs_with_baseline(module: nn.Module, cfg: dict):
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            setattr(module, name, make_triton_conv(child, cfg))
        else:
            replace_convs_with_baseline(child, cfg)


def build_model_pair(config: dict):
    reference = torchvision.models.resnet18(num_classes=config["num_classes"])
    baseline = copy.deepcopy(reference)
    replace_convs_with_baseline(baseline, config["baseline_conv"])
    return reference, baseline

In [6]:
def run_benchmark(model: nn.Module, label: str, loader: DataLoader, config: dict):
    model = model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
    )

    warmup = config["warmup_steps"]
    total_steps = config["benchmark_steps"]
    records = []
    data_iter = iter(loader)

    for step in range(total_steps):
        try:
            images, targets = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            images, targets = next(data_iter)

        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize()

        fwd_start = torch.cuda.Event(enable_timing=True)
        fwd_end = torch.cuda.Event(enable_timing=True)
        bwd_start = torch.cuda.Event(enable_timing=True)
        bwd_end = torch.cuda.Event(enable_timing=True)

        fwd_start.record()
        outputs = model(images)
        fwd_end.record()
        loss = criterion(outputs, targets)

        bwd_start.record()
        loss.backward()
        bwd_end.record()
        optimizer.step()

        torch.cuda.synchronize()

        fwd_ms = fwd_start.elapsed_time(fwd_end)
        bwd_ms = bwd_start.elapsed_time(bwd_end)
        step_ms = fwd_ms + bwd_ms
        mem_alloc = torch.cuda.max_memory_allocated(device) / 1024 ** 2
        mem_reserved = torch.cuda.max_memory_reserved(device) / 1024 ** 2

        if step >= warmup:
            records.append({
                "label": label,
                "step": step,
                "loss": float(loss.item()),
                "fwd_ms": fwd_ms,
                "bwd_ms": bwd_ms,
                "step_ms": step_ms,
                "throughput_sps": images.size(0) / (step_ms / 1000.0),
                "max_mem_alloc_mb": mem_alloc,
                "max_mem_reserved_mb": mem_reserved,
            })

    if not records:
        raise RuntimeError("No data recorded for benchmark")

    df = pd.DataFrame(records)
    summary = {
        "label": label,
        "avg_forward_ms": df["fwd_ms"].mean(),
        "avg_backward_ms": df["bwd_ms"].mean(),
        "avg_step_ms": df["step_ms"].mean(),
        "samples_per_s": df["throughput_sps"].mean(),
        "max_mem_alloc_mb": df["max_mem_alloc_mb"].max(),
        "max_mem_reserved_mb": df["max_mem_reserved_mb"].max(),
    }
    return df, summary

In [7]:
batch_summaries = []
batch_details = []

for bs, loader in train_loaders.items():
    print(f"=== Batch size {bs} ===")
    torch_model, baseline_model = build_model_pair(config)

    torch_df, torch_summary = run_benchmark(torch_model, f"nn.Conv2d (torch, bs={bs})", loader, config)
    torch_summary.update({"variant": "nn.Conv2d", "batch_size": bs})
    batch_summaries.append(torch_summary)
    batch_details.append(torch_df.assign(variant="nn.Conv2d", batch_size=bs))

    baseline_df, baseline_summary = run_benchmark(baseline_model, f"Baseline TritonConv2d (bs={bs})", loader, config)
    baseline_summary.update({"variant": "Baseline TritonConv2d", "batch_size": bs})
    batch_summaries.append(baseline_summary)
    batch_details.append(baseline_df.assign(variant="Baseline TritonConv2d", batch_size=bs))

summary_df = pd.DataFrame(batch_summaries).set_index(["variant", "batch_size"])
summary_df

=== Batch size 32 ===
=== Batch size 64 ===
=== Batch size 96 ===
=== Batch size 128 ===
=== Batch size 160 ===
=== Batch size 192 ===
=== Batch size 256 ===


Unnamed: 0_level_0,Unnamed: 1_level_0,label,avg_forward_ms,avg_backward_ms,avg_step_ms,samples_per_s,max_mem_alloc_mb,max_mem_reserved_mb
variant,batch_size,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
nn.Conv2d,32,"nn.Conv2d (torch, bs=32)",5.463528,5.896806,11.360335,2963.048966,192.057129,220.0
Baseline TritonConv2d,32,Baseline TritonConv2d (bs=32),21.482849,18.368073,39.850922,816.408838,207.492676,226.0
nn.Conv2d,64,"nn.Conv2d (torch, bs=64)",7.024066,7.423326,14.447392,4822.355612,190.058105,228.0
Baseline TritonConv2d,64,Baseline TritonConv2d (bs=64),19.026444,20.994926,40.02137,1615.257839,228.024902,258.0
nn.Conv2d,96,"nn.Conv2d (torch, bs=96)",7.722843,7.825406,15.54825,6471.987964,190.685059,222.0
Baseline TritonConv2d,96,Baseline TritonConv2d (bs=96),19.798643,24.3812,44.179843,2196.84862,252.019043,312.0
nn.Conv2d,128,"nn.Conv2d (torch, bs=128)",6.229737,9.515037,15.744774,8389.520638,204.617676,228.0
Baseline TritonConv2d,128,Baseline TritonConv2d (bs=128),20.065011,25.708573,45.773584,2814.298342,281.83252,340.0
nn.Conv2d,160,"nn.Conv2d (torch, bs=160)",6.769018,10.961773,17.73079,9296.235904,219.119629,252.0
Baseline TritonConv2d,160,Baseline TritonConv2d (bs=160),22.511662,32.139702,54.651364,2990.8916,306.271973,366.0


In [8]:
detail_df = pd.concat(batch_details, ignore_index=True)
metrics = ["step_ms", "fwd_ms", "bwd_ms", "max_mem_alloc_mb"]
detail_df.groupby(["variant", "batch_size"])[metrics].describe()

Unnamed: 0_level_0,Unnamed: 1_level_0,step_ms,step_ms,step_ms,step_ms,step_ms,step_ms,step_ms,step_ms,fwd_ms,fwd_ms,...,bwd_ms,bwd_ms,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb
Unnamed: 0_level_1,Unnamed: 1_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,...,75%,max,count,mean,std,min,25%,50%,75%,max
variant,batch_size,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
Baseline TritonConv2d,32,35.0,39.850922,5.521001,34.037823,35.449903,37.84115,42.014,53.476,35.0,21.482849,...,17.13408,32.754688,35.0,207.492676,0.0,207.492676,207.492676,207.492676,207.492676,207.492676
Baseline TritonConv2d,64,35.0,40.02137,4.599582,34.997089,37.546255,39.405184,40.693841,60.875553,35.0,19.026444,...,20.706816,41.250816,35.0,228.024902,0.0,228.024902,228.024902,228.024902,228.024902,228.024902
Baseline TritonConv2d,96,35.0,44.179843,4.964113,38.937216,40.468017,42.470079,46.01952,59.205729,35.0,19.798643,...,24.186368,34.711552,35.0,252.019043,0.0,252.019043,252.019043,252.019043,252.019043,252.019043
Baseline TritonConv2d,128,35.0,45.773584,3.91366,42.213663,42.945088,44.547585,46.685905,56.56912,35.0,20.065011,...,25.850881,27.433985,35.0,281.125377,0.748738,280.08252,280.45752,280.83252,281.83252,281.83252
Baseline TritonConv2d,160,35.0,54.651364,8.612499,47.131744,47.655553,50.296928,59.0668,77.979458,35.0,22.511662,...,30.516736,50.637825,35.0,306.271973,0.0,306.271973,306.271973,306.271973,306.271973,306.271973
Baseline TritonConv2d,192,35.0,60.667655,8.55173,53.696898,54.455696,56.85491,63.749506,84.075741,35.0,20.316644,...,44.986881,52.293633,35.0,334.874735,0.737501,333.960449,333.960449,334.960449,335.710449,335.710449
Baseline TritonConv2d,256,35.0,74.426799,5.109964,70.20163,70.932704,72.811424,75.380497,89.0424,35.0,21.891679,...,51.967487,66.821121,35.0,386.338379,0.0,386.338379,386.338379,386.338379,386.338379,386.338379
nn.Conv2d,32,35.0,11.360335,3.058877,9.1856,9.519824,9.89696,11.166256,21.353184,35.0,5.463528,...,5.267968,16.7936,35.0,192.057129,0.0,192.057129,192.057129,192.057129,192.057129,192.057129
nn.Conv2d,64,35.0,14.447392,5.096346,10.004064,11.260576,12.437568,15.387392,29.5464,35.0,7.024066,...,6.147568,21.141504,35.0,190.058105,0.0,190.058105,190.058105,190.058105,190.058105,190.058105
nn.Conv2d,96,35.0,15.54825,3.931324,11.785472,13.08144,14.124832,16.228624,29.505695,35.0,7.722843,...,7.563776,21.264383,35.0,190.685059,0.0,190.685059,190.685059,190.685059,190.685059,190.685059


## Следующие шаги
- При необходимости увеличить `benchmark_steps`, списки batch size и kernel size.
- Сохранять `summary_df`, `detail_df`, `kernel_df` в CSV для отчёта.
- Добавить метрики точности (accuracy) при полном обучении.
