In [None]:
import copy
import os

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import torch
import torchvision
from torchvision.transforms import ToTensor, Normalize, Compose

In [None]:
from scripts.architecture import MLP, MLPManual
from scripts.train import *
from scripts.plot_utils import plot_loss_accuracy, plotValAccuracy

In [None]:
# Record the library versions this notebook was run with (one per line).
print(torch.__version__, np.__version__, sep="\n")

## Create Parity Data Iterator

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then standardise it
# with a fixed mean/std (presumably the usual MNIST statistics - confirm).
transforms = Compose([ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,))])

In [None]:
# No transformation is performed here; the transforms only run once samples
# are actually fetched (e.g. through a loader).
# The train split is built first, then the test split.
trainset, testset = (
    torchvision.datasets.MNIST(root='data', train=is_train, download=True, transform=transforms)
    for is_train in (True, False)
)

In [None]:
# Hyper-parameters shared by the experiments below.
learn_rate, num_epochs, batch_size = 0.05, 20, 128

# The label string is passed to the model/training helpers; loss_fn is the
# criterion object itself.
loss_type = "Binary Cross Entropy"
loss_fn = torch.nn.BCELoss()

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

### For k = 1

In [None]:
# k = 1: ReLU MLP trained with Adadelta.
k = 1
model = MLP(k, "ReLU", loss_type)
optimizer = torch.optim.Adadelta(
    model.parameters(),
    lr=learn_rate,
    weight_decay=0.001,
)

# The four values returned by train_model are unpacked, in order, into the
# train-loss, train-accuracy, validation-loss and validation-accuracy names.
(
    trainLostList_Ada1,
    trainAccList_Ada1,
    valLossList_Ada1,
    valAccList_Ada1,
) = train_model(
    model, k, trainset, testset, loss_type, loss_fn, optimizer, num_epochs, batch_size,
    validate_model=True,
    performance=accuracy,
    device=device,
    lr_scheduler=None,
    updateWManually=False,
)


In [None]:
# Loss / accuracy curves of the k = 1 Adadelta run.
plot_loss_accuracy(
    trainLostList_Ada1,
    valLossList_Ada1,
    trainAccList_Ada1,
    valAccList_Ada1,
    num_epochs,
)

In [None]:
# k = 1: same ReLU MLP architecture, this time optimised with plain SGD.
k = 1
model2 = MLP(k, "ReLU", loss_type)
optimizer = torch.optim.SGD(
    model2.parameters(),
    lr=learn_rate,
    weight_decay=0.001,
)

# Unlike the Adadelta run, the learning rate is also passed to train_model.
(
    trainLostList_sgd1,
    trainAccList_sgd1,
    valLossList_sgd1,
    valAccList_sgd1,
) = train_model(
    model2, k, trainset, testset, loss_type, loss_fn, optimizer, num_epochs, batch_size,
    validate_model=True,
    performance=accuracy,
    device=device,
    lr=learn_rate,
    lr_scheduler=None,
    updateWManually=False,
)


In [None]:
# Loss / accuracy curves of the k = 1 SGD run.
plot_loss_accuracy(
    trainLostList_sgd1,
    valLossList_sgd1,
    trainAccList_sgd1,
    valAccList_sgd1,
    num_epochs,
)

In [None]:
# k = 1: hand-written MLP trained through train_model_manually.
# The last constructor argument is False here; elsewhere in this notebook a
# (w1, w2) tuple is passed in that slot, so False presumably means
# "no preset weights" - confirm against MLPManual.
k = 1
modelManual = MLPManual(k, learn_rate, loss_type, False)

(
    trainLostList_sgd1_scratch,
    trainAccList_sgd1_scratch,
    valLossList_sgd1_scratch,
    valAccList_sgd1_scratch,
) = train_model_manually(
    modelManual, k, trainset, testset, loss_type, loss_fn, num_epochs, batch_size,
    validate_model=True,
    device=device,
)

In [None]:
# Validation accuracy for k = 1: one curve per training procedure.
fig, ax = plt.subplots(figsize=(15, 8))
ax.set_ylim(0.5, 1)
ax.plot(valAccList_sgd1, label="SGD")
ax.plot(valAccList_Ada1, label="Adadelta")
ax.plot(valAccList_sgd1_scratch, label= "SGD Scratch")
# One major tick per plotted index.
ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
ax.legend();


