In [1]:
import torch, torch.nn.functional as F
import os, sys
sys.path.insert(0, os.path.abspath("."))  # вставляем в начало
from conv_gemm.layers.triton_conv2d import TritonConv2d

In [2]:
import time
import torch

def bench(fn, warmup=50, iters=500):
    """Простой CUDA benchmark: среднее время итерации в секундах."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

In [3]:
def check_forward_backward_with_precision(cfg, precision_mode: str):
    device = "cuda"
    B, Cin, Cout, H, W, ks = cfg["B"], cfg["Cin"], cfg["Cout"], cfg["H"], cfg["W"], cfg["ks"]
    stride, padding, dilation = cfg["stride"], cfg["padding"], cfg["dilation"]
    dtype = cfg["dtype"]                      # у тебя уже fp16
    use_bias = cfg.get("use_bias", True)

    torch.manual_seed(0)

    # вход — half (чтобы имело смысл гонять img2col в 2 байта/элемент)
    x = torch.randn(B, Cin, H, W, device=device, dtype=dtype, requires_grad=True)

    # Triton-модель
    m_tri = TritonConv2d(
        in_channels=Cin, out_channels=Cout, kernel_size=ks,
        stride=stride, padding=padding, dilation=dilation, bias=use_bias,
        BLOCK_M=cfg.get("BLOCK_M",64), BLOCK_N=cfg.get("BLOCK_N",64), BLOCK_K=cfg.get("BLOCK_K",32),
        NUM_WARPS=cfg.get("NUM_WARPS",4), NUM_STAGES=cfg.get("NUM_STAGES",2),
        precision_mode=precision_mode
    ).to(device)

    # Эталон: Conv2d
    m_ref = torch.nn.Conv2d(Cin, Cout, ks, stride=stride, padding=padding,
                            dilation=dilation, bias=use_bias).to(device)

    # Синхронизируем начальные веса/биас
    with torch.no_grad():
        m_ref.weight.copy_(m_tri.weight.float())   # у ref всё в fp16
        if use_bias: m_ref.bias.copy_( (m_tri.bias.float() if m_tri.bias is not None else torch.zeros_like(m_ref.bias)) )

    # Применяем precision режим к Triton-слою (переводит хранение/тени)
    m_tri.set_precision(precision_mode)

    # Для корректного сравнения: прогоняем ref в том же «эффективном» dtype входа
    # y_ref всегда считаем в fp16 (как torch делает по умолчанию),
    # а для сравнения кастаем в float, как и раньше.
    with torch.no_grad():
        y_tri = m_tri(x)          # Triton: в зависимости от режима считает в fp16/fp16 с acc fp16
        y_ref = m_ref(x.float())  # Torch: считает в fp16 (стабильная эталонная база)
        mae = (y_ref.float() - y_tri.float()).abs().mean().item()
        mx  = (y_ref.float() - y_tri.float()).abs().max().item()
        print(f"[{precision_mode}][FWD] MAE={mae:.6e} | max|diff|={mx:.3e}")

    # ---- Backward сравнение (общий upstream-градиент) ----
    gy = torch.randn_like(y_ref, dtype=torch.float32)

    # ref grads
    for p in m_ref.parameters():
        if p.grad is not None: p.grad.zero_()
    if x.grad is not None: x.grad.zero_()
    y = m_ref(x.float())
    (y.float() * gy).sum().backward()
    dx_ref = x.grad.detach().float().clone()
    dw_ref = m_ref.weight.grad.detach().float().clone()
    db_ref = m_ref.bias.grad.detach().float().clone() if use_bias else None

    # triton grads
    if x.grad is not None: x.grad.zero_()
    for p in m_tri.parameters():
        if p.grad is not None: p.grad.zero_()
    y = m_tri(x)                                  # внутри bwd у нас всё в fp16 — стабильно
    (y.float() * gy).sum().backward()
    dx_tri = x.grad.detach().float().clone()
    dw_tri = m_tri.weight.grad.detach().float().clone()
    db_tri = m_tri.bias.grad.detach().float().clone() if use_bias else None

    def stats(name, a, b):
        d = (a - b).abs()
        mae = d.mean().item(); mx = d.max().item()
        rel = d.norm().item() / max(a.norm().item(), 1e-12)
        print(f"[{precision_mode}][BWD:{name}] MAE={mae:.6e} | max|diff|={mx:.3e} | relL2={rel:.3e}")

    stats("dx", dx_ref, dx_tri)
    stats("dw", dw_ref, dw_tri)
    if use_bias: stats("db", db_ref, db_tri)

    # ---- Тайминг end-to-end ----
    warmup = cfg.get("warmup", 50); iters = cfg.get("iters", 500)
    x_b = torch.randn_like(x, requires_grad=True)
    with torch.no_grad():
        m_ref.weight.copy_(m_tri.weight.float())
        if use_bias and (m_tri.bias is not None):
            m_ref.bias.copy_(m_tri.bias.float())
    gy_b = torch.randn_like(m_ref(x_b.float()), dtype=torch.float32)

    def f_ref():
        if x_b.grad is not None: x_b.grad = None
        for p in m_ref.parameters():
            if p.grad is not None: p.grad = None
        y = m_ref(x_b.float())
        (y.float() * gy_b).sum().backward()

    def f_tri():
        if x_b.grad is not None: x_b.grad = None
        for p in m_tri.parameters():
            if p.grad is not None: p.grad = None
        y = m_tri(x_b)
        (y.float() * gy_b).sum().backward()

    t_ref = bench(f_ref, warmup=warmup, iters=iters)
    t_tri = bench(f_tri, warmup=warmup, iters=iters)
    print(f"[{precision_mode}][TIME] Torch: {t_ref*1e3:.3f} ms/it | Triton: {t_tri*1e3:.3f} ms/it | speedup={t_ref/max(t_tri,1e-12):.3f}x")


In [4]:

RUNS = [
        dict(B=2, Cin=64, Cout=128, H=32, W=32, ks=5,
             stride=1, padding=1, dilation=1,
             dtype=torch.float16, use_bias=True, warmup=50, iters=500,
             BLOCK_M=64, BLOCK_N=64, BLOCK_K=32),

        dict(B=2, Cin=64, Cout=128, H=64, W=64, ks=5,
             stride=1, padding=1, dilation=1,
             dtype=torch.float16, use_bias=False, warmup=50, iters=500,
             BLOCK_M=64, BLOCK_N=64, BLOCK_K=32),
    ]
for cfg in RUNS:
    print("\n=== CASE:",
          f"B={cfg['B']} Cin={cfg['Cin']} Cout={cfg['Cout']} H={cfg['H']} W={cfg['W']} ks={cfg['ks']}",
          f"dtype={cfg['dtype']} bias={cfg['use_bias']} ===")
    check_forward_backward_with_precision(cfg, "fp16_runtime")  # W-FP32 (мастер), compute в FP16
    check_forward_backward_with_precision(cfg, "fp16_infer")    # W-FP16 (хранение и compute в FP16)




=== CASE: B=2 Cin=64 Cout=128 H=32 W=32 ks=5 dtype=torch.float16 bias=True ===
[fp16_runtime][FWD] MAE=1.300658e-04 | max|diff|=1.188e-03


AttributeError: 'NoneType' object has no attribute 'detach'

In [None]:
import time, torch, triton
from torch import nn
from conv_gemm.triton_kernels.fp16.img2col_kernel import img2col_kernel as i2c_k
from conv_gemm.triton_kernels.fp16.gemm_kernel    import gemm_kernel    as gemm_k
from conv_gemm.triton_kernels.fp16.col2img_kernel import col2img_kernel as c2i_k

def sync(): torch.cuda.synchronize()

def bench(fn, warmup=50, iters=300):
    for _ in range(warmup): fn()
    sync()
    t0 = time.perf_counter()
    for _ in range(iters): fn()
    sync()
    return (time.perf_counter() - t0) / iters

def out_hw(H, W, Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw):
    Ho = (H + 2*Ph - Dh*(Kh - 1) - 1) // Sh + 1
    Wo = (W + 2*Pw - Dw*(Kw - 1) - 1) // Sw + 1
    return Ho, Wo

def triton_img2col(x, B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo, BLOCK_M=64, BLOCK_K=32, fp16=False):
    K = Cin*Kh*Kw
    cols = torch.empty((B*Ho*Wo, K), device=x.device, dtype=(torch.float16 if fp16 else torch.float32))
    sN,sC,sH,sW = x.stride()
    grid = (triton.cdiv(B*Ho*Wo, BLOCK_M), triton.cdiv(K, BLOCK_K))
    i2c_k[grid](
        x, cols, B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
        sN,sC,sH,sW, K,
        BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K,
        CAST_FP16=fp16, num_warps=4, num_stages=2,
    )
    return cols

def triton_gemm(A, B, M, N, K, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, use_fp16=False):
    C = torch.empty((M,N), device=A.device, dtype=torch.float32)  # acc fp16
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    gemm_k[grid](
        A, B, C, M,N,K,
        K,1,  N,1,  N,1,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        USE_FP16=use_fp16, num_warps=4, num_stages=2,
    )
    return C

def triton_col2im(cols, B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo, BLOCK_M=64, BLOCK_K=32):
    K = Cin*Kh*Kw; M = B*Ho*Wo
    x = torch.zeros((B,Cin,H,W), device=cols.device, dtype=torch.float32)
    sN,sC,sH,sW = x.stride()
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
    c2i_k[grid](
        cols, x, B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
        sN,sC,sH,sW, K,
        BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K,
        num_warps=4, num_stages=2,
    )
    return x


In [None]:
def profile_stages(cfg, precision_mode: str):
    device = "cuda"
    B, Cin, Cout, H, W, ks = cfg["B"], cfg["Cin"], cfg["Cout"], cfg["H"], cfg["W"], cfg["ks"]
    stride, padding, dilation = cfg["stride"], cfg["padding"], cfg["dilation"]
    Sh,Sw = (stride, stride) if isinstance(stride,int) else stride
    Ph,Pw = (padding, padding) if isinstance(padding,int) else padding
    Dh,Dw = (dilation, dilation) if isinstance(dilation,int) else dilation
    Kh=Kw=ks
    Ho, Wo = out_hw(H,W,Kh,Kw,Sh,Sw,Ph,Pw,Dh,Dw)
    L = Ho*Wo; M = B*L; K = Cin*Kh*Kw

    dtype_in = cfg["dtype"]
    x_fp = torch.randn(B, Cin, H, W, device=device, dtype=dtype_in, requires_grad=True)
    W_full = torch.randn(Cout, Cin, Kh, Kw, device=device, dtype=torch.float32)

    # Torch stage ops
    unfold = nn.Unfold(kernel_size=(Kh,Kw), dilation=(Dh,Dw), padding=(Ph,Pw), stride=(Sh,Sw)).to(device)
    fold   = nn.Fold(output_size=(H,W), kernel_size=(Kh,Kw), dilation=(Dh,Dw), padding=(Ph,Pw), stride=(Sh,Sw)).to(device)
    W_mat32 = W_full.view(Cout,-1).t().contiguous()  # [K, Cout]

    # Triton dtypes
    fp16_mode = (precision_mode in ("fp16_runtime","fp16_infer"))
    x_tri = (x_fp.half() if fp16_mode else x_fp.float()).detach().requires_grad_(True)
    W_tri = (W_full.half() if precision_mode=="fp16_infer" else W_full.float())

    # --- FWD: Torch stages ---
    def f_unfold(): unfold(x_fp.float())
    def f_gemm_torch():
        cols_ref = unfold(x_fp.float()).transpose(1,2).contiguous().view(-1,K).float()
        return cols_ref @ W_mat32
    def f_fold():
        cols_ref = unfold(x_fp.float()).transpose(1,2).contiguous().view(-1,K).float()
        y_mat = cols_ref @ W_mat32
        return y_mat.view(B,L,Cout).transpose(1,2).contiguous()

    t_unfold = bench(f_unfold, warmup=30, iters=2000)
    t_gemm_t = bench(f_gemm_torch, warmup=30, iters=2000)
    t_fold   = bench(f_fold, warmup=30, iters=2000)

    # --- FWD: Triton stages ---
    def f_i2c():
        triton_img2col((x_tri.half() if fp16_mode else x_tri.float()),
                       B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                       BLOCK_M=cfg.get("BLOCK_M",64), BLOCK_K=cfg.get("BLOCK_K",32), fp16=fp16_mode)
    def f_gemm_tri():
        cols = triton_img2col((x_tri.half() if fp16_mode else x_tri.float()),
                              B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                              BLOCK_M=cfg.get("BLOCK_M",64), BLOCK_K=cfg.get("BLOCK_K",32), fp16=fp16_mode)
        W_mat = W_tri.view(Cout,-1).t().contiguous()
        if fp16_mode and W_mat.dtype != torch.float16: W_mat = W_mat.half()
        triton_gemm(cols, W_mat, M, Cout, K,
                    BLOCK_M=cfg.get("BLOCK_M",64), BLOCK_N=cfg.get("BLOCK_N",64), BLOCK_K=cfg.get("BLOCK_K",32),
                    use_fp16=fp16_mode)
    def f_c2i():
        cols = triton_img2col((x_tri.half() if fp16_mode else x_tri.float()),
                              B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                              BLOCK_M=cfg.get("BLOCK_M",64), BLOCK_K=cfg.get("BLOCK_K",32), fp16=fp16_mode)
        triton_col2im(cols.float(), B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                      BLOCK_M=cfg.get("BLOCK_M",64), BLOCK_K=cfg.get("BLOCK_K",32))

    t_i2c  = bench(f_i2c,  warmup=30, iters=2000)
    t_gemm = bench(f_gemm_tri, warmup=30, iters=2000)
    t_c2i  = bench(f_c2i,  warmup=30, iters=2000)

    print(f"\n[STAGES {precision_mode}] (ms/it)")
    print(f"Torch   Unfold: {t_unfold*1e3:.3f} | GEMM@: {t_gemm_t*1e3:.3f} | Fold:  {t_fold*1e3:.3f}")
    print(f"Triton  i2c:    {t_i2c*1e3:.3f} | GEMM:   {t_gemm*1e3:.3f} | c2i:   {t_c2i*1e3:.3f}")


In [None]:
profile_stages(cfg, "fp16")
profile_stages(cfg, "fp16_runtime")
profile_stages(cfg, "fp16_infer")

In [5]:
# tests/bench_units_img2col_gemm_col2im.py
import time, torch, triton
from torch import nn

from conv_gemm.triton_kernels.fp16.img2col_kernel import img2col_kernel as i2c_k
from conv_gemm.triton_kernels.fp16.col2img_kernel import col2img_kernel as c2i_k
from conv_gemm.triton_kernels.fp16.gemm_kernel    import triton_gemm

def _force_strict_fp32():
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
_force_strict_fp32()
torch.backends.cudnn.benchmark = True

device = "cuda"
def sync(): torch.cuda.synchronize()

def bench(fn, warmup=50, iters=300):
    for _ in range(warmup): fn()
    sync()
    t0 = time.perf_counter()
    for _ in range(iters): fn()
    sync()
    return (time.perf_counter() - t0) / iters

def out_hw(H, W, Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw):
    Ho = (H + 2*Ph - Dh*(Kh - 1) - 1) // Sh + 1
    Wo = (W + 2*Pw - Dw*(Kw - 1) - 1) // Sw + 1
    return Ho, Wo

def triton_img2col(x, B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                   BLOCK_M=64, BLOCK_K=32, fp16=False):
    K = Cin*Kh*Kw
    cols_dtype = torch.float16 if fp16 else torch.float32
    cols = torch.empty((B*Ho*Wo, K), device=x.device, dtype=cols_dtype)
    sN,sC,sH,sW = x.stride()
    grid = (triton.cdiv(B*Ho*Wo, BLOCK_M), triton.cdiv(K, BLOCK_K))
    i2c_k[grid](
        x, cols,
        B,Cin,H,W,
        Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw,
        Ho,Wo,
        sN,sC,sH,sW,
        K,
        BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K,
        CAST_FP16=fp16,
        num_warps=4, num_stages=2
    )
    return cols

def triton_col2im(cols, B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                  BLOCK_M=64, BLOCK_K=32):
    K = Cin*Kh*Kw; M = B*Ho*Wo
    x = torch.zeros((B,Cin,H,W), device=cols.device, dtype=torch.float32)
    sN,sC,sH,sW = x.stride()
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
    c2i_k[grid](
        cols, x,
        B,Cin,H,W,
        Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw,
        Ho,Wo,
        sN,sC,sH,sW,
        K,
        BLOCK_M=BLOCK_M, BLOCK_K=BLOCK_K,
        num_warps=4, num_stages=2
    )
    return x

def row(name, items):
    print("| " + name + " | " + " | ".join(items) + " |")

# ===== ГЛАВА 1: ImageToColumn =====
def chapter_img2col(cfg):
    print("\n# Глава 1: ImageToColumn (Unfold vs Triton img2col)")
    B,Cin,Cout,H,W,ks = cfg["B"],cfg["Cin"],cfg["Cout"],cfg["H"],cfg["W"],cfg["ks"]
    stride,padding,dilation = cfg["stride"],cfg["padding"],cfg["dilation"]
    Sh,Sw = (stride,stride) if isinstance(stride,int) else stride
    Ph,Pw = (padding,padding) if isinstance(padding,int) else padding
    Dh,Dw = (dilation,dilation) if isinstance(dilation,int) else dilation
    Kh=Kw=ks
    Ho,Wo = out_hw(H,W,Kh,Kw,Sh,Sw,Ph,Pw,Dh,Dw)
    K = Cin*Kh*Kw

    unfold = nn.Unfold((Kh,Kw), dilation=(Dh,Dw), padding=(Ph,Pw), stride=(Sh,Sw)).to(device)

    for prec in ("fp16","fp16"):
        fp16 = (prec=="fp16")
        dtype = torch.float16 if fp16 else torch.float32

        x = torch.randn(B,Cin,H,W, device=device, dtype=dtype)

        # Forward acc
        with torch.no_grad():
            cols_ref = unfold(x.float()).transpose(1,2).contiguous().view(-1,K).float()
            cols_tri = triton_img2col(x if not fp16 else x.half(),
                                      B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                                      BLOCK_M=64, BLOCK_K=32, fp16=fp16)
            diff = (cols_ref - cols_tri.float()).abs()
            f_mae = diff.mean().item(); f_mx = diff.max().item()

        # Backward эквивалент
        dcols = torch.randn_like(cols_ref, dtype=torch.float32)
        x_ref = x.clone().detach().float().requires_grad_(True)
        cols_ref2 = unfold(x_ref).transpose(1,2).contiguous().view(-1,K)
        (cols_ref2 * dcols).sum().backward()
        dx_ref = x_ref.grad.detach().float()
        dx_tri = triton_col2im(dcols, B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                               BLOCK_M=64, BLOCK_K=32)
        bd = (dx_ref - dx_tri).abs()
        b_mae = bd.mean().item(); b_mx = bd.max().item()

        # Timings
        def f_unfold(): unfold(x.float())
        def f_i2c(): triton_img2col(x if not fp16 else x.half(),
                                    B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                                    BLOCK_M=64, BLOCK_K=32, fp16=fp16)
        t_unfold = bench(f_unfold, warmup=30, iters=200)
        t_i2c    = bench(f_i2c,    warmup=30, iters=200)

        print(f"\n[img2col:{prec}]")
        row("Этап", ["FWD MAE", "FWD max", "BWD MAE(dx)", "BWD max(dx)", "Torch Unfold ms", "Triton i2c ms", "Speedup"])
        row("Знач", [f"{f_mae:.3e}", f"{f_mx:.3e}", f"{b_mae:.3e}", f"{b_mx:.3e}",
                     f"{t_unfold*1e3:.3f}", f"{t_i2c*1e3:.3f}", f"{t_unfold/max(t_i2c,1e-12):.2f}x"])

# ===== ГЛАВА 2: GEMM =====
def chapter_gemm(cfg):
    print("\n# Глава 2: GEMM (torch @ vs Triton gemm)")
    B,Cin,Cout,H,W,ks = cfg["B"],cfg["Cin"],cfg["Cout"],cfg["H"],cfg["W"],cfg["ks"]
    stride,padding,dilation = cfg["stride"],cfg["padding"],cfg["dilation"]
    Sh,Sw = (stride,stride) if isinstance(stride,int) else stride
    Ph,Pw = (padding,padding) if isinstance(padding,int) else padding
    Dh,Dw = (dilation,dilation) if isinstance(dilation,int) else dilation
    Kh=Kw=ks
    Ho,Wo = out_hw(H,W,Kh,Kw,Sh,Sw,Ph,Pw,Dh,Dw)
    M = B*Ho*Wo; K = Cin*Kh*Kw; N = Cout

    for prec in ("fp16","fp16"):
        fp16 = (prec=="fp16")
        A = torch.randn(M,K, device=device, dtype=(torch.float16 if fp16 else torch.float32))
        Bm = torch.randn(K,N, device=device, dtype=(torch.float16 if fp16 else torch.float32))

        # Референс — строгий fp16
        A_ref = A.float().contiguous()
        B_ref = Bm.float().contiguous()

        # Forward acc
        with torch.no_grad():
            C_ref = (A_ref @ B_ref).float()
            C_tri = triton_gemm(A, Bm, use_fp16=fp16)
            dd = (C_ref - C_tri).abs()
            f_mae = dd.mean().item(); f_mx = dd.max().item()

        # Backward (по формулам)
        R = torch.randn_like(C_ref, dtype=torch.float32)
        dA_ref = R @ B_ref.t()
        dB_ref = A_ref.t() @ R

        dA_tri = triton_gemm(R, B_ref.t(), use_fp16=False)
        dB_tri = triton_gemm(A_ref.t(), R, use_fp16=False)

        dA_mae = (dA_ref - dA_tri).abs().mean().item(); dA_mx = (dA_ref - dA_tri).abs().max().item()
        dB_mae = (dB_ref - dB_tri).abs().mean().item(); dB_mx = (dB_ref - dB_tri).abs().max().item()

        # Timings
        def f_torch(): A_ref @ B_ref
        def f_tri():   triton_gemm(A, Bm, use_fp16=fp16)
        t_torch = bench(f_torch, warmup=30, iters=200)
        t_tri   = bench(f_tri,   warmup=30, iters=200)

        print(f"\n[GEMM:{prec}]")
        row("Этап", ["FWD MAE", "FWD max", "dA MAE", "dA max", "dB MAE", "dB max", "Torch mm ms", "Triton GEMM ms", "Speedup"])
        row("Знач", [f"{f_mae:.3e}", f"{f_mx:.3e}", f"{dA_mae:.3e}", f"{dA_mx:.3e}",
                     f"{dB_mae:.3e}", f"{dB_mx:.3e}", f"{t_torch*1e3:.3f}", f"{t_tri*1e3:.3f}",
                     f"{t_torch/max(t_tri,1e-12):.2f}x"])

# ===== ГЛАВА 3: ColumnToImage =====
def chapter_col2im(cfg):
    print("\n# Глава 3: ColumnToImage (Fold vs Triton col2im)")
    B,Cin,Cout,H,W,ks = cfg["B"],cfg["Cin"],cfg["Cout"],cfg["H"],cfg["W"],cfg["ks"]
    stride,padding,dilation = cfg["stride"],cfg["padding"],cfg["dilation"]
    Sh,Sw = (stride,stride) if isinstance(stride,int) else stride
    Ph,Pw = (padding,padding) if isinstance(padding,int) else padding
    Dh,Dw = (dilation,dilation) if isinstance(dilation,int) else dilation
    Kh=Kw=ks
    Ho,Wo = out_hw(H,W,Kh,Kw,Sh,Sw,Ph,Pw,Dh,Dw)
    K = Cin*Kh*Kw

    fold = nn.Fold(output_size=(H,W), kernel_size=(Kh,Kw),
                   dilation=(Dh,Dw), padding=(Ph,Pw), stride=(Sh,Sw)).to(device)

    for prec in ("fp16","fp16"):
        cols = torch.randn(B*Ho*Wo, K, device=device, dtype=torch.float32)
        with torch.no_grad():
            cols_3d = cols.view(B,Ho*Wo,K).transpose(1,2).contiguous()  # [B,K,L]
            x_ref   = fold(cols_3d).float()
            x_tri   = triton_col2im(cols.float(), B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                                    BLOCK_M=64, BLOCK_K=32)
            dd = (x_ref - x_tri).abs()
            f_mae = dd.mean().item(); f_mx = dd.max().item()

        dx = torch.randn_like(x_ref, dtype=torch.float32)
        cols_var = cols_3d.clone().detach().requires_grad_(True)
        (fold(cols_var) * dx).sum().backward()
        dcols_ref = cols_var.grad.detach().view(B, K, Ho*Wo).transpose(1,2).contiguous().view(-1, K).float()
        dcols_tri = triton_img2col(dx.float(), B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                                   BLOCK_M=64, BLOCK_K=32, fp16=False)
        bd = (dcols_ref - dcols_tri).abs()
        b_mae = bd.mean().item(); b_mx = bd.max().item()

        def f_fold():
            c3d = cols.view(B,Ho*Wo,K).transpose(1,2).contiguous()
            fold(c3d)
        def f_c2i():
            triton_col2im(cols.float(), B,Cin,H,W, Kh,Kw, Sh,Sw, Ph,Pw, Dh,Dw, Ho,Wo,
                          BLOCK_M=64, BLOCK_K=32)
        t_fold = bench(f_fold, warmup=30, iters=200)
        t_c2i  = bench(f_c2i,  warmup=30, iters=200)

        print(f"\n[col2im:{prec}]")
        row("Этап", ["FWD MAE", "FWD max", "BWD MAE(dcols)", "BWD max(dcols)", "Torch Fold ms", "Triton c2i ms", "Speedup"])
        row("Знач", [f"{f_mae:.3e}", f"{f_mx:.3e}", f"{b_mae:.3e}", f"{b_mx:.3e}",
                     f"{t_fold*1e3:.3f}", f"{t_c2i*1e3:.3f}", f"{t_fold/max(t_c2i,1e-12):.2f}x"])



RUN = dict(B=2, Cin=64, Cout=128, H=64, W=64, ks=5, stride=1, padding=1, dilation=1)
chapter_img2col(RUN)
chapter_gemm(RUN)
chapter_col2im(RUN)



# Глава 1: ImageToColumn (Unfold vs Triton img2col)

[img2col:fp32]
| Этап | FWD MAE | FWD max | BWD MAE(dx) | BWD max(dx) | Torch Unfold ms | Triton i2c ms | Speedup |
| Знач | 0.000e+00 | 0.000e+00 | 4.188e-07 | 5.722e-06 | 0.128 | 0.140 | 0.91x |

[img2col:fp16]
| Этап | FWD MAE | FWD max | BWD MAE(dx) | BWD max(dx) | Torch Unfold ms | Triton i2c ms | Speedup |
| Знач | 0.000e+00 | 0.000e+00 | 4.186e-07 | 7.629e-06 | 0.119 | 0.108 | 1.11x |

# Глава 2: GEMM (torch @ vs Triton gemm)

[GEMM:fp32]
| Этап | FWD MAE | FWD max | dA MAE | dA max | dB MAE | dB max | Torch mm ms | Triton GEMM ms | Speedup |
| Знач | 9.366e-03 | 6.069e-02 | 2.635e-03 | 2.032e-02 | 2.057e-02 | 1.216e-01 | 0.231 | 0.319 | 0.73x |

[GEMM:fp16]
| Этап | FWD MAE | FWD max | dA MAE | dA max | dB MAE | dB max | Torch mm ms | Triton GEMM ms | Speedup |
| Знач | 7.932e-05 | 5.951e-04 | 1.866e-03 | 1.331e-02 | 1.455e-02 | 8.316e-02 | 0.251 | 0.111 | 2.26x |

# Глава 3: ColumnToImage (Fold vs Triton col2im)

[col2im:fp

In [6]:
print("allow_tf32:", torch.backends.cuda.matmul.allow_tf32)
try: print("f32 matmul precision:", torch.get_float32_matmul_precision())
except: pass

allow_tf32: True
f32 matmul precision: high
