# Sparsity Kernel Search (FP16)
Search over the same kernel grid as the baseline, using the same block-grid as in baseline_kernel_search.ipynb. It compares torch Conv2d, TritonConv2d (dense) and three types of sparsity over a list of keep_ratio values. Every result is printed as soon as it is measured, so that progress is visible.

In [1]:
import pathlib
import sys

# Put the parent of the current working directory at the front of the import
# path (presumably the repository root, so that `conv_gemm` can be imported
# when the notebook is started from a sub-folder).
sys.path.insert(0, str(pathlib.Path.cwd().resolve().parent))
print('sys.path[0]=', sys.path[0])

sys.path[0]= /mnt/d/VSCode-Projects/conv2d-img2col-gemm


In [2]:
import time
import torch
import pandas as pd
from pathlib import Path
import importlib
from conv_gemm.baseline_layers.triton_conv2d import TritonConv2d
import conv_gemm.baseline_operators.triton_conv2d_fp16_fn as tri_fn
from conv_gemm.configs import kernel_config as kc

# Let cuDNN benchmark its convolution algorithms and keep the fastest one.
torch.backends.cudnn.benchmark = True

# Run on the GPU when one is available; half precision is only used there,
# the CPU fallback stays in float32.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32
print(f'device={device}, dtype={dtype}')
if device != 'cuda':
    # The original message was mis-encoded (rendered as question marks).
    print('Warning: this search is meant for a CUDA GPU; on CPU the sweep is skipped.')

device=cuda, dtype=torch.float16


In [3]:
def sync_device():
    """Wait for queued CUDA work to finish; do nothing when running on CPU."""
    if device != 'cuda':
        return
    torch.cuda.synchronize()

def clone_weights(dst: torch.nn.Module, src: torch.nn.Module):
    """Copy ``src``'s weight (and bias, when both modules have one) into ``dst``.

    The copy is done in place under ``torch.no_grad()``. ``dst.bias`` is left
    untouched when either module has no bias.
    """
    with torch.no_grad():
        dst.weight.copy_(src.weight)
        target_bias, source_bias = dst.bias, src.bias
        if target_bias is None or source_bias is None:
            return
        target_bias.copy_(source_bias)

def benchmark_layer(layer: torch.nn.Module, x: torch.Tensor, warmup: int = 10, iters: int = 50) -> float:
    """Return the mean forward-pass time of ``layer`` on ``x`` in milliseconds.

    The timing method follows the device that ``x`` actually lives on, not a
    module-level flag: CUDA tensors are timed with CUDA events, everything
    else with ``time.perf_counter``.

    Args:
        layer: Module to time. It is switched to eval mode and left there.
        x: Input passed to ``layer`` on every call.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; must be positive (``iters == 0`` raises
            ``ZeroDivisionError``).

    Returns:
        Total elapsed time divided by ``iters``, in milliseconds.
    """
    layer.eval()
    on_gpu = x.is_cuda
    with torch.no_grad():
        for _ in range(warmup):
            layer(x)
        if on_gpu:
            # Drain the warm-up kernels so they are not counted in the timing.
            torch.cuda.synchronize(x.device)
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(iters):
                layer(x)
            end.record()
            # elapsed_time() needs both events to have completed.
            torch.cuda.synchronize(x.device)
            elapsed_ms = start.elapsed_time(end)
        else:
            t0 = time.perf_counter()
            for _ in range(iters):
                layer(x)
            elapsed_ms = (time.perf_counter() - t0) * 1e3
    return elapsed_ms / iters

def build_torch_conv(cfg):
    """Create a biased ``torch.nn.Conv2d`` from ``cfg`` on the global device/dtype.

    ``cfg`` must provide the keys ``in_channels``, ``out_channels``,
    ``kernel_size``, ``stride``, ``padding`` and ``dilation``; any other keys
    are ignored.
    """
    conv_keys = ('in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation')
    conv_kwargs = {key: cfg[key] for key in conv_keys}
    conv = torch.nn.Conv2d(bias=True, **conv_kwargs)
    return conv.to(device=device, dtype=dtype)

