In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
import copy
from conv_gemm.layers.triton_conv2d import TritonConv2d


In [2]:
def set_seed(seed=0):
    import random, numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(0)
device = "cuda"

# Базовая сеть (обычный ResNet18)
base = resnet18(weights=None, num_classes=200).to(device)

# Triton-версия (копируем веса и меняем все Conv2d на TritonConv2d)
precision_mode   = "fp32"        # начнём без fp16, чтобы меньше дебага
use_weight_shadow = False
BEST_BLOCKS      = (32, 32, 32)
BEST_LAUNCH      = (4, 2)

tri = copy.deepcopy(base)

In [8]:
def make_triton_from_conv(
    conv: nn.Conv2d,
    *,
    precision_mode: str = "fp32",
    use_weight_shadow: bool = True,
    triton_blocks: tuple[int, int, int] = (64, 64, 32),   # (BLOCK_M, BLOCK_N, BLOCK_K)
    triton_launch: tuple[int, int] = (4, 2),              # (NUM_WARPS, NUM_STAGES)
) -> TritonConv2d:
    """
    Оборачивает один nn.Conv2d в TritonConv2d, копируя все гиперпараметры и веса.

    conv            — исходная свёртка PyTorch.
    precision_mode  — "fp32", "fp16_runtime" или "fp16_infer".
    use_weight_shadow — использовать ли fp16-shadow для веса в режиме fp16_runtime.
    triton_blocks   — размеры тайлов GEMM: (BLOCK_M, BLOCK_N, BLOCK_K).
    triton_launch   — параметры запуска: (NUM_WARPS, NUM_STAGES).
    """
    tm = TritonConv2d(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=(conv.bias is not None),
        BLOCK_M=triton_blocks[0],
        BLOCK_N=triton_blocks[1],
        BLOCK_K=triton_blocks[2],
        NUM_WARPS=triton_launch[0],
        NUM_STAGES=triton_launch[1],
        precision_mode=precision_mode,
        use_weight_shadow=use_weight_shadow,
    ).to(conv.weight.device).to(conv.weight.dtype)

    with torch.no_grad():
        tm.weight.copy_(conv.weight)
        if conv.bias is not None and tm.bias is not None:
            tm.bias.copy_(conv.bias)

    return tm


In [10]:
def replace_convs_with_triton(
    model: nn.Module,
    *,
    precision_mode: str = "fp32",
    use_weight_shadow: bool = True,
    triton_blocks: tuple[int, int, int] = (64, 64, 32),
    triton_launch: tuple[int, int] = (4, 2),
) -> nn.Module:
    """
    Рекурсивно заменяет все nn.Conv2d в модели на TritonConv2d.

    Возвращает ту же модель (in-place), но с Triton-свёртками.
    """
    for name, module in list(model.named_children()):
        # сначала рекурсивно спускаемся внутрь
        replace_convs_with_triton(
            module,
            precision_mode=precision_mode,
            use_weight_shadow=use_weight_shadow,
            triton_blocks=triton_blocks,
            triton_launch=triton_launch,
        )

        # потом уже заменяем конкретный Conv2d на этом уровне
        if isinstance(module, nn.Conv2d):
            triton_conv = make_triton_from_conv(
                module,
                precision_mode=precision_mode,
                use_weight_shadow=use_weight_shadow,
                triton_blocks=triton_blocks,
                triton_launch=triton_launch,
            )
            setattr(model, name, triton_conv)

    return model


In [11]:

tri = replace_all_convs_with_triton(
    tri,
    precision_mode=precision_mode,
    use_weight_shadow=use_weight_shadow,
    triton_blocks=BEST_BLOCKS,
    triton_launch=BEST_LAUNCH,
).to(device)

# БЕНЧМАРК

In [17]:
from typing import Dict, List, Any, Optional

