In [1]:
import time
import torch
import triton
from torch import nn

# ==== ИМПОРТЫ ОБЁРТОК INT8 ==== 
# правь пути под свой проект, но идея такая:
from img2col_int8_kernel import img2col_int8
from gemm_int8_kernel    import gemm_int8_tc
from col2img_int8_kernel import col2img_int32

In [2]:
def _force_strict_fp32():
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass


_force_strict_fp32()
torch.backends.cudnn.benchmark = True

In [3]:
# Target device for the benchmarks; the timing helpers synchronize via
# torch.cuda, so a CUDA GPU is required.
device: str = "cuda"

In [4]:
# Re-apply the global numeric settings; both statements are idempotent, so
# re-running this cell is harmless.
_force_strict_fp32()
setattr(torch.backends.cudnn, "benchmark", True)

In [5]:
def sync():
    """Block the host until all work queued on the current CUDA device is done."""
    torch.cuda.synchronize(device=None)


def bench(fn, warmup=50, iters=300):
    """Mean wall-clock time of a single ``fn()`` call, in seconds.

    ``warmup`` untimed calls run first. The device is synchronized after the
    warm-up and after the timed loop, so queued GPU work is counted.
    """
    def _run(times):
        # Call fn `times` times, then wait for the device to drain.
        for _ in range(times):
            fn()
        sync()

    _run(warmup)
    started = time.perf_counter()
    _run(iters)
    return (time.perf_counter() - started) / iters


def out_hw(H, W, Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw):
    """Spatial output size (Ho, Wo) of a 2-D convolution / unfold.

    Per axis: (size + 2*pad - (dil*(k-1) + 1)) // stride + 1, the output-size
    formula from the PyTorch Conv2d / Unfold documentation.
    """
    def _axis(size, k, s, p, d):
        span = d * (k - 1) + 1  # extent of the dilated kernel
        return (size + 2 * p - span) // s + 1

    return _axis(H, Kh, Sh, Ph, Dh), _axis(W, Kw, Sw, Pw, Dw)


def row(name, items):
    """Print one Markdown-table row: ``| name | item1 | item2 | ... |``."""
    cells = " | ".join(items)
    print("".join(["| ", name, " | ", cells, " |"]))

In [6]:
def quantize_per_tensor_symmetric(x: torch.Tensor, num_bits: int = 8):
    """
    Symmetric per-tensor quantizer: x ≈ scale * q.

    ``scale`` maps max|x| onto qmax = 2**(num_bits-1) - 1, and q is clamped to
    [-2**(num_bits-1), qmax] ([-128, 127] for the default 8 bits).

    Args:
        x: floating-point tensor to quantize.
        num_bits: code width in 2..8. The result is stored as int8, so wider
            codes would silently wrap around.

    Returns:
        (q, scale): ``q`` is an int8 tensor shaped like ``x``; ``scale`` is a
        Python float (1.0 for an all-zero or empty input).

    Raises:
        ValueError: if ``num_bits`` is outside 2..8.
    """
    assert x.is_floating_point()
    if not 2 <= num_bits <= 8:
        raise ValueError(f"num_bits must be in [2, 8] for int8 storage, got {num_bits}")
    qmax = 2 ** (num_bits - 1) - 1
    qmin = -qmax - 1
    if x.numel() == 0:
        # .max() of an empty tensor raises; there is nothing to scale.
        return torch.empty_like(x, dtype=torch.int8), 1.0
    max_abs = x.abs().max().item()
    scale = max_abs / qmax if max_abs != 0.0 else 1.0
    q = torch.clamp((x / scale).round(), qmin, qmax).to(torch.int8)
    return q, float(scale)