def calc_diff(ref: torch.Tensor, test: torch.Tensor):
    """Compare ``test`` against ``ref`` and return scalar error metrics.

    Both tensors are upcast to float32 before any arithmetic, so
    half-precision inputs neither round the difference nor overflow the
    reference norm.

    Returns:
        A dict with ``'mae'`` (mean absolute error), ``'max'`` (largest
        absolute error) and ``'rel_l2'`` (norm of the difference divided by
        the norm of ``ref``). ``'rel_l2'`` is not finite when ``ref`` has
        zero norm.
    """
    ref32 = ref.float()
    diff = ref32 - test.float()
    abs_diff = diff.abs()
    return {
        'mae': abs_diff.mean().item(),
        'max': abs_diff.max().item(),
        'rel_l2': (torch.norm(diff) / torch.norm(ref32)).item(),
    }

def apply_sparsity(layer: TritonConv2d, mode: str, ratio: float):
    """Configure ``layer`` for the requested sparsity ``mode``.

    ``layer.clear_sparsity()`` is always called first, including for an
    unknown mode. ``'dense'`` then leaves the layer as is; ``'channel'``,
    ``'block'`` (with ``block_size=4``) and ``'input'`` forward ``ratio`` to
    the matching ``set_*_sparsity`` method of the layer.

    Raises:
        ValueError: If ``mode`` is not one of the four supported names.
    """
    layer.clear_sparsity()
    appliers = {
        'dense': lambda: None,
        'channel': lambda: layer.set_channel_sparsity(ratio),
        'block': lambda: layer.set_block_sparsity(ratio, block_size=4),
        'input': lambda: layer.set_input_channel_sparsity(ratio),
    }
    if mode not in appliers:
        raise ValueError(f'unknown mode {mode}')
    appliers[mode]()

def apply_block_cfg(cfg):
    """Install the GEMM tile configuration described by ``cfg``.

    ``cfg`` must provide the keys ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``,
    ``NUM_WARPS`` and ``NUM_STAGES``. A new ``kc.KernelConfig`` is built from
    them, stored as ``FP16_GEMM_CFG`` on both ``tri_fn`` and ``kc``, and
    ``tri_fn`` is then reloaded.
    """
    # Switch the GEMM config before the layer is run.
    tri_fn.FP16_GEMM_CFG = kc.KernelConfig(
        BLOCK_M=cfg['BLOCK_M'],
        BLOCK_N=cfg['BLOCK_N'],
        BLOCK_K=cfg['BLOCK_K'],
        NUM_WARPS=cfg['NUM_WARPS'],
        NUM_STAGES=cfg['NUM_STAGES'],
    )
    # Keep the shared config module pointing at the same object.
    kc.FP16_GEMM_CFG = tri_fn.FP16_GEMM_CFG
    # reload() re-executes the module body of tri_fn.
    # NOTE(review): presumably tri_fn re-reads FP16_GEMM_CFG from kc on reload,
    # which would be why kc is updated first; if it defines its own default
    # instead, the assignment above is overwritten here. Also confirm that
    # TritonConv2d sees the reloaded module rather than objects it imported
    # before the reload.
    importlib.reload(tri_fn)

# Fix the random seeds so that the randomly initialised weights and inputs
# are the same on every run of the notebook.
torch.manual_seed(0)
if device == 'cuda':
    torch.cuda.manual_seed(0)

