## TODO:
- Add bias terms
    * Here need to check how to update these terms
- Update the print statement, just need to write another functionn like test_model
- Put functions into script folder
- Have plots Dogan SGD vs Pytorch SGD on k=1 and k=3
- Try to switch to DFA from here and check the performance results.

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import math
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.optim import lr_scheduler

In [2]:
from scripts.mnistParity import MNISTParity
from scripts.architecture import MLP
from scripts.train_utils import AverageMeter, accuracy
from scripts.train import train_epoch, train_model, test_model
from scripts.plot_utils import plot_loss_accuracy, plotValAccuracy

In [3]:
print(torch.__version__)
print(np.__version__)

1.9.0
1.20.3


## Create Parity Data Iterator

In [4]:
transforms = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,))
])

In [5]:
# doesn't perform and transformation until we call the loader
trainset = torchvision.datasets.MNIST(root='data', train=True, download=True, transform=transforms)
testset = torchvision.datasets.MNIST(root='data', train=False, download=True, transform=transforms)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [6]:
learn_rate = 0.05
num_epochs = 20
batch_size = 128
loss_fn = torch.nn.BCELoss()

## MLP Scratch

In [7]:
class MLP_Manual(torch.nn.Module):
    def __init__(self, k, device, loss_type = "Cross Entropy"):
        super().__init__()

        self.input_dim = 28 * 28 * k
        self.hidden_dim = 512
        if loss_type == "Cross Entropy":
            self.output_dim = 2
        else:                           # BCE case
            self.output_dim = 1
        self.learning_rate = 0.001
        self.flat = torch.nn.Flatten() # when input comes as 28x28, this'll convert to 784
        # WEIGHTS
        # initialize the weights as pytorch does by default --> IT DIVERGES and perform worse (90%) for k=1
        # e.g. 784 x 512
        self.w1 = torch.zeros(self.input_dim, self.hidden_dim).to(device)
        stdv1 = 1. / math.sqrt(self.w1.size(1))
        self.w1.uniform_(-stdv1, +stdv1)
        #  e.g. 512 x 1
        self.w2 = torch.zeros(self.hidden_dim, self.output_dim).to(device)
        stdv2 = 1. / math.sqrt(self.w2.size(1))
        self.w2.uniform_(-stdv2, +stdv2)
        
    def softmax(self,x):    
        maxes = torch.max(x, 1, keepdim=True)[0]
        x_exp = torch.exp(x-maxes)
        x_exp_sum = torch.sum(x_exp, 1, keepdim=True)
        return x_exp/x_exp_sum
        
    def sigmoid(self, s):
        return 1 / (1 + torch.exp(-s))
     
    def reLU(self, s):
        s[s < 0] = 0
        return s.float()
    
    def reLUPrime(self, s):
        s[s < 0] = 0
        s[s > 0] = 1
        return s.float()

    # Forward propagation
    def forward(self, X):
        
        X = self.flat(X)
        # a_k = W_k @ h_{k-1} + b_k, h_k = f(a_k) where h_0 = X and f is the non linearity, a_2 = y^
        self.a1 = torch.matmul(X, self.w1) # e.g. k=1 --> 128x784 @ 784x512
        self.h1 = self.reLU(self.a1)       # f is the reLU
        self.a2 = torch.matmul(self.h1, self.w2) #
        
        if loss_type == "Cross Entropy":
            y_hat = torch.nn.functional.softmax(self.a2, dim=1)
        else:
            y_hat = self.sigmoid(self.a2)
            
        return y_hat # some loss functions handle output layer non-linearity

    # Backward propagation
    def backward(self, X, y, y_hat):
        X = self.flat(X)
        # gradients of W2 --> dBCE/dW2 = dE/dy^.dy^/da2. da2/dW2 = (y^ - y) h1
        if loss_type == "Cross Entropy":
            self.e = y_hat - torch.nn.functional.one_hot(y) # e - 128x2, h1.t - 512,128 for k=1
        else:
            self.e = y_hat - y.reshape(len(y),1) # e - 128x1, h1.t - 512,128 for k=1
            
        self.w2_grads = torch.matmul(self.h1.t(), self.e)
        # gradients of W1 --> dBCE/dW1 = dE/dh1 . dh1/da1 . da1/dW1
        # where dE/dh1 = dE/dy^ . dy^/da2 . da2/dh1
        self.dBCE_da1 = torch.matmul(self.e, self.w2.t()) * self.reLUPrime(self.a1) # e - 128x1, w2.t - 1,512 , a1 - 128,512
        self.w1_grads = torch.matmul(X.t(), self.dBCE_da1) # x.t - 784,128, dBCE_da1 128,512
        
        # Implement SGD here
        self.w1 -= self.learning_rate * self.w1_grads
        self.w2 -= self.learning_rate * self.w2_grads

    def train(self, X, y_hat):
        # Forward propagation
        y_hat = self.forward(X)
        # Backward propagation and gradient descent
        self.backward(X, y, y_hat)
        
        
