In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.optim import lr_scheduler

In [None]:
from scripts.mnistParity import MNISTParity
from scripts.architecture import MLP
from scripts.train_utils import accuracy
from scripts.train import train_epoch, train_model, test_model
from scripts.plot_utils import plot_loss_accuracy, plotValAccuracy

In [None]:
# Record the torch and numpy versions this notebook is running against.
print(torch.__version__, np.__version__, sep="\n")

## Create Parity Data Iterator

### Vertical

In [None]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# with mean 0.1307 and std 0.3081.
transforms = Compose([ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,))])

In [None]:
# No transformation is performed here; it is applied only when samples are
# actually read (e.g. through a loader).
mnist_kwargs = dict(root='data', download=True, transform=transforms)
trainset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
testset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

In [None]:
# Hyper-parameters shared by the experiments in this notebook.
batch_size = 128
num_epochs = 20
learn_rate = 0.05

# Loss handed to train_model for the MLP experiments.
loss_fn = torch.nn.CrossEntropyLoss()

### For k = 1

In [None]:
# Baseline experiment: ReLU MLP on k = 1 digit, optimised with Adadelta.
k = 1
model = MLP(k, "ReLU")
optimizer = torch.optim.Adadelta(
    model.parameters(),
    lr=learn_rate,
    weight_decay=0.001,
)

# train_model returns four histories (train loss, train accuracy,
# validation loss, validation accuracy), unpacked in that order.
(trainLostList, trainAccList, valLossList, valAccList) = train_model(
    model, k, trainset, testset, loss_fn, optimizer, num_epochs, batch_size,
    validate_model=True,
    performance=accuracy,
    device="cuda:0",
    lr_scheduler=None,
)


In [None]:
# TODO notes for MLP_Manual:
# - Add bias terms; with batched inputs it is not yet clear how to update them.
# - Use the same structure as the other training code. A 2-output variant with cross-entropy loss was tried but performed poorly (~50% accuracy).
# - Try to reach the same performance for k = 3; currently it does not perform well at all.
# - Check the equations once more (DFA paper, Sebastian's paper, online backpropagation blog post).

class MLP_Manual(torch.nn.Module):
    """Two-layer perceptron trained with hand-written backpropagation.

    Architecture: flatten -> linear (no bias) -> ReLU -> linear (no bias) -> sigmoid.
    The weights are plain tensors (not ``nn.Parameter``) that :meth:`backward`
    updates in place with plain gradient descent; autograd is not used.

    Args:
        k: multiplier on the 28 * 28 input width (input_dim = 28 * 28 * k).
        device: device on which the weight tensors are stored.
    """

    def __init__(self, k, device):
        super().__init__()

        self.input_dim = 28 * 28 * k
        self.hidden_dim = 512
        self.output_dim = 1  # a 2-output + cross-entropy variant was tried and did not perform well
        self.learning_rate = 0.001
        self.flat = torch.nn.Flatten()
        # Weights only; bias terms are not implemented.
        # e.g. 784 x 512
        self.w1 = torch.randn(self.input_dim, self.hidden_dim).to(device)
        # e.g. 512 x 1
        self.w2 = torch.randn(self.hidden_dim, self.output_dim).to(device)

    def sigmoid(self, s):
        """Element-wise logistic function."""
        return torch.sigmoid(s)

    def reLU(self, s):
        """Element-wise max(s, 0); returns a new tensor and leaves ``s`` untouched."""
        return torch.clamp(s, min=0).float()

    def reLUPrime(self, s):
        """Derivative of ReLU at pre-activation ``s``: 1 where s > 0, else 0.

        Returns a new tensor; ``s`` is not modified.
        """
        return (s > 0).float()

    # Forward propagation
    def forward(self, X):
        """Compute sigmoid outputs for a batch, caching the intermediates.

        ``self.a1`` (hidden pre-activation), ``self.h1`` (hidden activation)
        and ``self.ay`` (output pre-activation) are stored for :meth:`backward`.
        """
        X = self.flat(X)
        # First linear layer
        self.a1 = torch.matmul(X, self.w1)
        # First non-linearity (a1 is kept intact: backward needs the pre-activation)
        self.h1 = self.reLU(self.a1)
        # Second linear layer
        self.ay = torch.matmul(self.h1, self.w2)
        # Second non-linearity
        y_hat = self.sigmoid(self.ay)
        return y_hat

    # Backward propagation
    def backward(self, X, y, y_hat):
        """Run one gradient-descent step on ``w1`` and ``w2``.

        Must be called right after :meth:`forward` on the same ``X`` because it
        reads the cached ``self.a1`` / ``self.h1``.

        Args:
            X: input batch (flattened internally).
            y: labels, one per sample; reshaped to a column.
            y_hat: output of :meth:`forward` for ``X``.
        """
        X = self.flat(X)
        # Output error; this equals dL/d(ay) when L is binary cross-entropy on
        # the sigmoid output. Gradients below are summed (not averaged) over
        # the batch.
        self.e = y_hat - y.reshape(len(y), 1).to(y_hat.dtype)

        # Back-propagate the error through w2 and the ReLU.
        self.delta_a1 = torch.matmul(self.e, self.w2.t()) * self.reLUPrime(self.a1)

        # Gradient descent on the weights from our 2 linear layers
        self.change_w1 = -self.learning_rate * torch.matmul(X.t(), self.delta_a1)
        self.change_w2 = -self.learning_rate * torch.matmul(self.h1.t(), self.e)

        self.w1 += self.change_w1
        self.w2 += self.change_w2

    def train(self, X, y):
        """Run one forward pass followed by one manual weight update.

        NOTE: this overrides ``nn.Module.train(mode)``, so the usual
        ``model.train()`` / ``model.eval()`` mode switches are not available
        on this class.
        """
        # Forward propagation
        y_hat = self.forward(X)
        # Backward propagation and gradient descent
        self.backward(X, y, y_hat)
        
        
