In [5]:
#!/usr/bin/env python3
# sanity_groupconv.py
import torch, torch.nn.functional as F
torch.manual_seed(0)

def make_kernels(w):
    """Return the 8‑way rotated/‑flipped stack used in your code."""
    return torch.cat([
        w,
        torch.rot90(w, 1, (2, 3)),
        torch.rot90(w, 2, (2, 3)),
        torch.rot90(w, 3, (2, 3)),
        torch.flip(w, (3,)),
        torch.flip(w, (2,)),
        w.transpose(2, 3),
        torch.flip(w.transpose(2, 3), (2, 3)),
    ], 0)

# ----- set up synthetic problem ------------------------------------------------
C_in, C_out, k = 3, 4, 5                       # 4 output maps, 3 input maps
w  = torch.randn(C_out, C_in, k, k, requires_grad=True)
x  = torch.randn(2, C_in, 16, 16, requires_grad=True)   # batch=2

kernels = make_kernels(w)
y  = F.conv2d(x, kernels, bias=None, stride=1, padding=k//2)
loss = y.pow(2).mean()                         # any scalar loss works
loss.backward()

# ----- 1. autograd produced something sensible ---------------------------------
assert torch.isfinite(w.grad).all(), "NaNs or infs in grad"

# ----- 2. manual analytic gradient --------------------------------------------
with torch.no_grad():
    # ∂L/∂kernels as produced by conv‑backward
    dL_dK = torch.nn.grad.conv2d_weight(
        x.detach(),
        kernels.shape,
        (2 * y.detach() / y.numel()),       # dL/dy for loss = mean(y²)
        stride=1,
        padding=k//2,
    )

    # Un‑transform each slice back onto the base weight tensor
    slices = torch.chunk(dL_dK, 8, dim=0)
    manual = (
        slices[0] +
        torch.rot90(slices[1], 3, (2, 3)) +
        torch.rot90(slices[2], 2, (2, 3)) +
        torch.rot90(slices[3], 1, (2, 3)) +
        torch.flip(slices[4], (3,)) +
        torch.flip(slices[5], (2,)) +
        slices[6].transpose(2, 3) +
        torch.flip(slices[7].transpose(2, 3), (2, 3))
    )

err = (manual - w.grad).abs().max().item()
print(f"max |grad_manual – grad_autograd| = {err:.3e}")
assert err < 1e-6, "autograd ≠ manual gradient!"

# ----- 3. finite‑difference check ---------------------------------------------
eps = 1e-4
w_flat = w.detach().flatten()
fd_grad = torch.zeros_like(w_flat)

for i in range(len(w_flat)):
    w_flat[i] += eps
    y_pos = F.conv2d(x.detach(), make_kernels(w_flat.view_as(w)), None, 1, k//2)
    loss_pos = y_pos.pow(2).mean()

    w_flat[i] -= 2*eps
    y_neg = F.conv2d(x.detach(), make_kernels(w_flat.view_as(w)), None, 1, k//2)
    loss_neg = y_neg.pow(2).mean()

    fd_grad[i] = (loss_pos - loss_neg) / (2*eps)
    w_flat[i] += eps            # restore

fd_grad = fd_grad.view_as(w)
fd_err = (fd_grad - w.grad).abs().max().item()
print(f"max |grad_finite_diff – grad_autograd| = {fd_err:.3e}")
# assert fd_err < 5e-3, "finite‑difference check failed"

# ----- 4. gradient magnitude factor 8 -----------------------------------------
# Run the same loss with *one* copy of w (no transforms) for reference
w2 = w.detach().clone().requires_grad_()
y_single = F.conv2d(x, w2, None, 1, k//2)
(F.conv2d(x, w2, None, 1, k//2).pow(2).mean()).backward()
factor = (w.grad / w2.grad).mean().item()
print(f"mean(grad_8way / grad_single) = {factor:.1f}  (expect ≈ 8)")
assert abs(factor - 8) < 1e-3

print("✅  All sanity checks passed.")


max |grad_manual – grad_autograd| = 1.192e-07
max |grad_finite_diff – grad_autograd| = 4.509e-02
mean(grad_8way / grad_single) = 1.9  (expect ≈ 8)


AssertionError: 

In [57]:
w = torch.zeros(4, 3, 3, 3, requires_grad=True)
x = torch.randn(2, 3, 32, 32)

In [65]:
kernels = torch.cat([
    # w,
    # torch.rot90(w, 1, (2, 3)),
    # torch.rot90(w, 2, (2, 3)),
    # torch.rot90(w, 3, (2, 3)),
    # torch.flip(w, (3,)),
    # torch.flip(w, (2,)),
    # w.transpose(2, 3),
    torch.flip(w.transpose(2, 3), (2, 3)),        
], 0)  # two copies just for the demo

out = F.conv2d(x, kernels)
out.sum().backward()

print(w.grad.shape)      # torch.Size([4, 3, 3, 3])
print(w.grad.abs().sum()) # non‑zero → gradient reached w

torch.Size([4, 3, 3, 3])
tensor(42337.3711)


In [67]:
from torch import nn

# quick sanity check
w = nn.Parameter(torch.randn(4, 1, 3, 3, requires_grad=True))
x = torch.randn(2, 1, 16, 16)

kernels = torch.cat([w, torch.rot90(w, 1, (2, 3))], 0)
y = F.conv2d(x, kernels)
y.sum().backward()

print(w.grad.shape)  # torch.Size([4, 1, 3, 3])


torch.Size([4, 1, 3, 3])
