# Kernel Size & Block Tuning Benchmark

Этот ноутбук перебирает разные размеры ядер / пространственных размеров и конфигурации Triton-блоков, чтобы найти сочетания, где наша реализация хотя бы сопоставима или быстрее PyTorch Conv2d.

## Подготовка окружения
Добавляем корень репозитория в `sys.path`, чтобы импортировать пакет `conv_gemm`.

In [5]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent))

## Импорты и базовые настройки

In [6]:
import time, itertools
import torch
import pandas as pd
import torch.nn.functional as F

from conv_gemm.layers import TritonConv2d, Gem2ColConv2d

torch.backends.cudnn.benchmark = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32
print(f'device = {device}, dtype = {dtype}')
print('Triton available:', TritonConv2d is not None)

ImportError: cannot import name 'TritonConv2d' from 'conv_gemm.layers' (/home/manzhura/ITMO/EDLM/conv2d-img2col-gemm/conv_gemm/layers/__init__.py)

## Вспомогательные функции

In [None]:
def sync_device():
    if device == 'cuda':
        torch.cuda.synchronize()

def clone_weights(dst, src):
    with torch.no_grad():
        dst.weight.copy_(src.weight)
        if dst.bias is not None and src.bias is not None:
            dst.bias.copy_(src.bias)


def compare_outputs(ref, other, x):
    ref_out = ref(x).float()
    test_out = other(x).float()
    diff = (ref_out - test_out).abs()
    return {
        'mae': diff.mean().item(),
        'max': diff.max().item(),
        'rel_l2': diff.norm().item() / (ref_out.norm().item() + 1e-12)
    }


def benchmark_layer(layer, x, iters=20, warmup=10):
    layer.train(False)
    sync_device()
    for _ in range(warmup):
        y = layer(x)
        if y.requires_grad:
            y.sum().backward()
    sync_device()
    start = time.perf_counter()
    for _ in range(iters):
        y = layer(x)
    sync_device()
    return (time.perf_counter() - start) * 1000.0 / iters


def build_torch_conv(cfg):
    conv = torch.nn.Conv2d(
        cfg['in_channels'], cfg['out_channels'], cfg['kernel_size'],
        stride=cfg['stride'], padding=cfg['padding'], dilation=cfg['dilation'], bias=True
    ).to(device=device, dtype=dtype)
    return conv


def build_triton_conv(cfg, block_cfg):
    if TritonConv2d is None or device != 'cuda':
        return None
    tri = TritonConv2d(
        in_channels=cfg['in_channels'], out_channels=cfg['out_channels'],
        kernel_size=cfg['kernel_size'], stride=cfg['stride'], padding=cfg['padding'],
        dilation=cfg['dilation'], bias=True,
        BLOCK_M=block_cfg['BLOCK_M'], BLOCK_N=block_cfg['BLOCK_N'], BLOCK_K=block_cfg['BLOCK_K'],
        NUM_WARPS=block_cfg['NUM_WARPS'], NUM_STAGES=block_cfg['NUM_STAGES'],
        precision_mode='fp16_infer'
    ).to(device)
    return tri

## Поисковое пространство
* `kernel_grid` — комбинации (B, H, W, kernel_size, stride, padding).
* `block_grid` — варианты параметров Triton ядра.

