In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.model_selection import GridSearchCV, ShuffleSplit
from skorch import NeuralNetClassifier
from torch.optim import Optimizer, SGD
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import mass
from mass import Mass
import numpy as np

# Fix every random generator used below so that "Restart & Run All"
# reproduces the same weights, batch order and results.
seed = 7
np.random.seed(seed)
torch.manual_seed(seed)

# cuDNN: force deterministic kernels and turn off autotuning, which may
# otherwise pick different algorithms from one run to the next.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# CIFAR-10 train/test splits. ToTensor scales pixels to [0, 1]; Normalize with
# mean = std = 0.5 per channel then maps each channel to [-1, 1].
batch_size = 64

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

trainset, testset = (
    CIFAR10(root='.', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader reshuffles every epoch; the test order stays fixed.
trainloader, testloader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=2)
    for dataset, is_train in ((trainset, True), (testset, False))
)

In [None]:
class CNN(nn.Module):
    """Small three-convolution CNN for 3x32x32 CIFAR-10 images.

    Spatial sizes (no padding): 32 -conv1-> 28 -conv2-> 24 -pool-> 12
    -conv3-> 8 -pool-> 4, which gives the ``128 * 4**2`` input features of
    ``fc1``.

    ``forward`` returns raw logits in training mode, which is what
    ``nn.CrossEntropyLoss`` expects. In eval mode it returns softmax
    probabilities.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, (5, 5), padding=0)
        self.conv2 = nn.Conv2d(64, 64, (5, 5), padding=0)
        self.conv3 = nn.Conv2d(64, 128, (5, 5), padding=0)
        self.fc1 = nn.Linear(128 * 4**2, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, 10) logits or probabilities."""
        # NOTE(review): there is no activation between the conv layers (only
        # max-pooling). Confirm this is intentional before comparing optimizers.
        out = self.conv1(x)
        out = self.conv2(out)
        out = F.max_pool2d(out, (2, 2), stride=2)

        out = self.conv3(out)
        out = F.max_pool2d(out, (2, 2), stride=2)

        out = out.view(out.shape[0], -1)
        out = F.relu(self.fc1(out))

        # F.dropout defaults to training=True, so the module's mode must be
        # passed explicitly. Otherwise dropout stays active during evaluation
        # and predictions become random.
        out = F.dropout(out, 0.5, training=self.training)

        out = self.fc2(out)

        # NOTE(review): any loss computed in eval mode with nn.CrossEntropyLoss
        # (e.g. a validation loss) would apply softmax a second time.
        # Argmax-based accuracy is unaffected.
        if not self.training:
            out = F.softmax(out, dim=1)
        return out

