# ResNet18 Baseline Conv2d Benchmark

Сравнение nn.Conv2d и кастомной img2col→GEMM свёртки (Baseline TritonConv2d) на ResNet18 с разными batch size и сценариями спарсификации.


In [1]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent))


In [2]:
import copy
import json
import math
import random
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

from conv_gemm.baseline_layers.triton_conv2d import TritonConv2d


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise RuntimeError("CUDA GPU is required for this benchmark")

seed = 42
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = True

data_root = Path("../data").resolve()
data_root.mkdir(parents=True, exist_ok=True)

config = {
    "data_root": str(data_root),
    "num_classes": 10,
    "batch_sizes": [32, 64, 96, 128, 160, 192, 256],
    "num_workers": 4,
    "train_subset": 8192,
    "lr": 1e-3,
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "warmup_steps": 5,
    "model_warmup_steps": 3,
    "benchmark_steps": 40,
    "baseline_conv": {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_K": 64,
        "NUM_WARPS": 4,
        "NUM_STAGES": 2,
    },
    "sparsity_bench": {
        "modes": ["channel", "block", "input"],
        "keep_ratios": [0.75, 0.6, 0.5, 0.25],
        "block_size": 4,
        "batch_size": 128,
    },
    "conv_layer_bench": {
        "warmup_steps": 5,
        "bench_steps": 20,
    },
}
print(json.dumps(config, indent=2))


{
  "data_root": "/mnt/d/VSCode-Projects/conv2d-img2col-gemm/data",
  "num_classes": 10,
  "batch_sizes": [
    32,
    64,
    96,
    128,
    160,
    192,
    256
  ],
  "num_workers": 4,
  "train_subset": 8192,
  "lr": 0.001,
  "momentum": 0.9,
  "weight_decay": 0.0005,
  "warmup_steps": 5,
  "model_warmup_steps": 3,
  "benchmark_steps": 40,
  "baseline_conv": {
    "BLOCK_M": 64,
    "BLOCK_N": 64,
    "BLOCK_K": 64,
    "NUM_WARPS": 4,
    "NUM_STAGES": 2
  },
  "sparsity_bench": {
    "modes": [
      "channel",
      "block",
      "input"
    ],
    "keep_ratios": [
      0.75,
      0.6,
      0.5,
      0.25
    ],
    "block_size": 4,
    "batch_size": 128
  },
  "conv_layer_bench": {
    "warmup_steps": 5,
    "bench_steps": 20
  }
}


In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

full_train = torchvision.datasets.CIFAR10(
    root=config["data_root"], train=True, download=True, transform=transform_train
)
if config["train_subset"] is not None and config["train_subset"] < len(full_train):
    g = torch.Generator().manual_seed(seed)
    subset_idx = torch.randperm(len(full_train), generator=g)[: config["train_subset"]]
    train_dataset = torch.utils.data.Subset(full_train, subset_idx)
else:
    train_dataset = full_train


def make_loader(batch_size: int) -> DataLoader:
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=config["num_workers"],
        pin_memory=True,
    )

train_loaders: Dict[int, DataLoader] = {}
for bs in config["batch_sizes"]:
    train_loaders[bs] = make_loader(bs)

print({bs: len(loader) for bs, loader in train_loaders.items()})


{32: 256, 64: 128, 96: 85, 128: 64, 160: 51, 192: 42, 256: 32}


In [5]:
def make_triton_conv(src: nn.Conv2d, cfg: dict) -> TritonConv2d:
    if src.groups != 1:
        raise ValueError("Baseline TritonConv2d currently supports groups=1 only")
    layer = TritonConv2d(
        in_channels=src.in_channels,
        out_channels=src.out_channels,
        kernel_size=src.kernel_size,
        stride=src.stride,
        padding=src.padding,
        dilation=src.dilation,
        bias=(src.bias is not None),
        # **cfg,
    ).to(src.weight.device)
    with torch.no_grad():
        layer.weight.copy_(src.weight.detach().to(layer.weight.dtype))
        if layer.bias is not None and src.bias is not None:
            layer.bias.copy_(src.bias.detach().to(layer.bias.dtype))
    return layer


def replace_convs_with_baseline(module: nn.Module, cfg: dict):
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            setattr(module, name, make_triton_conv(child, cfg))
        else:
            replace_convs_with_baseline(child, cfg)


def build_model_pair(config: dict):
    reference = torchvision.models.resnet18(num_classes=config["num_classes"])
    baseline = copy.deepcopy(reference)
    replace_convs_with_baseline(baseline, config["baseline_conv"])
    return reference.half(), baseline.half()


def apply_sparsity_to_model(model: nn.Module, mode: str, keep_ratio: float, block_size: int = 4):
    for layer in model.modules():
        if isinstance(layer, TritonConv2d):
            layer.clear_sparsity()
            if keep_ratio >= 1.0:
                continue
            if mode == "channel":
                layer.set_channel_sparsity(keep_ratio)
                layer.set_backward_channel_sparsity(keep_ratio)
            elif mode == "block":
                layer.set_block_sparsity(keep_ratio, block_size=block_size)
                layer.set_backward_block_sparsity(keep_ratio, block_size=block_size)
            elif mode == "input":
                layer.set_input_channel_sparsity(keep_ratio)
                layer.set_backward_input_channel_sparsity(keep_ratio)
            else:
                raise ValueError(f"Unknown sparsity mode: {mode}")


