## TODO:
- Add bias terms
    * Need to check how these terms should be updated
- Plot Dogan SGD vs. PyTorch SGD for k=1 and k=3
- Try switching to DFA from here and check the performance results.

In [None]:
import numpy as np
import math
import torchvision
from torchvision.transforms import ToTensor, Normalize, Compose

In [None]:
from scripts.train import *

In [None]:
# Report the versions of the numerical libraries in use for this run.
for _lib in (torch, np):
    print(_lib.__version__)

## Create Parity Data Iterator

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalize it
# with a single-channel mean/std pair.
# NOTE(review): 0.1307 / 0.3081 look like the commonly quoted MNIST
# statistics - confirm if the dataset changes.
_pipeline_steps = [
    ToTensor(),
    Normalize(mean=(0.1307,), std=(0.3081,)),
]
transforms = Compose(_pipeline_steps)

In [None]:
# Constructing the datasets does not perform any transformation yet; the
# transform is only applied once samples are requested (e.g. by the loader).
_mnist_kwargs = dict(root='data', download=True, transform=transforms)
trainset = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
testset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

## MLP from Scratch

In [None]:
# Training hyper-parameters shared by the cells below.
batch_size = 128   # samples per mini-batch
num_epochs = 20    # passes over the training set
# NOTE(review): learn_rate is not passed to MLP_Manual or train_model_manually
# in the visible cells (MLP_Manual uses its own 0.001) - confirm it is used.
learn_rate = 0.05
# BCELoss expects probabilities, i.e. a sigmoid activation in the last layer.
# NOTE(review): the run cell selects loss_type = "Cross Entropy" while this
# stays BCELoss - confirm how train_model_manually reconciles the two.
loss_fn = torch.nn.BCELoss()

