# Channel Sparsity Benchmark

Эксперименты по сравнению PyTorch Conv2d и TritonConv2d (fp16, с поддержкой канальной спарсификации).

## Подготовка окружения
Ниже мы добавляем корень репозитория в `sys.path`, чтобы ноутбук, запущенный из папки `notebooks/`, мог импортировать пакет `conv_gemm`.

In [2]:
import sys
import pathlib

# Put the repository root first on sys.path so that `conv_gemm` is importable.
# The root is the parent of the working directory, i.e. of `notebooks/` when the
# notebook is launched from there.
# Skip the insert when the root is already first, so re-running this cell does not
# keep prepending duplicate entries.
_repo_root = str(pathlib.Path().resolve().parent)
if sys.path[:1] != [_repo_root]:
    sys.path.insert(0, _repo_root)

## Импорты и настройки
В этой секции загружаем библиотеки и настраиваем устройство и тип данных. По умолчанию вычисления идут в fp16 на GPU (если доступна CUDA), иначе — на CPU в fp32.

In [3]:
import time, copy, math
import torch
import pandas as pd
import torch.nn.functional as F
from torch.nn.utils import prune

from conv_gemm.baseline_layers.triton_conv2d import TritonConv2d as BaselineTritonConv2d

# Let cuDNN benchmark and cache the fastest convolution algorithm for each input shape.
torch.backends.cudnn.benchmark = True

# Prefer the GPU in half precision; otherwise fall back to the CPU in full precision.
if torch.cuda.is_available():
    device, dtype = 'cuda', torch.float16
else:
    device, dtype = 'cpu', torch.float32

print(f'device: {device}, dtype: {dtype}')
# NOTE(review): the import above raises ImportError if the module is missing,
# so this check can only ever print True.
print('Baseline Triton available:', BaselineTritonConv2d is not None)

device: cuda, dtype: torch.float16
Baseline Triton available: True


## Вспомогательные функции
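- `sync_device` и `benchmark_module` — измерение среднего времени прямого прохода (только forward, под `torch.no_grad()`, с синхронизацией GPU).
- `clone_weights` — копирование весов (и bias, если он есть у обоих слоёв) из одного слоя в другой.
- `compare_modules` — средняя и максимальная абсолютная ошибка, а также относительная L2-ошибка относительно эталонного модуля (PyTorch Conv2d).
- `finetune_module` — лёгкий тюнинг sparse-модели (имитируем дистилляцию от dense-версии).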

In [4]:
def sync_device():
    """Wait for all queued CUDA work to finish when running on the GPU; no-op on CPU."""
    if device != 'cuda':
        return
    torch.cuda.synchronize()

def clone_weights(dst, src):
    """Copy ``src.weight`` into ``dst.weight`` in place, plus the bias when both have one.

    If either module has no bias (``bias is None``), the bias is left untouched.
    Runs under ``torch.no_grad()`` so that the in-place copies are not tracked by autograd.
    """
    with torch.no_grad():
        dst.weight.copy_(src.weight)
        dst_bias, src_bias = dst.bias, src.bias
        if dst_bias is None or src_bias is None:
            return
        dst_bias.copy_(src_bias)

def compare_modules(ref, other, x):
    """Compare the outputs of two modules on the same input.

    Both forwards run under ``torch.no_grad()``, because an accuracy check needs no
    autograd graph. Outputs are cast to float32 before differencing, so low-precision
    outputs are compared without further rounding.

    Args:
        ref: reference module (or callable) whose output is treated as ground truth.
        other: module (or callable) under test.
        x: input passed unchanged to both.

    Returns:
        A dict with:
            mae:    mean absolute difference,
            max:    maximum absolute difference,
            rel_l2: ||ref - other||_2 / (||ref||_2 + 1e-12).
    """
    with torch.no_grad():
        ref_out = ref(x).float()
        test_out = other(x).float()
        diff = (ref_out - test_out).abs()
        return {
            'mae': diff.mean().item(),
            'max': diff.max().item(),
            # The epsilon guards against division by zero for an all-zero reference.
            'rel_l2': diff.norm().item() / (ref_out.norm().item() + 1e-12),
        }

def benchmark_module(module, x, iters=50, warmup=10):
    """Return the mean forward-pass latency of *module* on *x*, in milliseconds.

    Puts *module* in eval mode, then:
    - runs *warmup* untimed forwards;
    - times *iters* forwards.

    Everything runs under ``torch.no_grad()``. The device is synchronised before
    and after each phase, so queued asynchronous GPU work is included in the
    measurement.
    """
    module.eval()
    with torch.no_grad():
        sync_device()
        for _ in range(warmup):
            module(x)
        sync_device()

        t0 = time.perf_counter()
        for _ in range(iters):
            module(x)
        sync_device()
        elapsed_s = time.perf_counter() - t0
    return elapsed_s * 1000.0 / iters

def finetune_module(module, teacher, steps=0, lr=1e-3, batch_shape=(16, 64, 56, 56), grad_clip=None):
    """Distil *teacher* into *module* with a short float32 Adam run on random inputs.

    Each step:
    - draws a fresh ``torch.randn(*batch_shape)`` batch;
    - computes the teacher's output without gradients;
    - minimises the MSE between that output and *module*'s output.

    Training uses a float32 deep copy of *teacher*, so the caller's teacher is not
    modified. Afterwards *module* is cast back to the dtype its first parameter had
    on entry and switched to eval mode. This restore also happens when training
    raises, so a failed run does not leave the layer in fp32/train mode.

    Args:
        module: layer to fine-tune in place.
        teacher: layer whose outputs are the regression targets.
        steps: number of optimisation steps; ``<= 0`` makes this a no-op.
        lr: Adam learning rate.
        batch_shape: shape of each random input batch.
        grad_clip: optional max global gradient norm passed to ``clip_grad_norm_``.
    """
    if steps <= 0:
        return
    orig_dtype = next(module.parameters()).dtype
    module.to(torch.float32)
    try:
        teacher_copy = copy.deepcopy(teacher).to(torch.float32).eval()
        module.train()
        opt = torch.optim.Adam(module.parameters(), lr=lr)
        # Log roughly five times per run (at least every step for short runs).
        log_every = max(1, steps // 5)
        for step in range(steps):
            x = torch.randn(*batch_shape, device=device, dtype=torch.float32)
            with torch.no_grad():
                target = teacher_copy(x)
            loss = F.mse_loss(module(x), target)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(module.parameters(), grad_clip)
            opt.step()
            if step % log_every == 0:
                print(f'[finetune step {step}] loss={loss.item():.4e}')
    finally:
        # Restore the original precision and inference mode even if training failed.
        module.to(orig_dtype).eval()
    sync_device()


## Базовая конфигурация слоёв
Задаём параметры свёртки и создаём две реализации: PyTorch Conv2d и TritonConv2d (последняя — только при наличии GPU).

In [5]:
# Convolution hyper-parameters shared by every implementation below.
params = {
    'in_channels': 64,
    'out_channels': 128,
    'kernel_size': 3,
    'stride': 1,
    'padding': 1,
    'bias': True,
}
# Batch size and spatial resolution of the benchmark inputs.
B, H, W = 16, 56, 56

# Reference PyTorch convolution; build_baseline copies its weights into each Triton layer.
torch_conv = torch.nn.Conv2d(**params).to(device=device, dtype=dtype)

# Default Triton tile sizes and launch settings.
baseline_block_cfg = {'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'NUM_WARPS': 4, 'NUM_STAGES': 2}

def build_baseline(block_cfg=None):
    """Build a dense Triton conv layer initialised with ``torch_conv``'s weights.

    *block_cfg* supplies the Triton tile/launch parameters (BLOCK_M/N/K,
    NUM_WARPS, NUM_STAGES). When it is falsy, ``baseline_block_cfg`` is used.
    Returns None when not running on CUDA.
    """
    if device == 'cuda':
        cfg = block_cfg if block_cfg else baseline_block_cfg
        tile_kwargs = {key: cfg[key] for key in ('BLOCK_M', 'BLOCK_N', 'BLOCK_K', 'NUM_WARPS', 'NUM_STAGES')}
        layer = BaselineTritonConv2d(**params, **tile_kwargs).to(device)
        clone_weights(layer, torch_conv)
        return layer
    return None

# Dense Triton reference (None without a GPU); the sweeps below require it.
tri_dense = build_baseline()

## Базовое сравнение точности/времени
Сначала проверяем расхождение выходов и время одной итерации для PyTorch и для Triton без спарсификации.

In [6]:
x_sample = torch.randn(B, params['in_channels'], H, W, device=device, dtype=dtype)

# Layers to profile; the Triton one only exists on GPU.
modules = [('Torch Conv2d', torch_conv)]
if tri_dense is not None:
    modules.append(('Baseline Triton fp16', tri_dense))

records = []
for name, module in modules:
    # Accuracy is always measured against the PyTorch reference, so the first row is a self-check.
    stats = compare_modules(torch_conv, module, x_sample)
    t_ms = benchmark_module(module, x_sample.clone().detach())
    records.append({
        'layer': name,
        **{metric: stats[metric] for metric in ('mae', 'max', 'rel_l2')},
        'time_ms': t_ms,
    })
baseline_df = pd.DataFrame(records)
baseline_df

Unnamed: 0,layer,mae,max,rel_l2,time_ms
0,Torch Conv2d,0.0,0.0,0.0,1.103303
1,Baseline Triton fp16,8.2e-05,0.001953,0.000359,3.107725


## Эксперимент 1: Sweep по доле оставляемых каналов
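Создаём функции, которые для каждой доли `keep_ratio` строят новый `TritonConv2d` с весами PyTorch-свёртки и включают на нём спарсификацию (по выходным каналам, по блокам фильтров или по входным каналам). Затем они профилируют слой и собирают статистику по ошибкам и времени.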

In [7]:
# Block-config entries recorded in every sweep row (also their column order).
_SWEEP_CFG_KEYS = ('BLOCK_M', 'BLOCK_N', 'BLOCK_K', 'NUM_WARPS', 'NUM_STAGES')


def _require_triton():
    """Raise RuntimeError when the dense Triton baseline is unavailable (no GPU)."""
    if tri_dense is None:
        raise RuntimeError('Triton недоступен (нужен GPU)')


def _run_sparsity_sweep(mode, keep_ratios, sparsify, cfg, batch_shape, prepare=None):
    """Profile one freshly built Triton layer per keep ratio and collect result rows.

    For each ratio:
    - build a Triton layer with ``torch_conv``'s weights and block config *cfg*;
    - apply ``sparsify(layer, ratio)``;
    - optionally run ``prepare(layer)`` (e.g. fine-tuning);
    - measure error vs. the dense ``torch_conv`` and forward latency on a fresh
      random batch of shape *batch_shape*.
    """
    rows = []
    for ratio in keep_ratios:
        tri = build_baseline(cfg)
        sparsify(tri, ratio)
        if prepare is not None:
            prepare(tri)
        x = torch.randn(*batch_shape, device=device, dtype=dtype)
        stats = compare_modules(torch_conv, tri, x)
        t_ms = benchmark_module(tri, x.clone().detach())
        row = {'mode': mode, 'keep_ratio': ratio, 'mae': stats['mae'], 'max': stats['max'],
               'rel_l2': stats['rel_l2'], 'time_ms': t_ms}
        row.update((key, cfg[key]) for key in _SWEEP_CFG_KEYS)
        rows.append(row)
    return pd.DataFrame(rows)


def run_channel_sweep(keep_ratios, finetune_steps=0, finetune_lr=1e-3, block_cfg=None, grad_clip=None, batch_shape=None):
    """Sweep output-channel sparsity (``set_channel_sparsity``) over *keep_ratios*.

    With ``finetune_steps > 0``, each sparse layer is first distilled via
    ``finetune_module`` from a dense Triton teacher. The teacher is built once,
    with the same block config, and shared across ratios.
    Returns one DataFrame row per ratio.
    """
    _require_triton()
    cfg = block_cfg or baseline_block_cfg
    batch_shape = batch_shape or (B, params['in_channels'], H, W)
    finetune = None
    if finetune_steps > 0:
        teacher = build_baseline(cfg)

        def _distil(tri):
            finetune_module(tri, teacher, steps=finetune_steps, lr=finetune_lr,
                            batch_shape=batch_shape, grad_clip=grad_clip)

        finetune = _distil
    return _run_sparsity_sweep('channel', keep_ratios,
                               lambda tri, ratio: tri.set_channel_sparsity(ratio),
                               cfg, batch_shape, finetune)


def run_block_sweep(keep_ratios, block_size=4, block_cfg=None, batch_shape=None):
    """Sweep block sparsity (``set_block_sparsity`` with *block_size*) over *keep_ratios*.

    *batch_shape* defaults to ``(B, in_channels, H, W)``.
    """
    _require_triton()
    cfg = block_cfg or baseline_block_cfg
    batch_shape = batch_shape or (B, params['in_channels'], H, W)
    return _run_sparsity_sweep(f'block-{block_size}', keep_ratios,
                               lambda tri, ratio: tri.set_block_sparsity(ratio, block_size=block_size),
                               cfg, batch_shape)


def run_input_sweep(keep_ratios, block_cfg=None, batch_shape=None):
    """Sweep input-channel sparsity (``set_input_channel_sparsity``) over *keep_ratios*.

    *batch_shape* defaults to ``(B, in_channels, H, W)``.
    """
    _require_triton()
    cfg = block_cfg or baseline_block_cfg
    batch_shape = batch_shape or (B, params['in_channels'], H, W)
    return _run_sparsity_sweep('input', keep_ratios,
                               lambda tri, ratio: tri.set_input_channel_sparsity(ratio),
                               cfg, batch_shape)

keep_ratios = [1.0, 0.85, 0.75, 0.65, 0.5, 0.35, 0.25]
channel_sweep_df = run_channel_sweep(keep_ratios)
channel_sweep_df

Unnamed: 0,mode,keep_ratio,mae,max,rel_l2,time_ms,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_STAGES
0,channel,1.0,8.2e-05,0.001953,0.000359,3.257513,64,64,32,4,2
1,channel,0.85,0.065919,2.953125,0.375299,4.318055,64,64,32,4,2
2,channel,0.75,0.111443,2.814453,0.489092,3.299439,64,64,32,4,2
3,channel,0.65,0.157027,2.931641,0.581392,3.755666,64,64,32,4,2
4,channel,0.5,0.224219,3.419922,0.696688,2.761235,64,64,32,4,2
5,channel,0.35,0.292061,3.306641,0.796521,2.755066,64,64,32,4,2
6,channel,0.25,0.338279,3.087891,0.858843,2.44431,64,64,32,4,2


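### Trade-off: точность vs. скорость
Ниже дополняем таблицу channel-sweep относительным ускорением `speedup_vs_dense` по сравнению с плотным Triton. Время плотного слоя берётся из базового сравнения; значения > 1 означают, что sparse-слой быстрее.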

In [8]:
# Dense Triton latency from the baseline table (None when Triton was not run).
has_dense = 'Baseline Triton fp16' in baseline_df['layer'].values
dense_time = (
    baseline_df.loc[baseline_df['layer'] == 'Baseline Triton fp16', 'time_ms'].iloc[0]
    if has_dense
    else None
)

# speedup_vs_dense > 1 means the sparse layer ran faster than the dense Triton layer.
viz_df = channel_sweep_df.copy()
if dense_time is not None:
    viz_df['speedup_vs_dense'] = dense_time / viz_df['time_ms']
viz_df.sort_values('keep_ratio')

Unnamed: 0,mode,keep_ratio,mae,max,rel_l2,time_ms,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_STAGES,speedup_vs_dense
6,channel,0.25,0.338279,3.087891,0.858843,2.44431,64,64,32,4,2,1.271412
5,channel,0.35,0.292061,3.306641,0.796521,2.755066,64,64,32,4,2,1.128004
4,channel,0.5,0.224219,3.419922,0.696688,2.761235,64,64,32,4,2,1.125484
3,channel,0.65,0.157027,2.931641,0.581392,3.755666,64,64,32,4,2,0.827476
2,channel,0.75,0.111443,2.814453,0.489092,3.299439,64,64,32,4,2,0.941895
1,channel,0.85,0.065919,2.953125,0.375299,4.318055,64,64,32,4,2,0.719705
0,channel,1.0,8.2e-05,0.001953,0.000359,3.257513,64,64,32,4,2,0.954018


### Block sparsity (grouped filters)

In [9]:
# Block-sparsity sweep (block_size=4), reusing the keep ratios of the channel sweep.
block_sweep_df = run_block_sweep(keep_ratios=keep_ratios, block_size=4)
block_sweep_df

Unnamed: 0,mode,keep_ratio,mae,max,rel_l2,time_ms,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_STAGES
0,block-4,1.0,8.2e-05,0.001953,0.000359,3.286024,64,64,32,4,2
1,block-4,0.85,0.070396,2.921875,0.390733,4.102976,64,64,32,4,2
2,block-4,0.75,0.112799,2.982422,0.495316,4.173936,64,64,32,4,2
3,block-4,0.65,0.155179,2.935547,0.581133,3.877464,64,64,32,4,2
4,block-4,0.5,0.226138,2.828125,0.702511,2.775524,64,64,32,4,2
5,block-4,0.35,0.297469,3.421875,0.806343,2.836737,64,64,32,4,2
6,block-4,0.25,0.339996,3.066406,0.862926,2.346706,64,64,32,4,2


### Input-channel sparsity

In [10]:
# Input-channel sparsity sweep over a finer grid of keep ratios.
input_sweep_df = run_input_sweep(keep_ratios=[1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.25])
input_sweep_df

Unnamed: 0,mode,keep_ratio,mae,max,rel_l2,time_ms,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_STAGES
0,input,1.0,8.2e-05,0.001953,0.000359,3.040364,64,64,32,4,2
1,input,0.9,0.13518,0.951416,0.297872,3.641055,64,64,32,4,2
2,input,0.8,0.200654,1.351562,0.44074,3.452546,64,64,32,4,2
3,input,0.7,0.243655,1.639404,0.535744,3.210645,64,64,32,4,2
4,input,0.6,0.286175,2.08197,0.628993,2.930699,64,64,32,4,2
5,input,0.5,0.317927,2.299316,0.698936,2.653592,64,64,32,4,2
6,input,0.4,0.347557,2.325439,0.763294,2.450043,64,64,32,4,2
7,input,0.3,0.378818,2.560303,0.832627,2.365436,64,64,32,4,2
8,input,0.25,0.391766,2.664551,0.860058,2.174155,64,64,32,4,2


## Summary
Ниже сводим сравнение baseline vs. sparsity режимов (channel/block/input).

In [11]:
# Stack all sweeps into one table and normalise `mode` to channel/block/input.
# Note: this overwrites the 'block-<size>' label produced by run_block_sweep.
summary_frames = [
    channel_sweep_df.assign(mode='channel'),
    block_sweep_df.assign(mode='block'),
    input_sweep_df.assign(mode='input'),
]
summary_df = pd.concat(summary_frames, ignore_index=True)
if 'Baseline Triton fp16' in baseline_df['layer'].values:
    dense_time = baseline_df.loc[baseline_df['layer'] == 'Baseline Triton fp16', 'time_ms'].iloc[0]
    summary_df['speedup_vs_dense'] = dense_time / summary_df['time_ms']
else:
    # Always create the column: without it, the selection below raises KeyError
    # when there is no Triton baseline.
    summary_df['speedup_vs_dense'] = float('nan')
summary_df[['mode', 'keep_ratio', 'mae', 'max', 'time_ms', 'speedup_vs_dense']]

Unnamed: 0,mode,keep_ratio,mae,max,time_ms,speedup_vs_dense
0,channel,1.0,8.2e-05,0.001953,3.257513,0.954018
1,channel,0.85,0.065919,2.953125,4.318055,0.719705
2,channel,0.75,0.111443,2.814453,3.299439,0.941895
3,channel,0.65,0.157027,2.931641,3.755666,0.827476
4,channel,0.5,0.224219,3.419922,2.761235,1.125484
5,channel,0.35,0.292061,3.306641,2.755066,1.128004
6,channel,0.25,0.338279,3.087891,2.44431,1.271412
7,block,1.0,8.2e-05,0.001953,3.286024,0.94574
8,block,0.85,0.070396,2.921875,4.102976,0.757432
9,block,0.75,0.112799,2.982422,4.173936,0.744555


## Summary
Ниже сводим сравнение baseline vs. sparsity режимов (channel/block/input).

In [12]:
# Stack the three sweeps into one table, relabelling modes as channel/block/input.
summary_frames = [
    sweep.assign(mode=label)
    for sweep, label in (
        (channel_sweep_df, 'channel'),
        (block_sweep_df, 'block'),
        (input_sweep_df, 'input'),
    )
]
summary_df = pd.concat(summary_frames, ignore_index=True)

# Speedup relative to the dense Triton layer; NaN when no Triton baseline exists.
if 'Baseline Triton fp16' not in baseline_df['layer'].values:
    summary_df['speedup_vs_dense'] = float('nan')
else:
    dense_time = baseline_df.loc[baseline_df['layer'] == 'Baseline Triton fp16', 'time_ms'].iloc[0]
    summary_df['speedup_vs_dense'] = dense_time / summary_df['time_ms']
summary_df[['mode', 'keep_ratio', 'mae', 'max', 'time_ms', 'speedup_vs_dense']]

Unnamed: 0,mode,keep_ratio,mae,max,time_ms,speedup_vs_dense
0,channel,1.0,8.2e-05,0.001953,3.257513,0.954018
1,channel,0.85,0.065919,2.953125,4.318055,0.719705
2,channel,0.75,0.111443,2.814453,3.299439,0.941895
3,channel,0.65,0.157027,2.931641,3.755666,0.827476
4,channel,0.5,0.224219,3.419922,2.761235,1.125484
5,channel,0.35,0.292061,3.306641,2.755066,1.128004
6,channel,0.25,0.338279,3.087891,2.44431,1.271412
7,block,1.0,8.2e-05,0.001953,3.286024,0.94574
8,block,0.85,0.070396,2.921875,4.102976,0.757432
9,block,0.75,0.112799,2.982422,4.173936,0.744555


### Top-10 configurations by speedup

In [13]:
# Ten configurations with the highest speedup over dense Triton (rows without a speedup are dropped).
top10 = (
    summary_df
    .dropna(subset=['speedup_vs_dense'])
    .sort_values('speedup_vs_dense', ascending=False)
    .head(10)
)
top10[['mode', 'keep_ratio', 'time_ms', 'speedup_vs_dense', 'mae', 'max']]

Unnamed: 0,mode,keep_ratio,time_ms,speedup_vs_dense,mae,max
22,input,0.25,2.174155,1.429394,0.391766,2.664551
13,block,0.25,2.346706,1.324292,0.339996,3.066406
21,input,0.3,2.365436,1.313806,0.378818,2.560303
6,channel,0.25,2.44431,1.271412,0.338279,3.087891
20,input,0.4,2.450043,1.268437,0.347557,2.325439
19,input,0.5,2.653592,1.171139,0.317927,2.299316
5,channel,0.35,2.755066,1.128004,0.292061,3.306641
4,channel,0.5,2.761235,1.125484,0.224219,3.419922
11,block,0.5,2.775524,1.119689,0.226138,2.828125
12,block,0.35,2.836737,1.095528,0.297469,3.421875
