# ResNet18 Baseline Conv2d Benchmark

Сравнение nn.Conv2d и кастомной img2col→GEMM свёртки (Baseline TritonConv2d) на ResNet18 с разными batch size и сценариями спарсификации.


In [1]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent))


In [2]:
import copy
import json
import math
import random
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

from conv_gemm.baseline_layers.triton_conv2d import TritonConv2d


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise RuntimeError("CUDA GPU is required for this benchmark")

seed = 42
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = True

data_root = Path("../data").resolve()
data_root.mkdir(parents=True, exist_ok=True)

config = {
    "data_root": str(data_root),
    "num_classes": 10,
    "batch_sizes": [32, 64, 96, 128, 160, 192, 256],
    "num_workers": 4,
    "train_subset": 8192,
    "lr": 1e-3,
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "warmup_steps": 5,
    "benchmark_steps": 40,
    "baseline_conv": {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_K": 64,
        "NUM_WARPS": 4,
        "NUM_STAGES": 2,
    },
    "sparsity_bench": {
        "modes": ["channel", "block", "input"],
        "keep_ratios": [0.75, 0.6, 0.5, 0.25],
        "block_size": 4,
        "batch_size": 128,
    },
}
print(json.dumps(config, indent=2))


{
  "data_root": "/home/manzhura/ITMO/EDLM/conv2d-img2col-gemm/data",
  "num_classes": 10,
  "batch_sizes": [
    32,
    64,
    96,
    128,
    160,
    192,
    256
  ],
  "num_workers": 4,
  "train_subset": 8192,
  "lr": 0.001,
  "momentum": 0.9,
  "weight_decay": 0.0005,
  "warmup_steps": 5,
  "benchmark_steps": 40,
  "baseline_conv": {
    "BLOCK_M": 64,
    "BLOCK_N": 64,
    "BLOCK_K": 64,
    "NUM_WARPS": 4,
    "NUM_STAGES": 2
  },
  "sparsity_bench": {
    "modes": [
      "channel",
      "block",
      "input"
    ],
    "keep_ratios": [
      0.75,
      0.6,
      0.5,
      0.25
    ],
    "block_size": 4,
    "batch_size": 128
  }
}


In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

full_train = torchvision.datasets.CIFAR10(
    root=config["data_root"], train=True, download=True, transform=transform_train
)
if config["train_subset"] is not None and config["train_subset"] < len(full_train):
    g = torch.Generator().manual_seed(seed)
    subset_idx = torch.randperm(len(full_train), generator=g)[: config["train_subset"]]
    train_dataset = torch.utils.data.Subset(full_train, subset_idx)
else:
    train_dataset = full_train


def make_loader(batch_size: int) -> DataLoader:
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=config["num_workers"],
        pin_memory=True,
    )

train_loaders: Dict[int, DataLoader] = {}
for bs in config["batch_sizes"]:
    train_loaders[bs] = make_loader(bs)

print({bs: len(loader) for bs, loader in train_loaders.items()})


{32: 256, 64: 128, 96: 85, 128: 64, 160: 51, 192: 42, 256: 32}


In [5]:
def make_triton_conv(src: nn.Conv2d, cfg: dict) -> TritonConv2d:
    if src.groups != 1:
        raise ValueError("Baseline TritonConv2d currently supports groups=1 only")
    layer = TritonConv2d(
        in_channels=src.in_channels,
        out_channels=src.out_channels,
        kernel_size=src.kernel_size,
        stride=src.stride,
        padding=src.padding,
        dilation=src.dilation,
        bias=(src.bias is not None),
        # **cfg,
    ).to(src.weight.device)
    with torch.no_grad():
        layer.weight.copy_(src.weight.detach().to(layer.weight.dtype))
        if layer.bias is not None and src.bias is not None:
            layer.bias.copy_(src.bias.detach().to(layer.bias.dtype))
    return layer


def replace_convs_with_baseline(module: nn.Module, cfg: dict):
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            setattr(module, name, make_triton_conv(child, cfg))
        else:
            replace_convs_with_baseline(child, cfg)


def build_model_pair(config: dict):
    reference = torchvision.models.resnet18(num_classes=config["num_classes"])
    baseline = copy.deepcopy(reference)
    replace_convs_with_baseline(baseline, config["baseline_conv"])
    return reference, baseline


