In [1]:
import numpy as np
import torch

import npgrad as npg

In [2]:
# Experiment configuration shared by both models.
BATCH_SIZE = 16  # samples in the single batch used for every training step
N_CLASSES = 10  # number of label classes = output width of linear2
H, W = 64, 64  # input image height and width

In [3]:
class NpgNet(npg.nn.Module):
    """Small CNN classifier built from npgrad layers.

    Uses the same layer types and constructor arguments as ``TorchNet`` so that
    weights can be copied between the two and their outputs/gradients compared.
    """

    def __init__(self) -> None:
        nn = npg.nn
        super().__init__()
        # Layers are created and assigned in the same order as in TorchNet.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(3, 8, 5, 2, 2, 1, bias=False)
        self.pool1 = nn.AvgPool2d(2, padding=1)
        self.conv2 = nn.Conv2d(8, 16, 5, 2, 2, 1, bias=False)
        self.pool2 = nn.MaxPool2d(4, padding=2)
        # 144 = 16 channels * 3 * 3: the flattened feature size for 64x64 inputs
        # under PyTorch's conv/pool output-size formulas (assumes npgrad uses the
        # same arithmetic -- TODO confirm). Must be updated if H/W change.
        self.linear1 = nn.Linear(144, 64, bias=False)
        self.linear2 = nn.Linear(64, N_CLASSES, bias=False)

    def forward(self, x):
        """Return the class logits for a batch ``x`` of images."""
        # Two conv -> ReLU -> pool stages.
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = pool(self.relu(conv(x)))
        batch = len(x)
        hidden = self.relu(self.linear1(x.reshape((batch, -1))))
        return self.linear2(hidden)


class TorchNet(torch.nn.Module):
    """Reference PyTorch CNN classifier with the same layers as ``NpgNet``.

    Serves as ground truth: ``NpgNet`` receives copies of these weights and its
    outputs/gradients are compared against this model's.
    """

    def __init__(self) -> None:
        nn = torch.nn
        super().__init__()
        # Creation order is kept fixed: it determines how torch's RNG is consumed
        # by the default weight initialisers.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=5, stride=2, padding=2, dilation=1, bias=False)
        self.pool1 = nn.AvgPool2d(kernel_size=2, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5, stride=2, padding=2, dilation=1, bias=False)
        self.pool2 = nn.MaxPool2d(kernel_size=4, padding=2)
        # For 64x64 inputs: conv1 -> 32x32, pool1 -> 17x17, conv2 -> 9x9, pool2 -> 3x3,
        # so the flattened feature size is 16 * 3 * 3 = 144. Must change if H/W change.
        self.linear1 = nn.Linear(in_features=144, out_features=64, bias=False)
        self.linear2 = nn.Linear(in_features=64, out_features=N_CLASSES, bias=False)

    def forward(self, x):
        """Return the class logits for a batch ``x`` of images."""
        # Two conv -> ReLU -> pool stages.
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = pool(self.relu(conv(x)))
        batch = len(x)
        hidden = self.relu(self.linear1(x.reshape((batch, -1))))
        return self.linear2(hidden)

In [4]:
def copy_weights(weights: "tuple[tuple[npg.Array, torch.Tensor], ...]") -> None:
    """Overwrite each npgrad array's ``data`` with a copy of its paired torch tensor.

    The tensor is copied as-is when its shape already matches the target array,
    otherwise transposed (``.T``). This presumably accommodates npgrad storing
    Linear weights transposed relative to torch -- TODO confirm against npgrad.

    NOTE: a square 2-D weight matches in both orientations and is copied as-is.

    Args:
        weights: ``(npgrad_param, torch_param)`` pairs.

    Raises:
        ValueError: if the tensor fits the target array's shape neither as-is nor
            transposed.
    """
    for a, t in weights:
        # np.array(...) makes an independent copy so the models never share memory.
        data = np.array(t.detach().numpy())
        target_shape = np.shape(a.data)
        # Compare full shapes (not just the first axis) so a mismatch is never
        # silently assigned or wrongly transposed.
        if data.shape == target_shape:
            a.data = data
        elif data.T.shape == target_shape:
            a.data = data.T
        else:
            raise ValueError(
                f"cannot copy tensor of shape {data.shape} into array of shape {target_shape}"
            )


def compute_npg(data: np.ndarray, lbls: np.ndarray, model: NpgNet) -> npg.Array:
    """Forward ``data`` through the npgrad model, backpropagate the mean
    cross-entropy against ``lbls``, and return the model output.

    Gradients are accumulated into the model's parameters; this function neither
    zeroes them nor steps an optimizer.
    """
    logits = model(data)
    # The explicit .mean() suggests npgrad's cross_entropy returns unreduced
    # per-sample losses (torch reduces to the mean by default) -- TODO confirm.
    loss_terms = npg.nn.functional.cross_entropy(logits, lbls)
    loss_terms.mean().backward()
    return logits