In [None]:
def fit(model_instance, loss_fn, optim, data_loader, n_iter=50, device=None):
    """Train ``model_instance`` for ``n_iter`` epochs and record the loss of each epoch.

    Args:
        model_instance: the ``nn.Module`` to train. It is updated in place and
            must already be on its target device.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        optim: optimizer over ``model_instance``'s parameters.
        data_loader: iterable of ``(inputs, labels)`` batches.
        n_iter: number of epochs, i.e. full passes over ``data_loader``.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        1-D float tensor of length ``n_iter``. Entry ``e`` is the *sum* of the
        per-batch losses in epoch ``e``, not their mean.
    """
    if device is None:
        # Infer the device from the model rather than reading a notebook-global
        # `device`, which only exists after later cells have run.
        device = next(model_instance.parameters()).device
    train_loss = torch.zeros(n_iter)

    for epoch in range(n_iter):
        model_instance.train()
        running_loss = 0.0
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optim.zero_grad()
            outputs = model_instance(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optim.step()

            running_loss += loss.item()
        print("Epoch %d, loss %4.2f" % (epoch, running_loss))
        train_loss[epoch] = running_loss

    print('**** Finished Training ****')
    return train_loss

In [None]:
# SGD (plain, no momentum)
torch.cuda.empty_cache()
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Re-seed so every optimizer run starts from the same initial weights. Without
# this, the loss curves would also differ because of the initialisation, and
# re-running this cell alone would not reproduce its result.
torch.manual_seed(seed)
model = CNN().to(device)

loss_function = nn.CrossEntropyLoss()
sgd = SGD(model.parameters(), lr=0.001)
train_loss_sgd = fit(model_instance=model, loss_fn=loss_function, optim=sgd, data_loader=trainloader)
torch.save(train_loss_sgd, "./train_loss_sgd.cifarcnn")
torch.save(model.state_dict(), "./model_sgd.cifarcnn")

In [None]:
# Nesterov momentum SGD
torch.cuda.empty_cache()
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Re-seed so this run starts from the same initial weights as the other
# optimizer runs.
torch.manual_seed(seed)
model = CNN().to(device)

loss_function = nn.CrossEntropyLoss()
# NOTE(review): weight_decay=3 is unusually large for lr=0.001, and no
# weight_decay is passed to the other two optimizers, so the comparison is not
# like-for-like. Confirm the intended value.
sgd_nesterov = SGD(model.parameters(), lr=0.001, momentum=0.9, nesterov=True, weight_decay=3)
train_loss_nes = fit(model_instance=model, loss_fn=loss_function, optim=sgd_nesterov, data_loader=trainloader)
# Save this run's own loss curve.
torch.save(train_loss_nes, "./train_loss_nes.cifarcnn")
torch.save(model.state_dict(), "./model_nes.cifarcnn")

In [None]:
# Mass
torch.cuda.empty_cache()
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Re-seed so this run starts from the same initial weights as the other
# optimizer runs.
torch.manual_seed(seed)
model = CNN().to(device)

loss_function = nn.CrossEntropyLoss()
# Keep the optimizer under its own name. Assigning it to `mass` would shadow
# the `mass` module and break re-running this cell and the grid search below.
mass_optim = Mass(model.parameters(), lr=0.001, alpha=0.05, kappa_t=2)
train_loss_mass = fit(model, loss_function, mass_optim, trainloader)
# Save this run's own loss curve.
torch.save(train_loss_mass, "./train_loss_mas.cifarcnn")
torch.save(model.state_dict(), "./model_mas.cifarcnn")

In [None]:
# Training-loss curves on a log10 scale. Each point is the summed per-batch
# loss of one epoch, as returned by `fit`.
fig, ax = plt.subplots()
for losses, colour, label in (
    (train_loss_sgd, 'red', 'sgd'),
    (train_loss_nes, 'blue', 'nesterov'),
    (train_loss_mass, 'green', 'mass'),
):
    ax.plot(torch.log10(losses), c=colour, label=label)
ax.set(
    title="CIFAR-10 CNN training loss by optimizer",
    xlabel="epoch",
    ylabel="log10(summed epoch loss)",
)
ax.legend();

In [None]:
# Hyperparameter grid for the Mass optimizer. skorch routes `optimizer__*`
# keys to the optimizer's constructor.
parameters_mas = {
    'lr': np.arange(0.01, 0.3, 0.05),        # 0.01, 0.06, ..., 0.26
    'optimizer__alpha': [0.05],
    'optimizer__kappa_t': range(2, 24, 5),    # 2, 7, 12, 17, 22
}

# 10 random 70/30 splits.
# NOTE: 6 lr values x 5 kappa_t values x 10 splits = 300 fits of up to 10
# epochs each, so this search is very expensive.
cv_split = ShuffleSplit(n_splits=10, test_size=.3, train_size=.7, random_state=0)

net = NeuralNetClassifier(
    CNN,
    max_epochs=10,
    # Refer to the class directly so this cell does not depend on the module
    # name `mass`, which an earlier cell may have rebound to an optimizer.
    optimizer=Mass,
    criterion=nn.CrossEntropyLoss,
    # Match the manual training loops: same batch size, and reshuffle every
    # epoch. skorch's training iterator does not shuffle by default.
    batch_size=batch_size,
    iterator_train__shuffle=True,
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

In [None]:
# skorch takes the whole training set as arrays, so rebuild it from the raw
# (N, H, W, C) 0-255 pixel array. Apply the same preprocessing as `transform`:
# ToTensor scales to [0, 1], then Normalize(0.5, 0.5) maps to [-1, 1], and the
# channels move to (N, C, H, W). Without this, the search would tune
# hyperparameters on unnormalized inputs the CNN never sees elsewhere.
X_train = np.ascontiguousarray(
    ((trainset.data.astype(np.float32) / 255.0 - 0.5) / 0.5).transpose(0, 3, 1, 2)
)
# CIFAR10.targets is a plain Python list, which torch.from_numpy rejects.
# Integer class labels are what nn.CrossEntropyLoss expects.
y_train = np.asarray(trainset.targets, dtype=np.int64)

gs = GridSearchCV(net, parameters_mas, cv=cv_split, scoring='accuracy')
gs.fit(X_train, y_train)
print(gs.best_score_, gs.best_params_)