# Kernel Size & Block Tuning Benchmark

Этот ноутбук перебирает разные размеры ядер / пространственных размеров и конфигурации Triton-блоков, чтобы найти сочетания, где наша реализация хотя бы сопоставима или быстрее PyTorch Conv2d.

## Подготовка окружения
Добавляем корень репозитория в `sys.path`, чтобы импортировать пакет `conv_gemm`.

In [1]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent))

## Импорты и базовые настройки

In [2]:
import time, itertools
import torch
import pandas as pd
import torch.nn.functional as F

from conv_gemm.baseline_layers.triton_conv2d import TritonConv2d as BaselineTritonConv2d

torch.backends.cudnn.benchmark = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32
print(f'device = {device}, dtype = {dtype}')
print('Baseline Triton available:', BaselineTritonConv2d is not None)

device = cuda, dtype = torch.float16
Baseline Triton available: True


## Вспомогательные функции

In [3]:
def sync_device():
    if device == 'cuda':
        torch.cuda.synchronize()

def clone_weights(dst, src):
    with torch.no_grad():
        dst.weight.copy_(src.weight)
        if dst.bias is not None and src.bias is not None:
            dst.bias.copy_(src.bias)

def compare_outputs(ref, other, x):
    ref_out = ref(x).float()
    test_out = other(x).float()
    diff = (ref_out - test_out).abs()
    return {
        'mae': diff.mean().item(),
        'max': diff.max().item(),
        'rel_l2': diff.norm().item() / (ref_out.norm().item() + 1e-12)
    }

def benchmark_layer(layer, x, iters=20, warmup=10):
    layer.eval()
    sync_device()
    with torch.no_grad():
        for _ in range(warmup):
            layer(x)
    sync_device()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(iters):
            layer(x)
    sync_device()
    return (time.perf_counter() - start) * 1000.0 / iters

def build_torch_conv(cfg):
    conv = torch.nn.Conv2d(
        cfg['in_channels'], cfg['out_channels'], cfg['kernel_size'],
        stride=cfg['stride'], padding=cfg['padding'], dilation=cfg['dilation'], bias=True
    ).to(device=device, dtype=dtype)
    return conv

def build_triton_conv(cfg, block_cfg):
    if device != 'cuda':
        return None
    tri = BaselineTritonConv2d(
        in_channels=cfg['in_channels'], out_channels=cfg['out_channels'],
        kernel_size=cfg['kernel_size'], stride=cfg['stride'], padding=cfg['padding'],
        dilation=cfg['dilation'], bias=True,
        BLOCK_M=block_cfg['BLOCK_M'], BLOCK_N=block_cfg['BLOCK_N'], BLOCK_K=block_cfg['BLOCK_K'],
        NUM_WARPS=block_cfg['NUM_WARPS'], NUM_STAGES=block_cfg['NUM_STAGES']
    ).to(device)
    return tri


## Поисковое пространство
* `kernel_grid` — комбинации (B, H, W, kernel_size, stride, padding).
* `block_grid` — варианты параметров Triton ядра.

In [4]:
kernel_grid = [
    dict(name='1x1@56', in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0, dilation=1, B=32, H=56, W=56),
    dict(name='3x3@56', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=16, H=56, W=56),
    dict(name='5x5@56', in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2, dilation=1, B=16, H=56, W=56),
    dict(name='7x7@56', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=16, H=56, W=56),
    # dict(name='8x8@56', in_channels=64, out_channels=128, kernel_size=8, stride=1, padding=4, dilation=1, B=12, H=56, W=56),
    dict(name='11x11@56', in_channels=64, out_channels=128, kernel_size=11, stride=1, padding=5, dilation=1, B=8, H=56, W=56),
    # dict(name='16x16@56', in_channels=64, out_channels=128, kernel_size=16, stride=1, padding=8, dilation=1, B=4, H=56, W=56),
    dict(name='3x3@112', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=16, H=112, W=112),
    dict(name='5x5@112', in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2, dilation=1, B=16, H=112, W=112),
    dict(name='7x7@112', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=12, H=112, W=112),
    # dict(name='8x8@112', in_channels=64, out_channels=128, kernel_size=8, stride=1, padding=4, dilation=1, B=12, H=112, W=112),
    dict(name='11x11@112', in_channels=64, out_channels=128, kernel_size=11, stride=1, padding=5, dilation=1, B=8, H=112, W=112),
    # dict(name='16x16@112', in_channels=64, out_channels=128, kernel_size=16, stride=1, padding=8, dilation=1, B=4, H=112, W=112),
    dict(name='3x3@224', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=8, H=224, W=224),
    dict(name='7x7@224', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=6, H=224, W=224),
    dict(name='11x11@224', in_channels=64, out_channels=128, kernel_size=11, stride=1, padding=5, dilation=1, B=4, H=224, W=224),
    dict(name='13x13@224', in_channels=64, out_channels=128, kernel_size=13, stride=1, padding=6, dilation=1, B=2, H=224, W=224),
    dict(name='3x3@512', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=4, H=512, W=512),
    dict(name='7x7@512', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=2, H=512, W=512),
    dict(name='3x3@32', in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, dilation=1, B=32, H=32, W=32),
    dict(name='5x5@32', in_channels=128, out_channels=256, kernel_size=5, stride=1, padding=2, dilation=1, B=32, H=32, W=32),
    dict(name='7x7@32', in_channels=128, out_channels=256, kernel_size=7, stride=1, padding=3, dilation=1, B=32, H=32, W=32),
    dict(name='3x3@16', in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, dilation=1, B=64, H=16, W=16),
    dict(name='3x3_s2@112', in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, dilation=1, B=16, H=112, W=112),
    dict(name='3x3_dil2@56', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2, B=16, H=56, W=56),
]

