In [1]:
import torch
import time
from torch import nn

# === импорт твоего слоя INT8 ===
from conv_gemm.layers.triton_conv2d import TritonConv2d

In [2]:
torch.manual_seed(0)
device = "cuda"

In [3]:
N = 1
Cin = 3
Cout = 8
H = W = 32
K = 3



# test forvard fp32

In [23]:
x_base = torch.randn(N, Cin, H, W, device=device, dtype=torch.float32)

In [24]:
conv_ref = nn.Conv2d(
    Cin, Cout, kernel_size=K,
    stride=1,
    padding=K // 2,
    bias=True,
).to(device)

In [25]:
conv_triton = TritonConv2d(
    in_channels=Cin,
    out_channels=Cout,
    kernel_size=K,
    stride=1,
    padding=K // 2,
    dilation=1,
    bias=True,
    BLOCK_M=32,
    BLOCK_N=32,
    BLOCK_K=32,
    NUM_WARPS=4,
    NUM_STAGES=2,
    precision_mode="fp32",       # можно поменять на "fp16_infer" или "fp16_runtime"
    use_weight_shadow=True,
).to(device)

In [26]:
with torch.no_grad():
    conv_triton.weight.copy_(conv_ref.weight)
    if conv_ref.bias is not None:
        conv_triton.bias.copy_(conv_ref.bias)

In [27]:
with torch.no_grad():
    y_ref    = conv_ref(x_base)
    y_triton = conv_triton(x_base)

In [28]:
print("y_ref.shape:   ", y_ref.shape)
print("y_triton.shape:", y_triton.shape)

err = (y_ref - y_triton).abs()
print("\n=== FORWARD ACCURACY ===")
print("max error: ", err.max().item())
print("mean error:", err.mean().item())

y_ref.shape:    torch.Size([1, 8, 32, 32])
y_triton.shape: torch.Size([1, 8, 32, 32])

=== FORWARD ACCURACY ===
max error:  0.0
mean error: 0.0


# test backrvard fp32

In [16]:
x_ref = x_base.detach().clone().requires_grad_(True)
x_tri = x_base.detach().clone().requires_grad_(True)

conv_ref.zero_grad(set_to_none=True)
conv_triton.zero_grad(set_to_none=True)

y_ref = conv_ref(x_ref)
y_tri = conv_triton(x_tri)

loss_ref = y_ref.sum()
loss_tri = y_tri.sum()

loss_ref.backward()
loss_tri.backward()

# grad по входу
dx_err = (x_ref.grad - x_tri.grad).abs()
dx_err_max = dx_err.max().item()
dx_err_mean = dx_err.mean().item()


dx_err = (x_ref.grad - x_tri.grad).abs()
dw_err = (conv_ref.weight.grad - conv_triton.weight.grad).abs()

print("after dx/dw compute")
print("\n=== BACKWARD GRAD ACCURACY ===")
print(f"dX max err:  {dx_err.max().item():.6e}, mean err: {dx_err.mean().item():.6e}")
print(f"dW max err:  {dw_err.max().item():.6e}, mean err: {dw_err.mean().item():.6e}")


after dx/dw compute

=== BACKWARD GRAD ACCURACY ===
dX max err:  3.576279e-07, mean err: 1.042305e-07
dW max err:  3.051758e-05, mean err: 1.273332e-05


# test forvard fp16

In [34]:
x_fp16 = torch.randn(N, Cin, H, W, device=device, dtype=torch.float16)
def bench_ms(fn, iters=50):
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) * 1000.0 / iters

In [35]:
conv_ref = nn.Conv2d(
    Cin, Cout,
    kernel_size=K,
    stride=1,
    padding=K // 2,
    bias=True,
).to(device).half()

In [36]:
conv_triton = TritonConv2d(
    in_channels=Cin,
    out_channels=Cout,
    kernel_size=K,
    stride=1,
    padding=K // 2,
    dilation=1,
    bias=True,
    # блочки можешь подставить свои, если хочешь
    BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
    NUM_WARPS=4, NUM_STAGES=2,
    precision_mode="fp16_runtime",   # или "fp16_infer"
    use_weight_shadow=True,
).to(device)

In [37]:
with torch.no_grad():
    # conv_triton вес хранит в fp32, conv_ref — в fp16
    conv_triton.weight.data.copy_(conv_ref.weight.data.float())
    if conv_ref.bias is not None:
        conv_triton.bias.data.copy_(conv_ref.bias.data.float())


# ----- forward -----
with torch.no_grad():
    y_ref = conv_ref(x_fp16)        # [N, Cout, H, W] fp16
    y_tri = conv_triton(x_fp16) 

In [38]:
print("y_ref.shape:", y_ref.shape)
print("y_tri.shape:", y_tri.shape)

err = (y_ref.float() - y_tri.float()).abs()
print("\n=== FP16 FORWARD ACCURACY ===")
print("max error:", err.max().item())
print("mean error:", err.mean().item())

# ----- speed -----
t_ref = bench_ms(lambda: conv_ref(x_fp16), iters=100)
t_tri = bench_ms(lambda: conv_triton(x_fp16), iters=100)