In [6]:
def run_benchmark(model: nn.Module, label: str, loader: DataLoader, config: dict):
    model = model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
    )

    warmup = config["warmup_steps"]
    total_steps = config["benchmark_steps"]
    model_warmup = config.get("model_warmup_steps", 0)
    records = []

    if model_warmup > 0:
        warmup_iter = iter(loader)
        for _ in range(model_warmup):
            try:
                images, targets = next(warmup_iter)
            except StopIteration:
                warmup_iter = iter(loader)
                images, targets = next(warmup_iter)

            images = images.half().to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.zero_grad(set_to_none=True)
            torch.cuda.synchronize()

        # Extra GPU warmup to drop JIT/cudnn noise from timed iterations
        torch.cuda.reset_peak_memory_stats(device)

    data_iter = iter(loader)

    for step in range(total_steps):
        try:
            images, targets = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            images, targets = next(data_iter)

        images = images.half().to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize()

        fwd_start = torch.cuda.Event(enable_timing=True)
        fwd_end = torch.cuda.Event(enable_timing=True)
        bwd_start = torch.cuda.Event(enable_timing=True)
        bwd_end = torch.cuda.Event(enable_timing=True)

        fwd_start.record()
        outputs = model(images)
        fwd_end.record()
        loss = criterion(outputs, targets)

        bwd_start.record()
        loss.backward()
        bwd_end.record()
        optimizer.step()

        torch.cuda.synchronize()

        fwd_ms = fwd_start.elapsed_time(fwd_end)
        bwd_ms = bwd_start.elapsed_time(bwd_end)
        step_ms = fwd_ms + bwd_ms
        mem_alloc = torch.cuda.max_memory_allocated(device) / 1024 ** 2
        mem_reserved = torch.cuda.max_memory_reserved(device) / 1024 ** 2

        if step >= warmup:
            records.append({
                "label": label,
                "step": step,
                "loss": float(loss.item()),
                "fwd_ms": fwd_ms,
                "bwd_ms": bwd_ms,
                "step_ms": step_ms,
                "throughput_sps": images.size(0) / (step_ms / 1000.0),
                "max_mem_alloc_mb": mem_alloc,
                "max_mem_reserved_mb": mem_reserved,
            })

    if not records:
        raise RuntimeError("No data recorded for benchmark")

    df = pd.DataFrame(records)
    summary = {
        "label": label,
        "avg_forward_ms": df["fwd_ms"].mean(),
        "avg_backward_ms": df["bwd_ms"].mean(),
        "avg_step_ms": df["step_ms"].mean(),
        "samples_per_s": df["throughput_sps"].mean(),
        "max_mem_alloc_mb": df["max_mem_alloc_mb"].max(),
        "max_mem_reserved_mb": df["max_mem_reserved_mb"].max(),
    }
    return df, summary


In [7]:
def is_conv_module(module: nn.Module) -> bool:
    return isinstance(module, (nn.Conv2d, TritonConv2d))


def collect_conv_input_shapes(model: nn.Module, sample: torch.Tensor) -> Dict[str, torch.Size]:
    shapes: Dict[str, torch.Size] = {}
    handles = []

    def make_hook(layer_name: str):
        def _hook(mod, inp):
            shapes.setdefault(layer_name, inp[0].shape)
            return None  # do not override inputs
        return _hook

    for name, module in model.named_modules():
        if is_conv_module(module):
            handles.append(module.register_forward_pre_hook(make_hook(name)))
    with torch.no_grad():
        model(sample)
    for h in handles:
        h.remove()
    return shapes


def conv_metadata(name: str, module: nn.Module) -> Dict[str, object]:
    meta = {
        "layer": name,
        "layer_type": type(module).__name__,
        "in_channels": getattr(module, "in_channels", None),
        "out_channels": getattr(module, "out_channels", None),
        "kernel_size": tuple(getattr(module, "kernel_size", [])) if hasattr(module, "kernel_size") else None,
        "stride": tuple(getattr(module, "stride", [])) if hasattr(module, "stride") else None,
        "padding": tuple(getattr(module, "padding", [])) if hasattr(module, "padding") else None,
        "dilation": tuple(getattr(module, "dilation", [])) if hasattr(module, "dilation") else None,
    }
    if isinstance(module, TritonConv2d):
        keep_out = float(module.channel_mask.float().mean().item()) if hasattr(module, "channel_mask") else 1.0
        keep_in = float(module.input_channel_mask.float().mean().item()) if hasattr(module, "input_channel_mask") else 1.0
        meta.update({
            "channel_keep_ratio": keep_out,
            "input_keep_ratio": keep_in,
            "block_size": getattr(module, "block_size", None),
            "grad_block_size": getattr(module, "grad_block_size", None),
        })
    return meta


def benchmark_single_conv(module: nn.Module, input_shape: torch.Size, device: torch.device, warmup: int, steps: int) -> Dict[str, float]:
    x = torch.randn(input_shape, device=device, dtype=torch.float16, requires_grad=True)
    layer = copy.deepcopy(module).to(device)
    layer.train()
    torch.cuda.synchronize()

    for _ in range(warmup):
        layer.zero_grad(set_to_none=True)
        out = layer(x)
        loss = out.float().sum()
        loss.backward()
        torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats(device)
    records: List[Dict[str, float]] = []

    for _ in range(steps):
        layer.zero_grad(set_to_none=True)
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize()

        fwd_start = torch.cuda.Event(enable_timing=True)
        fwd_end = torch.cuda.Event(enable_timing=True)
        bwd_start = torch.cuda.Event(enable_timing=True)
        bwd_end = torch.cuda.Event(enable_timing=True)

        fwd_start.record()
        out = layer(x)
        fwd_end.record()

        loss = out.float().sum()

        bwd_start.record()
        loss.backward()
        bwd_end.record()

        torch.cuda.synchronize()

        fwd_ms = fwd_start.elapsed_time(fwd_end)
        bwd_ms = bwd_start.elapsed_time(bwd_end)
        step_ms = fwd_ms + bwd_ms
        records.append({
            "avg_forward_ms": fwd_ms,
            "avg_backward_ms": bwd_ms,
            "avg_step_ms": step_ms,
            "throughput_sps": input_shape[0] / (step_ms / 1000.0),
            "max_mem_alloc_mb": torch.cuda.max_memory_allocated(device) / 1024 ** 2,
            "max_mem_reserved_mb": torch.cuda.max_memory_reserved(device) / 1024 ** 2,
        })

    if not records:
        raise RuntimeError("No data recorded for conv benchmark")

    df = pd.DataFrame(records)
    return {
        "avg_forward_ms": df["avg_forward_ms"].mean(),
        "avg_backward_ms": df["avg_backward_ms"].mean(),
        "avg_step_ms": df["avg_step_ms"].mean(),
        "throughput_sps": df["throughput_sps"].mean(),
        "max_mem_alloc_mb": df["max_mem_alloc_mb"].max(),
        "max_mem_reserved_mb": df["max_mem_reserved_mb"].max(),
    }


def benchmark_conv_layers(torch_model: nn.Module, baseline_model: nn.Module, batch_size: int, config: dict):
    bench_cfg = config.get("conv_layer_bench", {"warmup_steps": 3, "bench_steps": 10})
    warmup = bench_cfg.get("warmup_steps", 3)
    steps = bench_cfg.get("bench_steps", 10)

    sample = torch.randn(batch_size, 3, 32, 32, device=device, dtype=torch.float16)
    torch_model = torch_model.to(device).eval()
    baseline_model = baseline_model.to(device).eval()

    input_shapes = collect_conv_input_shapes(torch_model, sample)
    torch_conv_map = dict(torch_model.named_modules())
    baseline_conv_map = dict(baseline_model.named_modules())

    rows: List[Dict[str, object]] = []
    for name, inp_shape in input_shapes.items():
        torch_layer = torch_conv_map.get(name)
        baseline_layer = baseline_conv_map.get(name)
        if not (is_conv_module(torch_layer) and is_conv_module(baseline_layer)):
            continue

        for variant, layer in [("nn.Conv2d", torch_layer), ("Baseline TritonConv2d", baseline_layer)]:
            summary = benchmark_single_conv(layer, inp_shape, device, warmup, steps)
            meta = conv_metadata(name, layer)
            meta.update({
                "variant": variant,
                "batch_size": batch_size,
            })
            meta.update(summary)
            rows.append(meta)

    torch.cuda.empty_cache()
    return rows