def benchmark_one_step(
    model: nn.Module,
    B: int,
    H: int,
    W: int,
    *,
    device: str = "cuda",
    num_classes: int = 200,
    dtype: torch.dtype = torch.float32,
    iters_warmup: int = 2,
    iters: int = 5,
    lr: float = 1e-3,
    use_amp: bool = False,
) -> Dict[str, float]:
    """
    Замеряет один training-step модели (forward + backward + step)
    на фиксированном batch size.

    Что делает:
      - генерит случайный батч (x, y)
      - прогревает модель `iters_warmup` шагов
      - меряет среднее время по `iters` итерациям через cuda.Event
      - возвращает:
          avg_ms       — среднее время шага (ms)
          peak_mem_mb  — пиковое GPU-потребление (MB) за всё время вызова
          last_loss    — последний значение loss (чисто sanity-check)

    Параметры:
      model        — любая nn.Module (должна принимать [B, 3, H, W])
      B, H, W      — batch size и spatial размер входа
      device       — "cuda" или "cpu" (по факту нам нужен cuda для тайминга)
      num_classes  — размерность выхода для CrossEntropyLoss
      dtype        — тип входного тензора (обычно torch.float32)
      iters_warmup — сколько итераций прогреть перед замером
      iters        — по скольким итерациям усреднять время
      lr           — learning rate для SGD
      use_amp      — включать ли torch.cuda.amp.autocast
    """
    model = model.to(device)
    model.train()

    x = torch.randn(B, 3, H, W, device=device, dtype=dtype)
    y = torch.randint(0, num_classes, (B,), device=device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    scaler: Optional[torch.cuda.amp.GradScaler] = None
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device=device)

    # --- прогрев ---
    for _ in range(iters_warmup):
        optimizer.zero_grad(set_to_none=True)
        if use_amp:
            with torch.cuda.amp.autocast():
                out = model(x)
                loss = criterion(out, y)
            assert scaler is not None
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
    torch.cuda.synchronize()

    # --- замер ---
    start_event = torch.cuda.Event(enable_timing=True)
    end_event   = torch.cuda.Event(enable_timing=True)

    total_ms = 0.0
    last_loss = 0.0

    for _ in range(iters):
        optimizer.zero_grad(set_to_none=True)

        start_event.record()
        if use_amp:
            with torch.cuda.amp.autocast():
                out = model(x)
                loss = criterion(out, y)
            assert scaler is not None
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
        end_event.record()

        torch.cuda.synchronize()
        iter_ms = start_event.elapsed_time(end_event)
        total_ms += iter_ms
        last_loss = float(loss.item())

    avg_ms = total_ms / iters
    peak_mem_bytes = torch.cuda.max_memory_allocated(device=device)
    peak_mem_mb = peak_mem_bytes / (1024**2)

    return {
        "avg_ms": avg_ms,
        "peak_mem_mb": peak_mem_mb,
        "last_loss": last_loss,
    }


