In [9]:
import torch
import time
from torch import nn


from conv_gemm.baseline_layers.triton_conv2d_int8 import TritonConv2dINT8

In [10]:
# Every tensor and module in this notebook is placed on the GPU.
device = "cuda"
# Fixed seed so the random input and the reference layer's initial weights
# are the same on every fresh run.
torch.manual_seed(0)

In [17]:
# Problem size: one 3-channel 32x32 input mapped to 8 output channels.
N, Cin, Cout = 1, 3, 8
H, W = 32, 32
K = 3  # square kernel; padding of K // 2 below keeps the output at H x W

# Random FP32 input created directly on `device`.
x = torch.randn((N, Cin, H, W), dtype=torch.float32, device=device)

# FP32 PyTorch convolution that serves as the reference implementation.
conv_ref = nn.Conv2d(
    in_channels=Cin,
    out_channels=Cout,
    kernel_size=K,
    padding=K // 2,
    bias=True,
).to(device)

In [18]:
# Layer under test, constructed with the same channel counts, kernel size and
# padding that were passed to conv_ref.
conv_int8 = TritonConv2dINT8(Cin, Cout, K, padding=K // 2).to(device)

In [19]:
# Give both layers identical parameters so that any output difference comes
# from the INT8 implementation rather than from different initial values.
ref_weight, ref_bias = conv_ref.weight, conv_ref.bias
with torch.no_grad():  # in-place parameter writes must not be recorded by autograd
    conv_int8.weight.copy_(ref_weight)
    if ref_bias is not None:
        conv_int8.bias.copy_(ref_bias)

In [20]:
# Feed the same input through both layers: the FP32 PyTorch reference first,
# then the Triton INT8 layer.
y_ref = conv_ref(x)
y_int8 = conv_int8(x)

# Print both output shapes side by side for a quick sanity check.
for label, out in (("y_ref.shape:", y_ref), ("y_int8.shape:", y_int8)):
    print(label, out.shape)

y_ref.shape: torch.Size([1, 8, 32, 32])
y_int8.shape: torch.Size([1, 8, 32, 32])


In [21]:
# ============================================================
#                   ACCURACY EVALUATION
# ============================================================

# Element-wise absolute deviation of the INT8 output from the FP32 reference.
err = torch.abs(y_ref - y_int8)
max_err, mean_err = err.max().item(), err.mean().item()

print("\n=== ACCURACY CHECK ===")
print("max error:", max_err)
print("mean error:", mean_err)


=== ACCURACY CHECK ===
max error: 2.2881076335906982
mean error: 0.43634045124053955


In [22]:
def bench(fn, iters=200, warmup=10):
    """Return the mean wall-clock time of one ``fn()`` call, in milliseconds.

    Args:
        fn: Zero-argument callable to time.
        iters: Number of timed calls; must be positive.
        warmup: Number of untimed calls made before the measurement so that
            one-off costs (kernel JIT compilation / autotuning, cuDNN
            algorithm selection, allocator growth) are not charged to it.

    Returns:
        Average time per call in milliseconds (float).

    Raises:
        ValueError: If ``iters`` is not positive or ``warmup`` is negative.
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")
    if warmup < 0:
        raise ValueError(f"warmup must be non-negative, got {warmup}")

    # CUDA kernel launches are asynchronous, so the device has to be drained
    # before reading the clock. Guarded so the helper also runs without a GPU.
    use_cuda = torch.cuda.is_available()

    def _sync():
        if use_cuda:
            torch.cuda.synchronize()

    for _ in range(warmup):
        fn()
    _sync()

    # perf_counter is monotonic and high-resolution; time.time() is neither.
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    _sync()
    return (time.perf_counter() - start) * 1000 / iters  # ms

# Time both layers on the same input, reference first. The default argument
# binds each layer to its own lambda.
t_ref, t_int8 = [bench(lambda layer=layer: layer(x)) for layer in (conv_ref, conv_int8)]

# A speedup above 1 means the Triton INT8 layer is faster than PyTorch FP32.
print("\n=== SPEED (ms) ===")
print(f"PyTorch FP32 Conv2D:   {t_ref:.3f} ms")
print(f"Triton INT8 Conv2D:    {t_int8:.3f} ms")
print(f"Speedup: {t_ref / t_int8:.3f}x")


=== SPEED (ms) ===
PyTorch FP32 Conv2D:   0.038 ms
Triton INT8 Conv2D:    0.416 ms
Speedup: 0.091x
