In [1]:
import triton
import triton.language as tl
import torch
import time
from gemm_int8_kernel import gemm_int8_tc_kernel

In [2]:
def gemm_int8_tc(
    A_q: torch.Tensor,   # [M,K] int8
    B_q: torch.Tensor,   # [K,N] int8
    *,
    BLOCK_M: int = 64,
    BLOCK_N: int = 64,
    BLOCK_K: int = 32,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """
    Launch ``gemm_int8_tc_kernel`` to compute ``C = A_q @ B_q`` for int8 inputs.

    Per the original author's note, this is meant to be call-compatible with
    the existing ``gemm_int8()`` helper.

    Args:
        A_q: int8 CUDA tensor of shape [M, K].
        B_q: int8 CUDA tensor of shape [K, N], on the same device as ``A_q``.
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes forwarded to the kernel. The
            launch grid is ``cdiv(M, BLOCK_M) x cdiv(N, BLOCK_N)``.
        num_warps, num_stages: Triton launch options, forwarded unchanged.

    Returns:
        ``C_i32``: int32 tensor of shape [M, N] allocated on ``A_q.device``.

    Raises:
        AssertionError: if an input is not a 2-D int8 CUDA tensor, the inputs
            live on different devices, the inner dimensions differ, or
            ``K`` / ``BLOCK_K`` is not a multiple of 4.
    """
    assert A_q.is_cuda and B_q.is_cuda, "A_q и B_q должны лежать на CUDA"
    assert A_q.dtype == torch.int8 and B_q.dtype == torch.int8, "Оба тензора должны быть int8"
    assert A_q.dim() == 2 and B_q.dim() == 2, (
        f"A_q and B_q must be 2-D, got {tuple(A_q.shape)} and {tuple(B_q.shape)}"
    )
    assert A_q.device == B_q.device, (
        f"A_q and B_q must be on the same device, got {A_q.device} and {B_q.device}"
    )

    # .contiguous() returns the tensor itself when it is already contiguous.
    A_q = A_q.contiguous()
    B_q = B_q.contiguous()

    M, K1 = A_q.shape
    K2, N = B_q.shape
    assert K1 == K2, f"K mismatch: {K1} vs {K2}"

    # Requirement of the INT8 dot path as stated by the original checks.
    assert K1 % 4 == 0, f"K={K1} must be divisible by 4 for INT8 dot"
    assert BLOCK_K % 4 == 0, f"BLOCK_K={BLOCK_K} must be divisible by 4"

    C_i32 = torch.empty((M, N), dtype=torch.int32, device=A_q.device)

    a_m, a_k = A_q.stride()
    b_k, b_n = B_q.stride()
    c_m, c_n = C_i32.stride()

    grid = (
        triton.cdiv(M, BLOCK_M),
        triton.cdiv(N, BLOCK_N),
    )

    # Triton launches on the *current* CUDA device. Make the tensors' device
    # current for the launch so inputs on a non-default GPU are handled correctly.
    with torch.cuda.device(A_q.device):
        gemm_int8_tc_kernel[grid](
            A_q, B_q, C_i32,
            M, N, K1,
            a_m, a_k,
            b_k, b_n,
            c_m, c_n,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return C_i32

# Tile search

In [3]:
@torch.no_grad()
def bench_once_gemm_int8_vs_torch(
    M, K, N,
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    num_warps,
    num_stages,
    iters=100,
    device="cuda",
):
    """
    Benchmark ONE GEMM shape with ONE tile configuration.

    Computes C = A @ B for random int8 A [M,K] and B [K,N] and compares
    ``gemm_int8_tc`` against torch's fp32 matmul of the same values.

    Args:
        M, K, N: GEMM dimensions.
        BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages: forwarded to
            ``gemm_int8_tc``.
        iters: number of timed calls per implementation (must be >= 1).
        device: device on which the inputs are created.

    Returns:
        dict with the shape, the tile configuration, mean time per call in ms
        (``t_triton_ms``, ``t_torch_ms``), ``speed_vs_torch``, a rough
        bandwidth estimate in GB/s and ``max_abs_err`` (float).

    Raises:
        ValueError: if ``iters`` is smaller than 1.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    # torch.randint's upper bound is exclusive: 128 is needed to include 127.
    A_q = torch.randint(-128, 128, (M, K), device=device, dtype=torch.int8)
    B_q = torch.randint(-128, 128, (K, N), device=device, dtype=torch.int8)

    # fp32 copies for the torch baseline
    A_f = A_q.float()
    B_f = B_q.float()

    def _call_triton():
        return gemm_int8_tc(
            A_q, B_q,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    def _call_torch():
        return A_f @ B_f

    def _timed(fn):
        """Warm up, then return (mean wall-clock seconds per call, last result)."""
        for _ in range(5):
            fn()
        torch.cuda.synchronize(A_q.device)
        t0 = time.perf_counter()
        for _ in range(iters):
            out = fn()
        torch.cuda.synchronize(A_q.device)
        return (time.perf_counter() - t0) / iters, out

    t_triton, C_i32 = _timed(_call_triton)
    t_torch, _ = _timed(_call_torch)

    # Error check against an exact reference. fp32 represents integers exactly
    # only up to 2**24 = 1024 * 128 * 128, so an fp32 reference is trustworthy
    # only for K <= 1024; float64 keeps the check exact for larger K as well.
    C_ref = A_q.double() @ B_q.double()
    max_abs_err = (C_i32.double() - C_ref).abs().max().item()

    # Rough traffic estimate: read A and B once (int8), write C once (int32).
    bytes_moved = float(A_q.numel() + B_q.numel() + C_i32.numel() * 4)

    bw_triton = bytes_moved / t_triton / 1e9
    bw_torch = bytes_moved / t_torch / 1e9

    return {
        "M": M, "K": K, "N": N,
        "BLOCK_M": BLOCK_M,
        "BLOCK_N": BLOCK_N,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_torch_ms": t_torch * 1e3,
        "speed_vs_torch": t_torch / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_torch_GBs": bw_torch,
        "max_abs_err": max_abs_err,
    }

In [4]:
import pandas as pd
import torch
@torch.no_grad()
def tune_gemm_int8_tiles_for_shape(
    M, K, N,
    blocks_M=(32, 64, 128),
    blocks_N=(32, 64, 128),
    blocks_K=(32, 64, 128),
    warps=(2, 4, 8),
    stages=(2, 3),
    iters=100,
    device="cuda",
):
    """
    Benchmark every tile / warp / stage combination for one GEMM shape.

    Args:
        M, K, N: GEMM dimensions.
        blocks_M, blocks_N, blocks_K: candidate tile sizes.
        warps, stages: candidate ``num_warps`` / ``num_stages`` values.
        iters: timed iterations per configuration.
        device: device passed through to the single-config benchmark.

    Returns:
        DataFrame with one row per configuration that ran successfully.
        No "best" configuration is selected here; the caller does that.

    Raises:
        RuntimeError: if no configuration could be benchmarked.
    """
    # K and BLOCK_K must be multiples of 4. Neither depends on BM/BN, so the
    # candidates are filtered once instead of inside the tile loops.
    usable_blocks_K = []
    for BK in blocks_K:
        if (K % 4 != 0) or (BK % 4 != 0):
            print(f"[SKIP] BK={BK}: K/BK not multiple of 4")
        else:
            usable_blocks_K.append(BK)

    records = []

    for BM in blocks_M:
        for BN in blocks_N:
            for BK in usable_blocks_K:
                for W in warps:
                    for S in stages:
                        try:
                            rec = bench_once_gemm_int8_vs_torch(
                                M, K, N,
                                BLOCK_M=BM,
                                BLOCK_N=BN,
                                BLOCK_K=BK,
                                num_warps=W,
                                num_stages=S,
                                iters=iters,
                                device=device,
                            )
                        except Exception as e:
                            # A configuration that does not fit the GPU makes
                            # Triton raise OutOfResources, which is NOT a
                            # RuntimeError; matched by name because its module
                            # path differs between Triton releases. Anything
                            # else is a real bug and must propagate.
                            if not isinstance(e, RuntimeError) and type(e).__name__ != "OutOfResources":
                                raise
                            print(f"[SKIP] BM={BM}, BN={BN}, BK={BK}, W={W}, S={S}: {e}")
                            continue

                        print(
                            f"BM={BM}, BN={BN}, BK={BK}, W={W}, S={S}: "
                            f"t_triton={rec['t_triton_ms']:.3f} ms, "
                            f"speed_vs_torch={rec['speed_vs_torch']:.3f}x, "
                            f"err={rec['max_abs_err']}"
                        )
                        records.append(rec)

    if not records:
        raise RuntimeError("No valid tile configs found for this GEMM shape")

    return pd.DataFrame(records)

In [5]:
# Exhaustive sweep for the 4096 x 1024 x 1024 problem: 3 x 3 x 3 tile sizes,
# 3 warp counts and 2 stage counts (162 configurations), iters=200 for each.
df_gemm_tiles = tune_gemm_int8_tiles_for_shape(
    4096, 1024, 1024,  # M, K, N
    blocks_M=(32, 64, 128), blocks_N=(32, 64, 128), blocks_K=(32, 64, 128),
    warps=(2, 4, 8), stages=(2, 3),
    iters=200, device="cuda",
)

BM=32, BN=32, BK=32, W=2, S=2: t_triton=0.236 ms, speed_vs_torch=2.716x, err=0.0
BM=32, BN=32, BK=32, W=2, S=3: t_triton=0.245 ms, speed_vs_torch=2.578x, err=0.0
BM=32, BN=32, BK=32, W=4, S=2: t_triton=0.391 ms, speed_vs_torch=1.613x, err=0.0
BM=32, BN=32, BK=32, W=4, S=3: t_triton=0.408 ms, speed_vs_torch=1.511x, err=0.0
BM=32, BN=32, BK=32, W=8, S=2: t_triton=0.504 ms, speed_vs_torch=1.267x, err=0.0
BM=32, BN=32, BK=32, W=8, S=3: t_triton=0.470 ms, speed_vs_torch=1.340x, err=0.0
BM=32, BN=32, BK=64, W=2, S=2: t_triton=0.214 ms, speed_vs_torch=2.972x, err=0.0
BM=32, BN=32, BK=64, W=2, S=3: t_triton=0.220 ms, speed_vs_torch=2.920x, err=0.0
BM=32, BN=32, BK=64, W=4, S=2: t_triton=0.312 ms, speed_vs_torch=2.012x, err=0.0
BM=32, BN=32, BK=64, W=4, S=3: t_triton=0.323 ms, speed_vs_torch=1.925x, err=0.0
BM=32, BN=32, BK=64, W=8, S=2: t_triton=0.373 ms, speed_vs_torch=1.708x, err=0.0
BM=32, BN=32, BK=64, W=8, S=3: t_triton=0.376 ms, speed_vs_torch=1.658x, err=0.0
BM=32, BN=32, BK=128, W=2, S

In [6]:
# Ten fastest configurations, ordered by mean Triton time per call.
df_gemm_tiles.sort_values(by="t_triton_ms", ascending=True).iloc[:10]

Unnamed: 0,M,K,N,BLOCK_M,BLOCK_N,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch_ms,speed_vs_torch,bw_triton_GBs,bw_torch_GBs,max_abs_err
99,4096,1024,1024,64,128,64,4,3,0.104641,0.633273,6.051837,210.433932,34.77191,0.0
153,4096,1024,1024,128,128,64,4,3,0.106228,0.635428,5.981746,207.291235,34.65397,0.0
98,4096,1024,1024,64,128,64,4,2,0.106953,0.638751,5.972279,205.886332,34.47366,0.0
133,4096,1024,1024,128,64,64,2,3,0.108423,0.649938,5.994492,203.095065,33.880281,0.0
96,4096,1024,1024,64,128,64,2,2,0.110898,0.630017,5.681029,198.560855,34.951565,0.0
147,4096,1024,1024,128,128,32,4,3,0.111004,0.63828,5.750068,198.372238,34.499112,0.0
134,4096,1024,1024,128,64,64,4,2,0.111353,0.640743,5.754159,197.750308,34.366501,0.0
91,4096,1024,1024,64,128,32,2,3,0.113844,0.634232,5.571065,193.423634,34.719327,0.0
104,4096,1024,1024,64,128,128,4,2,0.114519,0.631522,5.514569,192.28351,34.868277,0.0
152,4096,1024,1024,128,128,64,4,2,0.117167,0.635612,5.424834,187.937391,34.6439,0.0


In [11]:
# Fastest configuration found by the tile sweep for M=4096, K=1024, N=1024
# (first row of df_gemm_tiles sorted by t_triton_ms: 64 x 128 x 64 tiles,
# 4 warps, 3 stages, about 0.105 ms per call).
INT8_GEMM_BEST_BLOCK_M = 64
INT8_GEMM_BEST_BLOCK_N = 128
INT8_GEMM_BEST_BLOCK_K = 64
INT8_GEMM_BEST_WARPS   = 4
INT8_GEMM_BEST_STAGES  = 3

# BENCH

In [11]:
import time
import pandas as pd
import torch

# Expects gemm_int8_tc to be available already (it is defined earlier in this notebook):
# the function that takes A_q[int8][M,K] and B_q[int8][K,N]
# and returns C_i32[int32][M,N].


@torch.no_grad()
def bench_once_gemm_int8_vs_torch(
    M, K, N,
    BLOCK_M,
    BLOCK_N,
    BLOCK_K,
    num_warps,
    num_stages,
    iters: int = 50,
    device: str = "cuda",
):
    """
    Benchmark ONE GEMM shape with ONE tile configuration.

      - A_q: [M, K] int8, B_q: [K, N] int8 (random)
      - Triton: gemm_int8_tc -> C_i32 [M,N] int32
      - Torch baseline (timed): (A_q.float() @ B_q.float()).to(int32)
      - Correctness: exact comparison against a float64 reference

    Args:
        M, K, N: GEMM dimensions.
        BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages: forwarded to
            ``gemm_int8_tc``.
        iters: number of timed calls per implementation (must be >= 1).
        device: device on which the inputs are created.

    Returns:
        dict with the shape, the tile configuration, ``t_triton_ms``,
        ``t_torch_ms``, ``speed_vs_torch``, a rough bandwidth estimate in GB/s
        and ``max_abs_err`` (int).

    Raises:
        RuntimeError: if ``K`` or ``BLOCK_K`` is not a multiple of 4.
        ValueError: if ``iters`` is smaller than 1.
    """
    # --- multiple-of-4 requirement of the INT8 dot path ---
    if (K % 4) != 0 or (BLOCK_K % 4) != 0:
        raise RuntimeError(f"K={K} или BLOCK_K={BLOCK_K} не кратны 4 — INT8 dot невалиден")
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    # --- data ---
    # torch.randint's upper bound is exclusive: 128 is needed to include 127.
    A_q = torch.randint(-128, 128, (M, K), device=device, dtype=torch.int8)
    B_q = torch.randint(-128, 128, (K, N), device=device, dtype=torch.int8)

    # fp32 copies for the timed torch baseline
    A_f = A_q.float()
    B_f = B_q.float()

    def _call_triton():
        return gemm_int8_tc(
            A_q, B_q,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_K=BLOCK_K,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    def _call_torch():
        # fp32 matmul followed by the cast that yields the same dtype as Triton
        return (A_f @ B_f).to(torch.int32)

    def _timed(fn):
        """Warm up, then return (mean wall-clock seconds per call, last result)."""
        for _ in range(5):
            fn()
        torch.cuda.synchronize(A_q.device)
        t0 = time.perf_counter()
        for _ in range(iters):
            out = fn()
        torch.cuda.synchronize(A_q.device)
        return (time.perf_counter() - t0) / iters, out

    t_triton, C_i32 = _timed(_call_triton)
    t_torch, _ = _timed(_call_torch)

    # --- error ---
    # fp32 represents integers exactly only up to 2**24 = 1024 * 128 * 128, so
    # the fp32 baseline is not a safe reference for K > 1024. float64 is exact.
    C_ref_i64 = (A_q.double() @ B_q.double()).to(torch.int64)
    max_abs_err = (C_i32.to(torch.int64) - C_ref_i64).abs().max().item()

    # Very rough traffic estimate: read A and B once (1 byte per element),
    # write C once (4 bytes per element).
    bytes_moved = A_q.numel() * 1 + B_q.numel() * 1 + C_i32.numel() * 4
    bw_triton = bytes_moved / t_triton / 1e9
    bw_torch = bytes_moved / t_torch / 1e9

    return {
        "M": M,
        "K": K,
        "N": N,
        "BLOCK_M": BLOCK_M,
        "BLOCK_N": BLOCK_N,
        "BLOCK_K": BLOCK_K,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "t_triton_ms": t_triton * 1e3,
        "t_torch_ms": t_torch * 1e3,
        "speed_vs_torch": t_torch / t_triton,
        "bw_triton_GBs": bw_triton,
        "bw_torch_GBs": bw_torch,
        "max_abs_err": max_abs_err,
    }


@torch.no_grad()
def benchmark_gemm_int8_best_tiles(
    tiles_cfg: dict,
    Ms=(1024, 2048, 4096),
    Ks=(256, 512, 1024),
    Ns=(256, 512, 1024),
    iters_per_shape: int = 50,
    device: str = "cuda",
):
    """
    Compare gemm_int8_tc against torch on a grid of (M, K, N) shapes using one
    fixed tile configuration.

    Args:
        tiles_cfg: dict with the keys "BLOCK_M", "BLOCK_N", "BLOCK_K",
            "num_warps" and "num_stages".
        Ms, Ks, Ns: values of each GEMM dimension; every combination is run.
        iters_per_shape: timed iterations per shape.
        device: device passed through to the single-shape benchmark.

    Returns:
        DataFrame with one row per shape that ran and produced a zero
        ``max_abs_err``. Shapes that are skipped or wrong are only printed.
        An empty DataFrame (no columns) is returned if nothing succeeded.

    Raises:
        KeyError: if ``tiles_cfg`` lacks one of the required keys.
    """
    BM = tiles_cfg["BLOCK_M"]
    BN = tiles_cfg["BLOCK_N"]
    BK = tiles_cfg["BLOCK_K"]
    NW = tiles_cfg["num_warps"]
    NS = tiles_cfg["num_stages"]

    records = []

    for M in Ms:
        for K in Ks:
            for N in Ns:
                print(f"=== GEMM SHAPE: M={M}, K={K}, N={N} ===")

                # multiple-of-4 requirement of the INT8 dot path
                if (K % 4) != 0 or (BK % 4) != 0:
                    print(f"[SKIP] M={M}, K={K}, N={N}: K/BLOCK_K not multiple of 4")
                    continue

                try:
                    rec = bench_once_gemm_int8_vs_torch(
                        M, K, N,
                        BLOCK_M=BM,
                        BLOCK_N=BN,
                        BLOCK_K=BK,
                        num_warps=NW,
                        num_stages=NS,
                        iters=iters_per_shape,
                        device=device,
                    )
                except Exception as e:
                    # Triton reports a configuration that does not fit the GPU
                    # with OutOfResources, which is NOT a RuntimeError; matched
                    # by name because its module path differs between Triton
                    # releases. Any other exception is a real bug: re-raise.
                    if not isinstance(e, RuntimeError) and type(e).__name__ != "OutOfResources":
                        raise
                    print(f"[SKIP] M={M}, K={K}, N={N}: {e}")
                    continue

                # Incorrect results are reported but kept out of the table.
                if rec["max_abs_err"] != 0:
                    print(f"[WRONG] M={M}, K={K}, N={N}, err={rec['max_abs_err']}")
                    continue

                records.append(rec)

    if not records:
        print("[WARN] no successful GEMM records")
        return pd.DataFrame()

    return pd.DataFrame(records)


In [13]:
# Reuse the INT8_GEMM_BEST_* constants chosen after the tile sweep instead of
# repeating the numbers, so the winning configuration lives in one place.
best_gemm_tiles = {
    "BLOCK_M": INT8_GEMM_BEST_BLOCK_M,
    "BLOCK_N": INT8_GEMM_BEST_BLOCK_N,
    "BLOCK_K": INT8_GEMM_BEST_BLOCK_K,
    "num_warps": INT8_GEMM_BEST_WARPS,
    "num_stages": INT8_GEMM_BEST_STAGES,
}

df_gemm = benchmark_gemm_int8_best_tiles(
    best_gemm_tiles,
    Ms=(512, 1024, 2048, 4096),
    Ks=(256, 512, 1024),
    Ns=(256, 512, 1024),
    iters_per_shape=50,
    device="cuda",
)

# The benchmark returns a column-less frame when every shape was skipped;
# sorting that by "speed_vs_torch" would raise KeyError, so show it as is.
df_gemm if df_gemm.empty else df_gemm.sort_values("speed_vs_torch", ascending=False).head(20)

=== GEMM SHAPE: M=512, K=256, N=256 ===
=== GEMM SHAPE: M=512, K=256, N=512 ===
=== GEMM SHAPE: M=512, K=256, N=1024 ===
=== GEMM SHAPE: M=512, K=512, N=256 ===
=== GEMM SHAPE: M=512, K=512, N=512 ===
=== GEMM SHAPE: M=512, K=512, N=1024 ===
=== GEMM SHAPE: M=512, K=1024, N=256 ===
=== GEMM SHAPE: M=512, K=1024, N=512 ===
=== GEMM SHAPE: M=512, K=1024, N=1024 ===
=== GEMM SHAPE: M=1024, K=256, N=256 ===
=== GEMM SHAPE: M=1024, K=256, N=512 ===
=== GEMM SHAPE: M=1024, K=256, N=1024 ===
=== GEMM SHAPE: M=1024, K=512, N=256 ===
=== GEMM SHAPE: M=1024, K=512, N=512 ===
=== GEMM SHAPE: M=1024, K=512, N=1024 ===
=== GEMM SHAPE: M=1024, K=1024, N=256 ===
=== GEMM SHAPE: M=1024, K=1024, N=512 ===
=== GEMM SHAPE: M=1024, K=1024, N=1024 ===
=== GEMM SHAPE: M=2048, K=256, N=256 ===
=== GEMM SHAPE: M=2048, K=256, N=512 ===
=== GEMM SHAPE: M=2048, K=256, N=1024 ===
=== GEMM SHAPE: M=2048, K=512, N=256 ===
=== GEMM SHAPE: M=2048, K=512, N=512 ===
=== GEMM SHAPE: M=2048, K=512, N=1024 ===
=== GEMM SH

Unnamed: 0,M,K,N,BLOCK_M,BLOCK_N,BLOCK_K,num_warps,num_stages,t_triton_ms,t_torch_ms,speed_vs_torch,bw_triton_GBs,bw_torch_GBs,max_abs_err
26,2048,1024,1024,64,128,64,4,3,0.060889,0.364229,5.981854,189.432118,31.667794,0
35,4096,1024,1024,64,128,64,4,3,0.113535,0.664686,5.854477,193.950395,33.12856,0
29,4096,256,1024,64,128,64,4,3,0.040263,0.232254,5.768443,449.24551,77.879862,0
34,4096,1024,512,64,128,64,4,3,0.059566,0.34203,5.742029,220.044549,38.321744,0
23,2048,512,1024,64,128,64,4,3,0.036624,0.204391,5.580816,271.99405,48.737327,0
32,4096,512,1024,64,128,64,4,3,0.071848,0.372878,5.189799,269.994717,52.024117,0
17,1024,1024,1024,64,128,64,4,3,0.046445,0.230206,4.956555,135.460993,27.329667,0
33,4096,1024,256,64,128,64,4,3,0.04583,0.224259,4.893245,188.755678,38.574744,0
31,4096,512,512,64,128,64,4,3,0.045775,0.20373,4.450675,234.797943,52.755586,0
25,2048,1024,512,64,128,64,4,3,0.058133,0.248542,4.275406,117.243935,27.422875,0