def predict(nn_output: torch.Tensor):
    """Threshold sigmoid outputs at 0.5 into integer class labels.

    Values strictly greater than 0.5 map to 1, everything else (including
    exactly 0.5) maps to 0. The input tensor is NOT modified, so the caller can
    keep using the raw probabilities afterwards.

    Args:
        nn_output: tensor with one value per sample, e.g. shape (N, 1) or (N,).

    Returns:
        Int tensor of shape (N,) containing 0/1 labels.
    """
    return (nn_output > 0.5).reshape(len(nn_output)).int()


# NOTE(review): this definition shadows the `accuracy` imported from
# scripts.train_utils at the top of the notebook; cells executed after this one
# that pass `performance=accuracy` will receive this binary-threshold version.
def accuracy(nn_output: torch.Tensor, ground_truth: torch.Tensor):
    """Fraction of samples whose thresholded output equals the label.

    Args:
        nn_output: sigmoid outputs, one per sample; left unmodified.
        ground_truth: labels, one per sample; flattened so that an (N, 1)
            column is compared element-wise instead of broadcasting to N x N.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    nn_out_classes = predict(nn_output)
    # Boolean tensor: True where the predicted class matches the label.
    correct_items = (nn_out_classes == ground_truth.reshape(-1))
    # Number of correct predictions divided by the number of examples.
    acc = correct_items.sum().item() / nn_output.shape[0]
    return acc

In [None]:
# Training loop for MLP_Manual (hand-written backpropagation, BCE objective).
from scripts.train_utils import AverageMeter

k = 1
# Fall back to the CPU when no GPU is available instead of failing on "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_a = MLP_Manual(k, device)
# Kept under its own name so the shared `loss_fn` (CrossEntropyLoss) that the
# train_model cells rely on is not silently replaced by this cell.
bce_loss_fn = torch.nn.BCELoss()

for epoch in range(num_epochs):
    # The parity dataset (and its loader) is rebuilt at the start of every epoch.
    trainData = MNISTParity(trainset, k, batch_size)
    loss_meter = AverageMeter()
    performance_meter = AverageMeter()

    for X, y in trainData.loader:
        X = X.to(device)
        y = y.to(device)

        # Metrics are computed on the prediction made BEFORE this batch's update.
        yhat = model_a(X)
        loss = bce_loss_fn(yhat, y.reshape(len(y), 1).float())
        # Score a copy so the raw probabilities in `yhat` can never be altered
        # by the metric code.
        acc = accuracy(yhat.clone(), y)
        loss_meter.update(val=loss.item(), n=X.shape[0])
        performance_meter.update(val=acc, n=X.shape[0])

        # One forward + manual backward/gradient-descent step on this batch.
        model_a.train(X, y)

    print(f"Epoch {epoch+1} completed. Loss - total: {loss_meter.sum:.4f} - average: {loss_meter.avg:.4f}; Performance: {performance_meter.avg:.4f}")

In [None]:
# k = 1 again, this time with SGD and train_model's `updateWManually` switch on.
# NOTE(review): `accuracy` resolves to the most recent definition in the kernel;
# the MLP_Manual cell above redefines it as a binary-threshold metric - confirm
# that this is the metric intended for this model.
k = 1
model2 = MLP(k, "ReLU")
optimizer = torch.optim.SGD(model2.parameters(), lr=learn_rate)

(trainLostList, trainAccList, valLossList, valAccList) = train_model(
    model2, k, trainset, testset, loss_fn, optimizer, num_epochs, batch_size,
    validate_model=True,
    performance=accuracy,
    device="cuda:0",
    lr=learn_rate,
    lr_scheduler=None,
    updateWManually=True,
)


In [None]:
# Plot the histories returned by the most recent k = 1 run
# (argument order: train loss, validation loss, train accuracy, validation accuracy).
plot_loss_accuracy(
    trainLostList,
    valLossList,
    trainAccList,
    valAccList,
    num_epochs,
)

### For k = 3

In [None]:
# ReLU MLP on k = 3 digits, optimised with Adadelta.
k = 3

model3 = MLP(k, "ReLU")
optimizer = torch.optim.Adadelta(
    model3.parameters(),
    lr=learn_rate,
    weight_decay=0.001,
)

(trainLostList3, trainAccList3, valLossList3, valAccList3) = train_model(
    model3, k, trainset, testset, loss_fn, optimizer, num_epochs, batch_size,
    validate_model=True,
    performance=accuracy,
    device="cuda:0",
    lr_scheduler=None,
)


In [None]:
# Plot the histories of the k = 3 Adadelta run
# (argument order: train loss, validation loss, train accuracy, validation accuracy).
plot_loss_accuracy(
    trainLostList3,
    valLossList3,
    trainAccList3,
    valAccList3,
    num_epochs,
)

In [None]:
# ReLU MLP on k = 3 digits, optimised with SGD (same weight decay as the Adadelta run).
k = 3
model4 = MLP(k, "ReLU")
optimizer = torch.optim.SGD(
    model4.parameters(),
    lr=learn_rate,
    weight_decay=0.001,
)

(trainLostList4, trainAccList4, valLossList4, valAccList4) = train_model(
    model4, k, trainset, testset, loss_fn, optimizer, num_epochs, batch_size,
    validate_model=True,
    performance=accuracy,
    device="cuda:0",
    lr_scheduler=None,
)


In [None]:
# Plot the histories of the k = 3 SGD run
# (argument order: train loss, validation loss, train accuracy, validation accuracy).
plot_loss_accuracy(
    trainLostList4,
    valLossList4,
    trainAccList4,
    valAccList4,
    num_epochs,
)

In [None]:
# Add lazy methods: train one MLP per configuration on K digits and draw each
# run's validation accuracy on a single shared figure.
learn_rate = 0.05
K = 3
num_epochs = 20

fig = plt.figure()
for activation in ["ReLU", "NTK", "Gaussian features", "ReLU features", "linear features", "SGD"]:
    model = MLP(K, activation)

    # Choose which parameters are optimised and with which optimiser class;
    # learning rate and weight decay are the same for every configuration.
    optimizer_cls = torch.optim.Adadelta
    if "features" in activation:
        # deactivate the first layer: only layer2 is handed to the optimiser
        trainable = model.layer2.parameters()
    elif "NTK" in activation:
        trainable = list(model.layer1.parameters()) + list(model.layer2.parameters())
    elif "SGD" in activation:
        trainable = model.parameters()
        optimizer_cls = torch.optim.SGD
    else:
        trainable = model.parameters()
    optimizer = optimizer_cls(trainable, lr=learn_rate, weight_decay=0.001)

    print("Activation:", activation)

    (trainLostList, trainAccList, valLossList, valAccList) = train_model(
        model, K, trainset, testset, loss_fn, optimizer, num_epochs, batch_size,
        validate_model=True,
        performance=accuracy,
        device="cuda:0",
        lr_scheduler=None,
    )

    plotValAccuracy(valAccList, num_epochs, activation, K)

# Save the figure before showing it.
fig.savefig(str(K) + "valAccuracy.png")
plt.show()

# Visual sanity check of the K-digit parity data.
dataset = MNISTParity(trainset, K, 128)
dataset.plotRandomData()

# TODO: find good lr and weight_decay values for the lazy methods so that the plots more closely match those in the paper.
