In [5]:
import triton
import triton.language as tl
import torch
from img2col_int8_kernel import img2col_int8_kernel

In [6]:
INT8_I2C_BLOCK_M = 64
INT8_I2C_BLOCK_K = 32


In [7]:
def img2col_int8(
    x_q: torch.Tensor,
    Kh: int, Kw: int,
    Sh: int, Sw: int,
    Ph: int, Pw: int,
    Dh: int, Dw: int,
    BLOCK_M: int = INT8_I2C_BLOCK_M,
    BLOCK_K: int = INT8_I2C_BLOCK_K,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """
    x_q: [N, Cin, H, W], int8
    return cols_q: [M, K], int8, (Ho, Wo)
    """
    assert x_q.is_cuda and x_q.dtype == torch.int8
    N, Cin, H, W = x_q.shape

    Ho = (H + 2 * Ph - Dh * (Kh - 1) - 1) // Sh + 1
    Wo = (W + 2 * Pw - Dw * (Kw - 1) - 1) // Sw + 1
    M = N * Ho * Wo
    K = Cin * Kh * Kw

    cols_q = torch.empty((M, K), device=x_q.device, dtype=torch.int8)
    sN, sC, sH, sW = x_q.stride()

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))
    img2col_int8_kernel[grid](
        x_q, cols_q,
        N, Cin, H, W,
        Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw,
        Ho, Wo,
        sN, sC, sH, sW,
        K,
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return cols_q, (Ho, Wo)

In [8]:
# BLOKS

In [9]:
@torch.no_grad()
def bench_tile_once_img2col(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    BLOCK_M,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=50,
    device="cuda",
):
    """
    Бенч ОДНОГО конфига тайлов:
      - замеряем ТОЛЬКО время img2col_int8
      - для корректности один раз сравниваем с F.unfold (но не меряем его время)
    """
    x_q = torch.randint(-128, 127, (N, Cin, H, W),
                        device=device, dtype=torch.int8)

    # ---- ref через unfold (только для проверки, без тайминга) ----
    x_f = x_q.float()
    unf = F.unfold(
        x_f,
        kernel_size=(Kh, Kw),
        dilation=(Dh, Dw),
        padding=(Ph, Pw),
        stride=(Sh, Sw),
    )  # [N, KC, L]

    N_, KC, L = unf.shape
    assert N_ == N
    assert KC == Cin * Kh * Kw

    M = N * L          # ВАЖНО: M = N * L, а не L
    K = KC             # == Cin * Kh * Kw

    unf_N_L_KC = unf.permute(0, 2, 1).contiguous()   # [N, L, KC]
    cols_ref = unf_N_L_KC.view(M, K)                 # [M, K]
    cols_ref_q = cols_ref.to(torch.int8)

    # ---- наш img2col_int8 ----
    def _call_triton():
        cols_q, (Ho, Wo) = img2col_int8(
            x_q,
            Kh, Kw,
            Sh, Sw,
            Ph, Pw,
            Dh, Dw,
            BLOCK_M=BLOCK_M,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return cols_q

    # warmup
    for _ in range(5):
        _ = _call_triton()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        cols_q = _call_triton()
    torch.cuda.synchronize()
    t_triton = (time.perf_counter() - t0) / iters

    # корректность (один раз, после тайминга)
    max_abs_err = (cols_q.to(torch.int16) - cols_ref_q.to(torch.int16)).abs().max().item()

    # грубая оценка памяти: 2*M*K байт (load + store int8)
    bytes_moved = 2.0 * M * K
    bw_triton = bytes_moved / t_triton / 1e9  # GB/s

    return {
        "N": N,
        "Cin": Cin,
        "H": H, "W": W,
        "Kh": Kh, "Kw": Kw,
        "BLOCK_M": BLOCK_M,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "bw_triton_GBs": bw_triton,
        "max_abs_err": max_abs_err,
    }


In [10]:
@torch.no_grad()
def bench_once_img2col_vs_unfold(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    BLOCK_M,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=50,
    device="cuda",
):
    """
    Бенч для ОДНОГО shape и ОДНОГО набора тайлов:
      - меряем t_triton_ms и t_unfold_ms
      - считаем speed_vs_unfold
      - проверяем max_abs_err
    """
    x_q = torch.randint(-128, 127, (N, Cin, H, W),
                        device=device, dtype=torch.int8)

    # ref
    x_f = x_q.float()
    unf = F.unfold(
        x_f,
        kernel_size=(Kh, Kw),
        dilation=(Dh, Dw),
        padding=(Ph, Pw),
        stride=(Sh, Sw),
    )  # [N, KC, L]

    N_, KC, L = unf.shape
    assert N_ == N
    assert KC == Cin * Kh * Kw

    M = N * L          # ВАЖНО: M = N * L
    K = KC

    unf_N_L_KC = unf.permute(0, 2, 1).contiguous()   # [N,L,KC]
    cols_ref = unf_N_L_KC.view(M, K)                 # [M,K]
    cols_ref_q = cols_ref.to(torch.int8)

    # наш
    def _call_triton():
        cols_q, (Ho, Wo) = img2col_int8(
            x_q,
            Kh, Kw,
            Sh, Sw,
            Ph, Pw,
            Dh, Dw,
            BLOCK_M=BLOCK_M,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return cols_q

    for _ in range(5):
        _ = _call_triton()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        cols_q = _call_triton()
    torch.cuda.synchronize()
    t_triton = (time.perf_counter() - t0) / iters

    # unfold
    def _call_unfold():
        unf = F.unfold(
            x_f,
            kernel_size=(Kh, Kw),
            dilation=(Dh, Dw),
            padding=(Ph, Pw),
            stride=(Sh, Sw),
        )
        unf_N_L_KC = unf.permute(0, 2, 1).contiguous()  # [N,L,KC]
        cols = unf_N_L_KC.view(M, K)                    # [M,K]
        return cols

    for _ in range(5):
        _ = _call_unfold()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        cols_ref2 = _call_unfold()
    torch.cuda.synchronize()
    t_unfold = (time.perf_counter() - t0) / iters

    # ошибка
    max_abs_err = (cols_q.to(torch.int16) - cols_ref_q.to(torch.int16)).abs().max().item()

    bytes_moved = 2.0 * M * K
    bw_triton = bytes_moved / t_triton / 1e9
    bw_unfold = bytes_moved / t_unfold / 1e9

    return {
        "N": N,
        "Cin": Cin,
        "H": H, "W": W,
        "Kh": Kh, "Kw": Kw,
        "BLOCK_M": BLOCK_M,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_unfold_ms": t_unfold * 1e3,
        "speed_vs_unfold": t_unfold / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_unfold_GBs": bw_unfold,
        "max_abs_err": max_abs_err,
    }


In [11]:
best_tiles, df_tiles = tune_tiles_for_one_shape(
    N=4, Cin=64, H=224, W=224,
    Kh=3, Kw=3,
    Sh=1, Sw=1,
    Ph=1, Pw=1,
    Dh=1, Dw=1,
    blocks_M=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters=500,
    device="cuda",
)
print("BEST TILES:", best_tiles)


NameError: name 'tune_tiles_for_one_shape' is not defined

In [31]:
INT8_I2C_BEST_BLOCK_M = 128
INT8_I2C_BEST_BLOCK_K = 32
INT8_I2C_BEST_WARPS   = 2
INT8_I2C_BEST_STAGES  = 3

# BENCH

In [None]:
@torch.no_grad()
def bench_once_img2col_vs_unfold(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    BLOCK_M,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=50,
    device="cuda",
):
    """
    Бенч для ОДНОГО shape и ОДНОГО набора тайлов:
      - меряем t_triton_ms и t_unfold_ms
      - считаем speed_vs_unfold
      - проверяем max_abs_err
    """
    x_q = torch.randint(-128, 127, (N, Cin, H, W),
                        device=device, dtype=torch.int8)

    # ref
    x_f = x_q.float()
    unf = F.unfold(
        x_f,
        kernel_size=(Kh, Kw),
        dilation=(Dh, Dw),
        padding=(Ph, Pw),
        stride=(Sh, Sw),
    )  # [N, KC, L]

    N_, KC, L = unf.shape
    assert N_ == N
    assert KC == Cin * Kh * Kw

    M = N * L          # ВАЖНО: M = N * L
    K = KC

    unf_N_L_KC = unf.permute(0, 2, 1).contiguous()   # [N,L,KC]
    cols_ref = unf_N_L_KC.view(M, K)                 # [M,K]
    cols_ref_q = cols_ref.to(torch.int8)

    # наш
    def _call_triton():
        cols_q, (Ho, Wo) = img2col_int8(
            x_q,
            Kh, Kw,
            Sh, Sw,
            Ph, Pw,
            Dh, Dw,
            BLOCK_M=BLOCK_M,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )
        return cols_q

    for _ in range(5):
        _ = _call_triton()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        cols_q = _call_triton()
    torch.cuda.synchronize()
    t_triton = (time.perf_counter() - t0) / iters

    # unfold
    def _call_unfold():
        unf = F.unfold(
            x_f,
            kernel_size=(Kh, Kw),
            dilation=(Dh, Dw),
            padding=(Ph, Pw),
            stride=(Sh, Sw),
        )
        unf_N_L_KC = unf.permute(0, 2, 1).contiguous()  # [N,L,KC]
        cols = unf_N_L_KC.view(M, K)                    # [M,K]
        return cols

    for _ in range(5):
        _ = _call_unfold()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        cols_ref2 = _call_unfold()
    torch.cuda.synchronize()
    t_unfold = (time.perf_counter() - t0) / iters

    # ошибка
    max_abs_err = (cols_q.to(torch.int16) - cols_ref_q.to(torch.int16)).abs().max().item()

    bytes_moved = 2.0 * M * K
    bw_triton = bytes_moved / t_triton / 1e9
    bw_unfold = bytes_moved / t_unfold / 1e9

    return {
        "N": N,
        "Cin": Cin,
        "H": H, "W": W,
        "Kh": Kh, "Kw": Kw,
        "BLOCK_M": BLOCK_M,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_unfold_ms": t_unfold * 1e3,
        "speed_vs_unfold": t_unfold / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_unfold_GBs": bw_unfold,
        "max_abs_err": max_abs_err,
    }


In [25]:
import pandas as pd
import torch

@torch.no_grad()
def benchmark_img2col_best_tiles(
    tiles_cfg,
    image_sizes=(112, 1024),
    batch_sizes=(1, 2, 4),
    channel_pairs=((1, 3), (3, 3), (3, 8), (8, 16), (32, 32)),
    kernels=(1, 3, 5, 7, 9, 11),
    iters_per_shape=20,
    device="cuda",
):
    """
    Сравнение img2col_int8 vs F.unfold на сетке параметров,
    с фиксированными лучшими тайлами tiles_cfg.

    tiles_cfg: dict с ключами:
      - BLOCK_M
      - BLOCK_K
      - num_warps
      - num_stages
    """
    BM = 128
    BK = 32
    NW = 2
    NS = 3

    records = []

    for H in image_sizes:
        W = H
        for N in batch_sizes:
            for (Cin, Cout) in channel_pairs:
                for k in kernels:
                    Ph = Pw = k // 2
                    print(f"=== SHAPE: N={N}, Cin={Cin}, Cout={Cout}, H=W={H}, k={k}x{k} ===")

                    try:
                        rec = bench_once_img2col_vs_unfold(
                            N, Cin, H, W,
                            k, k,
                            1, 1,       # Sh, Sw
                            Ph, Pw,     # padding
                            1, 1,       # Dh, Dw
                            BLOCK_M=BM,
                            BLOCK_K=BK,
                            num_warps=NW,
                            num_stages=NS,
                            iters=iters_per_shape,
                            device=device,
                        )
                    except RuntimeError as e:
                        print(f"[SKIP] N={N}, Cin={Cin}, Cout={Cout}, H={H}, k={k}: {e}")
                        continue

                    # проверка корректности
                    if rec["max_abs_err"] != 0:
                        print(f"[WRONG] N={N}, Cin={Cin}, Cout={Cout}, H={H}, k={k}, err={rec['max_abs_err']}")
                        continue

                    # добавляем Cout в запись (bench_once его не знает)
                    rec["Cout"] = Cout
                    records.append(rec)

    if not records:
        print("[WARN] no successful records")
        return pd.DataFrame()

    df = pd.DataFrame(records)


    return df


In [29]:
df_global = benchmark_img2col_best_tiles(
    best_tiles,
    image_sizes=(112, 256, 512, 1024),
    batch_sizes=(1, 2, 4),
    channel_pairs=((1, 3), (3, 3), (3, 8), (8, 16), (32, 32)),
    kernels=(1, 3, 5, 7, 9, 11),
    iters_per_shape=500,
    device="cuda",
)


=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=1x1 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=3x3 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=5x5 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=7x7 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=9x9 ===
=== SHAPE: N=1, Cin=1, Cout=3, H=W=112, k=11x11 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=1x1 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=3x3 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=5x5 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=7x7 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=9x9 ===
=== SHAPE: N=1, Cin=3, Cout=3, H=W=112, k=11x11 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=1x1 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=3x3 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=5x5 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=7x7 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=9x9 ===
=== SHAPE: N=1, Cin=3, Cout=8, H=W=112, k=11x11 ===
=== SHAPE: N=1, Cin=8, Cout=16, H=W=112, k=1x1 ===
=== SHAPE: N=1, Cin=8, Cout=16, H=W=112, k=

In [30]:
df_global.sort_values("speed_vs_unfold", ascending=False).head(5)

Unnamed: 0,N,Cin,H,W,Kh,Kw,BLOCK_M,BLOCK_K,num_warps,num_stages,t_triton_ms,t_unfold_ms,speed_vs_unfold,bw_triton_GBs,bw_unfold_GBs,max_abs_err,Cout
191,1,32,512,512,3,3,128,32,2,3,0.448237,7.392587,16.492597,336.864315,20.425183,0,32
262,4,32,1024,1024,1,1,128,32,2,3,0.513139,8.459648,16.486076,523.124304,31.73128,0,32
226,4,32,512,512,1,1,128,32,2,3,0.136437,2.006526,14.706562,491.8653,33.445296,0,32
253,2,32,1024,1024,1,1,128,32,2,3,0.263149,3.779029,14.360815,510.045315,35.516461,0,32
209,2,32,512,512,1,1,128,32,2,3,0.07324,0.987524,13.483335,458.141601,33.978359,0,32
