In [1]:
import torch
import triton
import triton.language as tl
import torch.nn.functional as F

In [2]:
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64,  "BLOCK_K": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 128, "BLOCK_K": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def _gemm_kernel(
    a_ptr, b_ptr, bias_ptr, c_ptr,
    M, N, K,
    lda, ldb, ldc,                              # leading dimensions (row-major: lda = K, ldb = N, ldc = N)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ADD_BIAS: tl.constexpr,
    OUT_FP16: tl.constexpr,
):
    """Tiled C = A @ B (+ bias) for row-major A (M, K), B (K, N), C (M, N).

    Each program instance computes one (BLOCK_M, BLOCK_N) tile of C: program_id(0)
    selects the tile along M, program_id(1) the tile along N. Only the row stride
    of each matrix is passed (lda / ldb / ldc); the column stride is assumed to
    be 1, so every operand must be contiguous along its last dimension.

    Products are accumulated in fp32. If ADD_BIAS, bias_ptr[n] is added to every
    row of column n. If OUT_FP16, the tile is cast to fp16 before the store;
    otherwise the fp32 accumulator is stored as-is.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)      # [BM] row indices of this tile
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)      # [BN] column indices of this tile
    mask_m = offs_m < M
    mask_n = offs_n < N
    # Promote row indices to int64 before multiplying by a row stride so that
    # offsets like offs_m * lda cannot overflow int32 on large matrices.
    rows64 = offs_m.to(tl.int64)

    # fp32 accumulator for numerical stability.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Loop over K. A `for ... in range(...)` loop (unlike `while`) is what the
    # Triton compiler software-pipelines, so the autotune `num_stages` values
    # only take effect with this form.
    for k0 in range(0, K, BLOCK_K):
        offs_k = k0 + tl.arange(0, BLOCK_K)               # [BK]
        mask_k = offs_k < K

        a_ptrs = a_ptr + (rows64[:, None] * lda + offs_k[None, :])                 # (BM, BK)
        b_ptrs = b_ptr + (offs_k[:, None].to(tl.int64) * ldb + offs_n[None, :])    # (BK, BN)

        # Out-of-range elements load as 0 so they contribute nothing to the dot.
        a = tl.load(a_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0)
        b = tl.load(b_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0)

        # fp32 accumulate
        acc += tl.dot(a, b, out_dtype=tl.float32)

    # Epilogue: add the per-column bias (OUT axis).
    if ADD_BIAS:
        bias = tl.load(bias_ptr + offs_n, mask=mask_n, other=0).to(tl.float32)  # [BN]
        acc = acc + bias[None, :]

    # Store the tile; rows/columns outside (M, N) are masked off.
    c_ptrs = c_ptr + (rows64[:, None] * ldc + offs_n[None, :])      # (BM, BN)
    c_mask = mask_m[:, None] & mask_n[None, :]
    if OUT_FP16:
        tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)
    else:
        tl.store(c_ptrs, acc, mask=c_mask)


In [3]:
# Display the autotuned kernel object for inspection only; evaluating the name launches nothing.
_gemm_kernel

In [4]:
def conv2d_forward_im2col_gemm_retongue(
    x: torch.Tensor,              # (N,C,H,W), CUDA
    weight: torch.Tensor,         # (OUT,C,KH,KW), CUDA
    bias: torch.Tensor | None,    # (OUT,) or None
    *,
    stride=1, padding=0, dilation=1,
    encodeRetongue=None,
    out_dtype=torch.float16,
):
    """Conv2d forward pass via im2col + a Triton GEMM, passing encodeRetongue through.

    Args:
        x: input of shape (N, C, H, W); must be a CUDA tensor.
        weight: filters of shape (OUT, C, KH, KW); must be a CUDA tensor. Only
            ungrouped convolution is supported (weight's C must equal x's C).
        bias: tensor with OUT elements, or None.
        stride, padding, dilation: an int or an (h, w) pair.
        encodeRetongue: forwarded unchanged to ``matmul_triton_retongue``.
        out_dtype: forwarded unchanged to ``matmul_triton_retongue``.

    Returns:
        Y of shape (N, OUT, H_OUT, W_OUT).

    Raises:
        AssertionError: on non-CUDA inputs, a channel mismatch, a bias of the
            wrong length, or an im2col result with an unexpected shape.

    Relies on two helpers defined elsewhere:
    ``im2col_triton(x, kernel_size, stride, padding, dilation) -> (cols, (H_OUT, W_OUT))``
    where ``cols`` is (N, C*KH*KW, H_OUT*W_OUT), and
    ``matmul_triton_retongue(A, B, bias, encodeRetongue=..., out_dtype=...)``
    which returns an (N*H_OUT*W_OUT, OUT) result.
    """
    def _pair(v):
        # int -> (v, v); any 2-element sequence (tuple, list, ...) -> tuple.
        return (v, v) if isinstance(v, int) else tuple(v)

    assert x.is_cuda and weight.is_cuda
    stride, padding, dilation = _pair(stride), _pair(padding), _pair(dilation)

    N, C, _, _ = x.shape
    OUT, Cw, KH, KW = weight.shape
    assert Cw == C
    if bias is not None:
        # Exactly one bias value per output channel is required.
        assert bias.numel() == OUT, f"bias must have {OUT} elements, got {bias.numel()}"

    # 1) im2col
    cols, (H_OUT, W_OUT) = im2col_triton(x, (KH, KW), stride=stride, padding=padding, dilation=dilation)
    L = H_OUT * W_OUT
    K_TOTAL = C * KH * KW
    # Enforce the documented (N, K_TOTAL, L) layout. A (N, L, K_TOTAL) result
    # has the same numel and would reshape below without error into a wrong A.
    assert tuple(cols.shape) == (N, K_TOTAL, L), (
        f"im2col_triton returned shape {tuple(cols.shape)}, expected {(N, K_TOTAL, L)}"
    )

    # 2) Build row-major GEMM operands. The trailing .contiguous() is required:
    # e.g. for N == 1 the reshape of the transposed cols can itself be a
    # non-contiguous view.
    A = cols.transpose(1, 2).reshape(N * L, K_TOTAL).contiguous()               # (M=N*L, K_TOTAL)
    B = weight.reshape(OUT, K_TOTAL).transpose(0, 1).contiguous()               # (K_TOTAL, OUT)

    # 3) GEMM (encodeRetongue and bias are passed to the helper)
    y2d = matmul_triton_retongue(A, B, bias, encodeRetongue=encodeRetongue, out_dtype=out_dtype)  # (N*L, OUT)

    # 4) Back to the Conv2d output layout. reshape (not view) so a
    # non-contiguous GEMM result is also accepted.
    y = y2d.reshape(N, L, OUT).transpose(1, 2).reshape(N, OUT, H_OUT, W_OUT)
    return y


# =========================
# 3) Quick correctness check example (optional)
# =========================
def _quick_check():
    """Smoke-test the im2col+GEMM convolution against ``F.conv2d`` on CUDA.

    With a fixed seed, draws a random fp16 input, fp16 weight and fp32 bias.
    The reference is computed in fp32 and cast to fp16. The Triton path runs
    with an identity ``encodeRetongue``. The two results must agree within
    rtol=atol=1e-2 (``torch.testing.assert_close`` raises AssertionError
    otherwise). Prints "OK" on success.
    """
    torch.manual_seed(0)
    dev = "cuda"
    batch, in_ch, height, width = 2, 3, 16, 16
    out_ch, k_h, k_w = 5, 3, 3
    conv_kwargs = dict(stride=(2, 1), padding=(1, 1), dilation=(1, 1))

    # Draw order (input, weight, bias) fixes the random values for the seed.
    inp = torch.randn(batch, in_ch, height, width, device=dev, dtype=torch.float16)
    wgt = torch.randn(out_ch, in_ch, k_h, k_w, device=dev, dtype=torch.float16)
    bia = torch.randn(out_ch, device=dev, dtype=torch.float32)

    def identity_encoder(t):
        # Stand-in for a real encodeRetongue: returns its input unchanged.
        return t

    expected = F.conv2d(
        inp.float(), wgt.float(), bia.float(), groups=1, **conv_kwargs
    ).half()

    actual = conv2d_forward_im2col_gemm_retongue(
        inp, wgt, bia,
        encodeRetongue=identity_encoder, out_dtype=torch.float16,
        **conv_kwargs,
    )

    torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)
    print("OK")