In [3]:
kernel_grid = [
    dict(name='1x1@56', in_channels=64, out_channels=64, kernel_size=1, stride=1, padding=0, dilation=1, B=32, H=56, W=56),
    dict(name='3x3@56', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=16, H=56, W=56),
    dict(name='5x5@56', in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2, dilation=1, B=16, H=56, W=56),
    dict(name='7x7@56', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=16, H=56, W=56),
    dict(name='8x8@56', in_channels=64, out_channels=128, kernel_size=8, stride=1, padding=4, dilation=1, B=12, H=56, W=56),
    dict(name='11x11@56', in_channels=64, out_channels=128, kernel_size=11, stride=1, padding=5, dilation=1, B=8, H=56, W=56),
    dict(name='16x16@56', in_channels=64, out_channels=128, kernel_size=16, stride=1, padding=8, dilation=1, B=4, H=56, W=56),
    dict(name='3x3@112', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=16, H=112, W=112),
    dict(name='5x5@112', in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2, dilation=1, B=16, H=112, W=112),
    dict(name='7x7@112', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=12, H=112, W=112),
    dict(name='8x8@112', in_channels=64, out_channels=128, kernel_size=8, stride=1, padding=4, dilation=1, B=12, H=112, W=112),
    dict(name='11x11@112', in_channels=64, out_channels=128, kernel_size=11, stride=1, padding=5, dilation=1, B=8, H=112, W=112),
    dict(name='16x16@112', in_channels=64, out_channels=128, kernel_size=16, stride=1, padding=8, dilation=1, B=4, H=112, W=112),
    dict(name='3x3@224', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=8, H=224, W=224),
    dict(name='7x7@224', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=6, H=224, W=224),
    dict(name='8x8@224', in_channels=64, out_channels=128, kernel_size=8, stride=1, padding=4, dilation=1, B=6, H=224, W=224),
    dict(name='11x11@224', in_channels=64, out_channels=128, kernel_size=11, stride=1, padding=5, dilation=1, B=4, H=224, W=224),
    dict(name='16x16@224', in_channels=64, out_channels=128, kernel_size=16, stride=1, padding=8, dilation=1, B=2, H=224, W=224),
    dict(name='3x3@512', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=4, H=512, W=512),
    dict(name='7x7@512', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=2, H=512, W=512),
    dict(name='8x8@512', in_channels=64, out_channels=128, kernel_size=8, stride=1, padding=4, dilation=1, B=2, H=512, W=512),
    dict(name='11x11@512', in_channels=64, out_channels=128, kernel_size=11, stride=1, padding=5, dilation=1, B=2, H=512, W=512),
    dict(name='16x16@512', in_channels=64, out_channels=128, kernel_size=16, stride=1, padding=8, dilation=1, B=1, H=512, W=512),
    dict(name='3x3@1024', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=2, H=1024, W=1024),
    dict(name='7x7@1024', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=1, H=1024, W=1024),
    dict(name='8x8@1024', in_channels=64, out_channels=128, kernel_size=8, stride=1, padding=4, dilation=1, B=1, H=1024, W=1024),
    dict(name='11x11@1024', in_channels=64, out_channels=128, kernel_size=11, stride=1, padding=5, dilation=1, B=1, H=1024, W=1024),
    dict(name='16x16@1024', in_channels=64, out_channels=128, kernel_size=16, stride=1, padding=8, dilation=1, B=1, H=1024, W=1024),
    dict(name='3x3@2048', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1, B=1, H=2048, W=2048),
    dict(name='7x7@2048', in_channels=64, out_channels=128, kernel_size=7, stride=1, padding=3, dilation=1, B=1, H=2048, W=2048),
    dict(name='8x8@2048', in_channels=64, out_channels=128, kernel_size=8, stride=1, padding=4, dilation=1, B=1, H=2048, W=2048),
    dict(name='11x11@2048', in_channels=64, out_channels=128, kernel_size=11, stride=1, padding=5, dilation=1, B=1, H=2048, W=2048),
    dict(name='16x16@2048', in_channels=64, out_channels=128, kernel_size=16, stride=1, padding=8, dilation=1, B=1, H=2048, W=2048),
    dict(name='3x3@32', in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, dilation=1, B=32, H=32, W=32),
    dict(name='5x5@32', in_channels=128, out_channels=256, kernel_size=5, stride=1, padding=2, dilation=1, B=32, H=32, W=32),
    dict(name='7x7@32', in_channels=128, out_channels=256, kernel_size=7, stride=1, padding=3, dilation=1, B=32, H=32, W=32),
    dict(name='3x3@16', in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, dilation=1, B=64, H=16, W=16),
    dict(name='3x3_s2@112', in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, dilation=1, B=16, H=112, W=112),
    dict(name='3x3_dil2@56', in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2, B=16, H=56, W=56),
]

