# Imports

In [16]:
import numpy as np

# pytorch imports
import torch
import torch.nn as torch_nn
import torch.nn.functional as TorchF
from torch import Tensor as TorchTensor

# mini_torch imports
from nn import Linear as MiniLinear, Module as MiniModule, MSELoss as MiniMSELoss
from optim import Adam as MiniAdam
from tensor import Tensor as MiniTensor
import nn as mini_nn
import nn.functional as MiniF

# helper functions
def gradients_are_equal(torch_tensor: TorchTensor, mini_tensor: MiniTensor):
    print(np.all(torch_tensor.grad.detach().numpy()==mini_tensor.grad))

## 1. Build a Simple MLP (with MiniTorch API)

In [17]:
class MLP(MiniModule):
    def __init__(self):
        super().__init__()
        self.linear_1 = MiniLinear(3, 3)
        self.linear_2 = MiniLinear(3, 6)
        self.linear_3 = MiniLinear(6, 1)

    def forward(self, x):
        x = MiniF.tanh(self.linear_1(x))
        x = MiniF.tanh(self.linear_2(x))
        x = MiniF.tanh(self.linear_3(x))
        return x
    
model = MLP()
optim = MiniAdam(model.parameters())
criterion = MiniMSELoss()


X = MiniTensor([[2.0, 3.0, -1.0], [3.0, -1.0, 0.5], [0.5, 1.0, 1.0], [1.0, 1.0, -1.0]])
Y = MiniTensor([[1.0], [-1.0], [-1.0], [1.0]])

epochs = 100

for epoch in range(epochs):
    y_hat = model(X)
    loss = criterion(y_hat, Y)
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(f"Epoch {epoch}, Loss: {loss.data}")
    
print('\nModel architecture: \n',model)
print('\n state dict:\n' ,model.state_dict())

Epoch 0, Loss: 1.9937796592712402
Epoch 1, Loss: 1.9923006296157837
Epoch 2, Loss: 1.9903831481933594
Epoch 3, Loss: 1.9878700971603394
Epoch 4, Loss: 1.984545350074768
Epoch 5, Loss: 1.9801126718521118
Epoch 6, Loss: 1.9741649627685547
Epoch 7, Loss: 1.9661449193954468
Epoch 8, Loss: 1.9552971124649048
Epoch 9, Loss: 1.9406157732009888
Epoch 10, Loss: 1.9207981824874878
Epoch 11, Loss: 1.8942272663116455
Epoch 12, Loss: 1.8590224981307983
Epoch 13, Loss: 1.8132188320159912
Epoch 14, Loss: 1.7551205158233643
Epoch 15, Loss: 1.6837944984436035
Epoch 16, Loss: 1.5995620489120483
Epoch 17, Loss: 1.5047229528427124
Epoch 18, Loss: 1.4042599201202393
Epoch 19, Loss: 1.3052639961242676
Epoch 20, Loss: 1.2151246070861816
Epoch 21, Loss: 1.139133095741272
Epoch 22, Loss: 1.079049825668335
Epoch 23, Loss: 1.0334875583648682
Epoch 24, Loss: 0.9993932247161865
Epoch 25, Loss: 0.9734480381011963
Epoch 26, Loss: 0.9528200030326843
Epoch 27, Loss: 0.935355544090271
Epoch 28, Loss: 0.9194931983947754

## 2. MiniTorch vs. PyTorch: Validation 

### Example 1: Simple Linear Case

In [18]:
mse_loss_torch = torch_nn.MSELoss(reduction='sum')
mse_loss_mini = mini_nn.MSELoss(reduction='sum')


a_torch = torch.tensor([[1, 2], [3, 4]], requires_grad=True, dtype=torch.float32)
b_torch = torch.tensor([[5, 6], [7, 8]], requires_grad=True, dtype=torch.float32)
y_torch = torch.tensor([[9, 2], [3, -1]], requires_grad=True, dtype=torch.float32)

y_hat_torch = TorchF.linear(a_torch, b_torch)
loss_torch = mse_loss_torch(y_hat_torch, y_torch)
loss_torch.backward()

a_mini = MiniTensor([[1, 2], [3, 4]])
b_mini = MiniTensor([[5, 6], [7, 8]])
y_mini = MiniTensor([[9, 2], [3, -1]])

y_hat_mini = MiniF.linear(a_mini, b_mini)
loss_mini = mse_loss_mini(y_hat_mini, y_mini)
loss_mini.backward()


gradients_are_equal(a_torch, a_mini) # compare d_loss/d_a
gradients_are_equal(b_torch, b_mini) # compare d_loss/d_b

True
True


### Example 2: More sophisticated case

In [19]:
val_a = [
    [1, 2, 3, 4],
    [4, 5, 6, 10],
    [9, -1, 1, 1]
]

val_b = [
    [1, -2, 3, 4],
    [4, 4, -6, 10],
    [1, -1, 0, 1]
]
val_q = [
    [2, 2, 2],
    [2, 2, 2],
    [2, 2, 2],
    [2, 2, 2],
    [2, 2, 2]
]

a_torch = torch.tensor(val_a, requires_grad=True, dtype=torch.float32)
b_torch = torch.tensor(val_b, requires_grad=True, dtype=torch.float32)
q_torch = torch.tensor(val_q, requires_grad=True, dtype=torch.float32)
c_torch = TorchF.linear(a_torch, b_torch); c_torch.retain_grad()
d_torch = TorchF.linear(c_torch, q_torch); d_torch.retain_grad()
e_torch = d_torch.sum(); e_torch.retain_grad()
e_torch.backward()


a_mini = MiniTensor(val_a)
b_mini = MiniTensor(val_b)
c_mini = MiniF.linear(a_mini, b_mini)
q_mini = MiniTensor(val_q)
d_mini = MiniF.linear(c_mini, q_mini)
e_mini = d_mini.sum()
e_mini.backward()



gradients_are_equal(a_torch, a_mini) # compare d_e/d_a
gradients_are_equal(b_torch, b_mini) # compare d_e/d_b
gradients_are_equal(c_torch, c_mini) # compare d_e/d_c
gradients_are_equal(d_torch, d_mini) # compare d_e/d_d
gradients_are_equal(e_torch, e_mini) # compare d_e/d_e
gradients_are_equal(q_torch, q_mini) # compare d_e/d_q


True
True
True
True
True
True


### Example 3: High Dimensional Case

In [24]:
val_a = [
    [
        [
            [1, 2, 3, 4]
        ]
    ],
    [
        [
            [1, 9, -1, 4]
        ]
    ],
    [
        [
            [1, 2, 3, -1]
        ]
    ]
]

val_b = [
    [1,   9,   3,   4],
    [0,   1,  -1, -11],
    [1,  21,  11,  -1]
]

a_torch = torch.tensor(val_a, requires_grad=True, dtype=torch.float32)
b_torch = torch.tensor(val_b, requires_grad=True, dtype=torch.float32)
c_torch = TorchF.linear(a_torch, b_torch)
d_torch = c_torch.sum(); d_torch.retain_grad()
d_torch.backward()

a_mini = MiniTensor(val_a)
b_mini = MiniTensor(val_b)
c_mini = MiniF.linear(a_mini, b_mini)
d_mini = c_mini.sum()
d_mini.backward()

gradients_are_equal(a_torch, a_mini) # compare d_d/d_a
gradients_are_equal(b_torch, b_mini) # compare d_d/d_b

True
True