block_grid = [
    dict(name='64-64-16/4x2', BLOCK_M=64, BLOCK_N=64, BLOCK_K=16, NUM_WARPS=4, NUM_STAGES=2),
    dict(name='64-64-32/4x2', BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=2),
    dict(name='64-64-64/4x2', BLOCK_M=64, BLOCK_N=64, BLOCK_K=64, NUM_WARPS=4, NUM_STAGES=2),
    dict(name='64-128-32/4x2', BLOCK_M=64, BLOCK_N=128, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=2),
    dict(name='128-64-32/8x2', BLOCK_M=128, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=8, NUM_STAGES=2),
    dict(name='128-128-32/8x2', BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, NUM_WARPS=8, NUM_STAGES=2),
    dict(name='128-128-64/8x2', BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, NUM_WARPS=8, NUM_STAGES=2),
    dict(name='128-64-32/8x3', BLOCK_M=128, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=8, NUM_STAGES=3),
    dict(name='64-128-32/4x3', BLOCK_M=64, BLOCK_N=128, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=3),
    dict(name='64-64-32/2x2', BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=2, NUM_STAGES=2),
]
print('kernel configs:', len(kernel_grid))
print('block configs:', len(block_grid))

kernel configs: 21
block configs: 10


## Основной перебор
Для каждой конфигурации ядра обучаем PyTorch Conv2d, клонируем веса в TritonConv2d и измеряем время/точность. Результаты сохраняем в `results_df`.

In [5]:
rows = []
if device != 'cuda' or BaselineTritonConv2d is None:
    print('Triton недоступен — пропускаем поиск')
else:
    for cfg in kernel_grid:
        # обычный Conv2d
        torch_conv = build_torch_conv(cfg)
        x = torch.randn(
            cfg['B'], cfg['in_channels'], cfg['H'], cfg['W'],
            device=device, dtype=dtype
        )

        #  плотный TritonConv2d 
        dense = BaselineTritonConv2d(
            in_channels=cfg['in_channels'],
            out_channels=cfg['out_channels'],
            kernel_size=cfg['kernel_size'],
            stride=cfg['stride'],
            padding=cfg['padding'],
            dilation=cfg['dilation'],
            bias=True,  
        ).to(device)

        clone_weights(dense, torch_conv)
        dense_time = benchmark_layer(dense, x)
        stats_dense = compare_outputs(torch_conv, dense, x)

        # зреженный по входным каналам TritonConv2d
        sparse = BaselineTritonConv2d(
            in_channels=cfg['in_channels'],
            out_channels=cfg['out_channels'],
            kernel_size=cfg['kernel_size'],
            stride=cfg['stride'],
            padding=cfg['padding'],
            dilation=cfg['dilation'],
            bias=True,
        ).to(device)

        clone_weights(sparse, torch_conv)
        sparse.set_input_channel_sparsity(keep_ratio=cfg.get('input_keep', 0.5))
        sparse_time = benchmark_layer(sparse, x)
        stats_sparse = compare_outputs(torch_conv, sparse, x)

        rows.append({
            'kernel': cfg['name'],
            'B': cfg['B'], 'H': cfg['H'], 'W': cfg['W'],
            'k': cfg['kernel_size'],
            'mode': 'dense',
            'time_ms': dense_time,
            'mae': stats_dense['mae'],
            'max': stats_dense['max'],
            'rel_l2': stats_dense['rel_l2'],
            'speedup_vs_dense': 'nan',
        })
        rows.append({
            'kernel': cfg['name'],
            'B': cfg['B'], 'H': cfg['H'], 'W': cfg['W'],
            'k': cfg['kernel_size'],
            'mode': 'input_sparse',
            'keep_ratio': cfg.get('input_keep', 0.5),
            'time_ms': sparse_time,
            'mae': stats_sparse['mae'],
            'max': stats_sparse['max'],
            'rel_l2': stats_sparse['rel_l2'],
            'speedup_vs_dense': dense_time / sparse_time if sparse_time > 0 else float('nan'),
        })