def quantize_to_int32(x: torch.Tensor, num_bits: int = 20):
    """
    Quantizer for the col2img_int32 path: x ≈ scale * q, q = round(x / scale) as int32.

    max|x| is mapped onto 2**num_bits (default: 20 magnitude bits).
    Presumably the remaining int32 bits are headroom for the sums that
    col2img accumulates over overlapping patches; TODO confirm against the
    kernel's accumulator width.

    Args:
        x: floating-point tensor.
        num_bits: magnitude bits for max|x|, 1..30 (so codes fit in int32).

    Returns:
        (q, scale): int32 tensor shaped like ``x`` and a Python float scale
        (1.0 for an all-zero or empty input).

    Raises:
        ValueError: if ``num_bits`` is outside 1..30.
    """
    assert x.is_floating_point()
    if not 1 <= num_bits <= 30:
        raise ValueError(f"num_bits must be in [1, 30] for int32 storage, got {num_bits}")
    if x.numel() == 0:
        # .max() of an empty tensor raises; there is nothing to scale.
        return torch.empty_like(x, dtype=torch.int32), 1.0
    max_abs = x.abs().max().item()
    scale = max_abs / (2 ** num_bits) if max_abs != 0.0 else 1.0
    q = torch.round(x / scale).to(torch.int32)
    return q, float(scale)

In [7]:
def int8_stage_img2col(
    x_fp32: torch.Tensor,
    Kh: int, Kw: int,
    Sh: int, Sw: int,
    Ph: int, Pw: int,
    Dh: int, Dw: int,
    BLOCK_M: int,
    BLOCK_K: int,
    NUM_WARPS: int,
    NUM_STAGES: int,
):
    """
    Forward im2col stage: fp32 reference vs. the Triton int8 path.

    Reference: nn.Unfold on ``x_fp32``, rearranged from [B, K, L] to [M, K]
    (K = Cin*Kh*Kw, M = B*L).
    Triton: per-tensor int8 quantization of ``x_fp32`` -> ``img2col_int8`` ->
    multiply by the quantization scale.

    Returns:
        (cols_ref_fp32, cols_tri_fp32). The reference is [M, K]; the Triton
        result is assumed to be [M, K] as well (TODO confirm against
        img2col_int8).
    """
    # Unpacking also rejects non-4-D inputs.
    _, Cin, _, _ = x_fp32.shape
    K = Cin * Kh * Kw

    unfold = nn.Unfold(
        kernel_size=(Kh, Kw),
        dilation=(Dh, Dw),
        padding=(Ph, Pw),
        stride=(Sh, Sw),
    )
    with torch.no_grad():
        # [B, K, L] -> [B, L, K] -> [M, K]; reshape copies into a contiguous fp32 buffer.
        cols_ref = unfold(x_fp32.float()).transpose(1, 2).reshape(-1, K)

    # INT8 path: quantize x, run the int8 im2col kernel, dequantize.
    x_q, s_x = quantize_per_tensor_symmetric(x_fp32)
    cols_q, _ = img2col_int8(
        x_q,
        Kh, Kw,
        Sh, Sw,
        Ph, Pw,
        Dh, Dw,
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        num_warps=NUM_WARPS,
        num_stages=NUM_STAGES,
    )
    cols_tri = cols_q.float() * s_x  # plain per-tensor dequantization

    return cols_ref, cols_tri

In [8]:
def int8_stage_col2img(
    dcols_fp32: torch.Tensor,
    B: int, Cin: int, H: int, W: int,
    Kh: int, Kw: int,
    Sh: int, Sw: int,
    Ph: int, Pw: int,
    Dh: int, Dw: int,
    BLOCK_M: int,
    BLOCK_K: int,
    NUM_WARPS: int,
    NUM_STAGES: int,
):
    """
    Backward-equivalent col2im stage: dcols [M, K] -> dx [B, Cin, H, W].

    Reference: nn.Fold applied to dcols rearranged to [B, K, L].
    Triton: dcols quantized to int32 with ``quantize_to_int32``, scattered with
    ``col2img_int32``, then multiplied by the quantization scale (the kernel
    sums integer codes).

    Returns:
        (dx_ref_fp32, dx_tri_fp32).
    """
    dev = dcols_fp32.device
    Ho, Wo = out_hw(H, W, Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw)
    L = Ho * Wo
    K = Cin * Kh * Kw
    assert dcols_fp32.shape == (B * L, K)

    fold = nn.Fold(
        output_size=(H, W),
        kernel_size=(Kh, Kw),
        dilation=(Dh, Dw),
        padding=(Ph, Pw),
        stride=(Sh, Sw),
    ).to(dev)
    with torch.no_grad():
        # [M, K] -> [B, L, K] -> [B, K, L], the layout Fold consumes.
        patches = dcols_fp32.view(B, L, K).transpose(1, 2).contiguous()
        dx_ref = fold(patches).float()

    dcols_q, s_d = quantize_to_int32(dcols_fp32)
    dx_codes = col2img_int32(
        dcols_q,
        B, Cin, H, W,
        Kh, Kw,
        Sh, Sw,
        Ph, Pw,
        Dh, Dw,
        BLOCK_M=BLOCK_M,
        BLOCK_K=BLOCK_K,
        num_warps=NUM_WARPS,
        num_stages=NUM_STAGES,
    )
    # Sums of int32 codes -> real units.
    return dx_ref, dx_codes * s_d