def apply_sparsity_to_model(model: nn.Module, mode: str, keep_ratio: float, block_size: int = 4):
    for layer in model.modules():
        if isinstance(layer, TritonConv2d):
            layer.clear_sparsity()
            if keep_ratio >= 1.0:
                continue
            if mode == "channel":
                layer.set_channel_sparsity(keep_ratio)
                layer.set_backward_channel_sparsity(keep_ratio)
            elif mode == "block":
                layer.set_block_sparsity(keep_ratio, block_size=block_size)
                layer.set_backward_block_sparsity(keep_ratio, block_size=block_size)
            elif mode == "input":
                layer.set_input_channel_sparsity(keep_ratio)
                layer.set_backward_input_channel_sparsity(keep_ratio)
            else:
                raise ValueError(f"Unknown sparsity mode: {mode}")


In [6]:
def run_benchmark(model: nn.Module, label: str, loader: DataLoader, config: dict):
    model = model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
    )

    warmup = config["warmup_steps"]
    total_steps = config["benchmark_steps"]
    records = []
    data_iter = iter(loader)

    for step in range(total_steps):
        try:
            images, targets = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            images, targets = next(data_iter)

        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize()

        fwd_start = torch.cuda.Event(enable_timing=True)
        fwd_end = torch.cuda.Event(enable_timing=True)
        bwd_start = torch.cuda.Event(enable_timing=True)
        bwd_end = torch.cuda.Event(enable_timing=True)

        fwd_start.record()
        outputs = model(images)
        fwd_end.record()
        loss = criterion(outputs, targets)

        bwd_start.record()
        loss.backward()
        bwd_end.record()
        optimizer.step()

        torch.cuda.synchronize()

        fwd_ms = fwd_start.elapsed_time(fwd_end)
        bwd_ms = bwd_start.elapsed_time(bwd_end)
        step_ms = fwd_ms + bwd_ms
        mem_alloc = torch.cuda.max_memory_allocated(device) / 1024 ** 2
        mem_reserved = torch.cuda.max_memory_reserved(device) / 1024 ** 2

        if step >= warmup:
            records.append({
                "label": label,
                "step": step,
                "loss": float(loss.item()),
                "fwd_ms": fwd_ms,
                "bwd_ms": bwd_ms,
                "step_ms": step_ms,
                "throughput_sps": images.size(0) / (step_ms / 1000.0),
                "max_mem_alloc_mb": mem_alloc,
                "max_mem_reserved_mb": mem_reserved,
            })

    if not records:
        raise RuntimeError("No data recorded for benchmark")

    df = pd.DataFrame(records)
    summary = {
        "label": label,
        "avg_forward_ms": df["fwd_ms"].mean(),
        "avg_backward_ms": df["bwd_ms"].mean(),
        "avg_step_ms": df["step_ms"].mean(),
        "samples_per_s": df["throughput_sps"].mean(),
        "max_mem_alloc_mb": df["max_mem_alloc_mb"].max(),
        "max_mem_reserved_mb": df["max_mem_reserved_mb"].max(),
    }
    return df, summary


Таблица `summary_df` показывает средние метрики по каждому batch size: `avg_forward_ms`, `avg_backward_ms`, `avg_step_ms`, `samples_per_s`, а также пики памяти (`max_mem_alloc_mb`, `max_mem_reserved_mb`).


In [7]:
batch_summaries = []
batch_details = []

for bs, loader in train_loaders.items():
    print(f"=== Batch size {bs} ===")
    torch_model, baseline_model = build_model_pair(config)

    torch_df, torch_summary = run_benchmark(torch_model, f"nn.Conv2d (bs={bs})", loader, config)
    torch_summary.update({"variant": "nn.Conv2d", "batch_size": bs})
    batch_summaries.append(torch_summary)
    batch_details.append(torch_df.assign(variant="nn.Conv2d", batch_size=bs))

    baseline_df, baseline_summary = run_benchmark(baseline_model, f"Baseline TritonConv2d (bs={bs})", loader, config)
    baseline_summary.update({"variant": "Baseline TritonConv2d", "batch_size": bs})
    batch_summaries.append(baseline_summary)
    batch_details.append(baseline_df.assign(variant="Baseline TritonConv2d", batch_size=bs))

summary_df = pd.DataFrame(batch_summaries).set_index(["variant", "batch_size"])
summary_df


=== Batch size 32 ===
=== Batch size 64 ===
=== Batch size 96 ===
=== Batch size 128 ===
=== Batch size 160 ===
=== Batch size 192 ===
=== Batch size 256 ===