### For k = 3

In [None]:
# k = 3: ReLU MLP trained with Adadelta.
k = 3
model3 = MLP(k, "ReLU", loss_type)
optimizer = torch.optim.Adadelta(
    model3.parameters(),
    lr=learn_rate,
    weight_decay=0.001,
)

(
    trainLostList_Ada3,
    trainAccList_Ada3,
    valLossList_Ada3,
    valAccList_Ada3,
) = train_model(
    model3, k, trainset, testset, loss_type, loss_fn, optimizer, num_epochs, batch_size,
    validate_model=True,
    performance=accuracy,
    device=device,
    lr_scheduler=None,
)


In [None]:
# Loss / accuracy curves of the k = 3 Adadelta run.
plot_loss_accuracy(
    trainLostList_Ada3,
    valLossList_Ada3,
    trainAccList_Ada3,
    valAccList_Ada3,
    num_epochs,
)

In [None]:
# k = 3 (k is carried over from the previous training cell): ReLU MLP with SGD.
model4 = MLP(k, "ReLU", loss_type)
optimizer = torch.optim.SGD(
    model4.parameters(),
    lr=learn_rate,
    weight_decay=0.001,
)

(
    trainLostList_sgd3,
    trainAccList_sgd3,
    valLossList_sgd3,
    valAccList_sgd3,
) = train_model(
    model4, k, trainset, testset, loss_type, loss_fn, optimizer, num_epochs, batch_size,
    validate_model=True,
    performance=accuracy,
    device=device,
    lr_scheduler=None,
)

In [None]:
# Loss / accuracy curves of the k = 3 SGD run.
plot_loss_accuracy(
    trainLostList_sgd3,
    valLossList_sgd3,
    trainAccList_sgd3,
    valAccList_sgd3,
    num_epochs,
)

In [None]:
# k = 3: hand-written MLP trained through train_model_manually.
k = 3
modelManual3 = MLPManual(k, learn_rate, loss_type, False)

(
    trainLostList_sgd3_scratch,
    trainAccList_sgd3_scratch,
    valLossList_sgd3_scratch,
    valAccList_sgd3_scratch,
) = train_model_manually(
    modelManual3, k, trainset, testset, loss_type, loss_fn, num_epochs, batch_size,
    validate_model=True,
    device=device,
)

In [None]:
# Validation accuracy for k = 3: one curve per training procedure.
fig, ax = plt.subplots(figsize=(15, 8))
ax.set_ylim(0.4, 1)
ax.plot(valAccList_sgd3, label="SGD")
ax.plot(valAccList_Ada3, label="Adadelta")
ax.plot(valAccList_sgd3_scratch, label= "SGD Scratch")
# One major tick per plotted index.
ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
ax.legend();

### Try with the same weights

In [None]:
# k = 3, PyTorch side of the "same initial weights" experiment.
k = 3
modelx = MLP(k, "ReLU", loss_type).to(device)

# Snapshot both weight matrices BEFORE training, so the untouched initial
# values can later be handed to the hand-written model.
w1, w2 = (
    copy.deepcopy(modelx.state_dict()[weight_key]).to(device)
    for weight_key in ("layer1.weight", "layer2.weight")
)

optimizer = torch.optim.SGD(modelx.parameters(), lr=learn_rate)

(
    trainLostList_sgd3_w,
    trainAccList_sgd3_w,
    valLossList_sgd3_w,
    valAccList_sgd3_w,
) = train_model(
    modelx, k, trainset, testset, loss_type, loss_fn, optimizer, num_epochs, batch_size,
    validate_model=True,
    performance=accuracy,
    device=device,
    lr=learn_rate,
    lr_scheduler=None,
    updateWManually=True,
)

In [None]:
# Hand-written model seeded with the transposed weight snapshots (w1, w2)
# taken from the PyTorch model in the previous cell.
modelManualx = MLPManual(k, learn_rate, loss_type, (w1.t(), w2.t()))

