## TODO:
- Might need to use BCE in default pytorch implementation
- Have plots Dogan SGD vs Pytorch SGD on k=1 and k=3
- Try to switch to DFA from here and check the performance results.

In [1]:
import numpy as np
import math
import torchvision
from torchvision.transforms import ToTensor, Normalize, Compose

In [2]:
from scripts.train import *

In [3]:
print(torch.__version__)
print(np.__version__)

1.9.0
1.20.3


## Create Parity Data Iterator

In [4]:
transforms = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,))
])

In [5]:
# doesn't perform and transformation until we call the loader
trainset = torchvision.datasets.MNIST(root='data', train=True, download=True, transform=transforms)
testset = torchvision.datasets.MNIST(root='data', train=False, download=True, transform=transforms)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## MLP Scratch

In [6]:
learn_rate = 0.05
num_epochs = 20
batch_size = 128
loss_fn = torch.nn.BCELoss() # or BCELoss with sigmoid activation in last layer

In [7]:
class MLP_Manual(torch.nn.Module):
    def __init__(self, k, device, batch_size, loss_type = "Cross Entropy"):
        super().__init__()

        self.batch_size = batch_size
        self.input_dim = 28 * 28 * k
        self.hidden_dim = 512
        if loss_type == "Cross Entropy":
            self.output_dim = 2
        else:                           # BCE case
            self.output_dim = 1
        self.learning_rate = 0.001
        self.flat = torch.nn.Flatten() # when input comes as 28x28, this'll convert to 784
        # WEIGHTS
        # initialize the weights as pytorch does by default --> IT DIVERGES and perform worse (90%) for k=1
        # e.g. 784 x 512
        self.w1 = torch.empty(self.input_dim, self.hidden_dim).to(device)
        stdv1 = 1. / math.sqrt(self.w1.size(1))
        self.w1.uniform_(-stdv1, +stdv1)
        #  e.g. 512 x 1
        self.w2 = torch.empty(self.hidden_dim, self.output_dim).to(device)
        stdv2 = 1. / math.sqrt(self.w2.size(1))
        self.w2.uniform_(-stdv2, +stdv2)

        # BIASES
        self.b1 = torch.empty(batch_size, self.hidden_dim).to(device)
        self.b1.uniform_(-stdv1, stdv1)
        self.b2 = torch.empty(batch_size, self.output_dim).to(device)
        self.b2.uniform_(-stdv1, stdv1)


    @staticmethod
    def softmax(x):
        maxes = torch.max(x, 1, keepdim=True)[0]
        x_exp = torch.exp(x-maxes)
        x_exp_sum = torch.sum(x_exp, 1, keepdim=True)
        return x_exp/x_exp_sum

    @staticmethod
    def sigmoid(s):
        return 1 / (1 + torch.exp(-s))

    @staticmethod
    def reLU(s):
        s[s < 0] = 0
        return s.float()

    @staticmethod
    def reLUPrime(s):
        s[s < 0] = 0
        s[s > 0] = 1
        return s.float()

    # Forward propagation
    def forward(self, X):

        X = self.flat(X)
        # batch_size changes at the end of the apoch from 128 to 96, this spawned a problem in calculations 
        self.dynamic_batch_size =  X.shape[0]  
        # a_k = W_k @ h_{k-1} + b_k, h_k = f(a_k) where h_0 = X and f is the non linearity, a_2 = y^
        self.a1 = torch.matmul(X, self.w1) + self.b1[:self.dynamic_batch_size, :] # e.g. k=1 --> 128x784 @ 784x512 + 128x512 where 128 is batch_size (X.shape[1])
        self.h1 = self.reLU(self.a1)       # f is the reLU
        self.a2 = torch.matmul(self.h1, self.w2) + self.b2[:self.dynamic_batch_size, :]

        if loss_type == "Cross Entropy":
            y_hat = torch.nn.functional.softmax(self.a2, dim=1)
        else:
            y_hat = self.sigmoid(self.a2)

        return y_hat # some loss functions handle output layer non-linearity

    # Backward propagation
    def backward(self, X, y, y_hat):
        X = self.flat(X)
        # gradients of W2 --> dBCE/dW2 = dE/dy^.dy^/da2. da2/dW2 = (y^ - y) h1
        if loss_type == "Cross Entropy":
            self.e = y_hat - torch.nn.functional.one_hot(y) # e - 128x2, h1.t - 512,128 for k=1
        else:
            self.e = y_hat - y.reshape(len(y),1) # e - 128x1, h1.t - 512,128 for k=1

        self.w2_grads = torch.matmul(self.h1.t(), self.e)
        # gradients of W1 --> dBCE/dW1 = dE/dh1 . dh1/da1 . da1/dW1
        # where dE/dh1 = dE/dy^ . dy^/da2 . da2/dh1
        self.dBCE_da1 = torch.matmul(self.e, self.w2.t()) * self.reLUPrime(self.a1) # e - 128x1, w2.t - 1,512 , a1 - 128,512
        self.w1_grads = torch.matmul(X.t(), self.dBCE_da1) # x.t - 784,128, dBCE_da1 128,512
        # gradients of b2 --> dBCE/db2 = dBCE/dy^. dy^/da2. da2/db2 = (y^-y)*1
        self.b2_grads = self.e[:self.dynamic_batch_size, :]
        # gradients of b1 --> dBCE/db1 = dBCE/dh1. dh1/da1. da1/db1
        # where dBCE/dh1 = dBCE/dy^ . dy^/da2 . da2/dh1
        self.b1_grads = self.dBCE_da1[:self.dynamic_batch_size, :]

        # Implement SGD here
        self.w1 -= self.learning_rate * self.w1_grads
        self.w2 -= self.learning_rate * self.w2_grads
        self.b1[:self.dynamic_batch_size, :] -= self.learning_rate * self.b1_grads
        self.b2[:self.dynamic_batch_size, :] -= self.learning_rate * self.b2_grads

    def train(self, X, y):
        # Forward propagation
        y_hat = self.forward(X)
        # Backward propagation and gradient descent
        self.backward(X, y, y_hat)

In [11]:
k=3
device="cuda:0"
loss_type = "Cross Entropy"
model = MLP_Manual(k, device,batch_size, loss_type)

trainLostList, trainAccList, valLossList, valAccList  = train_model_manually(model, k, trainset, testset, loss_type, loss_fn, num_epochs,
                                                                             batch_size, validate_model = True, device=device)


Epoch 1 completed. Loss - total: 41628.5430 - average: 0.6938; Performance: 0.5039
TESTING - loss 6898.647859573364 - performance 0.5327
Epoch 2 completed. Loss - total: 40316.6055 - average: 0.6719; Performance: 0.5833
TESTING - loss 6494.511209487915 - performance 0.6229
Epoch 3 completed. Loss - total: 37967.9609 - average: 0.6328; Performance: 0.6488
TESTING - loss 6239.996013641357 - performance 0.6605
Epoch 4 completed. Loss - total: 36971.1406 - average: 0.6162; Performance: 0.6716
TESTING - loss 6077.1282777786255 - performance 0.6840
Epoch 5 completed. Loss - total: 36556.5547 - average: 0.6093; Performance: 0.6803
TESTING - loss 6079.064363479614 - performance 0.6852
Epoch 6 completed. Loss - total: 36224.9570 - average: 0.6037; Performance: 0.6892
TESTING - loss 5963.65124797821 - performance 0.7008
Epoch 7 completed. Loss - total: 35817.3750 - average: 0.5970; Performance: 0.6979
TESTING - loss 5941.411116600037 - performance 0.7083
Epoch 8 completed. Loss - total: 35683.08