Таблица `summary_df` показывает средние метрики по каждому batch size: `avg_forward_ms`, `avg_backward_ms`, `avg_step_ms`, `samples_per_s`, а также пики памяти (`max_mem_alloc_mb`, `max_mem_reserved_mb`).


In [8]:
batch_summaries = []
batch_details = []
conv_layer_rows = []

for bs, loader in train_loaders.items():
    print(f"=== Batch size {bs} ===")
    torch_model, baseline_model = build_model_pair(config)

    # per-layer bench (forward FP16, backward FP32)
    conv_layer_rows.extend(benchmark_conv_layers(torch_model, baseline_model, bs, config))

    torch_df, torch_summary = run_benchmark(torch_model, f"nn.Conv2d (bs={bs})", loader, config)
    torch_summary.update({"variant": "nn.Conv2d", "batch_size": bs})
    batch_summaries.append(torch_summary)
    batch_details.append(torch_df.assign(variant="nn.Conv2d", batch_size=bs))

    baseline_df, baseline_summary = run_benchmark(baseline_model, f"Baseline TritonConv2d (bs={bs})", loader, config)
    baseline_summary.update({"variant": "Baseline TritonConv2d", "batch_size": bs})
    batch_summaries.append(baseline_summary)
    batch_details.append(baseline_df.assign(variant="Baseline TritonConv2d", batch_size=bs))

summary_df = pd.DataFrame(batch_summaries).set_index(["variant", "batch_size"])
summary_df


=== Batch size 32 ===
=== Batch size 64 ===
=== Batch size 96 ===
=== Batch size 128 ===
=== Batch size 160 ===
=== Batch size 192 ===
=== Batch size 256 ===


Unnamed: 0_level_0,Unnamed: 1_level_0,label,avg_forward_ms,avg_backward_ms,avg_step_ms,samples_per_s,max_mem_alloc_mb,max_mem_reserved_mb
variant,batch_size,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
nn.Conv2d,32,nn.Conv2d (bs=32),7.778965,6.27437,14.053335,2353.629652,127.072266,138.0
Baseline TritonConv2d,32,Baseline TritonConv2d (bs=32),45.730244,18.65649,64.386734,503.617131,154.655273,182.0
nn.Conv2d,64,nn.Conv2d (bs=64),8.03872,6.581922,14.620642,4572.664194,127.260254,138.0
Baseline TritonConv2d,64,Baseline TritonConv2d (bs=64),43.745434,18.876972,62.622406,1029.01729,169.953613,204.0
nn.Conv2d,96,nn.Conv2d (bs=96),8.269396,6.013748,14.283144,6799.590031,127.57373,144.0
Baseline TritonConv2d,96,Baseline TritonConv2d (bs=96),41.662605,21.079979,62.742584,1538.861629,195.14209,238.0
nn.Conv2d,128,nn.Conv2d (bs=128),9.351014,6.526245,15.877258,8238.250459,132.817871,168.0
Baseline TritonConv2d,128,Baseline TritonConv2d (bs=128),42.343821,25.119656,67.463477,1904.794747,213.830078,312.0
nn.Conv2d,160,nn.Conv2d (bs=160),9.560735,7.603258,17.163994,9408.25363,140.381836,162.0
Baseline TritonConv2d,160,Baseline TritonConv2d (bs=160),45.031539,28.71887,73.750409,2197.614342,235.769043,284.0


Вывод `detail_df.groupby(...).describe()` содержит count/mean/std/min/25%/50%/75%/max для метрик `step_ms`, `fwd_ms`, `bwd_ms`, `max_mem_alloc_mb` отдельно по каждому `(variant, batch_size)`.


In [9]:
detail_df = pd.concat(batch_details, ignore_index=True)
metrics = ["step_ms", "fwd_ms", "bwd_ms", "max_mem_alloc_mb"]
detail_df.groupby(["variant", "batch_size"])[metrics].describe()


Unnamed: 0_level_0,Unnamed: 1_level_0,step_ms,step_ms,step_ms,step_ms,step_ms,step_ms,step_ms,step_ms,fwd_ms,fwd_ms,...,bwd_ms,bwd_ms,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb
Unnamed: 0_level_1,Unnamed: 1_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,...,75%,max,count,mean,std,min,25%,50%,75%,max
variant,batch_size,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
Baseline TritonConv2d,32,35.0,64.386734,7.714205,53.815456,57.372768,62.656929,68.443359,82.406975,35.0,45.730244,...,21.241857,39.800831,35.0,154.655273,0.0,154.655273,154.655273,154.655273,154.655273,154.655273
Baseline TritonConv2d,64,35.0,62.622406,5.43871,55.743107,58.467583,61.369698,65.310529,76.923553,35.0,43.745434,...,18.515968,30.216192,35.0,169.953613,0.0,169.953613,169.953613,169.953613,169.953613,169.953613
Baseline TritonConv2d,96,35.0,62.742584,5.100138,57.877247,59.487199,60.23085,63.796144,78.579872,35.0,41.662605,...,20.440576,32.645119,35.0,195.14209,0.0,195.14209,195.14209,195.14209,195.14209,195.14209
Baseline TritonConv2d,128,35.0,67.463477,4.368278,61.957441,63.979456,66.749025,70.167774,77.208542,35.0,42.343821,...,25.134592,35.337215,35.0,213.830078,0.0,213.830078,213.830078,213.830078,213.830078,213.830078
Baseline TritonConv2d,160,35.0,73.750409,9.086172,64.520256,67.379122,70.951519,78.539248,104.96035,35.0,45.031539,...,28.389888,43.01107,35.0,235.769043,0.0,235.769043,235.769043,235.769043,235.769043,235.769043
Baseline TritonConv2d,192,35.0,76.138919,4.620434,71.011902,73.222321,74.886974,76.5448,93.375553,35.0,40.830962,...,34.584576,45.238274,35.0,260.832031,0.0,260.832031,260.832031,260.832031,260.832031,260.832031
Baseline TritonConv2d,256,35.0,96.340511,9.653058,87.93568,89.920609,93.105152,98.683104,127.200932,35.0,46.241049,...,49.689089,58.495998,35.0,300.833496,0.0,300.833496,300.833496,300.833496,300.833496,300.833496
nn.Conv2d,32,35.0,14.053335,2.677401,10.188448,12.031408,13.859904,15.608224,20.778144,35.0,7.778965,...,7.471616,12.1856,35.0,127.072266,0.0,127.072266,127.072266,127.072266,127.072266,127.072266
nn.Conv2d,64,35.0,14.620642,3.887574,11.221376,12.570896,13.617663,15.03768,29.330273,35.0,8.03872,...,6.107136,20.396032,35.0,127.260254,0.0,127.260254,127.260254,127.260254,127.260254,127.260254
nn.Conv2d,96,35.0,14.283144,1.578264,10.919808,13.20088,14.309408,14.829072,18.125888,35.0,8.269396,...,6.769664,9.326592,35.0,127.57373,0.0,127.57373,127.57373,127.57373,127.57373,127.57373


