In [1]:
import torch
import time
from torch import nn

# === импорт твоего слоя INT8 ===
from conv_gemm.baseline_layers.triton_conv2d import TritonConv2d

In [2]:
torch.manual_seed(0)
device = "cuda"

In [11]:
N = 1
Cin = 3
Cout = 8
H = W = 512
K = 11


In [12]:
x_fp16 = torch.randn(N, Cin, H, W, device=device, dtype=torch.float16)
def bench_ms(fn, iters=50):
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) * 1000.0 / iters

In [13]:
conv_ref = nn.Conv2d(
    Cin, Cout,
    kernel_size=K,
    stride=1,
    padding=K // 2,
    bias=True,
).to(device).half()

In [14]:
conv_triton = TritonConv2d(
    in_channels=Cin,
    out_channels=Cout,
    kernel_size=K,
    stride=1,
    padding=K // 2,
    dilation=1,
    bias=True,
    BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
    NUM_WARPS=4, NUM_STAGES=2,
).to(device)

In [15]:
with torch.no_grad():
    conv_triton.weight.data.copy_(conv_ref.weight.data.float())
    if conv_ref.bias is not None:
        conv_triton.bias.data.copy_(conv_ref.bias.data.float())


with torch.no_grad():
    y_ref = conv_ref(x_fp16)        # [N, Cout, H, W] fp16
    y_tri = conv_triton(x_fp16) 

In [16]:
print("y_ref.shape:", y_ref.shape)
print("y_tri.shape:", y_tri.shape)

err = (y_ref.float() - y_tri.float()).abs()
print("\n=== FP16 FORWARD ACCURACY ===")
print("max error:", err.max().item())
print("mean error:", err.mean().item())

# ----- speed -----
t_ref = bench_ms(lambda: conv_ref(x_fp16), iters=100)
t_tri = bench_ms(lambda: conv_triton(x_fp16), iters=100)

print("\n=== FP16 FORWARD SPEED (ms) ===")
print(f"PyTorch FP16 Conv2D:   {t_ref:.3f} ms")
print(f"Triton FP16 Conv2D:    {t_tri:.3f} ms")
print(f"Speedup: {t_ref / t_tri:.3f}x")

y_ref.shape: torch.Size([1, 8, 512, 512])
y_tri.shape: torch.Size([1, 8, 512, 512])

=== FP16 FORWARD ACCURACY ===
max error: 0.001953125
mean error: 8.337707549799234e-05

=== FP16 FORWARD SPEED (ms) ===
PyTorch FP16 Conv2D:   0.907 ms
Triton FP16 Conv2D:    1.916 ms
Speedup: 0.473x