In [4]:
# Convolution problems to benchmark, one row per problem. Every row is turned
# into a dict keyed by the field names listed in the zip() call below.
kernel_grid = [
    dict(zip(
        ('name', 'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'B', 'H', 'W'),
        row,
    ))
    for row in (
        # 56x56 inputs
        ('1x1@56', 64, 64, 1, 1, 0, 1, 32, 56, 56),
        ('3x3@56', 64, 128, 3, 1, 1, 1, 16, 56, 56),
        ('5x5@56', 64, 128, 5, 1, 2, 1, 16, 56, 56),
        ('7x7@56', 64, 128, 7, 1, 3, 1, 16, 56, 56),
        ('11x11@56', 64, 128, 11, 1, 5, 1, 8, 56, 56),
        # 112x112 inputs
        ('3x3@112', 64, 128, 3, 1, 1, 1, 16, 112, 112),
        ('5x5@112', 64, 128, 5, 1, 2, 1, 16, 112, 112),
        ('7x7@112', 64, 128, 7, 1, 3, 1, 12, 112, 112),
        ('11x11@112', 64, 128, 11, 1, 5, 1, 8, 112, 112),
        # 224x224 inputs
        ('3x3@224', 64, 128, 3, 1, 1, 1, 8, 224, 224),
        ('7x7@224', 64, 128, 7, 1, 3, 1, 6, 224, 224),
        ('11x11@224', 64, 128, 11, 1, 5, 1, 4, 224, 224),
        ('13x13@224', 64, 128, 13, 1, 6, 1, 2, 224, 224),
        # 512x512 inputs
        ('3x3@512', 64, 128, 3, 1, 1, 1, 4, 512, 512),
        ('7x7@512', 64, 128, 7, 1, 3, 1, 2, 512, 512),
        # small inputs with more channels
        ('3x3@32', 128, 256, 3, 1, 1, 1, 32, 32, 32),
        ('5x5@32', 128, 256, 5, 1, 2, 1, 32, 32, 32),
        ('7x7@32', 128, 256, 7, 1, 3, 1, 32, 32, 32),
        ('3x3@16', 256, 512, 3, 1, 1, 1, 64, 16, 16),
        # strided and dilated variants
        ('3x3_s2@112', 64, 128, 3, 2, 1, 1, 16, 112, 112),
        ('3x3_dil2@56', 64, 128, 3, 1, 2, 2, 16, 56, 56),
    )
]
print('kernel configs:', len(kernel_grid))

# GEMM tile configurations to try. Each tuple is
# (BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_STAGES); the display name is
# derived from those values as 'M-N-K/WARPSxSTAGES'.
block_grid = [
    dict(
        name=f'{block_m}-{block_n}-{block_k}/{warps}x{stages}',
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        NUM_WARPS=warps,
        NUM_STAGES=stages,
    )
    for block_m, block_n, block_k, warps, stages in (
        (64, 64, 16, 4, 2),
        (64, 64, 32, 4, 2),
        (64, 64, 64, 4, 2),
        (64, 128, 32, 4, 2),
        (128, 64, 32, 8, 2),
        (128, 128, 32, 8, 2),
        (128, 128, 64, 8, 2),
        (128, 64, 32, 8, 3),
        (64, 128, 32, 4, 3),
        (64, 64, 32, 2, 2),
    )
]
print('block configs:', len(block_grid))

# Execution modes of the Triton layer: 'dense' plus three sparsity variants.
modes = ['dense', 'channel', 'block', 'input']
# keep_ratio values swept for every sparsity mode.
keep_ratios = [
    1.0,
    0.85,
    0.75,
    0.65,
    0.5,
    0.35,
    0.25,
]
print('keep_ratios:', keep_ratios)
print('modes:', modes)

kernel configs: 21
block configs: 10
keep_ratios: [1.0, 0.85, 0.75, 0.65, 0.5, 0.35, 0.25]
modes: ['dense', 'channel', 'block', 'input']


In [5]:

# Benchmark every (kernel, block config, mode, keep_ratio) combination against
# the torch.nn.Conv2d reference and collect one result row per combination.
rows = []
try:
    if device != 'cuda':
        # The original message was mis-encoded (rendered as question marks).
        print('CUDA is unavailable - skipping the benchmark.')
    else:
        # 'dense' is measured once; every other mode once per keep_ratio.
        runs_per_block = sum(1 if m == 'dense' else len(keep_ratios) for m in modes)
        total = len(kernel_grid) * len(block_grid) * runs_per_block
        done = 0
        print(f'total combinations: {total}')
        for cfg in kernel_grid:
            # The torch reference is timed and evaluated once per kernel.
            torch_conv = build_torch_conv(cfg)
            x = torch.randn(cfg['B'], cfg['in_channels'], cfg['H'], cfg['W'], device=device, dtype=dtype)
            torch_time = benchmark_layer(torch_conv, x)
            with torch.no_grad():
                torch_out = torch_conv(x).float()
            for block in block_grid:
                apply_block_cfg(block)
                for mode in modes:
                    ratios = keep_ratios if mode != 'dense' else [1.0]
                    for ratio in ratios:
                        # A fresh layer per run, so sparsity state never carries over.
                        # NOTE(review): the layer is moved to the device without a
                        # dtype, unlike the torch reference - confirm TritonConv2d
                        # handles the fp16 input itself.
                        layer = TritonConv2d(
                            in_channels=cfg['in_channels'],
                            out_channels=cfg['out_channels'],
                            kernel_size=cfg['kernel_size'],
                            stride=cfg['stride'],
                            padding=cfg['padding'],
                            dilation=cfg['dilation'],
                            bias=True,
                        ).to(device)
                        clone_weights(layer, torch_conv)
                        apply_sparsity(layer, mode, ratio)
                        time_ms = benchmark_layer(layer, x)
                        with torch.no_grad():
                            y = layer(x).float()
                        diff_stats = calc_diff(torch_out, y)
                        speedup = torch_time / time_ms if time_ms > 0 else float('nan')
                        done += 1
                        # Flushed immediately so progress is visible during the long sweep.
                        print(f"[{done}/{total}] kernel={cfg['name']}, block={block['name']}, mode={mode}, keep_ratio={ratio:.2f}, time_ms={time_ms:.3f}, speedup_vs_torch={speedup:.3f}", flush=True)
                        rows.append({
                            'kernel': cfg['name'],
                            'B': cfg['B'], 'H': cfg['H'], 'W': cfg['W'], 'k': cfg['kernel_size'],
                            'block': block['name'],
                            'BLOCK_M': block['BLOCK_M'], 'BLOCK_N': block['BLOCK_N'], 'BLOCK_K': block['BLOCK_K'],
                            'NUM_WARPS': block['NUM_WARPS'], 'NUM_STAGES': block['NUM_STAGES'],
                            'mode': mode,
                            'keep_ratio': float(ratio),
                            'time_ms': time_ms,
                            'torch_time_ms': torch_time,
                            'speedup_vs_torch': speedup,
                            'mae': diff_stats['mae'],
                            'max': diff_stats['max'],
                            'rel_l2': diff_stats['rel_l2'],
                        })
finally:
    # Build the frame even when the sweep is interrupted or fails part-way,
    # so the rows measured so far stay available to the following cells.
    results_df = pd.DataFrame(rows)
results_df


total combinations: 4620
[1/4620] kernel=1x1@56, block=64-64-16/4x2, mode=dense, keep_ratio=1.00, time_ms=2.684, speedup_vs_torch=0.138
[2/4620] kernel=1x1@56, block=64-64-16/4x2, mode=channel, keep_ratio=1.00, time_ms=2.752, speedup_vs_torch=0.134
[3/4620] kernel=1x1@56, block=64-64-16/4x2, mode=channel, keep_ratio=0.85, time_ms=3.198, speedup_vs_torch=0.116
[4/4620] kernel=1x1@56, block=64-64-16/4x2, mode=channel, keep_ratio=0.75, time_ms=3.159, speedup_vs_torch=0.117
[5/4620] kernel=1x1@56, block=64-64-16/4x2, mode=channel, keep_ratio=0.65, time_ms=3.064, speedup_vs_torch=0.121
[6/4620] kernel=1x1@56, block=64-64-16/4x2, mode=channel, keep_ratio=0.50, time_ms=2.720, speedup_vs_torch=0.136
[7/4620] kernel=1x1@56, block=64-64-16/4x2, mode=channel, keep_ratio=0.35, time_ms=2.645, speedup_vs_torch=0.140
[8/4620] kernel=1x1@56, block=64-64-16/4x2, mode=channel, keep_ratio=0.25, time_ms=2.453, speedup_vs_torch=0.151
[9/4620] kernel=1x1@56, block=64-64-16/4x2, mode=block, keep_ratio=1.00, 


KeyboardInterrupt



In [None]:
if not results_df.empty:
    # For every kernel keep the whole row with the smallest time_ms.
    # groupby(...).first() is not used because it takes the first non-null
    # value of each column separately, so it can combine values from different
    # rows when a metric is NaN. Rows come out ordered by kernel name.
    best_df = results_df.loc[results_df.groupby('kernel')['time_ms'].idxmin()].reset_index(drop=True)
    display(best_df[['kernel', 'block', 'mode', 'keep_ratio', 'time_ms', 'speedup_vs_torch']])
else:
    # The original message was mis-encoded (rendered as question marks).
    print('No results: run the benchmark on CUDA first.')

In [None]:
if not results_df.empty:
    # NOTE(review): both paths are relative to the current working directory.
    out_all = Path('notebooks/sparsity_kernel_search_results.csv')
    out_best = Path('notebooks/sparsity_kernel_best.csv')
    # Both files share one directory; create it so to_csv cannot fail on a
    # missing folder.
    out_all.parent.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(out_all, index=False)
    # Whole fastest row per kernel. groupby(...).first() would pick the first
    # non-null value per column and could mix values from different rows.
    best_rows = results_df.loc[results_df.groupby('kernel')['time_ms'].idxmin()].reset_index(drop=True)
    best_rows.to_csv(out_best, index=False)
    print('saved:', out_all)
    print('saved:', out_best)
else:
    # The original message was mis-encoded (rendered as question marks).
    print('No data to save.')