In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tqdm import tqdm

In [2]:
# Hyperparameters shared by the torch and numpy implementations below.
batch_size = 128
lr = 0.01

# Seed both RNGs so weight initialisation and DataLoader shuffling are
# repeatable under Restart Kernel -> Run All.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

### torch version

In [3]:
class Net(nn.Module):
  """Fully connected classifier: 784 -> 128 (ReLU) -> 10, returning log-probabilities."""

  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(784, 128)
    self.fc2 = nn.Linear(128, 10)

  def forward(self, x):
    """Map a ``(batch, 784)`` input to ``(batch, 10)`` log-softmax scores."""
    hidden = torch.relu(self.fc1(x))
    logits = self.fc2(hidden)
    # Normalise across the class dimension.
    return torch.log_softmax(logits, dim=1)

In [4]:
# MNIST train/test splits, downloaded into ./data on first use and converted to tensors.
train_data = datasets.MNIST(root="data", train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST(root="data", train=False, download=True, transform=transforms.ToTensor())

# Shuffle only the training batches; accuracy does not depend on sample order,
# so the test loader iterates in a fixed order and consumes no RNG state.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [5]:
model = Net()
# Use the shared `lr` constant rather than a duplicate literal so the torch and
# numpy models always train with the same step size.
optim = torch.optim.SGD(model.parameters(), lr=lr)
# The network outputs log-probabilities, so pair it with negative log-likelihood.
loss_fn = nn.NLLLoss()

In [6]:
# Mini-batch SGD over the training set for 5 epochs.
model.train()
for e in range(5):
  with tqdm(train_dataloader) as pbar:
    for step, (x, y) in enumerate(pbar):
      # Clear gradients left over from the previous step.
      optim.zero_grad(set_to_none=True)

      # Flatten each image to a 784-vector to match the first linear layer.
      log_probs = model(x.view(-1, 784))
      loss = loss_fn(log_probs, y)

      loss.backward()
      optim.step()

      # Refresh the progress-bar label every 50 batches.
      if step % 50 == 0:
        pbar.set_description(f"Epoch {e+1} | Loss: {loss.item():.4f}")

Epoch 1 | Loss: 1.1189: 100%|██████████| 469/469 [00:01<00:00, 268.06it/s]
Epoch 2 | Loss: 0.5625: 100%|██████████| 469/469 [00:01<00:00, 241.73it/s]
Epoch 3 | Loss: 0.4950: 100%|██████████| 469/469 [00:01<00:00, 251.62it/s]
Epoch 4 | Loss: 0.5094: 100%|██████████| 469/469 [00:01<00:00, 265.46it/s]
Epoch 5 | Loss: 0.4668: 100%|██████████| 469/469 [00:01<00:00, 276.82it/s]


In [7]:
# Top-1 accuracy of the torch model on the test split.
model.eval()
correct, total = 0, 0
with torch.no_grad():
  for x, y in test_dataloader:
    x = x.view(-1, 784)
    preds = model(x).argmax(dim=1)
    # Count predictions that match the labels in this batch.
    correct += preds.eq(y).sum()
    total += y.size(0)
print(f"Accuracy: {correct/total:.4f}")

Accuracy: 0.9003


### numpy version

#### forward pass

In [8]:
class ScratchNet():
  """Two-layer MLP (linear -> ReLU -> linear -> log-softmax) in plain NumPy.

  Weights are stored so that activations are computed as ``x @ W + b``:
  ``fc1`` is ``(input_size, hidden_size)``, ``fc2`` is
  ``(hidden_size, output_size)`` and the biases are ``(1, n)`` row vectors.
  ``backward`` assumes the loss is the mean negative log-likelihood of the
  log-probabilities returned by ``forward``.
  """

  def __init__(self, input_size, hidden_size, output_size):
    """Draw Gaussian weights scaled by 1/sqrt(fan_in) and zero the biases."""
    self.fc1 = np.random.randn(input_size, hidden_size) / np.sqrt(input_size)
    self.b1 = np.zeros((1, hidden_size))
    self.fc2 = np.random.randn(hidden_size, output_size) / np.sqrt(hidden_size)
    self.b2 = np.zeros((1, output_size))
    # One gradient buffer per parameter, keyed by the parameter's attribute name.
    self.grads = {
      'fc1': np.zeros_like(self.fc1),
      'b1': np.zeros_like(self.b1),
      'fc2': np.zeros_like(self.fc2),
      'b2': np.zeros_like(self.b2)
    }

  def forward(self, x):
    """Return log-probabilities of shape ``(batch, output_size)`` for input ``x``.

    Caches ``h1`` (post-ReLU hidden activations), ``h2`` (max-shifted logits)
    and ``softmaxed`` (class probabilities) for use by ``backward``.
    """
    self.h1 = np.maximum(0, np.dot(x, self.fc1) + self.b1)
    logits = np.dot(self.h1, self.fc2) + self.b2
    # Shift by the row max so the exponentials cannot overflow.
    self.h2 = logits - np.max(logits, axis=1, keepdims=True)
    # Compute log-softmax directly instead of log(softmax): a probability that
    # underflows to 0 would otherwise turn into -inf.
    log_norm = np.log(np.sum(np.exp(self.h2), axis=1, keepdims=True))
    log_probs = self.h2 - log_norm
    self.softmaxed = np.exp(log_probs)
    return log_probs

  def backward(self, x, y):
    """Fill ``self.grads`` with gradients of the mean NLL loss.

    Must be called after ``forward(x)`` on the same batch. ``y`` holds integer
    class indices, one per row of ``x``.
    """
    batch = y.shape[0]
    # Gradient w.r.t. the logits is (softmax - one_hot) / batch. Work on a copy
    # so the cached probabilities are not corrupted and backward can be
    # called more than once per forward pass.
    d_h2 = self.softmaxed.copy()
    d_h2[np.arange(batch), y] -= 1
    d_h2 /= batch

    self.grads["fc2"] = self.h1.T @ d_h2
    self.grads["b2"] = d_h2.sum(axis=0, keepdims=True)
    d_h1 = d_h2 @ self.fc2.T
    # ReLU passes gradient only where its output was positive.
    d_relu = d_h1 * (self.h1 > 0).astype(float)
    self.grads["fc1"] = x.T @ d_relu
    self.grads["b1"] = d_relu.sum(axis=0, keepdims=True)

  def update_params(self, step_size=None):
    """Apply one in-place SGD step using the gradients stored by ``backward``.

    ``step_size`` defaults to the notebook-level ``lr`` constant, so existing
    ``update_params()`` calls behave as before.
    """
    if step_size is None:
      step_size = lr
    self.fc1 -= step_size * self.grads["fc1"]
    self.b1 -= step_size * self.grads["b1"]
    self.fc2 -= step_size * self.grads["fc2"]
    self.b2 -= step_size * self.grads["b2"]

  def zero_grads(self):
    """Reset every gradient buffer to zero in place."""
    for grad in self.grads.values():
      # fill() also clears NaN/inf entries, which `*= 0` would leave behind.
      grad.fill(0)

In [9]:
scratch_model = ScratchNet(28*28, 128, 10)

# Load the trained torch parameters into the numpy model. The weight matrices
# are transposed because ScratchNet computes x @ W.
# .numpy() returns an array that shares memory with the torch parameter, so
# take copies: otherwise an in-place update on scratch_model would silently
# modify `model` as well.
scratch_model.fc1 = model.fc1.weight.detach().numpy().T.copy()
scratch_model.b1 = model.fc1.bias.detach().numpy().reshape(1, -1).copy()
scratch_model.fc2 = model.fc2.weight.detach().numpy().T.copy()
scratch_model.b2 = model.fc2.bias.detach().numpy().reshape(1, -1).copy()

In [10]:
# Accuracy of the numpy forward pass on the test split.
correct, total = 0, 0
for x, y in test_dataloader:
  x = x.view(-1, 784).numpy()
  y = y.numpy()
  log_probs = scratch_model.forward(x)
  preds = log_probs.argmax(axis=1)
  correct += np.sum(preds == y)
  total += y.shape[0]
print(f"Accuracy: {correct/total:.4f}")

Accuracy: 0.9003


#### backward pass

In [11]:
# Start from freshly initialised weights for training the numpy model.
scratch_model = ScratchNet(input_size=28 * 28, hidden_size=128, output_size=10)

In [12]:
def nll_loss(y_hat, y):
  """Mean negative log-likelihood.

  ``y_hat`` holds per-class log-probabilities, one row per sample; ``y`` holds
  the integer class index for each row.
  """
  rows = np.arange(y.shape[0])
  # Pick the log-probability assigned to the true class of every sample.
  true_class_log_probs = y_hat[rows, y]
  return -np.mean(true_class_log_probs)

In [13]:
# Train the numpy model with mini-batch SGD for 5 epochs.
for e in range(5):
  with tqdm(train_dataloader) as pbar:
    for step, (x, y) in enumerate(pbar):
      # Convert the torch batch to flat numpy arrays.
      inputs = x.view(-1, 784).numpy()
      targets = y.numpy()

      log_probs = scratch_model.forward(inputs)
      loss = nll_loss(log_probs, targets)

      scratch_model.backward(inputs, targets)
      scratch_model.update_params()
      scratch_model.zero_grads()

      # Refresh the progress-bar label every 50 batches.
      if step % 50 == 0:
        pbar.set_description(f"Epoch {e+1} | Loss: {loss.item():.4f}")

Epoch 1 | Loss: 0.7908: 100%|██████████| 469/469 [00:06<00:00, 77.45it/s] 
Epoch 2 | Loss: 0.5459: 100%|██████████| 469/469 [00:04<00:00, 104.73it/s]
Epoch 3 | Loss: 0.4781: 100%|██████████| 469/469 [00:05<00:00, 80.97it/s] 
Epoch 4 | Loss: 0.5928: 100%|██████████| 469/469 [00:04<00:00, 94.90it/s] 
Epoch 5 | Loss: 0.2400: 100%|██████████| 469/469 [00:05<00:00, 83.13it/s] 


In [14]:
# Test-set accuracy of the numpy model after training.
correct, total = 0, 0
for x, y in test_dataloader:
  x = x.view(-1, 784).numpy()
  y = y.numpy()
  log_probs = scratch_model.forward(x)
  preds = log_probs.argmax(axis=1)
  correct += np.sum(preds == y)
  total += y.shape[0]
print(f"Accuracy: {correct/total:.4f}")

Accuracy: 0.9055