Unnamed: 0_level_0,Unnamed: 1_level_0,label,avg_forward_ms,avg_backward_ms,avg_step_ms,samples_per_s,max_mem_alloc_mb,max_mem_reserved_mb
variant,batch_size,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
nn.Conv2d,32,nn.Conv2d (bs=32),4.718899,4.217519,8.936418,3730.567146,190.307129,220.0
Baseline TritonConv2d,32,Baseline TritonConv2d (bs=32),19.836949,14.022024,33.858973,976.347011,207.492676,226.0
nn.Conv2d,64,nn.Conv2d (bs=64),4.07905,3.708465,7.787515,8566.29246,190.058105,228.0
Baseline TritonConv2d,64,Baseline TritonConv2d (bs=64),17.765971,15.295524,33.061495,2087.275114,228.024902,256.0
nn.Conv2d,96,nn.Conv2d (bs=96),2.89512,3.304866,6.199985,15602.282334,190.685059,222.0
Baseline TritonConv2d,96,Baseline TritonConv2d (bs=96),16.232723,14.990402,31.223125,3169.658079,252.019043,312.0
nn.Conv2d,128,nn.Conv2d (bs=128),2.407906,3.489559,5.897465,21868.949079,204.617676,228.0
Baseline TritonConv2d,128,Baseline TritonConv2d (bs=128),13.662614,15.38072,29.043334,4512.833889,281.83252,340.0
nn.Conv2d,160,nn.Conv2d (bs=160),2.614747,3.974958,6.589705,24334.832347,219.119629,252.0
Baseline TritonConv2d,160,Baseline TritonConv2d (bs=160),11.562864,16.098852,27.661716,5791.891497,306.271973,366.0


Вывод `detail_df.groupby(...).describe()` содержит count/mean/std/min/25%/50%/75%/max для метрик `step_ms`, `fwd_ms`, `bwd_ms`, `max_mem_alloc_mb` отдельно по каждому `(variant, batch_size)`.


In [8]:
detail_df = pd.concat(batch_details, ignore_index=True)
metrics = ["step_ms", "fwd_ms", "bwd_ms", "max_mem_alloc_mb"]
detail_df.groupby(["variant", "batch_size"])[metrics].describe()


Unnamed: 0_level_0,Unnamed: 1_level_0,step_ms,step_ms,step_ms,step_ms,step_ms,step_ms,step_ms,step_ms,fwd_ms,fwd_ms,...,bwd_ms,bwd_ms,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb,max_mem_alloc_mb
Unnamed: 0_level_1,Unnamed: 1_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,...,75%,max,count,mean,std,min,25%,50%,75%,max
variant,batch_size,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2
Baseline TritonConv2d,32,35.0,33.858973,5.744641,22.376801,28.662608,37.607103,38.315744,40.636736,35.0,19.836949,...,16.02048,16.560127,35.0,207.492676,0.0,207.492676,207.492676,207.492676,207.492676,207.492676
Baseline TritonConv2d,64,35.0,33.061495,8.217262,20.058176,22.53936,38.267839,38.718416,42.518528,35.0,17.765971,...,17.565184,21.909248,35.0,228.024902,0.0,228.024902,228.024902,228.024902,228.024902,228.024902
Baseline TritonConv2d,96,35.0,31.223125,5.574035,23.408384,26.591424,30.303201,35.195487,41.889696,35.0,16.232723,...,17.261568,24.86272,35.0,252.019043,0.0,252.019043,252.019043,252.019043,252.019043,252.019043
Baseline TritonConv2d,128,35.0,29.043334,4.832932,23.702687,26.200144,27.2488,31.422801,41.012735,35.0,13.662614,...,17.470352,18.46784,35.0,281.125377,0.748738,280.08252,280.45752,280.83252,281.83252,281.83252
Baseline TritonConv2d,160,35.0,27.661716,1.02454,25.974336,26.774592,27.709888,28.542384,29.760223,35.0,11.562864,...,16.915456,17.731585,35.0,306.271973,0.0,306.271973,306.271973,306.271973,306.271973,306.271973
Baseline TritonConv2d,192,35.0,31.344689,1.965179,28.798271,29.997295,30.671456,32.075392,37.516481,35.0,12.458638,...,19.495936,22.489088,35.0,334.874735,0.737501,333.960449,333.960449,334.960449,335.710449,335.710449
Baseline TritonConv2d,256,35.0,35.782031,2.107488,32.995071,34.510256,35.054015,37.044016,41.610048,35.0,12.359798,...,24.132735,26.356735,35.0,383.338379,0.0,383.338379,383.338379,383.338379,383.338379,383.338379
nn.Conv2d,32,35.0,8.936418,1.600873,4.734464,8.181216,9.590016,10.108624,10.567328,35.0,4.718899,...,4.426512,4.614592,35.0,190.307129,0.0,190.307129,190.307129,190.307129,190.307129,190.307129
nn.Conv2d,64,35.0,7.787515,1.444092,4.973568,6.666848,8.32512,8.859728,9.406336,35.0,4.07905,...,4.124752,4.864096,35.0,190.058105,0.0,190.058105,190.058105,190.058105,190.058105,190.058105
nn.Conv2d,96,35.0,6.199985,0.544631,5.111008,5.660896,6.205376,6.742352,7.021024,35.0,2.89512,...,3.536896,4.148224,35.0,190.685059,0.0,190.685059,190.685059,190.685059,190.685059,190.685059