In [9]:
def chapter_img2col_col2im_int8(cfg, tile_cfgs):
    """
    INT8 variant of chapter 1: im2col / col2im accuracy and speed.

      - FWD: nn.Unfold (fp32) vs img2col_int8  (int8 input, dequantized output)
      - BWD: nn.Fold   (fp32) vs col2img_int32 (int32 input, rescaled output)

    For every tile config in ``tile_cfgs``, prints a two-row Markdown table.
    It holds the MAE / max absolute error of both stages against the fp32
    reference and the mean per-call time (ms) of the four implementations.

    Timing: the Triton columns measure only the kernel call plus the
    dequantization multiply. Input quantization is done once, outside the
    timed region, and the fp32 Unfold/Fold reference is not executed inside
    the Triton timing. Unfold/Fold do not depend on the tile config, so they
    are timed once.

    Args:
        cfg: dict with keys B, Cin, H, W, ks, stride, padding, dilation
            (stride/padding/dilation: int or an (h, w) pair).
        tile_cfgs: iterable of dicts with BLOCK_M, BLOCK_K, NUM_WARPS, NUM_STAGES.
    """
    print("\n# Глава 1 (INT8): ImageToColumn + ColumnToImage")
    device = "cuda"
    torch.cuda.empty_cache()          # release cached allocator memory
    torch.cuda.reset_peak_memory_stats(device=device)

    B, Cin, H, W, ks = cfg["B"], cfg["Cin"], cfg["H"], cfg["W"], cfg["ks"]
    stride, padding, dilation = cfg["stride"], cfg["padding"], cfg["dilation"]
    Sh, Sw = (stride, stride) if isinstance(stride, int) else stride
    Ph, Pw = (padding, padding) if isinstance(padding, int) else padding
    Dh, Dw = (dilation, dilation) if isinstance(dilation, int) else dilation
    Kh = Kw = ks
    # Positional geometry arguments shared by out_hw and every stage/kernel.
    geom = (Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw)

    Ho, Wo = out_hw(H, W, *geom)
    K = Cin * Kh * Kw
    L = Ho * Wo
    M = B * L

    # One input / gradient for all tile configs so the comparison is fair.
    torch.manual_seed(0)
    x = torch.randn(B, Cin, H, W, device=device, dtype=torch.float32)
    dcols = torch.randn(M, K, device=device, dtype=torch.float32)

    # Quantize once; the timed Triton closures then run only the kernels.
    x_q, s_x = quantize_per_tensor_symmetric(x)
    dcols_q, s_d = quantize_to_int32(dcols)

    # -------- fp32 reference timings (tile-independent) --------
    unfold = nn.Unfold(
        kernel_size=(Kh, Kw),
        dilation=(Dh, Dw),
        padding=(Ph, Pw),
        stride=(Sh, Sw),
    )
    fold = nn.Fold(
        output_size=(H, W),
        kernel_size=(Kh, Kw),
        dilation=(Dh, Dw),
        padding=(Ph, Pw),
        stride=(Sh, Sw),
    )

    def f_unfold():
        unfold(x)

    def f_fold():
        # Fold consumes [B, K, L]; the [M, K] -> [B, K, L] relayout is part of its cost.
        fold(dcols.view(B, L, K).transpose(1, 2).contiguous())

    t_unfold = bench(f_unfold, warmup=30, iters=200)
    t_fold = bench(f_fold, warmup=30, iters=200)

    for tile in tile_cfgs:
        stage_kw = dict(
            BLOCK_M=tile["BLOCK_M"],
            BLOCK_K=tile["BLOCK_K"],
            NUM_WARPS=tile["NUM_WARPS"],
            NUM_STAGES=tile["NUM_STAGES"],
        )
        kernel_kw = dict(
            BLOCK_M=tile["BLOCK_M"],
            BLOCK_K=tile["BLOCK_K"],
            num_warps=tile["NUM_WARPS"],
            num_stages=tile["NUM_STAGES"],
        )

        print(f"\n[int8 img2col+col2im] cfg={tile}")

        # -------- accuracy (FWD cols, BWD dx) --------
        with torch.no_grad():
            cols_ref, cols_tri = int8_stage_img2col(x, *geom, **stage_kw)
            diff_f = (cols_ref - cols_tri).abs()
            f_mae = diff_f.mean().item()
            f_mx = diff_f.max().item()

            dx_ref, dx_tri = int8_stage_col2img(dcols, B, Cin, H, W, *geom, **stage_kw)
            diff_b = (dx_ref - dx_tri).abs()
            b_mae = diff_b.mean().item()
            b_mx = diff_b.max().item()

        # -------- Triton timings: kernel + dequant only --------
        def f_i2c_int8(kw=kernel_kw):
            cols_q, _ = img2col_int8(x_q, *geom, **kw)
            return cols_q.float() * s_x

        def f_c2i_int8(kw=kernel_kw):
            return col2img_int32(dcols_q, B, Cin, H, W, *geom, **kw) * s_d

        t_i2c = bench(f_i2c_int8, warmup=30, iters=200)
        t_c2i = bench(f_c2i_int8, warmup=30, iters=200)

        row(
            "Этап",
            [
                "FWD MAE(cols)",
                "FWD max(cols)",
                "BWD MAE(dx)",
                "BWD max(dx)",
                "Torch Unfold ms",
                "Triton i2c ms",
                "Torch Fold ms",
                "Triton c2i ms",
            ],
        )
        row(
            "Знач",
            [
                f"{f_mae:.3e}",
                f"{f_mx:.3e}",
                f"{b_mae:.3e}",
                f"{b_mx:.3e}",
                f"{t_unfold * 1e3:.3f}",
                f"{t_i2c * 1e3:.3f}",
                f"{t_fold * 1e3:.3f}",
                f"{t_c2i * 1e3:.3f}",
            ],
        )