In [18]:
def find_max_batch_size(
    model: nn.Module,
    candidate_batches: List[int],
    H: int,
    W: int,
    *,
    num_classes: int = 200,
    label: str = "",
    benchmark_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Подбирает максимальный batch size, который модель выдерживает
    при полном шаге обучения (forward + backward + optimizer.step).

    Что делает:
      - для каждого B из candidate_batches:
          * вызывает benchmark_one_step(model, B, H, W, ...)
          * логирует время и пиковую память
      - если падает по OOM / OutOfResources — помечает B как FAIL
      - в конце выбирает 'best' как:
          * максимальный B, который успешно прошёл

    Параметры:
      model             — модель (предполагается уже на нужном device)
      candidate_batches — список batch size'ов, которые пробуем (можно в любом порядке)
      H, W              — spatial размер входа (3 x H x W)
      num_classes       — число классов (для CrossEntropy внутри benchmark_one_step)
      label             — метка для логов (например, "Torch ResNet18" / "Triton ResNet18")
      benchmark_kwargs  — доп. аргументы, которые будут проброшены в benchmark_one_step
                          (например: {'use_amp': True})

    Возвращает:
      dict с полями:
        - 'best_B'       — максимальный B, который прошёл (или None, если всё упало)
        - 'best_stats'   — словарь с метриками для best_B (avg_ms, peak_mem_mb, last_loss)
        - 'per_batch'    — список (B, status, stats_or_msg) по всем кандидатам
    """
    if benchmark_kwargs is None:
        benchmark_kwargs = {}

    print(f"\n=== Finding max batch size for {label} ===")
    per_batch = []
    best_B: Optional[int] = None
    best_stats: Optional[Dict[str, float]] = None

    for B in candidate_batches:
        print(f"\n-- Try B = {B} --")
        try:
            res = benchmark_one_step(
                model=model,
                B=B,
                H=H,
                W=W,
                num_classes=num_classes,
                **benchmark_kwargs,
            )
            print(f"OK: B={B}, avg_time={res['avg_ms']:.2f} ms, "
                  f"peak_mem={res['peak_mem_mb']:.1f} MB, "
                  f"loss={res['last_loss']:.4f}")

            per_batch.append((B, "OK", res))

            # выбираем максимальный B, который прошёл
            if (best_B is None) or (B > best_B):
                best_B = B
                best_stats = res

        except RuntimeError as e:
            msg = str(e)
            if ("out of memory" in msg.lower()
                or "outofresources" in msg.lower()
                or "out of resource" in msg.lower()):
                print(f"FAIL (OOM / OutOfResources) at B={B}")
                per_batch.append((B, "OOM", msg))
            else:
                print(f"FAIL other error at B={B}: {msg}")
                per_batch.append((B, "ERR", msg))

    print("\n=== Best that worked for", label, "===")
    if best_B is None:
        print("Nothing worked at all (даже минимальный батч)")
    else:
        print(f"B={best_B}, avg_time={best_stats['avg_ms']:.2f} ms, "
              f"peak_mem={best_stats['peak_mem_mb']:.1f} MB")

    return {
        "best_B": best_B,
        "best_stats": best_stats,
        "per_batch": per_batch,
    }

In [20]:
H = 224
W = 224
candidate_batches = [256, 192, 128, 96, 64, 48, 32, 24, 16, 8, 4]

# Для честности — копии моделей, чтобы оптимизатор не вмешивался между прогонками
base_for_test = copy.deepcopy(base).to(device)
tri_for_test  = copy.deepcopy(tri).to(device)

best_base = find_max_batch_size(
    model=base_for_test,
    candidate_batches=candidate_batches,
    H=H,
    W=W,
    num_classes=200,
    label="Torch ResNet18",
)

best_tri = find_max_batch_size(
    model=tri_for_test,
    candidate_batches=candidate_batches,
    H=H,
    W=W,
    num_classes=200,
    label="Triton ResNet18",
)



=== Finding max batch size for Torch ResNet18 ===

-- Try B = 256 --
FAIL (OOM / OutOfResources) at B=256

-- Try B = 192 --
OK: B=192, avg_time=219.36 ms, peak_mem=4438.6 MB, loss=5.3775

-- Try B = 128 --
OK: B=128, avg_time=149.69 ms, peak_mem=3032.8 MB, loss=5.3628

-- Try B = 96 --
OK: B=96, avg_time=117.99 ms, peak_mem=2336.5 MB, loss=5.2413

-- Try B = 64 --
OK: B=64, avg_time=78.59 ms, peak_mem=1642.9 MB, loss=5.3118

-- Try B = 48 --
OK: B=48, avg_time=62.58 ms, peak_mem=1290.3 MB, loss=5.1061

-- Try B = 32 --
OK: B=32, avg_time=45.63 ms, peak_mem=946.1 MB, loss=5.0791

-- Try B = 24 --
OK: B=24, avg_time=34.13 ms, peak_mem=776.0 MB, loss=4.9017

-- Try B = 16 --
OK: B=16, avg_time=23.95 ms, peak_mem=604.8 MB, loss=4.8180

-- Try B = 8 --
OK: B=8, avg_time=12.95 ms, peak_mem=434.1 MB, loss=4.7043

-- Try B = 4 --
OK: B=4, avg_time=8.62 ms, peak_mem=347.5 MB, loss=3.8065

=== Best that worked for Torch ResNet18 ===
B=192, avg_time=219.36 ms, peak_mem=4438.6 MB

=== Finding ma

In [21]:
B_TORCH_MAX = 192
B_TRITON_MAX = 96

In [14]:
import time
import copy
import torch
import torch.nn as nn

def bench(fn, warmup=10, iters=100):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters


def make_pair(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    dilation=1,
    bias=True,
    device="cuda",
):
    conv_ref = nn.Conv2d(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=padding, dilation=dilation, bias=bias
    ).to(device)

    triton_conv = TritonConv2d(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=padding, dilation=dilation, bias=bias,
        BLOCK_M=32, BLOCK_N=32, BLOCK_K=32,
        NUM_WARPS=4, NUM_STAGES=2,
        precision_mode="fp32",
    ).to(device)

    with torch.no_grad():
        triton_conv.weight.copy_(conv_ref.weight)
        if conv_ref.bias is not None and triton_conv.bias is not None:
            triton_conv.bias.copy_(conv_ref.bias)

    return conv_ref, triton_conv


def run_triton_vs_conv_bench():
    device = "cuda"

    N, Cin, Cout = 32, 64, 64
    H, W = 56, 56
    ks = 3
    stride = 1
    padding = 1
    dilation = 1

    x_fp32 = torch.randn(N, Cin, H, W, device=device)

    conv_ref_fp32, triton_conv = make_pair(
        in_channels=Cin,
        out_channels=Cout,
        kernel_size=ks,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
        device=device,
    )

    # =========================================================
    # FP32
    # =========================================================
    print("=== FP32 режим ===")
    triton_conv.set_precision("fp32")

    with torch.no_grad():
        y_ref = conv_ref_fp32(x_fp32)
        y_triton = triton_conv(x_fp32)

    max_err = (y_ref - y_triton).abs().max().item()
    print(f"fp32 max_err = {max_err:.6e}")

    with torch.no_grad():
        t_ref = bench(lambda: conv_ref_fp32(x_fp32))
        t_tr = bench(lambda: triton_conv(x_fp32))

    print(f"Conv2d (cuDNN, fp32):    {t_ref*1e3:.3f} ms")
    print(f"TritonConv2d (fp32):     {t_tr*1e3:.3f} ms")
    print(f"speedup (Triton/Conv2d): {t_ref/t_tr:.3f}x\n")

    # =========================================================
    # FP16 INFER — отдельная копия conv_ref
    # =========================================================
    print("=== FP16 INFER режим (вся сеть в half) ===")
    triton_conv.set_precision("fp16_infer")

    x_fp16 = x_fp32.half()

    # делаем deepcopy, чтобы не портить conv_ref_fp32
    conv_ref_fp16 = copy.deepcopy(conv_ref_fp32).half()

    with torch.no_grad():
        y_ref16 = conv_ref_fp16(x_fp16)      # half
        y_triton16 = triton_conv(x_fp16)     # half

    max_err16 = (y_ref16.float() - y_triton16.float()).abs().max().item()
    print(f"fp16_infer max_err = {max_err16:.6e}")

    with torch.no_grad():
        t_ref16 = bench(lambda: conv_ref_fp16(x_fp16))
        t_tr16 = bench(lambda: triton_conv(x_fp16))

    print(f"Conv2d (cuDNN, fp16):       {t_ref16*1e3:.3f} ms")
    print(f"TritonConv2d (fp16_infer):  {t_tr16*1e3:.3f} ms")
    print(f"speedup (Triton/Conv2d):    {t_ref16/t_tr16:.3f}x\n")

    # =========================================================
    # FP16 RUNTIME — опять используем исходный fp32 conv
    # =========================================================
    print("=== FP16 RUNTIME режим (вход fp32, внутри half, наружу fp32) ===")
    triton_conv.set_precision("fp16_runtime")

    with torch.no_grad():
        y_ref_rt = conv_ref_fp32(x_fp32)        # fp32 эталон
        y_triton_rt = triton_conv(x_fp32)       # внутри half, наружу fp32

    max_err_rt = (y_ref_rt - y_triton_rt).abs().max().item()
    print(f"fp16_runtime max_err = {max_err_rt:.6e}")

    with torch.no_grad():
        t_ref_rt = bench(lambda: conv_ref_fp32(x_fp32))
        t_tr_rt = bench(lambda: triton_conv(x_fp32))

    print(f"Conv2d (cuDNN, fp32):         {t_ref_rt*1e3:.3f} ms")
    print(f"TritonConv2d (fp16_runtime):  {t_tr_rt*1e3:.3f} ms")
    print(f"speedup (Triton/Conv2d):      {t_ref_rt/t_tr_rt:.3f}x")

In [15]:
run_triton_vs_conv_bench()

=== FP32 режим ===
fp32 max_err = 4.053116e-06
Conv2d (cuDNN, fp32):    0.717 ms
TritonConv2d (fp32):     1.990 ms
speedup (Triton/Conv2d): 0.360x

=== FP16 INFER режим (вся сеть в half) ===
fp16_infer max_err = 1.953125e-03
Conv2d (cuDNN, fp16):       0.421 ms
TritonConv2d (fp16_infer):  1.706 ms
speedup (Triton/Conv2d):    0.247x

=== FP16 RUNTIME режим (вход fp32, внутри half, наружу fp32) ===
fp16_runtime max_err = 1.393557e-03
Conv2d (cuDNN, fp32):         0.701 ms
TritonConv2d (fp16_runtime):  1.875 ms
speedup (Triton/Conv2d):      0.374x


In [4]:
from conv_gemm.layers.triton_conv2d import TritonConv2d

In [7]:
import time
import torch
import torch.nn as nn
import pandas as pd
from IPython.display import display

from conv_gemm.layers.triton_conv2d import TritonConv2d
from conv_gemm.layers.triton_conv2d_int8 import TritonConv2dInt8


def bench_mean(fn, iters: int = 50) -> float:
    """Среднее время выполнения fn() за iters прогонов (в секундах)."""
    # небольшой прогрев
    for _ in range(5):
        fn()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters


def safe_forward(fn):
    """Выполнить forward и аккуратно поймать OOM/другие RuntimeError."""
    try:
        out = fn()
        return out, None
    except RuntimeError as e:
        msg = str(e).lower()
        if "out of memory" in msg:
            torch.cuda.empty_cache()
            return None, "OOM"
        return None, "ERR"


def show_results_as_table(results):
    cols = [
        "img", "N", "Cin", "Cout", "K",
        "cuDNN FP16 ms", "cuDNN FP32 ms",
        "Triton FP16 ms", "Triton INT8 ms",
        "Speed FP16", "Speed INT8",
        "Err FP16", "Err INT8",
    ]
    df = pd.DataFrame(results, columns=cols)
    display(df)
    return df


def run_full_conv_bench_fp16_int8(
    image_sizes=(224, 112, 56),
    batch_sizes=(1, 2, 4, 8),
    channels=((3, 32), (32, 32), (32, 64), (64, 64), (64, 128)),
    kernels=(3, 5, 7, 9),
    stride=1,
    iters=200,
):
    device = "cuda"
    results = []

    for img in image_sizes:
        H = W = img

        for N in batch_sizes:
            for Cin, Cout in channels:
                for ks in kernels:
                    print(f"DONE: img={img}, N={N}, Cin={Cin}, Cout={Cout}, ks={ks}")

                    x_fp16 = torch.randn(
                        N, Cin, H, W, device=device, dtype=torch.float16
                    )

                    # === БАЗОВЫЕ свёртки: cuDNN FP32 + cuDNN FP16 ===
                    conv_ref_fp32 = nn.Conv2d(
                        Cin, Cout, ks,
                        stride=stride, padding=ks // 2, dilation=1, bias=True,
                    ).to(device)          # fp32
                    conv_ref_fp32.eval()

                    conv_ref_fp16 = nn.Conv2d(
                        Cin, Cout, ks,
                        stride=stride, padding=ks // 2, dilation=1, bias=True,
                    ).to(device).half()   # fp16
                    conv_ref_fp16.eval()

                    # одинаковые веса в обоих референсах
                    with torch.no_grad():
                        conv_ref_fp16.weight.copy_(conv_ref_fp32.weight.half())
                        if conv_ref_fp32.bias is not None:
                            conv_ref_fp16.bias.copy_(conv_ref_fp32.bias.half())

                    # === Triton FP16 ===
                    triton_conv_fp16 = TritonConv2d(
                        Cin, Cout, ks,
                        stride=stride, padding=ks // 2, dilation=1, bias=True,
                        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
                        NUM_WARPS=4, NUM_STAGES=2,
                        precision_mode="fp16_infer",  # вход/веса half, выход half
                    ).to(device)
                    triton_conv_fp16.eval()

                    with torch.no_grad():
                        triton_conv_fp16.weight.copy_(conv_ref_fp16.weight)
                        if triton_conv_fp16.bias is not None:
                            triton_conv_fp16.bias.copy_(conv_ref_fp16.bias)

                    # ---------- FORWARD / ERRORS ----------

                    # cuDNN FP16 baseline (этот выход используем как "истину")
                    y_ref, err_ref = safe_forward(lambda: conv_ref_fp16(x_fp16))
                    if err_ref:
                        results.append((
                            img, N, Cin, Cout, ks,
                            "OOM",      # cuDNN FP16 ms
                            "OOM",      # cuDNN FP32 ms
                            "OOM",      # Triton FP16 ms
                            "OOM",      # Triton INT8 ms
                            "-", "-", "-", "-",
                        ))
                        continue

                    # Triton FP16
                    y_tri_fp16, err_tri_fp16 = safe_forward(
                        lambda: triton_conv_fp16(x_fp16)
                    )
                    if err_tri_fp16:
                        results.append((
                            img, N, Cin, Cout, ks,
                            "OK",       # cuDNN FP16 живой, но Triton упал
                            "OK",
                            "OOM",
                            "OOM",
                            "-", "-", "-", "-",
                        ))
                        continue

                    # Можно ли вообще делать INT8? (K кратно 4)
                    K = Cin * ks * ks
                    supports_int8 = (K % 4 == 0)

                    # ---------- Ветвь: INT8 не поддержан (UNSUP) ----------
                    if not supports_int8:
                        err_fp16 = (y_ref - y_tri_fp16).float().abs().max().item()

                        t_ref_fp16 = bench_mean(lambda: conv_ref_fp16(x_fp16), iters)
                        t_ref_fp32 = bench_mean(lambda: conv_ref_fp32(x_fp16.float()), iters)
                        t_tri_fp16 = bench_mean(lambda: triton_conv_fp16(x_fp16), iters)

                        speed_fp16 = t_ref_fp16 / t_tri_fp16

                        results.append((
                            img, N, Cin, Cout, ks,
                            f"{t_ref_fp16 * 1e3:.3f}",   # cuDNN FP16 ms
                            f"{t_ref_fp32 * 1e3:.3f}",   # cuDNN FP32 ms
                            f"{t_tri_fp16 * 1e3:.3f}",   # Triton FP16 ms
                            "UNSUP",                     # Triton INT8 ms
                            f"{speed_fp16:.3f}x",
                            "-",
                            f"{err_fp16:.2e}",
                            "-",
                        ))
                        continue

                    # === Triton INT8 ===
                    triton_conv_int8 = TritonConv2dInt8(
                        Cin, Cout, ks,
                        stride=stride, padding=ks // 2, dilation=1, bias=True,
                    ).to(device)
                    triton_conv_int8.eval()

                    with torch.no_grad():
                        triton_conv_int8.init_from_conv(conv_ref_fp16)

                    y_tri_int8, err_tri_int8 = safe_forward(
                        lambda: triton_conv_int8(x_fp16)
                    )

                    # ---------- Ветвь: INT8 упал (ERR) ----------
                    if err_tri_int8:
                        err_fp16 = (y_ref - y_tri_fp16).float().abs().max().item()

                        t_ref_fp16 = bench_mean(lambda: conv_ref_fp16(x_fp16), iters)
                        t_ref_fp32 = bench_mean(lambda: conv_ref_fp32(x_fp16.float()), iters)
                        t_tri_fp16 = bench_mean(lambda: triton_conv_fp16(x_fp16), iters)

                        speed_fp16 = t_ref_fp16 / t_tri_fp16

                        results.append((
                            img, N, Cin, Cout, ks,
                            f"{t_ref_fp16 * 1e3:.3f}",
                            f"{t_ref_fp32 * 1e3:.3f}",
                            f"{t_tri_fp16 * 1e3:.3f}",
                            "ERR",
                            f"{speed_fp16:.3f}x",
                            "-",
                            f"{err_fp16:.2e}",
                            "-",
                        ))
                        continue

                    # ---------- Основная ветка: всё отработало ----------

                    err_fp16 = (y_ref - y_tri_fp16).float().abs().max().item()
                    err_int8 = (y_ref - y_tri_int8).float().abs().max().item()

                    t_ref_fp16 = bench_mean(lambda: conv_ref_fp16(x_fp16), iters)
                    t_ref_fp32 = bench_mean(lambda: conv_ref_fp32(x_fp16.float()), iters)
                    t_tri_fp16 = bench_mean(lambda: triton_conv_fp16(x_fp16), iters)
                    t_tri_int8 = bench_mean(lambda: triton_conv_int8(x_fp16), iters)

                    speed_fp16 = t_ref_fp16 / t_tri_fp16
                    speed_int8 = t_ref_fp16 / t_tri_int8

                    results.append((
                        img, N, Cin, Cout, ks,
                        f"{t_ref_fp16 * 1e3:.3f}",   # cuDNN FP16 ms
                        f"{t_ref_fp32 * 1e3:.3f}",   # cuDNN FP32 ms
                        f"{t_tri_fp16 * 1e3:.3f}",   # Triton FP16 ms
                        f"{t_tri_int8 * 1e3:.3f}",   # Triton INT8 ms
                        f"{speed_fp16:.3f}x",
                        f"{speed_int8:.3f}x",
                        f"{err_fp16:.2e}",
                        f"{err_int8:.2e}",
                    ))

    return show_results_as_table(results)


# пример вызова
df = run_full_conv_bench_fp16_int8(
    image_sizes=[1024],
    # image_sizes=[224, 112, 56],
    batch_sizes=[1, ],
    # batch_sizes=[1, 2, 4, 8],
    channels=[(1, 1),(1, 3), (3, 8), (32, 64)],
    # channels=[(1, 3), (3, 16), (32, 64), (64, 64), (64, 128)],
    kernels=[1, 3, 7, 9, 11],
    # kernels=[3, 5, 7, 9],
    stride=1,
    iters=300,
)


DONE: img=1024, N=1, Cin=1, Cout=1, ks=1
DONE: img=1024, N=1, Cin=1, Cout=1, ks=3
DONE: img=1024, N=1, Cin=1, Cout=1, ks=7
DONE: img=1024, N=1, Cin=1, Cout=1, ks=9
DONE: img=1024, N=1, Cin=1, Cout=1, ks=11
DONE: img=1024, N=1, Cin=1, Cout=3, ks=1
DONE: img=1024, N=1, Cin=1, Cout=3, ks=3
DONE: img=1024, N=1, Cin=1, Cout=3, ks=7
DONE: img=1024, N=1, Cin=1, Cout=3, ks=9
DONE: img=1024, N=1, Cin=1, Cout=3, ks=11
DONE: img=1024, N=1, Cin=3, Cout=8, ks=1
DONE: img=1024, N=1, Cin=3, Cout=8, ks=3
DONE: img=1024, N=1, Cin=3, Cout=8, ks=7
DONE: img=1024, N=1, Cin=3, Cout=8, ks=9
DONE: img=1024, N=1, Cin=3, Cout=8, ks=11
DONE: img=1024, N=1, Cin=32, Cout=64, ks=1
DONE: img=1024, N=1, Cin=32, Cout=64, ks=3
DONE: img=1024, N=1, Cin=32, Cout=64, ks=7
DONE: img=1024, N=1, Cin=32, Cout=64, ks=9
DONE: img=1024, N=1, Cin=32, Cout=64, ks=11


Unnamed: 0,img,N,Cin,Cout,K,cuDNN FP16 ms,cuDNN FP32 ms,Triton FP16 ms,Triton INT8 ms,Speed FP16,Speed INT8,Err FP16,Err INT8
0,1024,1,1,1,1,0.085,0.063,0.200,UNSUP,0.427x,-,9.77e-04,-
1,1024,1,1,1,3,0.099,0.119,0.570,UNSUP,0.174x,-,9.77e-04,-
2,1024,1,1,1,7,0.144,0.142,1.000,UNSUP,0.144x,-,1.95e-03,-
3,1024,1,1,1,9,0.195,1.002,1.415,UNSUP,0.138x,-,1.95e-03,-
4,1024,1,1,1,11,0.272,1.330,1.931,UNSUP,0.141x,-,1.95e-03,-
5,1024,1,1,3,1,0.118,0.167,0.509,UNSUP,0.232x,-,1.95e-03,-
6,1024,1,1,3,3,0.202,0.325,0.935,UNSUP,0.216x,-,1.95e-03,-
7,1024,1,1,3,7,0.399,0.624,1.417,UNSUP,0.282x,-,1.95e-03,-
8,1024,1,1,3,9,0.499,1.033,1.843,UNSUP,0.271x,-,1.95e-03,-
9,1024,1,1,3,11,0.788,1.372,2.415,UNSUP,0.326x,-,1.95e-03,-


In [None]:
import pandas as pd
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 2000)
from IPython.display import display

display(df)