In [2]:
from col2img_int8_kernel import col2img_int32_kernel
import time
import pandas as pd
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

In [3]:
# 1. Обёртка над col2img_int32_kernel
# -----------------------------
@torch.no_grad()
def col2img_int32(
    cols_i32: torch.Tensor,  # [M, K] int32, M = N*Ho*Wo, K = Cin*Kh*Kw
    N: int, Cin: int, H: int, W: int,
    Kh: int, Kw: int,
    Sh: int, Sw: int,
    Ph: int, Pw: int,
    Dh: int, Dw: int,
    BLOCK_M: int,
    BLOCK_K: int,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """
    Обёртка для col2img_int32_kernel.
    Восстанавливает x_i32: [N, Cin, H, W] из cols_i32[int32][M,K].
    """
    assert cols_i32.is_cuda
    assert cols_i32.dtype == torch.int32
    cols_i32 = cols_i32.contiguous()

    # считаем Ho, Wo так же, как в img2col
    Ho = (H + 2 * Ph - Dh * (Kh - 1) - 1) // Sh + 1
    Wo = (W + 2 * Pw - Dw * (Kw - 1) - 1) // Sw + 1
    M = N * Ho * Wo
    K = Cin * Kh * Kw

    assert cols_i32.shape == (M, K), f"cols shape {cols_i32.shape}, expected {(M, K)}"

    x_i32 = torch.zeros((N, Cin, H, W), device=cols_i32.device, dtype=torch.int32)
    sN, sC, sH, sW = x_i32.stride()

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))

    col2img_int32_kernel[grid](
        cols_i32, x_i32,
        N, Cin, H, W,
        Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw,
        Ho, Wo,
        sN, sC, sH, sW,
        K,
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_i32

# title search

In [None]:
@torch.no_grad()
def tune_col2img_int32_tiles_for_shape(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    blocks_M=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters: int = 50,
    device: str = "cuda",
):
    """
    Перебираем BLOCK_M / BLOCK_K / num_warps / num_stages
    для одного фиксированного shape, возвращаем DataFrame
    со всеми результатами.
    Ты потом вручную выберешь лучшие тайлы.
    """
    records = []

    for BM in blocks_M:
        for BK in blocks_K:
            for W_ in warps:
                for S_ in stages:
                    try:
                        rec = bench_once_col2img_vs_fold(
                            N, Cin, H, W,
                            Kh, Kw,
                            Sh, Sw,
                            Ph, Pw,
                            Dh, Dw,
                            BLOCK_M=BM,
                            BLOCK_K=BK,
                            num_warps=W_,
                            num_stages=S_,
                            iters=iters,
                            device=device,
                        )
                    except RuntimeError as e:
                        print(f"[SKIP] BM={BM}, BK={BK}, W={W_}, S={S_}: {e}")
                        continue

                    print(
                        f"BM={BM}, BK={BK}, W={W_}, S={S_}: "
                        f"t_triton={rec['t_triton_ms']:.3f} ms, "
                        f"speed_vs_fold={rec['speed_vs_fold']:.3f}x, "
                        f"err={rec['max_abs_err']}"
                    )
                    records.append(rec)

    if not records:
        raise RuntimeError("No valid tile configs found for this col2img_int32 shape")

    df = pd.DataFrame(records)
    # можно сразу отсортировать по времени, чтобы топ был вверху
    df = df.sort_values("t_triton_ms").reset_index(drop=True)
    return df

# bench

In [4]:
# 2. Бенч ОДНОГО shape: col2img_int32 vs F.fold
# -----------------------------
@torch.no_grad()
def bench_once_col2img_vs_fold(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    BLOCK_M,
    BLOCK_K,
    num_warps,
    num_stages,
    iters: int = 50,
    device: str = "cuda",
):
    """
    Бенч для ОДНОГО shape и ОДНОГО набора тайлов:
      - генерим cols_i32 [M,K] int32
      - наш col2img_int32 -> x_i32 [N,Cin,H,W]
      - реф: F.fold (fp16) -> x_ref -> int32
      - меряем:
        t_triton_ms, t_fold_ms,
        speed_vs_fold,
        max_abs_err,
        bw_triton_GBs, bw_fold_GBs
    """
    # считаем Ho, Wo
    Ho = (H + 2 * Ph - Dh * (Kh - 1) - 1) // Sh + 1
    Wo = (W + 2 * Pw - Dw * (Kw - 1) - 1) // Sw + 1
    if Ho <= 0 or Wo <= 0:
        raise RuntimeError(f"Invalid Ho/Wo: Ho={Ho}, Wo={Wo}")

    M = N * Ho * Wo
    K = Cin * Kh * Kw

    # случайные int32 "колонки"
    cols_i32 = torch.randint(
        low=-1000, high=1000,
        size=(M, K),
        device=device,
        dtype=torch.int32,
    )

    # ---------- наш col2img_int32 ----------
    def _call_triton():
        x_i32 = col2img_int32(
            cols_i32,
            N, Cin, H, W,
            Kh, Kw,
            Sh, Sw,
            Ph, Pw,
            Dh, Dw,
            BLOCK_M=BLOCK_M,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return x_i32

    # прогрев
    for _ in range(5):
        _ = _call_triton()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        x_i32 = _call_triton()
    torch.cuda.synchronize()
    t_triton = (time.perf_counter() - t0) / iters

    # ---------- референс через F.fold ----------
    # cols_i32: [M,K] -> [N, L, K] -> [N, K, L]
    L = Ho * Wo
    assert M == N * L
    cols_view = cols_i32.view(N, L, K)          # [N,L,K]
    inp_fold = cols_view.permute(0, 2, 1).contiguous()  # [N,K,L]

    inp_fold_f = inp_fold.float()
    # прогрев
    def _call_fold():
        x_f = F.fold(
            inp_fold_f,
            output_size=(H, W),
            kernel_size=(Kh, Kw),
            dilation=(Dh, Dw),
            padding=(Ph, Pw),
            stride=(Sh, Sw),
        )  # [N,Cin,H,W] fp16
        return x_f

    for _ in range(5):
        _ = _call_fold()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        x_ref_f = _call_fold()
    torch.cuda.synchronize()
    t_fold = (time.perf_counter() - t0) / iters

    x_ref_i32 = x_ref_f.to(torch.int32)

    # ---------- ошибка ----------
    max_abs_err = (x_i32 - x_ref_i32).abs().max().item()

    # ---------- BW (грубо) ----------
    # читаем cols (int32) и пишем x (int32)
    bytes_moved = cols_i32.numel() * 4 + x_i32.numel() * 4
    bw_triton = bytes_moved / t_triton / 1e9
    bw_fold = bytes_moved / t_fold / 1e9

    return {
        "N": N,
        "Cin": Cin,
        "H": H, "W": W,
        "Kh": Kh, "Kw": Kw,
        "BLOCK_M": BLOCK_M,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_fold_ms": t_fold * 1e3,
        "speed_vs_fold": t_fold / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_fold_GBs": bw_fold,
        "max_abs_err": max_abs_err,
    }


In [6]:
@torch.no_grad()
def benchmark_col2img_int32_best_tiles(
    tiles_cfg: dict,
    image_sizes=(112, 256, 512, 1024),
    batch_sizes=(1, 2, 4),
    channel_pairs=((1, 3), (3, 3), (3, 8), (8, 16), (32, 32)),
    kernels=(1, 3, 5, 7, 9, 11),
    iters_per_shape: int = 50,
    device: str = "cuda",
):
    """
    Сравнение col2img_int32 vs F.fold на сетке параметров
    для фиксированных лучших тайлов tiles_cfg.

    tiles_cfg:
      - "BLOCK_M"
      - "BLOCK_K"
      - "num_warps"
      - "num_stages"

    channel_pairs: (Cin, Cout) — Cout используется только для логов
    (по аналогии с img2col).
    """
    BM = tiles_cfg["BLOCK_M"]
    BK = tiles_cfg["BLOCK_K"]
    NW = tiles_cfg["num_warps"]
    NS = tiles_cfg["num_stages"]

    records = []

    for H in image_sizes:
        W = H
        for N in batch_sizes:
            for (Cin, Cout) in channel_pairs:
                for k in kernels:
                    Ph = Pw = k // 2
                    print(f"=== SHAPE: N={N}, Cin={Cin}, Cout={Cout}, H=W={H}, k={k}x{k} ===")
                    try:
                        rec = bench_once_col2img_vs_fold(
                            N, Cin, H, W,
                            k, k,
                            1, 1,       # Sh, Sw
                            Ph, Pw,     # padding
                            1, 1,       # Dh, Dw
                            BLOCK_M=BM,
                            BLOCK_K=BK,
                            num_warps=NW,
                            num_stages=NS,
                            iters=iters_per_shape,
                            device=device,
                        )
                    except RuntimeError as e:
                        print(f"[SKIP] N={N}, Cin={Cin}, Cout={Cout}, H={H}, k={k}: {e}")
                        continue

                    if rec["max_abs_err"] != 0:
                        print(f"[WRONG] N={N}, Cin={Cin}, Cout={Cout}, H={H}, k={k}, err={rec['max_abs_err']}")
                        continue

                    rec["Cout"] = Cout
                    records.append(rec)

    if not records:
        print("[WARN] no successful records for col2img_int32")
        return pd.DataFrame()

    df = pd.DataFrame(records)
    return df

In [7]:
# 1) Тюним тайлы для одного shape, например:
df_tiles = tune_col2img_int32_tiles_for_shape(
    N=4, Cin=64, H=224, W=224,
    Kh=3, Kw=3,
    Sh=1, Sw=1,
    Ph=1, Pw=1,
    Dh=1, Dw=1,
    blocks_M=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters=50,
    device="cuda",
)

# смотри top-10


BM=32, BK=32, W=2, S=2: t_triton=2.340 ms, speed_vs_fold=1.159x, err=0
BM=32, BK=32, W=2, S=3: t_triton=2.287 ms, speed_vs_fold=1.198x, err=0
BM=32, BK=32, W=4, S=2: t_triton=1.471 ms, speed_vs_fold=1.864x, err=0
BM=32, BK=32, W=4, S=3: t_triton=1.454 ms, speed_vs_fold=1.883x, err=0
BM=32, BK=32, W=8, S=2: t_triton=1.566 ms, speed_vs_fold=1.802x, err=0
BM=32, BK=32, W=8, S=3: t_triton=1.563 ms, speed_vs_fold=1.757x, err=0
BM=32, BK=64, W=2, S=2: t_triton=4.805 ms, speed_vs_fold=0.566x, err=0
BM=32, BK=64, W=2, S=3: t_triton=4.771 ms, speed_vs_fold=0.579x, err=0
BM=32, BK=64, W=4, S=2: t_triton=2.355 ms, speed_vs_fold=1.182x, err=0
BM=32, BK=64, W=4, S=3: t_triton=2.424 ms, speed_vs_fold=1.131x, err=0
BM=32, BK=64, W=8, S=2: t_triton=1.502 ms, speed_vs_fold=1.826x, err=0
BM=32, BK=64, W=8, S=3: t_triton=1.491 ms, speed_vs_fold=1.815x, err=0
BM=32, BK=128, W=2, S=2: t_triton=5.348 ms, speed_vs_fold=0.513x, err=0
BM=32, BK=128, W=2, S=3: t_triton=5.417 ms, speed_vs_fold=0.515x, err=0
BM=3

In [8]:
df_tiles.head(10)

Unnamed: 0,N,Cin,H,W,Kh,Kw,BLOCK_M,BLOCK_K,num_warps,num_stages,t_triton_ms,t_fold_ms,speed_vs_fold,bw_triton_GBs,bw_fold_GBs,max_abs_err
0,4,64,224,224,3,3,32,32,4,3,1.45405,2.737633,1.882764,353.359325,187.68117,0
1,4,64,224,224,3,3,32,32,4,2,1.470514,2.740385,1.863555,349.403066,187.492748,0
2,4,64,224,224,3,3,64,32,8,3,1.488509,2.779814,1.867516,345.179194,184.833319,0
3,4,64,224,224,3,3,32,64,8,3,1.490962,2.706579,1.815324,344.611229,189.834581,0
4,4,64,224,224,3,3,32,64,8,2,1.502034,2.742225,1.825674,342.07091,187.366899,0
5,4,64,224,224,3,3,64,32,8,2,1.530844,2.730468,1.783636,335.633254,188.173661,0
6,4,64,224,224,3,3,32,32,8,3,1.562804,2.746174,1.75721,328.769548,187.097474,0
7,4,64,224,224,3,3,32,32,8,2,1.566289,2.82306,1.802388,328.038034,182.001881,0
8,4,64,224,224,3,3,32,32,2,3,2.286667,2.739539,1.198049,224.694822,187.550633,0
9,4,64,224,224,3,3,64,32,4,2,2.303792,2.73645,1.187803,223.024608,187.762355,0


In [9]:
BEST_COL2IMG_INT32_TILES = {
    "BLOCK_M": 32,
    "BLOCK_K": 32,
    "num_warps": 4,
    "num_stages": 3,
}

In [10]:
df_global = benchmark_col2img_int32_best_tiles(
    BEST_COL2IMG_INT32_TILES,
    image_sizes=(112, 256, 512, 1024),
    batch_sizes=(1, 2, 4),
    channel_pairs=((1, 3), (3, 3), (3, 8), (8, 16), (32, 32)),
    kernels=(1, 3, 5, 7, 9, 11),
    iters_per_shape=50,
    device="cuda",
)

=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=1x1 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=3x3 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=5x5 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=7x7 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=9x9 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=11x11 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=1x1 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=3x3 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=5x5 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=7x7 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=9x9 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=11x11 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=1x1 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=3x3 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=5x5 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=7x7 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=9x9 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=11x11 ===
=== SHAPE: N=1, Cin=8, Cout=16, H=W=112, k=1x1 ===
=== SHAPE: N=1, Cin=8, Cout=16, H=W=112, k=

In [11]:
df_global.sort_values("speed_vs_fold", ascending=False).head(40)

Unnamed: 0,N,Cin,H,W,Kh,Kw,BLOCK_M,BLOCK_K,num_warps,num_stages,t_triton_ms,t_fold_ms,speed_vs_fold,bw_triton_GBs,bw_fold_GBs,max_abs_err,Cout
26,1,32,112,112,5,5,32,32,4,3,0.119208,0.28715,2.408815,350.19877,145.38215,0,32
21,1,8,112,112,7,7,32,32,4,3,0.076214,0.174012,2.283196,263.341795,115.339119,0,16
95,1,1,256,256,11,11,32,32,4,3,0.114009,0.255653,2.242389,280.517875,125.097771,0,3
28,1,32,112,112,9,9,32,32,4,3,0.337929,0.73058,2.161935,389.613981,180.21536,0,32
69,4,3,112,112,7,7,32,32,4,3,0.104523,0.224446,2.147335,288.028748,134.133123,0,3
199,1,32,512,512,1,1,32,32,4,3,0.279496,0.592216,2.118872,240.106912,113.318283,0,32
52,2,8,112,112,9,9,32,32,4,3,0.193782,0.410003,2.115798,339.716727,160.56197,0,16
22,1,8,112,112,9,9,32,32,4,3,0.103478,0.216384,2.091125,318.092819,152.115619,0,16
204,2,1,512,512,5,5,32,32,4,3,0.17607,0.367422,2.0868,309.68415,148.401459,0,3
115,1,32,256,256,3,3,32,32,4,3,0.236362,0.482924,2.043159,354.905764,173.704445,0,32
