In [1]:
%load_ext autoreload
%autoreload 2
from dl import Variable, Module
import numpy as np
from dl.functions import sum, cross_entropy_loss
from dl.modules import Convolution, Linear, ReLU, Flatten, MaxPool

import torch
import torch.nn as nn

In [2]:
def transfer_weights_to_torch(model, torch_model):
    """Copy every parameter of ``model`` into the matching parameter of ``torch_model``.

    Parameters are paired positionally: the i-th entry of ``model.parameters()``
    is written into the i-th entry of ``torch_model.parameters()``.

    The custom framework stores the weight of a Linear layer as the transpose
    of PyTorch's layout (so that ``Y = X @ W + b``). Weights that belong to an
    ``nn.Linear`` are therefore transposed whenever the transposed layout fits
    the target. This also covers square weights, where the shapes alone cannot
    reveal that a transpose is needed.

    Args:
        model: Object whose ``parameters()`` yields items exposing the raw
            values as ``.data`` (anything ``torch.tensor`` accepts).
        torch_model: ``torch.nn.Module`` that receives the values in place.

    Raises:
        AssertionError: If the two models expose a different number of parameters.
        ValueError: If a parameter cannot be brought to the shape of its
            PyTorch counterpart (instead of letting ``copy_`` broadcast it).
    """
    torch_params = list(torch_model.parameters())
    params = list(model.parameters())

    assert len(torch_params) == len(params), "Parameter count mismatch"

    # Identify Linear weights by object identity; shape is ambiguous for
    # square matrices.
    linear_weight_ids = {
        id(module.weight)
        for module in torch_model.modules()
        if isinstance(module, torch.nn.Linear)
    }

    for index, (torch_p, p) in enumerate(zip(torch_params, params)):
        value = torch.tensor(p.data, dtype=torch.float32)

        is_linear_weight = id(torch_p) in linear_weight_ids
        if is_linear_weight and value.ndim == 2 and value.T.shape == torch_p.shape:
            # Linear weight stored as (in, out): convert to PyTorch's (out, in).
            value = value.T
        elif value.shape != torch_p.shape:
            # Legacy fallback for any other mismatching layout: reverse the
            # axes (what ``.T`` did) without the deprecated N-d ``.T``.
            value = value.permute(tuple(reversed(range(value.ndim))))

        if value.shape != torch_p.shape:
            raise ValueError(
                f"Cannot copy parameter {index}: shape {tuple(value.shape)} "
                f"does not match target shape {tuple(torch_p.shape)}"
            )

        # In-place write that autograd must not record.
        with torch.no_grad():
            torch_p.copy_(value)

def compare_grads(model, torch_model):
    """Print, for each parameter, whether the custom gradient matches PyTorch's.

    Parameters are paired positionally between ``model.parameters()`` and
    ``torch_model.parameters()``. One line (``True`` or ``False``) is printed
    per pair, using ``np.allclose`` with ``atol=1e-4`` and ``rtol=1e-3``.

    The custom framework stores W in a Linear layer as the transpose of
    PyTorch's weight so that the output is ``Y = X @ W + b``. Gradients of
    weights that belong to an ``nn.Linear`` are therefore transposed whenever
    the transposed layout fits. This also covers square weights, where
    comparing shapes cannot reveal that a transpose is needed.

    Args:
        model: Object whose ``parameters()`` yields items exposing ``.grad``
            (assumed to be a NumPy array, since it is passed to ``np.allclose``).
        torch_model: ``torch.nn.Module`` whose parameters already hold ``.grad``.

    Returns:
        None. The results are only printed.

    Raises:
        AssertionError: If the two models expose a different number of parameters.
    """
    torch_params = list(torch_model.parameters())
    params = list(model.parameters())

    assert len(torch_params) == len(params), "Parameter count mismatch"

    # Identify Linear weights by object identity; shape is ambiguous for
    # square matrices.
    linear_weight_ids = {
        id(module.weight)
        for module in torch_model.modules()
        if isinstance(module, torch.nn.Linear)
    }

    for torch_p, p in zip(torch_params, params):
        torch_grad = torch_p.grad.numpy()
        mygrad = p.grad

        is_linear_weight = id(torch_p) in linear_weight_ids
        if (
            is_linear_weight
            and len(mygrad.shape) == 2
            and mygrad.T.shape == torch_grad.shape
        ):
            # Linear weight gradient stored as (in, out): convert to (out, in).
            mygrad = mygrad.T
        elif mygrad.shape != torch_grad.shape:
            # Legacy fallback for any other mismatching layout.
            mygrad = mygrad.T

        print(np.allclose(torch_grad, mygrad, atol=1e-4, rtol=1e-3))

## Convolution Implementation 

In [3]:
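# TODO: re-enable this Convolution vs. nn.Conv2d gradient sweep.
# NOTE: the helper below is also named compare_grads but takes layer hyper-parameters.
# Rename it before uncommenting, otherwise it shadows compare_grads(model, torch_model)
# defined in the cell above, which the "Small CNN" section still calls.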

# def compare_grads(C_in, C_out, K, stride, padding, N_batch, H, W, verbose=False):
#     # Input
#     X = Variable(np.random.randn(N_batch, C_in, H, W), keep_grad=True)
#     X_torch = torch.tensor(X.data, dtype=torch.float32, requires_grad=True)

#     # Your Conv
#     conv = Convolution(C_in, C_out, K, stride=stride, padding=padding)
#     Y = conv(X)
#     z = sum(Y)
#     z.backward()

#     # PyTorch Conv
#     conv_torch = nn.Conv2d(C_in, C_out, K, stride=stride, padding=padding, bias=False)
#     W_tensor = torch.tensor(conv.W.data, dtype=torch.float32)
#     with torch.no_grad():
#         conv_torch.weight.copy_(W_tensor)
#     Y_torch = conv_torch(X_torch)
#     z_torch = torch.sum(Y_torch)
#     z_torch.backward()

#     grad_torch_W = conv_torch.weight.grad.numpy()
#     grad_mine_W = conv.W.grad()

#     grad_torch_X = X_torch.grad.numpy()
#     grad_mine_X = X.grad()

#     w_close = np.allclose(grad_mine_W, grad_torch_W, atol=1e-4, rtol=1e-3)
#     x_close = np.allclose(grad_mine_X, grad_torch_X, atol=1e-4, rtol=1e-3)

#     if not (w_close and x_close) and verbose:
#         print(f"FAILED: Cin={C_in}, Cout={C_out}, K={K}, stride={stride}, pad={padding}, size=({N_batch},{H},{W})")
#         if not w_close:
#             print("Weight grad mismatch — max diff:", np.max(np.abs(grad_mine_W - grad_torch_W)))
#         if not x_close:
#             print("Input grad mismatch — max diff:", np.max(np.abs(grad_mine_X - grad_torch_X)))

#     return w_close and x_close

# def run_all_tests():
#     paddings = [0, 1, 2, 60]
#     strides = [1, 2, 37, 90]
#     kernels = [1, 3, 5]
#     batch_sizes = [1, 4]
#     image_sizes = [(5, 5), (7, 7), (16, 16), (8, 32)]
#     channels = [(1, 1), (3, 8), (8, 3), (4, 4)]

#     total = 0
#     passed = 0

#     for pad in paddings:
#         for stride in strides:
#             for K in kernels:
#                 for (C_in, C_out) in channels:
#                     for N in batch_sizes:
#                         for (H, W) in image_sizes:
#                             if H + 2 * pad < K or W + 2 * pad < K:
#                                 continue  # Skip invalid kernel size
#                             total += 1
#                             try:
#                                 result = compare_grads(C_in, C_out, K, stride, pad, N, H, W)
#                             except Exception as e:
#                                 print(f"\n❌ Exception for: Cin={C_in}, Cout={C_out}, K={K}, stride={stride}, pad={pad}, size=({N},{H},{W})")
#                                 raise e
#                             if result:
#                                 passed += 1
#                             else:
#                                 print(f"\n❌ Failed for: Cin={C_in}, Cout={C_out}, K={K}, stride={stride}, pad={pad}, size=({N},{H},{W})")
#                                 return
#     print(f"✅ Passed {passed} / {total} tests.")

# run_all_tests()

# def test_edge_cases():
#     assert compare_grads(1, 1, 1, 1, 0, 1, 1, 1)  # Minimal case
#     assert compare_grads(3, 4, 3, 2, 4, 16, 7, 7)  # From your original test
#     assert compare_grads(8, 16, 5, 1, 2, 2, 32, 32)  # Large channels and image
#     assert compare_grads(2, 2, 3, 1, 1, 1, 5, 5)  # Small batch

# test_edge_cases()

## Small CNN

In [4]:
class MiniCNN(Module):
    """Small CNN classifier assembled from the custom ``dl`` modules.

    Pipeline: Convolution -> ReLU -> MaxPool -> Flatten -> Linear -> ReLU -> Linear.
    The final layer produces 100 outputs per sample.
    """

    def __init__(self):
        super().__init__()

        # Attribute order is kept deliberately: this notebook pairs
        # parameters() positionally with the PyTorch reference model.
        # Convolution(3, 4, 4, 1): presumably in=3, out=4, kernel=4, stride=1,
        # mirroring nn.Conv2d(3, 4, 4, 1) in MiniCNNTorch. TODO confirm.
        self.conv1 = Convolution(3, 4, 4, 1)
        self.relu1 = ReLU()
        # MaxPool(2, 2): presumably kernel=2, stride=2. TODO confirm.
        self.maxpool1 = MaxPool(2, 2)
        self.flat1 = Flatten()
        # 4*14*14 flattened features feed the first fully connected layer.
        self.lin1 = Linear(4*14*14, 783)
        self.relu2 = ReLU()
        self.lin2 = Linear(783, 100)

    def forward(self, X):
        """Run ``X`` through every layer in order and return the final output."""
        stages = (
            self.conv1,
            self.relu1,
            self.maxpool1,
            self.flat1,
            self.lin1,
            self.relu2,
            self.lin2,
        )

        out = X
        for stage in stages:
            out = stage(out)
        return out
    
class MiniCNNTorch(nn.Module):
    """PyTorch reference network with the same layer sequence as ``MiniCNN``.

    Pipeline: Conv2d (no bias) -> ReLU -> MaxPool2d -> Flatten -> Linear
    -> ReLU -> Linear, ending in 100 outputs per sample. With a 3x32x32 input
    the 4x4 convolution yields 4x29x29 and the 2x2 pooling 4x14x14, which is
    the ``4 * 14 * 14`` input size of the first Linear layer.
    """

    def __init__(self):
        super().__init__()

        # Registration order defines the order of parameters(); keep it stable.
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=4, kernel_size=4, stride=1, bias=False
        )
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2, stride=2)
        self.flat1 = nn.Flatten()
        self.lin1 = nn.Linear(in_features=4 * 14 * 14, out_features=783)
        self.relu2 = nn.ReLU()
        self.lin2 = nn.Linear(in_features=783, out_features=100)

    def forward(self, X):
        """Apply every layer in sequence and return the final output."""
        stages = (
            self.conv1,
            self.relu1,
            self.maxpool1,
            self.flat1,
            self.lin1,
            self.relu2,
            self.lin2,
        )

        out = X
        for stage in stages:
            out = stage(out)
        return out
    
# Seed the global RNGs so the random inputs and labels (and any weights drawn
# from NumPy's or PyTorch's global generators) are identical on every
# Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Problem size: a batch of N inputs with C_in channels of H x W values.
N = 16
C_in = 3
H = 32
W = 32

# Random inputs and integer class labels in [0, 100), shared by both frameworks.
X = Variable(np.random.rand(N, C_in, H, W))
y = Variable(np.random.randint(0,100,size=(N,)))
torch_X = torch.tensor(X.data, dtype=torch.float32)
torch_y = torch.tensor(y.data, dtype=torch.int64)

model = MiniCNN()
torch_model = MiniCNNTorch()

# Both networks must start from identical weights for the gradients to be comparable.
transfer_weights_to_torch(model, torch_model)

# Forward and backward pass through the custom framework.
Y = model(X)
z = cross_entropy_loss(Y, y)
z.backward()

# Same computation in PyTorch.
torch_Y = torch_model(torch_X) 
criterion = nn.CrossEntropyLoss()
torch_z = criterion(torch_Y, torch_y)
torch_z.backward()

# Prints one True/False line per parameter.
compare_grads(model, torch_model)

True
True
True
True
True


In [5]:
# Show every (name, parameter) pair of the PyTorch reference model.
list(torch_model.named_parameters())

[('conv1.weight',
  Parameter containing:
  tensor([[[[ 0.1854,  0.0110, -0.3825,  0.1056],
            [-0.0541,  0.2143,  0.0157,  0.1009],
            [-0.0495,  0.3085,  0.1125,  0.0960],
            [ 0.1524,  0.0748, -0.3754, -0.0203]],
  
           [[ 0.0902, -0.2312,  0.0707, -0.1286],
            [ 0.3914, -0.1080, -0.2606,  0.3007],
            [ 0.1315, -0.1007,  0.0221, -0.1149],
            [ 0.0565,  0.0212,  0.3711,  0.1981]],
  
           [[ 0.3084, -0.0152, -0.1546,  0.1089],
            [-0.1067,  0.4666,  0.0616, -0.1474],
            [-0.0210,  0.1515,  0.1685,  0.0263],
            [ 0.2246,  0.2202, -0.0547, -0.0236]]],
  
  
          [[[-0.0235, -0.0415, -0.2790,  0.1413],
            [-0.0218, -0.0983, -0.3634, -0.2653],
            [ 0.2107,  0.3419,  0.1121, -0.1584],
            [ 0.1190,  0.0131, -0.2051,  0.2933]],
  
           [[ 0.0123,  0.3786,  0.0046, -0.1087],
            [ 0.0562,  0.1304,  0.1123,  0.1854],
            [-0.0229, -0.1895, -0.2647