print("\n=== FP16 FORWARD SPEED (ms) ===")
print(f"PyTorch FP16 Conv2D:   {t_ref:.3f} ms")
print(f"Triton FP16 Conv2D:    {t_tri:.3f} ms")
print(f"Speedup: {t_ref / t_tri:.3f}x")


y_ref.shape: torch.Size([1, 8, 32, 32])
y_tri.shape: torch.Size([1, 8, 32, 32])

=== FP16 FORWARD ACCURACY ===
max error: 0.0009765625
mean error: 7.395536522381008e-05

=== FP16 FORWARD SPEED (ms) ===
PyTorch FP16 Conv2D:   0.062 ms
Triton FP16 Conv2D:    0.484 ms
Speedup: 0.128x


# BACKWARD TEST (FP16)

In [39]:
x_base = torch.randn(N, Cin, H, W, device=device, dtype=torch.float32)


In [46]:
conv_ref = nn.Conv2d(
    Cin, Cout, kernel_size=K, stride=1, padding=K//2, bias=True
).to(device).half()

In [47]:
conv_tri = TritonConv2d(
    in_channels=Cin,
    out_channels=Cout,
    kernel_size=K,
    stride=1,
    padding=K//2,
    dilation=1,
    bias=True,
    BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
    NUM_WARPS=4, NUM_STAGES=2,
    precision_mode="fp16_infer",   # ВАЖНО: без shadow
    use_weight_shadow=False,
).to(device)

In [48]:
with torch.no_grad():
    conv_tri.weight.data.copy_(conv_ref.weight.data.float())
    if conv_ref.bias is not None:
        conv_tri.bias.data.copy_(conv_ref.bias.data.float())

In [49]:
x_ref = x_base.detach().clone().half().requires_grad_(True)
x_tri = x_base.detach().clone().half().requires_grad_(True)

conv_ref.zero_grad(set_to_none=True)
conv_tri.zero_grad(set_to_none=True)

# forward
y_ref = conv_ref(x_ref)
y_tri = conv_tri(x_tri)

loss_ref = y_ref.sum()
loss_tri = y_tri.sum()

loss_ref.backward()
loss_tri.backward()

# ----- grad по входу -----
dx_err = (x_ref.grad.float() - x_tri.grad.float()).abs()
dx_err_max = dx_err.max().item()
dx_err_mean = dx_err.mean().item()

# ----- grad по весам -----
dw_err = (conv_ref.weight.grad.float() - conv_tri.weight.grad.float()).abs()
dw_err_max = dw_err.max().item()
dw_err_mean = dw_err.mean().item()

print("\n=== FP16 BACKWARD GRAD ACCURACY (no AMP) ===")
print(f"dX max err:  {dx_err_max:.6e}, mean err: {dx_err_mean:.6e}")
print(f"dW max err:  {dw_err_max:.6e}, mean err: {dw_err_mean:.6e}")


=== FP16 BACKWARD GRAD ACCURACY (no AMP) ===
dX max err:  0.000000e+00, mean err: 0.000000e+00
dW max err:  0.000000e+00, mean err: 0.000000e+00


# BACKWARD TEST (FP16) AMP

In [53]:
import torch
from torch import nn
from torch.amp import autocast, GradScaler

from conv_gemm.layers.triton_conv2d import TritonConv2d

device = "cuda"
torch.manual_seed(0)

N, Cin, Cout, H, W, K = 2, 8, 16, 32, 32, 3

x_base = torch.randn(N, Cin, H, W, device=device, dtype=torch.float32)

# === Triton Conv2d под AMP (fp16_runtime + shadow) ===
conv_tri_amp = TritonConv2d(
    in_channels=Cin,
    out_channels=Cout,
    kernel_size=K,
    stride=1,
    padding=K//2,
    dilation=1,
    bias=True,
    BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
    NUM_WARPS=4, NUM_STAGES=2,
    precision_mode="fp16_runtime",   # runtime: x->fp16, weight_fp16_shadow
    use_weight_shadow=True,
).to(device)

optimizer = torch.optim.SGD(conv_tri_amp.parameters(), lr=1e-2)
scaler = GradScaler("cuda")

for step in range(5):
    optimizer.zero_grad(set_to_none=True)
    x = x_base.detach().clone().requires_grad_(True)

    # ---- ВАЖНО: новая форма autocast ----
    with autocast("cuda", dtype=torch.float16):
        y = conv_tri_amp(x)
        loss = y.mean()

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    print(f"step {step}: loss={loss.item():.4f}")
    print("  x.grad dtype:", x.grad.dtype if x.grad is not None else None)
    print("  weight.grad is None?:", conv_tri_amp.weight.grad is None)

print("\nAMP/backward completed without crash.")


step 0: loss=-0.0075
  x.grad dtype: torch.float32
  weight.grad is None?: True
step 1: loss=-0.0081
  x.grad dtype: torch.float32
  weight.grad is None?: True
step 2: loss=-0.0087
  x.grad dtype: torch.float32
  weight.grad is None?: True
step 3: loss=-0.0094
  x.grad dtype: torch.float32
  weight.grad is None?: True
step 4: loss=-0.0100
  x.grad dtype: torch.float32
  weight.grad is None?: True

AMP/backward completed without crash.