`baseline_vs_torch_df` сравнивает nn.Conv2d и Baseline TritonConv2d: пары столбцов с абсолютными значениями (forward/backward/step время, throughput, память) и коэффициенты ускорения (`speedup_*`, `throughput_ratio`, `mem_*_ratio`).


In [9]:
baseline_compare_rows = []
for bs in config["batch_sizes"]:
    torch_row = summary_df.loc[("nn.Conv2d", bs)]
    baseline_row = summary_df.loc[("Baseline TritonConv2d", bs)]
    comparison = {
        "batch_size": bs,
        "torch_forward_ms": torch_row["avg_forward_ms"],
        "baseline_forward_ms": baseline_row["avg_forward_ms"],
        "torch_backward_ms": torch_row["avg_backward_ms"],
        "baseline_backward_ms": baseline_row["avg_backward_ms"],
        "torch_step_ms": torch_row["avg_step_ms"],
        "baseline_step_ms": baseline_row["avg_step_ms"],
        "torch_samples_per_s": torch_row["samples_per_s"],
        "baseline_samples_per_s": baseline_row["samples_per_s"],
        "speedup_forward": torch_row["avg_forward_ms"] / baseline_row["avg_forward_ms"],
        "speedup_backward": torch_row["avg_backward_ms"] / baseline_row["avg_backward_ms"],
        "speedup_step": torch_row["avg_step_ms"] / baseline_row["avg_step_ms"],
        "throughput_ratio": baseline_row["samples_per_s"] / torch_row["samples_per_s"],
        "torch_mem_alloc_mb": torch_row["max_mem_alloc_mb"],
        "baseline_mem_alloc_mb": baseline_row["max_mem_alloc_mb"],
        "torch_mem_reserved_mb": torch_row["max_mem_reserved_mb"],
        "baseline_mem_reserved_mb": baseline_row["max_mem_reserved_mb"],
        "mem_alloc_ratio": baseline_row["max_mem_alloc_mb"] / torch_row["max_mem_alloc_mb"],
        "mem_reserved_ratio": baseline_row["max_mem_reserved_mb"] / torch_row["max_mem_reserved_mb"],
    }
    baseline_compare_rows.append(comparison)

baseline_vs_torch_df = pd.DataFrame(baseline_compare_rows).set_index("batch_size")
baseline_vs_torch_df


Unnamed: 0_level_0,torch_forward_ms,baseline_forward_ms,torch_backward_ms,baseline_backward_ms,torch_step_ms,baseline_step_ms,torch_samples_per_s,baseline_samples_per_s,speedup_forward,speedup_backward,speedup_step,throughput_ratio,torch_mem_alloc_mb,baseline_mem_alloc_mb,torch_mem_reserved_mb,baseline_mem_reserved_mb,mem_alloc_ratio,mem_reserved_ratio
batch_size,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
32,4.718899,19.836949,4.217519,14.022024,8.936418,33.858973,3730.567146,976.347011,0.237884,0.300778,0.263931,0.261715,190.307129,207.492676,220.0,226.0,1.090304,1.027273
64,4.07905,17.765971,3.708465,15.295524,7.787515,33.061495,8566.29246,2087.275114,0.229599,0.242454,0.235546,0.243661,190.058105,228.024902,228.0,256.0,1.199764,1.122807
96,2.89512,16.232723,3.304866,14.990402,6.199985,31.223125,15602.282334,3169.658079,0.178351,0.220465,0.19857,0.203153,190.685059,252.019043,222.0,312.0,1.321651,1.405405
128,2.407906,13.662614,3.489559,15.38072,5.897465,29.043334,21868.949079,4512.833889,0.17624,0.226879,0.203057,0.206358,204.617676,281.83252,228.0,340.0,1.377362,1.491228
160,2.614747,11.562864,3.974958,16.098852,6.589705,27.661716,24334.832347,5791.891497,0.226133,0.246909,0.238225,0.238008,219.119629,306.271973,252.0,366.0,1.397739,1.452381
192,2.697283,12.458638,4.452401,18.886051,7.149684,31.344689,26948.687933,6147.323264,0.216499,0.235751,0.228099,0.228112,230.495605,335.710449,274.0,402.0,1.456472,1.467153
256,2.951425,12.359798,5.560192,23.422233,8.511617,35.782031,30291.919684,7177.305322,0.238792,0.237389,0.237874,0.236938,258.373535,383.338379,306.0,598.0,1.48366,1.954248