results_df = pd.DataFrame(rows)
results_df

Unnamed: 0,kernel,B,H,W,k,mode,time_ms,mae,max,rel_l2,speedup_vs_dense,keep_ratio
0,1x1@56,32,56,56,1,dense,0.683379,8e-05,0.001953,0.000354,,
1,1x1@56,32,56,56,1,input_sparse,0.769938,0.305527,2.172607,0.6632,0.887576,0.5
2,3x3@56,16,56,56,3,dense,1.042263,7.8e-05,0.001953,0.000354,,
3,3x3@56,16,56,56,3,input_sparse,0.878565,0.318442,2.371094,0.700036,1.186325,0.5
4,5x5@56,16,56,56,5,dense,1.795365,7.6e-05,0.001953,0.000354,,
5,5x5@56,16,56,56,5,input_sparse,1.263141,0.315676,2.114136,0.702327,1.421349,0.5
6,7x7@56,16,56,56,7,dense,2.975042,8.1e-05,0.001953,0.000365,,
7,7x7@56,16,56,56,7,input_sparse,1.84928,0.313114,2.035156,0.703221,1.608757,0.5
8,11x11@56,8,56,56,11,dense,3.583289,7.8e-05,0.001953,0.000364,,
9,11x11@56,8,56,56,11,input_sparse,1.983689,0.307548,2.074219,0.704482,1.806376,0.5


## Лучшая конфигурация по каждому ядру

In [6]:
if not results_df.empty:
    best_df = results_df.sort_values('speedup', ascending=False).groupby('kernel', as_index=False).first()
    best_df[['kernel', 'block', 'speedup_vs_dense', 'triton_time', 'torch_time', 'mae', 'max']]
else:и
    best_df = pd.DataFrame()
    print('Нет данных (Triton отключён)')

IndentationError: unexpected indent (698521774.py, line 5)

In [7]:
if not results_df.empty:
    # Берём только разрежённые варианты, потому что speedup_vs_dense есть только там
    sparse_df = results_df[results_df['mode'] == 'input_sparse'].copy()

    if sparse_df.empty:
        best_df = pd.DataFrame()
        print('Нет строк с mode == "input_sparse"')
    else:
        # Для каждого ядра берём конфиг с максимальным ускорением
        best_df = (
            sparse_df
            .sort_values('speedup_vs_dense', ascending=False)
            .groupby('kernel', as_index=False)
            .first()
        )

        # Выводим самые полезные поля
        best_df[
            [
                'kernel',     # имя ядра (cfg['name'])
                'B', 'H', 'W',
                'k',          # kernel_size
                'keep_ratio', # доля оставленных входных каналов
                'time_ms',    # время Triton-разрежённого варианта
                'speedup_vs_dense',
                'mae', 'max', 'rel_l2',
            ]
        ]
else:
    best_df = pd.DataFrame()
    print('Нет данных (Triton отключён или все запуски упали)')

## Сохранение таблицы (опционально)
Если нужно, можно выгрузить `results_df` в CSV для дальнейшего анализа.

In [8]:
top5_sparse = best_df[best_df["mode"] == "input_sparse"].sort_values("speedup_vs_dense", ascending=False).head(5)
top5_sparse

Unnamed: 0,kernel,B,H,W,k,mode,time_ms,mae,max,rel_l2,speedup_vs_dense,keep_ratio
3,13x13@224,2,224,224,13,input_sparse,10.389999,0.319993,2.243652,0.705721,1.849441,0.5
1,11x11@224,4,224,224,11,input_sparse,14.932144,0.320465,2.515625,0.705597,1.827728,0.5
0,11x11@112,8,112,112,11,input_sparse,6.903804,0.316052,2.156738,0.704964,1.807472,0.5
2,11x11@56,8,56,56,11,input_sparse,1.983689,0.307548,2.074219,0.704482,1.806376,0.5
16,7x7@112,12,112,112,7,input_sparse,4.858555,0.318963,2.235596,0.70474,1.746326,0.5


In [9]:
top5_dense = best_df[best_df["mode"] == "dense"].sort_values("time_ms", ascending=True).head(5)
top5_dense

Unnamed: 0,kernel,B,H,W,k,mode,time_ms,mae,max,rel_l2,speedup_vs_dense,keep_ratio