In [10]:
def chapter_gemm_int8(cfg, tile_cfgs):
    """
    INT8 variant of the GEMM chapter.

      Torch:  A_fp32 @ B_fp32 (fp32 reference and baseline timing)
      Triton: gemm_int8_tc on per-tensor int8 codes, rescaled by s_A * s_B

    Operand shapes follow the convolution described by ``cfg``. A is [M, K]
    with M = B*Ho*Wo and K = Cin*Kh*Kw; the second operand is [K, N] with
    N = Cout. For every tile config, prints MAE / max absolute error against
    the fp32 product, both mean per-call times (ms) and the speedup.

    Quantization and the Torch baseline timing do not depend on the tile
    config, so they are done once before the sweep.

    Args:
        cfg: dict with keys B, Cin, Cout, H, W, ks, stride, padding, dilation.
        tile_cfgs: iterable of dicts with BLOCK_M, BLOCK_N, BLOCK_K,
            NUM_WARPS, NUM_STAGES.
    """
    print("\n# Глава 2 (INT8): GEMM int8×int8 vs Torch mm")
    # Local device (as in chapter 1) instead of the notebook-global `device`,
    # so this function does not depend on another cell having run.
    device = "cuda"

    B, Cin, Cout, H, W, ks = cfg["B"], cfg["Cin"], cfg["Cout"], cfg["H"], cfg["W"], cfg["ks"]
    stride, padding, dilation = cfg["stride"], cfg["padding"], cfg["dilation"]
    Sh, Sw = (stride, stride) if isinstance(stride, int) else stride
    Ph, Pw = (padding, padding) if isinstance(padding, int) else padding
    Dh, Dw = (dilation, dilation) if isinstance(dilation, int) else dilation
    Kh = Kw = ks
    Ho, Wo = out_hw(H, W, Kh, Kw, Sh, Sw, Ph, Pw, Dh, Dw)

    M = B * Ho * Wo
    K = Cin * Kh * Kw
    N = Cout

    torch.manual_seed(1)
    A_fp32 = torch.randn(M, K, device=device, dtype=torch.float32)
    B_fp32 = torch.randn(K, N, device=device, dtype=torch.float32)

    # fp32 reference
    with torch.no_grad():
        C_ref = (A_fp32 @ B_fp32).float()

    # Tile-independent work, hoisted out of the sweep.
    A_q, s_A = quantize_per_tensor_symmetric(A_fp32)
    B_q, s_B = quantize_per_tensor_symmetric(B_fp32)
    alpha = s_A * s_B  # maps the product of integer codes back to real units

    def f_torch():
        A_fp32 @ B_fp32

    t_torch = bench(f_torch, warmup=30, iters=200)

    for tile in tile_cfgs:
        kernel_kw = dict(
            BLOCK_M=tile["BLOCK_M"],
            BLOCK_N=tile["BLOCK_N"],
            BLOCK_K=tile["BLOCK_K"],
            num_warps=tile["NUM_WARPS"],
            num_stages=tile["NUM_STAGES"],
        )

        print(f"\n[int8 GEMM] cfg={tile}")

        # FWD accuracy
        with torch.no_grad():
            C_tri = gemm_int8_tc(A_q, B_q, **kernel_kw) * alpha
            diff = (C_ref - C_tri).abs()
            f_mae = diff.mean().item()
            f_mx = diff.max().item()

        # Bind this iteration's launch parameters explicitly.
        def f_tri(kw=kernel_kw):
            gemm_int8_tc(A_q, B_q, **kw) * alpha

        t_tri = bench(f_tri, warmup=30, iters=200)

        row(
            "Этап",
            ["FWD MAE", "FWD max", "Torch mm ms", "Triton int8 GEMM ms", "Speedup"],
        )
        row(
            "Знач",
            [
                f"{f_mae:.3e}",
                f"{f_mx:.3e}",
                f"{t_torch * 1e3:.3f}",
                f"{t_tri * 1e3:.3f}",
                f"{t_torch / max(t_tri, 1e-12):.2f}x",
            ],
        )