def compute_torch(data: np.ndarray, lbls: np.ndarray, model: "TorchNet") -> torch.Tensor:
    """Forward ``data`` through the torch model, backpropagate the cross-entropy
    against ``lbls``, and return the model output.

    Gradients are accumulated into the model's parameters; this function neither
    zeroes them nor steps an optimizer.

    Args:
        data: input batch; wrapped with ``torch.from_numpy`` (shares memory).
        lbls: integer class indices, one per sample (any integer dtype).
        model: the torch network to evaluate.

    Returns:
        The model output with its autograd graph attached.
    """
    out = model(torch.from_numpy(data))
    # cross_entropy needs int64 (Long) class-index targets. NumPy's default integer
    # is int32 on some platforms (e.g. Windows with NumPy < 2), so cast explicitly.
    # The cast is a no-op for int64 input.
    targets = torch.from_numpy(lbls).long()
    loss = torch.nn.functional.cross_entropy(out, targets)  # batch mean by default
    loss.backward()
    return out


def max_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Return the largest absolute element-wise difference between ``a`` and ``b``.

    If the shapes differ, ``a`` is transposed (``.T`` reverses all axes) before
    comparing. This accommodates weights stored transposed by one framework.
    Empty inputs give ``0.0``; NaNs propagate into the result.

    Raises:
        ValueError: if ``a`` matches ``b``'s shape neither as-is nor transposed.
    """
    original_shape = a.shape
    if a.shape != b.shape:
        a = a.T
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if a.shape != b.shape:
        raise ValueError(f"cannot compare arrays of shapes {original_shape} and {b.shape}")
    # initial=0.0 makes empty inputs well-defined; float() honours the annotation
    # instead of leaking a NumPy scalar.
    return float(np.max(np.abs(a - b), initial=0.0))

In [5]:
# Seed every RNG involved so that Restart & Run All reproduces the printed diffs.
SEED = 0
np.random.seed(SEED)  # data/labels below (and npgrad, if it draws from NumPy's global RNG)
torch.manual_seed(SEED)  # TorchNet's default weight initialisation

data = np.random.rand(BATCH_SIZE, 3, H, W).astype(np.float32)
# Explicit int64: torch's cross_entropy needs Long class-index targets, and NumPy's
# default integer dtype is int32 on some platforms (e.g. Windows with NumPy < 2).
lbls = np.random.randint(0, N_CLASSES, BATCH_SIZE, dtype=np.int64)

npg_model = NpgNet()
torch_model = TorchNet()
models = npg_model, torch_model
# (npgrad parameter, torch parameter) pairs for every trainable weight.
weights = tuple(
    tuple(getattr(m, layer).weight for m in models)
    for layer in ("conv1", "conv2", "linear1", "linear2")
)
# Start both models from identical weights (taken from the torch model).
copy_weights(weights)  # type: ignore

lr = 1e-2
npg_optim = npg.optim.SGD(npg_model.parameters(), lr)
torch_optim = torch.optim.SGD(torch_model.parameters(), lr)

n_steps = 100
log_every = 20
for step in range(1, n_steps + 1):
    npg_optim.zero_grad()
    torch_optim.zero_grad()
    npg_out = compute_npg(data, lbls, npg_model)
    torch_out = compute_torch(data, lbls, torch_model)
    npg_optim.step()
    torch_optim.step()

    if step % log_every == 0:
        # Weights are compared after the update; grads are those of this step.
        # fmt: off
        print(f"\nStep {step}:")
        print(f"Max out diff: {max_diff(npg_out.data, torch_out.detach().numpy())}")
        print(f"Max weights data diff: {max(max_diff(a.data, t.detach().numpy()) for a, t in weights)}")
        print(f"Max weights grad diff: {max(max_diff(a.grad, t.grad.numpy()) for a, t in weights)}")
        # fmt: on


Step 20:
Max out diff: 6.05359673500061e-09
Max weights data diff: 1.4901161193847656e-08
Max weights grad diff: 1.3969838619232178e-08

Step 40:
Max out diff: 8.381903171539307e-09
Max weights data diff: 1.4901161193847656e-08
Max weights grad diff: 2.1886080503463745e-08

Step 60:
Max out diff: 6.984919309616089e-09
Max weights data diff: 1.4901161193847656e-08
Max weights grad diff: 1.862645149230957e-08

Step 80:
Max out diff: 5.587935447692871e-09
Max weights data diff: 2.2351741790771484e-08
Max weights grad diff: 2.3283064365386963e-08

Step 100:
Max out diff: 7.916241884231567e-09
Max weights data diff: 2.2351741790771484e-08
Max weights grad diff: 4.0978193283081055e-08
