In [27]:
import sys, pathlib
sys.path.insert(0, str(pathlib.Path().resolve().parent.parent))

In [28]:
import time
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
import pandas as pd
from conv_gemm.triton_kernels.fp16.col2img_kernel import col2img_kernel

In [29]:
def col2img_fp32(
    cols_f32: torch.Tensor,
    N: int, Cin: int,
    H: int, W: int,
    Kh: int, Kw: int,
    Sh: int, Sw: int,
    Ph: int, Pw: int,
    Dh: int, Dw: int,
    BLOCK_M: int,
    BLOCK_K: int,
    num_warps: int = 4,
    num_stages: int = 2,
):
    assert cols_f32.is_cuda
    assert cols_f32.dtype == torch.float32
    cols_f32 = cols_f32.contiguous()

    Ho = (H + 2*Ph - Dh*(Kh-1) - 1)//Sh + 1
    Wo = (W + 2*Pw - Dw*(Kw-1) - 1)//Sw + 1
    M  = N * Ho * Wo
    K  = Cin * Kh * Kw

    assert cols_f32.shape == (M, K), f"cols shape {cols_f32.shape}, expected {(M, K)}"

    x_f32 = torch.zeros((N, Cin, H, W), device=cols_f32.device, dtype=torch.float32)
    sN, sC, sH, sW = x_f32.stride()

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(K, BLOCK_K))

    col2img_kernel[grid](
        cols_f32, x_f32,
        N, Cin, H, W,
        Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw,
        Ho, Wo,
        sN, sC, sH, sW,
        K,
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_f32

In [34]:
@torch.no_grad()
def bench_once_col2img_fp32_vs_torch(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    BLOCK_M,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=100,
    device="cuda",
):
    """
    FP32 Triton col2img vs torch F.fold (FP32 и FP16)
    """
    Ho = (H + 2*Ph - Dh*(Kh-1) - 1)//Sh + 1
    Wo = (W + 2*Pw - Dw*(Kw-1) - 1)//Sw + 1
    M  = N * Ho * Wo
    K  = Cin * Kh * Kw

    cols_f32 = torch.randn((M, K), device=device, dtype=torch.float32)

    # Torch FP32 (baseline)
    cols_fold_f32 = cols_f32.view(N, Ho*Wo, K).permute(0, 2, 1).contiguous()

    def _call_torch_f32():
        return F.fold(
            cols_fold_f32,
            output_size=(H, W),
            kernel_size=(Kh, Kw),
            dilation=(Dh, Dw),
            padding=(Ph, Pw),
            stride=(Sh, Sw),
        )

    for _ in range(5):
        _ = _call_torch_f32()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        x_ref_f32 = _call_torch_f32()
    torch.cuda.synchronize()
    t_torch32 = (time.perf_counter() - t0) / iters

    # Torch FP16 (baseline)
    cols_f16 = cols_f32.half()
    cols_fold_f16 = cols_f16.view(N, Ho*Wo, K).permute(0, 2, 1).contiguous()

    def _call_torch_f16():
        return F.fold(
            cols_fold_f16,
            output_size=(H, W),
            kernel_size=(Kh, Kw),
            dilation=(Dh, Dw),
            padding=(Ph, Pw),
            stride=(Sh, Sw),
        )

    for _ in range(5):
        _ = _call_torch_f16()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        x_ref_f16 = _call_torch_f16()
    torch.cuda.synchronize()
    t_torch16 = (time.perf_counter() - t0) / iters

    # Triton col2img (FP32)
    def _call_triton():
        return col2img_fp32(
            cols_f32,
            N, Cin, H, W,
            Kh, Kw,
            Sh, Sw,
            Ph, Pw,
            Dh, Dw,
            BLOCK_M=BLOCK_M,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    for _ in range(5):
        _ = _call_triton()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        x_triton = _call_triton()
    torch.cuda.synchronize()
    t_triton = (time.perf_counter() - t0) / iters

    # Errors
    # Triton FP32 vs Torch FP32
    diff32 = (x_triton - x_ref_f32).abs()
    max_abs_err32 = diff32.max().item()
    mean_abs_err32 = diff32.mean().item()

    # Triton FP32 vs Torch FP16
    diff16 = (x_triton - x_ref_f16.float()).abs()
    max_abs_err16 = diff16.max().item()
    mean_abs_err16 = diff16.mean().item()

    # Bandwidth
    bytes_triton = (cols_f32.numel() + x_triton.numel()) * 4.0  # FP32
    bw_triton = bytes_triton / t_triton / 1e9

    bytes_torch32 = (cols_fold_f32.numel() + x_ref_f32.numel()) * 4.0
    bw_torch32 = bytes_torch32 / t_torch32 / 1e9

    bytes_torch16 = (cols_fold_f16.numel() + x_ref_f16.numel()) * 2.0
    bw_torch16 = bytes_torch16 / t_torch16 / 1e9

    return {
        "M": M, "K": K, "N": N,
        "BLOCK_M": BLOCK_M,
        "BLOCK_N": 0,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        # Triton
        "t_triton_ms": t_triton * 1e3,
        "bw_triton_GBs": bw_triton,
        # Torch FP32 с явной пометкой
        "t_torch32_ms": t_torch32 * 1e3,
        "speed_vs_torch32": t_torch32 / t_triton,
        "bw_torch32_GBs": bw_torch32,
        "max_abs_err32": max_abs_err32,
        "mean_abs_err32": mean_abs_err32,
        # Torch FP16
        "t_torch16_ms": t_torch16 * 1e3,
        "speed_vs_torch16": t_torch16 / t_triton,
        "bw_torch16_GBs": bw_torch16,
        "max_abs_err16": max_abs_err16,
        "mean_abs_err16": mean_abs_err16,
    }


In [35]:
@torch.no_grad()
def tune_col2img_fp16_tiles_for_shape(
    N, Cin, H, W,
    Kh, Kw,
    Sh, Sw,
    Ph, Pw,
    Dh, Dw,
    blocks_M=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(1, 2, 4, 8),
    stages=(2, 3, 4),
    iters=200,
    device="cuda",
):
    records = []
    for BM in blocks_M:
        for BK in blocks_K:
            for Wp in warps:
                for S in stages:
                    try:
                        rec = bench_once_col2img_fp32_vs_torch(
                            N, Cin, H, W,
                            Kh, Kw,
                            Sh, Sw,
                            Ph, Pw,
                            Dh, Dw,
                            BLOCK_M=BM,
                            BLOCK_K=BK,
                            num_warps=Wp,
                            num_stages=S,
                            iters=iters,
                            device=device,
                        )
                    except RuntimeError as e:
                        print(f"[SKIP] BM={BM}, BK={BK}, W={Wp}, S={S}: {e}")
                        continue

                    print(
                        f"BM={BM}, BK={BK}, W={Wp}, S={S}: "
                        f"t_triton={rec['t_triton_ms']:.3f} ms, "
                        f"speed_vs_torch={rec['speed_vs_torch16']:.3f}x, "
                    )
                    records.append(rec)

    if not records:
        raise RuntimeError("No valid tile configs found for this COL2IMG shape")

    df = pd.DataFrame(records)
    return df


In [36]:
df_c2i = tune_col2img_fp16_tiles_for_shape(
    N=16, Cin=1, H=256, W=256,
    Kh=11, Kw=11,
    Sh=1, Sw=1,
    Ph=5, Pw=5,
    Dh=1, Dw=1,
    blocks_M=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters=200,
    device="cuda",
)

BM=32, BK=32, W=2, S=2: t_triton=2.062 ms, speed_vs_torch=1.095x, 
BM=32, BK=32, W=2, S=3: t_triton=2.052 ms, speed_vs_torch=1.083x, 
BM=32, BK=32, W=4, S=2: t_triton=1.684 ms, speed_vs_torch=1.389x, 
BM=32, BK=32, W=4, S=3: t_triton=1.585 ms, speed_vs_torch=1.434x, 
BM=32, BK=32, W=8, S=2: t_triton=1.776 ms, speed_vs_torch=1.346x, 
BM=32, BK=32, W=8, S=3: t_triton=1.762 ms, speed_vs_torch=1.286x, 
BM=32, BK=64, W=2, S=2: t_triton=4.582 ms, speed_vs_torch=0.486x, 
BM=32, BK=64, W=2, S=3: t_triton=4.621 ms, speed_vs_torch=0.495x, 
BM=32, BK=64, W=4, S=2: t_triton=2.021 ms, speed_vs_torch=1.124x, 
BM=32, BK=64, W=4, S=3: t_triton=2.028 ms, speed_vs_torch=1.107x, 
BM=32, BK=64, W=8, S=2: t_triton=1.678 ms, speed_vs_torch=1.359x, 
BM=32, BK=64, W=8, S=3: t_triton=1.755 ms, speed_vs_torch=1.303x, 
BM=32, BK=128, W=2, S=2: t_triton=6.161 ms, speed_vs_torch=0.371x, 
BM=32, BK=128, W=2, S=3: t_triton=6.184 ms, speed_vs_torch=0.362x, 
BM=32, BK=128, W=4, S=2: t_triton=5.162 ms, speed_vs_torch=0

In [37]:
cols = [
    "BLOCK_M", "BLOCK_K",
    "num_warps", "num_stages",
    "t_triton_ms", 
    "t_torch16_ms", 
    "speed_vs_torch16",
    "mean_abs_err16",
]

df_c2i_filtered = df_c2i[cols].sort_values("speed_vs_torch16", ascending=False).head(5).reset_index(drop=True)
df_c2i_filtered

Unnamed: 0,BLOCK_M,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch16_ms,speed_vs_torch16,mean_abs_err16
0,32,32,4,3,1.584996,2.272959,1.434047,0.002471
1,64,32,8,3,1.762368,2.491867,1.413931,0.002469
2,32,32,4,2,1.683654,2.338107,1.38871,0.002469
3,32,64,8,2,1.677512,2.279804,1.359039,0.002469
4,32,32,8,2,1.776147,2.390902,1.346118,0.00247


In [8]:
FP32_COL2IMG_BEST_BLOCK_M = 32
FP32_COL2IMG_BEST_BLOCK_N = 0
FP32_COL2IMG_BEST_BLOCK_K = 32
FP32_COL2IMG_BEST_WARPS = 4
FP32_COL2IMG_BEST_STAGES = 2