In [20]:
# Convolution whose im2col / GEMM shapes are benchmarked below.
cfg = dict(
    B=2, Cin=64, Cout=128, H=320, W=320, ks=3,
    stride=1, padding=1, dilation=1,
)

# Tile / launch configurations to sweep: a baseline plus single-parameter
# variations of it (substitute the best-found configs here later).
# Chapter 1 reads BLOCK_M, BLOCK_K, NUM_WARPS, NUM_STAGES; chapter 2 also BLOCK_N.
TILE_CONFIGS = [
    dict(dict(BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=2), **override)
    for override in (
        {},                    # baseline
        {"BLOCK_N": 32},       # narrower N tile
        {"BLOCK_M": 32},       # shorter M tile
        {"NUM_STAGES": 1},     # one pipeline stage
        {"NUM_WARPS": 2},      # fewer warps
    )
]

print("\n=== INT8 img2col/col2im & GEMM benchmarks ===")
chapter_img2col_col2im_int8(cfg, TILE_CONFIGS)
chapter_gemm_int8(cfg, TILE_CONFIGS)



=== INT8 img2col/col2im & GEMM benchmarks ===


NameError: name 'chapter_img2col_col2im_int8' is not defined

In [21]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- пробуем зацепить gemm_int8, если есть ---

# Optional fast path: the external `gemm_int8` extension. Any failure while
# importing it (missing package, or a .so with an incompatible ABI raising
# OSError) leaves the flag False, and int8_gemm then uses the torch fallback.
_USE_GEMM_INT8 = False
try:
    import gemm_int8