Таблица `summary_df` показывает средние метрики по каждому batch size: `avg_forward_ms`, `avg_backward_ms`, `avg_step_ms`, `samples_per_s`, а также пики памяти (`max_mem_alloc_mb`, `max_mem_reserved_mb`).


In [None]:
sparsity_cfg = config["sparsity_bench"]
sparsity_bs = sparsity_cfg["batch_size"]
if sparsity_bs not in train_loaders:
    train_loaders[sparsity_bs] = make_loader(sparsity_bs)
sparsity_loader = train_loaders[sparsity_bs]

sparsity_summaries = []
sparsity_details = []

for mode in sparsity_cfg["modes"]:
    for ratio in sparsity_cfg["keep_ratios"]:
        _, baseline_model = build_model_pair(config)
        apply_sparsity_to_model(
            baseline_model,
            mode,
            keep_ratio=ratio,
            block_size=sparsity_cfg.get("block_size", 4),
        )
        label = f"{mode.capitalize()} sparsity (keep={ratio:.2f}, bs={sparsity_bs})"
        bench_df, bench_summary = run_benchmark(baseline_model, label, sparsity_loader, config)
        bench_summary.update({
            "variant": f"Sparsity::{mode}",
            "mode": mode,
            "keep_ratio": ratio,
            "batch_size": sparsity_bs,
        })
        sparsity_summaries.append(bench_summary)
        sparsity_details.append(
            bench_df.assign(variant=f"Sparsity::{mode}", mode=mode, keep_ratio=ratio, batch_size=sparsity_bs)
        )

sparsity_summary_df = pd.DataFrame(sparsity_summaries).sort_values("samples_per_s", ascending=False).reset_index(drop=True)
sparsity_summary_df


`sparsity_compare_df` добавляет к тем же сценариям относительные значения относительно эталонного nn.Conv2d (`speedup_*_vs_torch`, `throughput_ratio_vs_torch`, `mem_*_ratio_vs_torch`).


In [None]:
sparsity_reference = summary_df.loc[("nn.Conv2d", sparsity_bs)]

sparsity_compare_df = sparsity_summary_df.copy()
sparsity_compare_df["speedup_forward_vs_torch"] = sparsity_reference["avg_forward_ms"] / sparsity_compare_df["avg_forward_ms"]
sparsity_compare_df["speedup_backward_vs_torch"] = sparsity_reference["avg_backward_ms"] / sparsity_compare_df["avg_backward_ms"]
sparsity_compare_df["speedup_step_vs_torch"] = sparsity_reference["avg_step_ms"] / sparsity_compare_df["avg_step_ms"]
sparsity_compare_df["throughput_ratio_vs_torch"] = sparsity_compare_df["samples_per_s"] / sparsity_reference["samples_per_s"]
sparsity_compare_df["mem_alloc_ratio_vs_torch"] = sparsity_compare_df["max_mem_alloc_mb"] / sparsity_reference["max_mem_alloc_mb"]
sparsity_compare_df["mem_reserved_ratio_vs_torch"] = sparsity_compare_df["max_mem_reserved_mb"] / sparsity_reference["max_mem_reserved_mb"]
sparsity_compare_df = sparsity_compare_df.sort_values("samples_per_s", ascending=False).reset_index(drop=True)
sparsity_compare_df


`ranking_df` — упорядоченный рейтинг сценариев спарсификации: показывает `mode`, `keep_ratio`, абсолютный throughput и его отношение к торчу, а также ускорения forward/backward/step и изменение памяти.


In [None]:
ranking_df = sparsity_compare_df[[
    "variant",
    "mode",
    "keep_ratio",
    "samples_per_s",
    "throughput_ratio_vs_torch",
    "speedup_forward_vs_torch",
    "speedup_backward_vs_torch",
    "speedup_step_vs_torch",
    "mem_alloc_ratio_vs_torch",
    "mem_reserved_ratio_vs_torch",
]].copy()
ranking_df = ranking_df.sort_values("throughput_ratio_vs_torch", ascending=False).reset_index(drop=True)
ranking_df