(
    trainLostList_sgd3_scratch_w,
    trainAccList_sgd3_scratch_w,
    valLossList_sgd3_scratch_w,
    valAccList_sgd3_scratch_w,
) = train_model_manually(
    modelManualx, k, trainset, testset, loss_type, loss_fn, num_epochs, batch_size,
    validate_model=True,
    device=device,
)

In [None]:
# Validation accuracy for k = 3: the PyTorch-trained model vs. the hand-written
# one, both started from the same initial weights.
plt.figure(figsize=(15,8))

# The x axis is derived from each recorded history instead of a hard-coded
# range(1, 21), so changing num_epochs cannot cause an x/y length mismatch.
plt.plot(range(1, len(valAccList_sgd3_w) + 1), valAccList_sgd3_w,
         color = "blue", label = "BP SGD Pytorch")
plt.plot(range(1, len(valAccList_sgd3_scratch_w) + 1), valAccList_sgd3_scratch_w,
         color = "green", label = "BP SGD Dogan")

plt.ylim(0.4,1.05)
plt.title("Test Accuracy k=3")
plt.xlabel("Iteration")
plt.ylabel("Accuracy")
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.legend()
plt.grid(True)

# savefig does not create missing directories; make sure the target exists so
# the cell also works on a fresh checkout.
os.makedirs("plots", exist_ok=True)
plt.savefig("plots/doganVSPytorch.png")

plt.show();

# The two curves are expected to differ, because I recreate the data every epoch.
# Are the results still different even without recreating the data?

In [None]:
# Scratch cell: draw a 2x3 tensor from torch.randn and display it.
x = torch.randn((2, 3))
x

In [None]:
# Add Lazy methods
# Trains one model per entry of the list below (K = 3) and draws every
# validation-accuracy curve on a single figure, which is then saved to disk.
learn_rate = 0.05
K = 3
num_epochs = 20
loss_type = "Binary Cross Entropy"

# List entry that selects the hand-written MLPManual training path. The check
# used to compare against "SGD_Scratch", which never appears in the list, so
# the manual branch was unreachable.
manual_label = "SGD Dogan"

fig = plt.figure(figsize=(15,9))
for activation in ["ReLU", "NTK", "Gaussian features", "ReLU features", "Linear features", "SGD", manual_label]:
    if activation != manual_label:
        model = MLP(K, activation, loss_type)
        if "features" in activation:
            # deactivate the first layer: only layer2's parameters are optimised
            optimizer = torch.optim.Adadelta(model.layer2.parameters(), lr = learn_rate, weight_decay=0.001)
        elif "NTK" in activation:
            # only layer1 and layer2 are handed to the optimiser
            paramsToUpdate = list(model.layer1.parameters()) + list(model.layer2.parameters())
            optimizer = torch.optim.Adadelta(paramsToUpdate, lr = learn_rate, weight_decay=0.001)
        elif "SGD" in activation:
            optimizer = torch.optim.SGD(model.parameters(), lr = learn_rate, weight_decay=0.001)
        else:
            optimizer = torch.optim.Adadelta(model.parameters(), lr = learn_rate, weight_decay=0.001)

        print("Activation:",activation)

        # Use the notebook-wide `device` (CPU fallback included) instead of a
        # hard-coded "cuda:0", which fails on machines without a GPU.
        trainLostList, trainAccList, valLossList, valAccList  = train_model(model, K, trainset, testset, loss_type, loss_fn, optimizer, num_epochs,
                                                                            batch_size, validate_model = True, performance=accuracy,
                                                                            device=device, lr_scheduler=None)
    else:
        print("Activation:",activation)
        modelManual3 = MLPManual(K, learn_rate, loss_type, False)

        trainLostList, trainAccList, valLossList, valAccList  = train_model_manually(modelManual3, K, trainset, testset, loss_type, loss_fn, num_epochs,
                                                                                  batch_size, validate_model = True, device=device)

    plotValAccuracy(valAccList, num_epochs, activation, K)

# savefig does not create missing directories; make sure the target exists.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/" + str(K) + "valAccuracy.png")
plt.show()

# NOTE(review): the literal 128 equals batch_size elsewhere in this notebook;
# confirm what MNISTParity's third argument means before replacing it.
dataset = MNISTParity(trainset, K, 128)
dataset.plotRandomData()

# TODO: find good lr and weight_decay values for the lazy methods so that the plots look more like those in the paper.