except Exception as e:
    print("[WARN] gemm_int8 не найден или не загрузился, fallback на torch (медленно)")
    print("[WARN] причина:", repr(e))
else:
    _USE_GEMM_INT8 = True
    print("[INFO] gemm_int8 найден, буду использовать его для GEMM")


def quantize_int8_per_tensor(x: torch.Tensor):
    """
    Minimal symmetric per-tensor int8 quantizer, run without autograd tracking.

    scale = max|x| / 127, floored at 1e-8 so an all-zero tensor does not
    divide by zero; x_q = clamp(round(x / scale), -128, 127) as int8.

    Returns:
        (x_q, scale), where ``scale`` is a 0-dim tensor (not a Python float).
    """
    with torch.no_grad():
        scale = torch.clamp(x.abs().max() / 127.0, min=1e-8)
        x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale


def int8_gemm(x_q: torch.Tensor, w_q: torch.Tensor, s_x: torch.Tensor, s_w: torch.Tensor):
    """
    Scaled int8 GEMM: (x_q @ w_q.T) * (s_x * s_w).

    x_q: [M, K] int8; w_q: [N, K] int8 (the second operand is transposed
    inside the product); s_x / s_w: quantization scales.
    Uses ``gemm_int8.matmul`` when the extension was imported, otherwise a
    slow float32 torch matmul. Returns [M, N] with the scaling applied; the
    dtype is whatever the chosen backend produces (float32 for the fallback).
    """
    alpha = float(s_x * s_w)
    if not _USE_GEMM_INT8:
        # Slow but correct path: promote both operands to float32.
        return x_q.float().matmul(w_q.float().t()) * alpha
    # gemm_int8.matmul(X, Y, alpha) computes (X @ Y.T) * alpha.
    return gemm_int8.matmul(x_q, w_q, alpha=alpha)


class Int8UnfoldConv2d(nn.Module):
    """
    Forward-only 2-D convolution built from nn.Unfold (float), per-tensor
    int8 quantization of patches and weights, an int8 GEMM, then a reshape
    (plus optional bias).

    Patches and weights are re-quantized on every forward call with
    ``quantize_int8_per_tensor``, and the product goes through ``int8_gemm``.
    The weight is [out_channels, in_channels, Kh, Kw], i.e. no grouped
    convolutions.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()

        def as_pair(value):
            # An int v becomes (v, v); anything else is stored as given.
            return (value, value) if isinstance(value, int) else value

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = as_pair(kernel_size)
        self.stride = as_pair(stride)
        self.padding = as_pair(padding)
        self.dilation = as_pair(dilation)

        # Parameters are registered in the order: weight, bias.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels, *self.kernel_size)
        )
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

        self.unfold = nn.Unfold(
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )

        # kaiming_uniform_ with a=sqrt(5), the scheme nn.Conv2d uses by default.
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [N, Cin, H, W] floating tensor -> [N, Cout, Ho, Wo].
        """
        N, Cin, H, W = x.shape
        Kh, Kw = self.kernel_size

        # Patches [N, K, L]: K = Cin*Kh*Kw, L = Ho*Wo.
        patches = self.unfold(x)
        N_, K, L = patches.shape
        assert N_ == N

        patches_q, s_x = quantize_int8_per_tensor(patches)
        weight_q, s_w = quantize_int8_per_tensor(self.weight)

        # GEMM operands: one row per output pixel / per output channel.
        rows_q = patches_q.transpose(1, 2).contiguous().view(N * L, K)  # [N*L, K]
        kernels_q = weight_q.view(self.out_channels, K)                 # [Cout, K]

        out = int8_gemm(rows_q, kernels_q, s_x, s_w)                    # [N*L, Cout]

        sh, sw = self.stride[0], self.stride[1]
        ph, pw = self.padding[0], self.padding[1]
        dh, dw = self.dilation[0], self.dilation[1]
        Ho = (H + 2 * ph - dh * (Kh - 1) - 1) // sh + 1
        Wo = (W + 2 * pw - dw * (Kw - 1) - 1) // sw + 1

        out = out.view(N, L, self.out_channels).transpose(1, 2).contiguous()
        out = out.view(N, self.out_channels, Ho, Wo)

        if self.bias is not None:
            out = out + self.bias.view(1, -1, 1, 1)

        return out