def predict(nn_output: torch.Tensor):
    nn_output[nn_output > 0.5] = 1
    nn_output[nn_output < 0.5] = 0
    return nn_output.reshape(len(nn_output)).int()

def predict2(nn_output: torch.Tensor):
    return torch.argmax(nn_output, dim=1)
    
def accuracy(nn_output: torch.Tensor, ground_truth: torch.Tensor, loss_type = "Cross Entropy"):
    # nn_out_classes = torch.argmax(nn_output, dim=1)
    if loss_type == "Cross Entropy":
        nn_out_classes = predict2(nn_output)
    else:
        nn_out_classes = predict(nn_output)
    # nn_out_classes = predict2(nn_output)
    # produce tensor of booleans - at which position of the nn output is the correct class located?
    correct_items = (nn_out_classes == ground_truth)
    # now getting the accuracy is easy, we just operate the sum of the tensor and divide it by the number of examples
    acc = correct_items.sum().item() / nn_output.shape[0]
    return acc

In [11]:
# training loop for MLP_Manual

k=3
device = "cpu"
loss_type = "Binary Cross Entropy"
model_a = MLP_Manual(k, device, loss_type)
loss_fn = torch.nn.BCELoss() # or BCELoss with sigmoid activation in last layer

for epoch in range(20):
    trainData = MNISTParity(trainset, k, batch_size)
    loss_meter = AverageMeter()
    performance_meter = AverageMeter()

    for X,y in trainData.loader:
        X = X.to(device)
        y = y.to(device)
        
        y_hat = model_a(X)
        
        if loss_type == "Cross Entropy":
            loss = torch.nn.functional.cross_entropy(y_hat,y)
        else:
            loss = loss_fn(y_hat,y.reshape(len(y),1).float())
            
        acc = accuracy(y_hat, y, loss_type)
        loss_meter.update(val=loss, n=X.shape[0])
        performance_meter.update(val=acc, n=X.shape[0])
        model_a.train(X,y)
     
    print(f"Epoch {epoch+1} completed. Loss - total: {loss_meter.sum:.4f} - average: {loss_meter.avg:.4f}; Performance: {performance_meter.avg:.4f}")   

Epoch 1 completed. Loss - total: 53879.7969 - average: 0.8980; Performance: 0.5134
Epoch 2 completed. Loss - total: 40136.3555 - average: 0.6689; Performance: 0.5678
Epoch 3 completed. Loss - total: 37329.0625 - average: 0.6222; Performance: 0.6312
Epoch 4 completed. Loss - total: 35150.3516 - average: 0.5858; Performance: 0.6679
Epoch 5 completed. Loss - total: 34091.0234 - average: 0.5682; Performance: 0.6832
Epoch 6 completed. Loss - total: 33533.4023 - average: 0.5589; Performance: 0.6909
Epoch 7 completed. Loss - total: 33035.7617 - average: 0.5506; Performance: 0.6999
Epoch 8 completed. Loss - total: 32547.7207 - average: 0.5425; Performance: 0.7019
Epoch 9 completed. Loss - total: 32103.4668 - average: 0.5351; Performance: 0.7091
Epoch 10 completed. Loss - total: 31949.3457 - average: 0.5325; Performance: 0.7105
Epoch 11 completed. Loss - total: 31541.4707 - average: 0.5257; Performance: 0.7186
Epoch 12 completed. Loss - total: 31412.9668 - average: 0.5235; Performance: 0.7218
E