In [None]:
class MLP_Manual(torch.nn.Module):
    """Two-layer, bias-free MLP with hand-written forward/backward passes and SGD.

    The network is ``flatten -> W1 -> ReLU -> W2 -> softmax | sigmoid``.
    The weights are plain tensors (not ``nn.Parameter``), so autograd and
    ``torch.optim`` are not involved: ``backward`` computes the gradients
    analytically and applies the SGD step in place.

    Args:
        k: Multiplier on the 28*28 input size; the input dimension is
            ``28 * 28 * k``.
        device: Device on which the weight tensors are allocated.
        loss_type: ``"Cross Entropy"`` selects a 2-unit softmax output; any
            other value selects a 1-unit sigmoid output (the BCE case).
        learning_rate: Step size of the SGD update applied in ``backward``.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, k, device, loss_type="Cross Entropy",
                 learning_rate=0.001, hidden_dim=512):
        super().__init__()

        self.input_dim = 28 * 28 * k
        self.hidden_dim = hidden_dim
        # Keep the loss type on the instance so forward/backward never depend
        # on a module-level global that may disagree with this model.
        self.loss_type = loss_type
        self.output_dim = 2 if loss_type == "Cross Entropy" else 1  # else: BCE case
        self.learning_rate = learning_rate
        self.flat = torch.nn.Flatten()  # e.g. 28x28 input -> 784

        # WEIGHTS, stored as (fan_in, fan_out), e.g. 784 x 512 and 512 x 1.
        self.w1 = self._init_weight(self.input_dim, self.hidden_dim, device)
        self.w2 = self._init_weight(self.hidden_dim, self.output_dim, device)

    @staticmethod
    def _init_weight(fan_in, fan_out, device):
        """Return a (fan_in, fan_out) tensor drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

        This mirrors the default ``torch.nn.Linear`` bound, which depends on
        the number of *input* features. Because the weights here are stored
        transposed relative to ``nn.Linear``, fan_in is dimension 0.
        """
        bound = 1.0 / math.sqrt(fan_in)
        return torch.empty(fan_in, fan_out, device=device).uniform_(-bound, bound)

    @staticmethod
    def softmax(x):
        """Row-wise softmax; the row maximum is subtracted before exponentiating."""
        maxes = torch.max(x, 1, keepdim=True)[0]
        x_exp = torch.exp(x - maxes)
        x_exp_sum = torch.sum(x_exp, 1, keepdim=True)
        return x_exp / x_exp_sum

    @staticmethod
    def sigmoid(s):
        """Element-wise logistic function."""
        return torch.sigmoid(s)

    @staticmethod
    def reLU(s):
        """Return max(s, 0) as a new float tensor; ``s`` itself is left untouched."""
        return torch.clamp(s, min=0).float()

    @staticmethod
    def reLUPrime(s):
        """Return the ReLU derivative (1 where s > 0, else 0) without modifying ``s``."""
        return (s > 0).float()

    # Forward propagation
    def forward(self, X):
        """Compute the output probabilities for a batch and cache the activations.

        Stores ``self.a1`` (hidden pre-activation), ``self.h1`` (hidden
        activation) and ``self.a2`` (output pre-activation) for ``backward``.

        Returns:
            Softmax probabilities of shape (batch, 2) in the cross-entropy
            case, otherwise sigmoid probabilities of shape (batch, 1).
        """
        X = self.flat(X)
        # a_k = h_{k-1} @ W_k, h_k = f(a_k) where h_0 = X and f is the non-linearity
        self.a1 = torch.matmul(X, self.w1)        # e.g. k=1 --> 128x784 @ 784x512
        self.h1 = self.reLU(self.a1)              # f is the ReLU
        self.a2 = torch.matmul(self.h1, self.w2)

        if self.loss_type == "Cross Entropy":
            return self.softmax(self.a2)
        return self.sigmoid(self.a2)

    # Backward propagation
    def backward(self, X, y, y_hat):
        """Compute the weight gradients analytically and apply one SGD step.

        Must be called after ``forward`` on the same batch, since it reads the
        cached ``self.a1`` and ``self.h1``. Gradients are summed over the
        batch (not averaged) and the weights are updated in place.

        Args:
            X: Input batch as passed to ``forward``.
            y: Targets; class indices in the cross-entropy case (they are
                one-hot encoded here), otherwise one 0/1 value per sample.
            y_hat: Output of ``forward`` for ``X``.
        """
        X = self.flat(X)
        # For both softmax + cross entropy and sigmoid + BCE the gradient of
        # the loss w.r.t. the output pre-activation a2 is (y^ - y).
        if self.loss_type == "Cross Entropy":
            # Fix the width explicitly: without num_classes a batch holding a
            # single class would be encoded with the wrong number of columns.
            target = torch.nn.functional.one_hot(y, num_classes=self.output_dim)
        else:
            target = y.reshape(len(y), 1)
        self.e = y_hat - target  # e.g. 128x2 (cross entropy) or 128x1 (BCE)

        # dE/dW2 = h1^T @ e
        self.w2_grads = torch.matmul(self.h1.t(), self.e)
        # dE/dW1 = X^T @ (dE/dh1 * dh1/da1), where dE/dh1 = e @ W2^T
        self.dBCE_da1 = torch.matmul(self.e, self.w2.t()) * self.reLUPrime(self.a1)
        self.w1_grads = torch.matmul(X.t(), self.dBCE_da1)  # e.g. 784x128 @ 128x512

        # SGD step
        self.w1 -= self.learning_rate * self.w1_grads
        self.w2 -= self.learning_rate * self.w2_grads

    def train(self, X=True, y=None):
        """Run one forward/backward/SGD step, or switch the training mode.

        ``train(X, y)`` performs a training step on the batch and returns
        ``None``. Because this method shadows ``torch.nn.Module.train(mode)``,
        which ``eval()`` calls internally, a call without ``y`` (``train()``,
        ``train(False)``, ``eval()``) is forwarded to the base class and
        returns ``self``.
        """
        if y is None:
            return super().train(X)
        # Forward propagation
        y_hat = self.forward(X)
        # Backward propagation and gradient descent
        self.backward(X, y, y_hat)

In [None]:
# Experiment configuration for this run.
k = 3
device = "cpu"
loss_type = "Cross Entropy"
model = MLP_Manual(k, device, loss_type)

# Train the manual MLP (with validation enabled) and keep the four lists
# returned by train_model_manually.
trainLostList, trainAccList, valLossList, valAccList = train_model_manually(
    model,
    k,
    trainset,
    testset,
    loss_type,
    loss_fn,
    num_epochs,
    batch_size,
    validate_model=True,
    device=device,
)