[WARN] gemm_int8 не найден или не загрузился, fallback на torch (медленно)
[WARN] причина: OSError('/home/manzhura/ITMO/EDLM/venv/lib/python3.10/site-packages/gemm_int8/gemm_int8_CUDA.so: undefined symbol: _ZN5torch7LibraryC1ENS0_4KindESsN3c108optionalINS2_11DispatchKeyEEEPKcj')


In [22]:
def bench(fn, warmup=10, iters=50):
    """
    Mean wall-clock time of one ``fn()`` call, in seconds.

    Runs ``warmup`` untimed calls, then ``iters`` timed calls. When CUDA is
    available, the device is synchronized after the warm-up and after the
    timed loop so queued kernels are included. On a CPU-only host the sync is
    skipped instead of raising.

    Note: this redefines the earlier ``bench`` (defaults 50/300) for the cells
    that follow; callers needing specific counts pass them explicitly.

    Raises:
        ValueError: if ``iters`` < 1 (the mean would be undefined).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    def _sync():
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(warmup):
        fn()
    _sync()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    _sync()
    return (time.perf_counter() - t0) / iters


def run_benchmark():
    """
    Compare nn.Conv2d (fp32) against Int8UnfoldConv2d (Unfold + int8 GEMM).

    Fixed layer: N=32, Cin=Cout=64, 56x56 input, 3x3 kernel, stride 1,
    padding 1, dilation 1. The int8 layer gets a copy of the reference weight
    and bias. Prints the max / mean absolute error of the int8 output against
    Conv2d, the mean forward time of each path, and conv2d_time / int8_time
    (> 1 means the int8 path is faster).
    """
    device = "cuda"
    # Seed so the input and the reference layer's initial weights are reproducible.
    torch.manual_seed(0)

    # Representative layer sizes.
    N, Cin, Cout = 32, 64, 64
    H, W = 56, 56
    ks = 3
    stride = 1
    padding = 1
    dilation = 1

    x = torch.randn(N, Cin, H, W, device=device)

    conv_ref = nn.Conv2d(
        Cin, Cout, ks,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    ).to(device)

    conv_int8 = Int8UnfoldConv2d(
        Cin, Cout, ks,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    ).to(device)

    # Copy weight and bias so both layers compute the same convolution.
    with torch.no_grad():
        conv_int8.weight.copy_(conv_ref.weight)
        if conv_ref.bias is not None and conv_int8.bias is not None:
            conv_int8.bias.copy_(conv_ref.bias)

    # Warm-up pass, also used for the accuracy check.
    with torch.no_grad():
        y_ref = conv_ref(x)
        y_int8 = conv_int8(x)
    torch.cuda.synchronize()

    # Quantization error is expected to be noticeable; speed is the focus here.
    abs_err = (y_ref - y_int8.to(y_ref.dtype)).abs()
    max_err = abs_err.max().item()
    mean_err = abs_err.mean().item()
    print(f"max_err = {max_err:.6f}")
    print(f"mean_err = {mean_err:.6f}")

    with torch.no_grad():
        t_ref = bench(lambda: conv_ref(x))
        t_int8 = bench(lambda: conv_int8(x))

    print(f"Conv2d (cuDNN, fp32):      {t_ref*1e3:.3f} ms")
    print(f"Unfold+GEMM (int8 path):   {t_int8*1e3:.3f} ms")
    # Time ratio conv2d / int8 — the label now matches the computation.
    print(f"speedup (conv2d / int8) =  {t_ref / max(t_int8, 1e-12):.3f}x")