block_grid = [
    dict(name='64-64-16/4x2', BLOCK_M=64, BLOCK_N=64, BLOCK_K=16, NUM_WARPS=4, NUM_STAGES=2),
    dict(name='64-64-32/4x2', BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=2),
    dict(name='64-64-64/4x2', BLOCK_M=64, BLOCK_N=64, BLOCK_K=64, NUM_WARPS=4, NUM_STAGES=2),
    dict(name='64-128-32/4x2', BLOCK_M=64, BLOCK_N=128, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=2),
    dict(name='128-64-32/8x2', BLOCK_M=128, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=8, NUM_STAGES=2),
    dict(name='128-128-32/8x2', BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, NUM_WARPS=8, NUM_STAGES=2),
    dict(name='128-128-64/8x2', BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, NUM_WARPS=8, NUM_STAGES=2),
    dict(name='128-64-32/8x3', BLOCK_M=128, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=8, NUM_STAGES=3),
    dict(name='64-128-32/4x3', BLOCK_M=64, BLOCK_N=128, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=3),
    dict(name='64-64-32/2x2', BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=2, NUM_STAGES=2),
]
print('kernel configs:', len(kernel_grid))
print('block configs:', len(block_grid))

kernel configs: 39
block configs: 10


## Основной перебор
Для каждой конфигурации ядра обучаем PyTorch Conv2d, клонируем веса в TritonConv2d и измеряем время/точность. Результаты сохраняем в `results_df`.

In [4]:
rows = []
if device != 'cuda' or TritonConv2d is None:
    print('Triton недоступен — пропускаем поиск')
else:
    for cfg in kernel_grid:
        try:
            x = torch.randn(cfg['B'], cfg['in_channels'], cfg['H'], cfg['W'], device=device, dtype=dtype)
        except torch.cuda.OutOfMemoryError:
            print(f'OOM allocating input: kernel={cfg["name"]}')
            torch.cuda.empty_cache()
            continue
        torch_conv = build_torch_conv(cfg)
        torch_time = benchmark_layer(torch_conv, x)
        for block in block_grid:
            tri = build_triton_conv(cfg, block)
            if tri is None:
                continue
            clone_weights(tri, torch_conv)
            try:
                tri_time = benchmark_layer(tri, x)
                stats = compare_outputs(torch_conv, tri, x)
                rows.append({
                    'kernel': cfg['name'],
                    'B': cfg['B'], 'H': cfg['H'], 'W': cfg['W'],
                    'k': cfg['kernel_size'], 'block': block['name'],
                    'torch_time': torch_time,
                    'triton_time': tri_time,
                    'speedup': torch_time / tri_time if tri_time > 0 else float('nan'),
                    'mae': stats['mae'], 'max': stats['max'], 'rel_l2': stats['rel_l2']
                })
            except torch.cuda.OutOfMemoryError:
                print(f'OOM during run: kernel={cfg["name"]}, block={block["name"]}')
                torch.cuda.empty_cache()
                continue
results_df = pd.DataFrame(rows)
results_df

NameError: name 'device' is not defined

## Лучшая конфигурация по каждому ядру

In [None]:
if not results_df.empty:
    best_df = results_df.sort_values('speedup', ascending=False).groupby('kernel', as_index=False).first()
    best_df[['kernel', 'block', 'speedup', 'triton_time', 'torch_time', 'mae', 'max']]
else:
    best_df = pd.DataFrame()
    print('Нет данных (Triton отключён)')

## Отбор конфигураций ≥15% ускорения
Фильтруем результаты, где `speedup >= 1.15` и ошибки ниже порогов.

In [None]:
speed_threshold = 1.01
mae_threshold = 5e-3
if not results_df.empty:
    good = results_df[(results_df['speedup'] >= speed_threshold) & (results_df['mae'] <= mae_threshold)]
    if good.empty:
        print('Нет конфигураций, удовлетворяющих speedup>=1.15 и mae<=0.05')
    else:
        good.sort_values('speedup', ascending=False)


## Сохранение таблицы (опционально)
Если нужно, можно выгрузить `results_df` в CSV для дальнейшего анализа.

In [None]:
results_df.to_csv('kernel_search_results.csv', index=False)