In [10]:
forward_bs_top = (
    summary_df.reset_index()
    .sort_values("avg_forward_ms")
    .groupby("variant")
    .head(3)
    .reset_index(drop=True)
)

backward_bs_top = (
    summary_df.reset_index()
    .sort_values("avg_backward_ms")
    .groupby("variant")
    .head(3)
    .reset_index(drop=True)
)

forward_bs_top, backward_bs_top


(                 variant  batch_size                           label  \
 0              nn.Conv2d          32               nn.Conv2d (bs=32)   
 1              nn.Conv2d          64               nn.Conv2d (bs=64)   
 2              nn.Conv2d          96               nn.Conv2d (bs=96)   
 3  Baseline TritonConv2d         192  Baseline TritonConv2d (bs=192)   
 4  Baseline TritonConv2d          96   Baseline TritonConv2d (bs=96)   
 5  Baseline TritonConv2d         128  Baseline TritonConv2d (bs=128)   
 
    avg_forward_ms  avg_backward_ms  avg_step_ms  samples_per_s  \
 0        7.778965         6.274370    14.053335    2353.629652   
 1        8.038720         6.581922    14.620642    4572.664194   
 2        8.269396         6.013748    14.283144    6799.590031   
 3       40.830962        35.307957    76.138919    2529.765482   
 4       41.662605        21.079979    62.742584    1538.861629   
 5       42.343821        25.119656    67.463477    1904.794747   
 
    max_mem_allo

Per-layer metrics: forward/backward time and memory for each batch size and variant.

In [11]:
conv_layer_df = pd.DataFrame(conv_layer_rows)
conv_layer_df


Unnamed: 0,layer,layer_type,in_channels,out_channels,kernel_size,stride,padding,dilation,variant,batch_size,avg_forward_ms,avg_backward_ms,avg_step_ms,throughput_sps,max_mem_alloc_mb,max_mem_reserved_mb,channel_keep_ratio,input_keep_ratio,block_size,grad_block_size
0,conv1,Conv2d,3,64,"(7, 7)","(2, 2)","(3, 3)","(1, 1)",nn.Conv2d,32,0.137830,0.625101,0.762931,45026.716191,57.075684,68.0,,,,
1,conv1,TritonConv2d,3,64,"(7, 7)","(2, 2)","(3, 3)","(1, 1)",Baseline TritonConv2d,32,1.727437,1.455002,3.182438,10089.058649,74.058594,90.0,1.0,1.0,,
2,layer1.0.conv1,Conv2d,64,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,32,0.197837,0.647936,0.845773,41442.751906,54.921387,68.0,,,,
3,layer1.0.conv1,TritonConv2d,64,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",Baseline TritonConv2d,32,2.420838,1.333709,3.754547,9043.330822,70.132812,90.0,1.0,1.0,,
4,layer1.0.conv2,Conv2d,64,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,32,0.160666,0.608973,0.769638,45770.083699,54.921387,90.0,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
275,layer4.0.downsample.0,TritonConv2d,256,512,"(1, 1)","(2, 2)","(0, 0)","(1, 1)",Baseline TritonConv2d,256,1.783552,1.002085,2.785637,93776.683330,69.257812,82.0,1.0,1.0,,
276,layer4.1.conv1,Conv2d,512,512,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,256,0.277398,0.615270,0.892669,290090.643948,77.257324,102.0,,,,
277,layer4.1.conv1,TritonConv2d,512,512,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",Baseline TritonConv2d,256,1.790870,1.448704,3.239574,79458.781348,114.757812,142.0,1.0,1.0,,
278,layer4.1.conv2,Conv2d,512,512,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,256,0.266803,0.619878,0.886682,290782.575864,77.257324,142.0,,,,


In [12]:
torch_conv_df = conv_layer_df[conv_layer_df["variant"] == "nn.Conv2d"]
baseline_conv_df = conv_layer_df[conv_layer_df["variant"] == "Baseline TritonConv2d"]

conv_layer_compare_df = torch_conv_df.merge(
    baseline_conv_df,
    on=["layer", "batch_size"],
    suffixes=("_torch", "_baseline"),
)

conv_layer_compare_df["speedup_forward"] = conv_layer_compare_df["avg_forward_ms_torch"] / conv_layer_compare_df["avg_forward_ms_baseline"]
conv_layer_compare_df["speedup_backward"] = conv_layer_compare_df["avg_backward_ms_torch"] / conv_layer_compare_df["avg_backward_ms_baseline"]
conv_layer_compare_df["speedup_step"] = conv_layer_compare_df["avg_step_ms_torch"] / conv_layer_compare_df["avg_step_ms_baseline"]
conv_layer_compare_df["throughput_ratio"] = conv_layer_compare_df["throughput_sps_baseline"] / conv_layer_compare_df["throughput_sps_torch"]
conv_layer_compare_df["mem_alloc_ratio"] = conv_layer_compare_df["max_mem_alloc_mb_baseline"] / conv_layer_compare_df["max_mem_alloc_mb_torch"]
conv_layer_compare_df["mem_reserved_ratio"] = conv_layer_compare_df["max_mem_reserved_mb_baseline"] / conv_layer_compare_df["max_mem_reserved_mb_torch"]
conv_layer_compare_df


Unnamed: 0,layer,layer_type_torch,in_channels_torch,out_channels_torch,kernel_size_torch,stride_torch,padding_torch,dilation_torch,variant_torch,batch_size,...,channel_keep_ratio_baseline,input_keep_ratio_baseline,block_size_baseline,grad_block_size_baseline,speedup_forward,speedup_backward,speedup_step,throughput_ratio,mem_alloc_ratio,mem_reserved_ratio
0,conv1,Conv2d,3,64,"(7, 7)","(2, 2)","(3, 3)","(1, 1)",nn.Conv2d,32,...,1.0,1.0,,,0.079789,0.429622,0.239732,0.224068,1.297551,1.323529
1,layer1.0.conv1,Conv2d,64,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,32,...,1.0,1.0,,,0.081722,0.485815,0.225266,0.218213,1.276967,1.323529
2,layer1.0.conv2,Conv2d,64,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,32,...,1.0,1.0,,,0.085013,0.556757,0.257949,0.242236,1.276967,1.000000
3,layer1.1.conv1,Conv2d,64,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,32,...,1.0,1.0,,,0.091507,0.434048,0.211179,0.218649,1.276967,1.000000
4,layer1.1.conv2,Conv2d,64,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,32,...,1.0,1.0,,,0.087876,0.356348,0.194277,0.191248,1.276967,1.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
135,layer4.0.conv1,Conv2d,256,512,"(3, 3)","(2, 2)","(1, 1)","(1, 1)",nn.Conv2d,256,...,1.0,1.0,,,0.103661,0.545160,0.259775,0.255342,1.284700,1.282051
136,layer4.0.conv2,Conv2d,512,512,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,256,...,1.0,1.0,,,0.146154,0.440092,0.272823,0.276642,1.485397,1.408163
137,layer4.0.downsample.0,Conv2d,256,512,"(1, 1)","(2, 2)","(0, 0)","(1, 1)",nn.Conv2d,256,...,1.0,1.0,,,0.084140,0.479819,0.226478,0.221775,1.065385,1.051282
138,layer4.1.conv1,Conv2d,512,512,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",nn.Conv2d,256,...,1.0,1.0,,,0.154896,0.424704,0.275551,0.273910,1.485397,1.392157


In [13]:
conv_layer_ranking_df = conv_layer_compare_df[[
    "layer",
    "batch_size",
    "kernel_size_torch",
    "stride_torch",
    "padding_torch",
    "dilation_torch",
    "channel_keep_ratio_baseline",
    "input_keep_ratio_baseline",
    "block_size_baseline",
    "grad_block_size_baseline",
    "avg_forward_ms_torch",
    "avg_forward_ms_baseline",
    "avg_backward_ms_torch",
    "avg_backward_ms_baseline",
    "avg_step_ms_torch",
    "avg_step_ms_baseline",
    "throughput_ratio",
    "speedup_forward",
    "speedup_backward",
    "speedup_step",
    "mem_alloc_ratio",
    "mem_reserved_ratio",
]].sort_values("speedup_step", ascending=False).reset_index(drop=True)
conv_layer_ranking_df.head(15)


Unnamed: 0,layer,batch_size,kernel_size_torch,stride_torch,padding_torch,dilation_torch,channel_keep_ratio_baseline,input_keep_ratio_baseline,block_size_baseline,grad_block_size_baseline,...,avg_backward_ms_torch,avg_backward_ms_baseline,avg_step_ms_torch,avg_step_ms_baseline,throughput_ratio,speedup_forward,speedup_backward,speedup_step,mem_alloc_ratio,mem_reserved_ratio
0,layer3.0.downsample.0,96,"(1, 1)","(2, 2)","(0, 0)","(1, 1)",1.0,1.0,,,...,1.018408,1.554226,1.286645,3.819518,0.244453,0.118412,0.655251,0.33686,1.04167,1.058824
1,layer4.1.conv1,32,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,,,...,0.710349,1.164186,0.974746,3.010253,0.320859,0.143222,0.610168,0.323809,1.366681,1.227273
2,layer2.0.conv1,32,"(3, 3)","(2, 2)","(1, 1)","(1, 1)",1.0,1.0,,,...,0.740506,1.199565,1.005773,3.204506,0.291829,0.132307,0.617312,0.313862,1.095275,1.029412
3,layer3.1.conv1,96,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,,,...,0.636006,1.042944,0.874394,2.826803,0.289148,0.133636,0.609818,0.309322,1.261351,1.277778
4,layer2.1.conv2,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,,,...,0.663757,1.089843,0.827085,2.734746,0.269417,0.099293,0.609039,0.302436,1.259493,1.0
5,layer4.1.conv2,96,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,,,...,0.672614,1.287987,0.97495,3.261133,0.29809,0.153225,0.522221,0.298961,1.391945,1.0
6,layer3.0.conv1,64,"(3, 3)","(2, 2)","(1, 1)","(1, 1)",1.0,1.0,,,...,0.617267,1.039155,0.805069,2.706739,0.280809,0.112619,0.594009,0.297431,1.12855,1.285714
7,layer4.0.conv2,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,,,...,0.560742,0.98217,0.811469,2.745037,0.295681,0.142226,0.570922,0.295613,1.36437,1.232558
8,layer2.0.conv2,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,,,...,0.634098,1.049856,0.796813,2.724454,0.268399,0.097167,0.603985,0.292467,1.259493,1.352941
9,layer4.1.conv2,64,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,,,...,0.557926,1.016166,0.841062,2.87959,0.291464,0.151944,0.54905,0.292077,1.36437,1.0


`baseline_vs_torch_df` сравнивает nn.Conv2d и Baseline TritonConv2d: пары столбцов с абсолютными значениями (forward/backward/step время, throughput, память) и коэффициенты ускорения (`speedup_*`, `throughput_ratio`, `mem_*_ratio`).


In [14]:
baseline_compare_rows = []
for bs in config["batch_sizes"]:
    torch_row = summary_df.loc[("nn.Conv2d", bs)]
    baseline_row = summary_df.loc[("Baseline TritonConv2d", bs)]
    comparison = {
        "batch_size": bs,
        "torch_forward_ms": torch_row["avg_forward_ms"],
        "baseline_forward_ms": baseline_row["avg_forward_ms"],
        "torch_backward_ms": torch_row["avg_backward_ms"],
        "baseline_backward_ms": baseline_row["avg_backward_ms"],
        "torch_step_ms": torch_row["avg_step_ms"],
        "baseline_step_ms": baseline_row["avg_step_ms"],
        "torch_samples_per_s": torch_row["samples_per_s"],
        "baseline_samples_per_s": baseline_row["samples_per_s"],
        "speedup_forward": torch_row["avg_forward_ms"] / baseline_row["avg_forward_ms"],
        "speedup_backward": torch_row["avg_backward_ms"] / baseline_row["avg_backward_ms"],
        "speedup_step": torch_row["avg_step_ms"] / baseline_row["avg_step_ms"],
        "throughput_ratio": baseline_row["samples_per_s"] / torch_row["samples_per_s"],
        "torch_mem_alloc_mb": torch_row["max_mem_alloc_mb"],
        "baseline_mem_alloc_mb": baseline_row["max_mem_alloc_mb"],
        "torch_mem_reserved_mb": torch_row["max_mem_reserved_mb"],
        "baseline_mem_reserved_mb": baseline_row["max_mem_reserved_mb"],
        "mem_alloc_ratio": baseline_row["max_mem_alloc_mb"] / torch_row["max_mem_alloc_mb"],
        "mem_reserved_ratio": baseline_row["max_mem_reserved_mb"] / torch_row["max_mem_reserved_mb"],
    }
    baseline_compare_rows.append(comparison)

baseline_vs_torch_df = pd.DataFrame(baseline_compare_rows).set_index("batch_size")
baseline_vs_torch_df


Unnamed: 0_level_0,torch_forward_ms,baseline_forward_ms,torch_backward_ms,baseline_backward_ms,torch_step_ms,baseline_step_ms,torch_samples_per_s,baseline_samples_per_s,speedup_forward,speedup_backward,speedup_step,throughput_ratio,torch_mem_alloc_mb,baseline_mem_alloc_mb,torch_mem_reserved_mb,baseline_mem_reserved_mb,mem_alloc_ratio,mem_reserved_ratio
batch_size,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
32,7.778965,45.730244,6.27437,18.65649,14.053335,64.386734,2353.629652,503.617131,0.170105,0.33631,0.218264,0.213975,127.072266,154.655273,138.0,182.0,1.217066,1.318841
64,8.03872,43.745434,6.581922,18.876972,14.620642,62.622406,4572.664194,1029.01729,0.183761,0.348675,0.233473,0.225037,127.260254,169.953613,138.0,204.0,1.335481,1.478261
96,8.269396,41.662605,6.013748,21.079979,14.283144,62.742584,6799.590031,1538.861629,0.198485,0.285282,0.227647,0.226317,127.57373,195.14209,144.0,238.0,1.529642,1.652778
128,9.351014,42.343821,6.526245,25.119656,15.877258,67.463477,8238.250459,1904.794747,0.220835,0.259806,0.235346,0.231214,132.817871,213.830078,168.0,312.0,1.60995,1.857143
160,9.560735,45.031539,7.603258,28.71887,17.163994,73.750409,9408.25363,2197.614342,0.212312,0.264748,0.232731,0.233584,140.381836,235.769043,162.0,284.0,1.679484,1.753086
192,8.512421,40.830962,8.360785,35.307957,16.873207,76.138919,11492.926501,2529.765482,0.20848,0.236796,0.221611,0.220115,148.069824,260.832031,160.0,328.0,1.761548,2.05
256,9.572772,46.241049,10.040291,50.099462,19.613062,96.340511,13442.939432,2679.841985,0.207019,0.200407,0.203581,0.199349,162.446289,300.833496,196.0,500.0,1.851895,2.55102


Таблица `summary_df` показывает средние метрики по каждому batch size: `avg_forward_ms`, `avg_backward_ms`, `avg_step_ms`, `samples_per_s`, а также пики памяти (`max_mem_alloc_mb`, `max_mem_reserved_mb`).


In [19]:
sparsity_cfg = config["sparsity_bench"]
sparsity_bs = sparsity_cfg["batch_size"]
if sparsity_bs not in train_loaders:
    train_loaders[sparsity_bs] = make_loader(sparsity_bs)
sparsity_loader = train_loaders[sparsity_bs]

sparsity_summaries = []
sparsity_details = []

for mode in sparsity_cfg["modes"]:
    for ratio in sparsity_cfg["keep_ratios"]:
        _, baseline_model = build_model_pair(config)
        apply_sparsity_to_model(
            baseline_model,
            mode,
            keep_ratio=ratio,
            block_size=sparsity_cfg.get("block_size", 4),
        )
        label = f"{mode.capitalize()} sparsity (keep={ratio:.2f}, bs={sparsity_bs})"
        bench_df, bench_summary = run_benchmark(baseline_model, label, sparsity_loader, config)
        bench_summary.update({
            "variant": f"Sparsity::{mode}",
            "mode": mode,
            "keep_ratio": ratio,
            "batch_size": sparsity_bs,
        })
        sparsity_summaries.append(bench_summary)
        sparsity_details.append(
            bench_df.assign(variant=f"Sparsity::{mode}", mode=mode, keep_ratio=ratio, batch_size=sparsity_bs)
        )

sparsity_summary_df = pd.DataFrame(sparsity_summaries).sort_values("samples_per_s", ascending=False).reset_index(drop=True)
sparsity_summary_df


Unnamed: 0,label,avg_forward_ms,avg_backward_ms,avg_step_ms,samples_per_s,max_mem_alloc_mb,max_mem_reserved_mb,variant,mode,keep_ratio,batch_size
0,"Block sparsity (keep=0.50, bs=128)",58.12284,24.471259,82.594099,1554.019946,214.039062,506.0,Sparsity::block,block,0.5,128
1,"Block sparsity (keep=0.60, bs=128)",60.420029,27.095655,87.515684,1469.793185,214.513184,506.0,Sparsity::block,block,0.6,128
2,"Channel sparsity (keep=0.75, bs=128)",62.082142,25.897778,87.979921,1468.674844,214.404297,506.0,Sparsity::channel,channel,0.75,128
3,"Block sparsity (keep=0.25, bs=128)",62.456207,25.723875,88.180081,1466.526397,214.42627,506.0,Sparsity::block,block,0.25,128
4,"Input sparsity (keep=0.50, bs=128)",60.679336,27.183104,87.86244,1465.291647,190.521484,506.0,Sparsity::input,input,0.5,128
5,"Input sparsity (keep=0.75, bs=128)",60.569348,28.613808,89.183155,1446.018527,201.649414,506.0,Sparsity::input,input,0.75,128
6,"Channel sparsity (keep=0.50, bs=128)",63.34762,27.655227,91.002846,1422.986126,214.039062,506.0,Sparsity::channel,channel,0.5,128
7,"Channel sparsity (keep=0.25, bs=128)",63.75313,27.683518,91.436648,1408.948505,214.42627,506.0,Sparsity::channel,channel,0.25,128
8,"Input sparsity (keep=0.25, bs=128)",64.828435,27.167012,91.995447,1402.890584,172.675293,506.0,Sparsity::input,input,0.25,128
9,"Channel sparsity (keep=0.60, bs=128)",64.197925,29.116798,93.314724,1394.383639,214.482422,506.0,Sparsity::channel,channel,0.6,128


`sparsity_compare_df` добавляет к тем же сценариям относительные значения относительно эталонного nn.Conv2d (`speedup_*_vs_torch`, `throughput_ratio_vs_torch`, `mem_*_ratio_vs_torch`).


In [20]:
sparsity_reference = summary_df.loc[("nn.Conv2d", sparsity_bs)]

sparsity_compare_df = sparsity_summary_df.copy()
sparsity_compare_df["speedup_forward_vs_torch"] = sparsity_reference["avg_forward_ms"] / sparsity_compare_df["avg_forward_ms"]
sparsity_compare_df["speedup_backward_vs_torch"] = sparsity_reference["avg_backward_ms"] / sparsity_compare_df["avg_backward_ms"]
sparsity_compare_df["speedup_step_vs_torch"] = sparsity_reference["avg_step_ms"] / sparsity_compare_df["avg_step_ms"]
sparsity_compare_df["throughput_ratio_vs_torch"] = sparsity_compare_df["samples_per_s"] / sparsity_reference["samples_per_s"]
sparsity_compare_df["mem_alloc_ratio_vs_torch"] = sparsity_compare_df["max_mem_alloc_mb"] / sparsity_reference["max_mem_alloc_mb"]
sparsity_compare_df["mem_reserved_ratio_vs_torch"] = sparsity_compare_df["max_mem_reserved_mb"] / sparsity_reference["max_mem_reserved_mb"]
sparsity_compare_df = sparsity_compare_df.sort_values("samples_per_s", ascending=False).reset_index(drop=True)
sparsity_compare_df


Unnamed: 0,label,avg_forward_ms,avg_backward_ms,avg_step_ms,samples_per_s,max_mem_alloc_mb,max_mem_reserved_mb,variant,mode,keep_ratio,batch_size,speedup_forward_vs_torch,speedup_backward_vs_torch,speedup_step_vs_torch,throughput_ratio_vs_torch,mem_alloc_ratio_vs_torch,mem_reserved_ratio_vs_torch
0,"Block sparsity (keep=0.50, bs=128)",58.12284,24.471259,82.594099,1554.019946,214.039062,506.0,Sparsity::block,block,0.5,128,0.160884,0.26669,0.192232,0.188635,1.611523,3.011905
1,"Block sparsity (keep=0.60, bs=128)",60.420029,27.095655,87.515684,1469.793185,214.513184,506.0,Sparsity::block,block,0.6,128,0.154767,0.240859,0.181422,0.178411,1.615093,3.011905
2,"Channel sparsity (keep=0.75, bs=128)",62.082142,25.897778,87.979921,1468.674844,214.404297,506.0,Sparsity::channel,channel,0.75,128,0.150623,0.252,0.180465,0.178275,1.614273,3.011905
3,"Block sparsity (keep=0.25, bs=128)",62.456207,25.723875,88.180081,1466.526397,214.42627,506.0,Sparsity::block,block,0.25,128,0.149721,0.253704,0.180055,0.178014,1.614438,3.011905
4,"Input sparsity (keep=0.50, bs=128)",60.679336,27.183104,87.86244,1465.291647,190.521484,506.0,Sparsity::input,input,0.5,128,0.154105,0.240085,0.180706,0.177864,1.434457,3.011905
5,"Input sparsity (keep=0.75, bs=128)",60.569348,28.613808,89.183155,1446.018527,201.649414,506.0,Sparsity::input,input,0.75,128,0.154385,0.22808,0.17803,0.175525,1.51824,3.011905
6,"Channel sparsity (keep=0.50, bs=128)",63.34762,27.655227,91.002846,1422.986126,214.039062,506.0,Sparsity::channel,channel,0.5,128,0.147614,0.235986,0.17447,0.172729,1.611523,3.011905
7,"Channel sparsity (keep=0.25, bs=128)",63.75313,27.683518,91.436648,1408.948505,214.42627,506.0,Sparsity::channel,channel,0.25,128,0.146675,0.235745,0.173642,0.171025,1.614438,3.011905
8,"Input sparsity (keep=0.25, bs=128)",64.828435,27.167012,91.995447,1402.890584,172.675293,506.0,Sparsity::input,input,0.25,128,0.144242,0.240227,0.172587,0.17029,1.300091,3.011905
9,"Channel sparsity (keep=0.60, bs=128)",64.197925,29.116798,93.314724,1394.383639,214.482422,506.0,Sparsity::channel,channel,0.6,128,0.145659,0.22414,0.170147,0.169257,1.614861,3.011905


`ranking_df` — упорядоченный рейтинг сценариев спарсификации: показывает `mode`, `keep_ratio`, абсолютный throughput и его отношение к торчу, а также ускорения forward/backward/step и изменение памяти.


In [21]:
ranking_df = sparsity_compare_df[[
    "variant",
    "mode",
    "keep_ratio",
    "samples_per_s",
    "throughput_ratio_vs_torch",
    "speedup_forward_vs_torch",
    "speedup_backward_vs_torch",
    "speedup_step_vs_torch",
    "mem_alloc_ratio_vs_torch",
    "mem_reserved_ratio_vs_torch",
]].copy()
ranking_df = ranking_df.sort_values("throughput_ratio_vs_torch", ascending=False).reset_index(drop=True)
ranking_df


Unnamed: 0,variant,mode,keep_ratio,samples_per_s,throughput_ratio_vs_torch,speedup_forward_vs_torch,speedup_backward_vs_torch,speedup_step_vs_torch,mem_alloc_ratio_vs_torch,mem_reserved_ratio_vs_torch
0,Sparsity::block,block,0.5,1554.019946,0.188635,0.160884,0.26669,0.192232,1.611523,3.011905
1,Sparsity::block,block,0.6,1469.793185,0.178411,0.154767,0.240859,0.181422,1.615093,3.011905
2,Sparsity::channel,channel,0.75,1468.674844,0.178275,0.150623,0.252,0.180465,1.614273,3.011905
3,Sparsity::block,block,0.25,1466.526397,0.178014,0.149721,0.253704,0.180055,1.614438,3.011905
4,Sparsity::input,input,0.5,1465.291647,0.177864,0.154105,0.240085,0.180706,1.434457,3.011905
5,Sparsity::input,input,0.75,1446.018527,0.175525,0.154385,0.22808,0.17803,1.51824,3.011905
6,Sparsity::channel,channel,0.5,1422.986126,0.172729,0.147614,0.235986,0.17447,1.611523,3.011905
7,Sparsity::channel,channel,0.25,1408.948505,0.171025,0.146675,0.235745,0.173642,1.614438,3.011905
8,Sparsity::input,input,0.25,1402.890584,0.17029,0.144242,0.240227,0.172587,1.300091,3.011905
9,Sparsity::channel,channel,0.6,1394.383639,0.169257,0.145659,0.22414,0.170147,1.614861,3.011905


Final rankings for model batch sizes and per-layer convs.

Model batch-size rankings (step/throughput/memory).

In [22]:
model_step_top = (
    summary_df.reset_index()
    .sort_values("avg_step_ms")
    .groupby("variant")
    .head(3)
    .reset_index(drop=True)
)

model_throughput_top = (
    summary_df.reset_index()
    .sort_values("samples_per_s", ascending=False)
    .groupby("variant")
    .head(3)
    .reset_index(drop=True)
)

model_memory_top = (
    summary_df.reset_index()
    .sort_values("max_mem_alloc_mb")
    .groupby("variant")
    .head(3)
    .reset_index(drop=True)
)


In [23]:
model_rankings_df = pd.concat(
    [
        model_step_top.assign(metric="fastest_step"),
        model_throughput_top.assign(metric="highest_throughput"),
        model_memory_top.assign(metric="lowest_mem_alloc"),
    ],
    ignore_index=True,
)
model_rankings_df = model_rankings_df[[
    "metric",
    "variant",
    "batch_size",
    "avg_forward_ms",
    "avg_backward_ms",
    "avg_step_ms",
    "samples_per_s",
    "max_mem_alloc_mb",
    "max_mem_reserved_mb",
]]
model_rankings_df


Unnamed: 0,metric,variant,batch_size,avg_forward_ms,avg_backward_ms,avg_step_ms,samples_per_s,max_mem_alloc_mb,max_mem_reserved_mb
0,fastest_step,nn.Conv2d,32,7.778965,6.27437,14.053335,2353.629652,127.072266,138.0
1,fastest_step,nn.Conv2d,96,8.269396,6.013748,14.283144,6799.590031,127.57373,144.0
2,fastest_step,nn.Conv2d,64,8.03872,6.581922,14.620642,4572.664194,127.260254,138.0
3,fastest_step,Baseline TritonConv2d,64,43.745434,18.876972,62.622406,1029.01729,169.953613,204.0
4,fastest_step,Baseline TritonConv2d,96,41.662605,21.079979,62.742584,1538.861629,195.14209,238.0
5,fastest_step,Baseline TritonConv2d,32,45.730244,18.65649,64.386734,503.617131,154.655273,182.0
6,highest_throughput,nn.Conv2d,256,9.572772,10.040291,19.613062,13442.939432,162.446289,196.0
7,highest_throughput,nn.Conv2d,192,8.512421,8.360785,16.873207,11492.926501,148.069824,160.0
8,highest_throughput,nn.Conv2d,160,9.560735,7.603258,17.163994,9408.25363,140.381836,162.0
9,highest_throughput,Baseline TritonConv2d,256,46.241049,50.099462,96.340511,2679.841985,300.833496,500.0


In [24]:
conv_forward_top = conv_layer_compare_df.sort_values("avg_forward_ms_baseline").head(10).assign(metric="forward_time")
conv_backward_top = conv_layer_compare_df.sort_values("avg_backward_ms_baseline").head(10).assign(metric="backward_time")
conv_speedup_top = conv_layer_compare_df.sort_values("speedup_step", ascending=False).head(15).assign(metric="speedup_step")

conv_layer_best_df = pd.concat(
    [conv_forward_top, conv_backward_top, conv_speedup_top],
    ignore_index=True,
)

conv_layer_best_df = conv_layer_best_df[[
    "metric",
    "layer",
    "batch_size",
    "layer_type_baseline",
    "kernel_size_baseline",
    "stride_baseline",
    "padding_baseline",
    "dilation_baseline",
    "channel_keep_ratio_baseline",
    "input_keep_ratio_baseline",
    "block_size_baseline",
    "grad_block_size_baseline",
    "avg_forward_ms_baseline",
    "avg_backward_ms_baseline",
    "avg_step_ms_baseline",
    "speedup_forward",
    "speedup_backward",
    "speedup_step",
    "throughput_ratio",
    "mem_alloc_ratio",
    "mem_reserved_ratio",
]]
conv_layer_best_df


Unnamed: 0,metric,layer,batch_size,layer_type_baseline,kernel_size_baseline,stride_baseline,padding_baseline,dilation_baseline,channel_keep_ratio_baseline,input_keep_ratio_baseline,...,grad_block_size_baseline,avg_forward_ms_baseline,avg_backward_ms_baseline,avg_step_ms_baseline,speedup_forward,speedup_backward,speedup_step,throughput_ratio,mem_alloc_ratio,mem_reserved_ratio
0,forward_time,layer4.0.conv1,160,TritonConv2d,"(3, 3)","(2, 2)","(1, 1)","(1, 1)",1.0,1.0,...,,1.575014,1.020466,2.59548,0.12561,0.463853,0.258597,0.254855,1.236068,1.022222
1,forward_time,layer3.0.conv2,128,TritonConv2d,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,...,,1.595597,1.034035,2.629632,0.112117,0.476036,0.255219,0.254613,1.31884,1.314286
2,forward_time,layer4.1.conv2,192,TritonConv2d,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,...,,1.598261,1.326693,2.924954,0.157996,0.414635,0.274401,0.274209,1.438632,1.0
3,forward_time,layer4.0.conv1,192,TritonConv2d,"(3, 3)","(2, 2)","(1, 1)","(1, 1)",1.0,1.0,...,,1.608602,0.943414,2.552016,0.119517,0.509982,0.263862,0.2625,1.262721,1.255814
4,forward_time,layer4.0.conv2,160,TritonConv2d,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,...,,1.612544,1.285734,2.898278,0.159834,0.423508,0.276805,0.27737,1.421627,1.444444
5,forward_time,layer3.0.conv2,160,TritonConv2d,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,...,,1.630259,1.117392,2.747651,0.122892,0.448633,0.255362,0.254346,1.385012,1.6
6,forward_time,layer3.1.conv2,192,TritonConv2d,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,...,,1.631949,1.227008,2.858957,0.123518,0.420113,0.25081,0.249756,1.42332,1.0
7,forward_time,layer3.0.conv1,192,TritonConv2d,"(3, 3)","(2, 2)","(1, 1)","(1, 1)",1.0,1.0,...,,1.632819,0.980786,2.613605,0.115423,0.483247,0.253453,0.249666,1.238774,1.022727
8,forward_time,layer1.1.conv1,128,TritonConv2d,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,...,,1.637384,2.000333,3.637717,0.098311,0.234201,0.173035,0.171005,1.90432,1.0
9,forward_time,layer3.1.conv1,128,TritonConv2d,"(3, 3)","(1, 1)","(1, 1)","(1, 1)",1.0,1.0,...,,1.64009,1.037774,2.677864,0.10505,0.466574,0.245154,0.243783,1.31884,1.27027
