In [1]:
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data import random_split
import torchvision.models as models

import matplotlib.pyplot as plt
%matplotlib inline

import ssl
# WARNING: this swaps the ssl module's default HTTPS context factory for the unverified one, so HTTPS
# requests made from this kernel skip certificate verification (presumably added so the CIFAR-10
# download below succeeds on machines with a broken certificate store — prefer fixing that store).
setattr(ssl, "_create_default_https_context", ssl._create_unverified_context)

In [2]:
# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## 1. Implement a PyTorch model with VGG on CIFAR10

There are many optimizers that use adaptive learning rates to account for the different learning-rate needs at different phases of training. Your job is to first implement such an optimizer and evaluate its performance. As a baseline, we start with a VGG-16 model trained on CIFAR-10.

### a. Download Data CIFAR10

In [3]:
%%capture
# Per-channel RGB statistics used to normalize the CIFAR-10 images.
# NOTE(review): these std values are the ones widely copied for CIFAR-10; they are reported to differ from
# the per-channel pixel std of the training set (~0.247, 0.243, 0.261) — confirm before relying on them.
cifar10_channel_mean = (0.4914, 0.4822, 0.4465)
cifar10_channel_std = (0.2023, 0.1994, 0.2010)

# Resize every image to 224x224, convert it to a tensor, then normalize each channel.
transform_method = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_channel_mean, cifar10_channel_std),
    ]
)

# Both splits share the same location, download flag and preprocessing.
cifar10_kwargs = {"root": './data', "download": True, "transform": transform_method}
train_dataset = datasets.CIFAR10(train=True, **cifar10_kwargs)
test_dataset = datasets.CIFAR10(train=False, **cifar10_kwargs)

In [4]:
# NOTE(review): batch_size=1 makes each VGG-16 epoch over 224x224 CIFAR-10 images very slow; consider a
# larger batch (the MARTHE run further down uses 16).
batch_size = 1
# Only the training set needs shuffling. Evaluation metrics are averages over batches, so a fixed test
# order keeps repeated evaluations reproducible and comparable.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### b. Implement a VGG Model and Fit the Data with SGD

In [None]:
class ImageClassificationBase(nn.Module):
    """Base module adding training/validation helpers to an image classifier.

    Subclasses implement ``forward``. Batches are ``(images, labels)`` pairs that are moved to the
    module-level ``device`` before the forward pass.
    """

    def training_step(self, batch):
        """Return the cross-entropy loss of the model on one ``(images, labels)`` batch."""
        images, labels = batch
        images, labels = images.to(device), labels.to(device)
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels) # Calculate loss
        return loss

    def validation_step(self, batch):
        """Return the detached loss, accuracy and sample count for one ``(images, labels)`` batch."""
        images, labels = batch
        images, labels = images.to(device), labels.to(device)
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        # 'batch_size' lets validation_epoch_end weight each batch by how many samples it contained.
        return {'val_loss': loss.detach(), 'val_acc': acc, 'batch_size': len(labels)}

    @staticmethod
    def _weighted_mean(values, weights):
        """Weighted mean of a list of scalar tensors, returned as a Python float."""
        stacked = torch.stack(values).float()
        # Build the weights on the same device as the values (the loss may live on the GPU).
        w = torch.tensor(weights, dtype=stacked.dtype, device=stacked.device)
        return ((stacked * w).sum() / w.sum()).item()

    def validation_epoch_end(self, outputs):
        """Combine per-batch validation results into dataset-level loss and accuracy.

        Each batch is weighted by its ``'batch_size'`` entry (1 when absent), so a smaller final batch
        does not count as much as a full one. With equal weights this is the plain mean over batches.

        Returns:
            ``{'val_loss': float, 'val_acc': float}``.
        """
        weights = [x.get('batch_size', 1) for x in outputs]
        epoch_loss = self._weighted_mean([x['val_loss'] for x in outputs], weights)   # Combine losses
        epoch_acc = self._weighted_mean([x['val_acc'] for x in outputs], weights)     # Combine accuracies
        return {'val_loss': epoch_loss, 'val_acc': epoch_acc}

    def epoch_end(self, epoch, result, log_every=1):
        """Print the validation metrics for 0-based ``epoch`` every ``log_every`` epochs (default: every epoch)."""
        if (epoch + 1) % log_every != 0:
            return
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch+1, result['val_loss'], result['val_acc']))

class CIFAR10VGGModel(ImageClassificationBase):
    """torchvision VGG-16 (no pretrained weights requested) with an output layer sized for CIFAR-10.

    Args:
        num_classes: number of output logits. Defaults to the 10 CIFAR-10 classes; torchvision's
            default head would otherwise emit 1000 logits, most of which can never be a correct label.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.model = models.vgg16(num_classes=num_classes).to(device)

    def forward(self, xb):
        """Return raw class logits for the image batch ``xb``."""
        return self.model(xb)

def accuracy(outputs, labels):
    """Fraction of samples whose highest-scoring class equals the label.

    Args:
        outputs: tensor of per-class scores (logits), one row per sample.
        labels: tensor of integer class indices, one per row of ``outputs``.

    Returns:
        0-dim float tensor holding the accuracy in [0, 1].
    """
    # argmax over the class dimension; the max scores themselves are not needed.
    predictions = outputs.argmax(dim=1)
    correct = (predictions == labels).sum().item()
    return torch.tensor(correct / len(predictions))
    
def evaluate(model, val_loader):
    """Run ``model.validation_step`` over every batch of ``val_loader`` and aggregate the results.

    The model is put in eval mode (so layers such as Dropout behave deterministically) and autograd is
    disabled for the pass. The model's previous train/eval mode is restored afterwards, even on error.

    Returns:
        Whatever ``model.validation_epoch_end`` returns for the collected per-batch outputs.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = [model.validation_step(batch) for batch in val_loader]
    finally:
        model.train(was_training)
    return model.validation_epoch_end(outputs)

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``val_loader`` after each one.

    Args:
        epochs: number of passes over ``train_loader``.
        lr: learning rate, passed positionally to ``opt_func``.
        model: module providing ``training_step`` and ``epoch_end`` (and whatever ``evaluate`` needs).
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches, passed to ``evaluate``.
        opt_func: optimizer class, constructed as ``opt_func(model.parameters(), lr)``.

    Returns:
        List with one ``evaluate`` result per epoch.
    """
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        # Training phase: (re-)enter train mode every epoch so layers such as Dropout are active even if
        # the model was left in eval mode beforehand.
        model.train()
        for batch in train_loader:
            # Clear gradients before the backward pass so stale gradients already stored on the
            # parameters never leak into an update.
            optimizer.zero_grad()
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
        # Validation phase
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result)
        history.append(result)
    return history

In [None]:
# NOTE(review): the section heading says "Fit the Data with SGD", but Adam is passed as the optimizer
# here — confirm which one is intended.
model1 = CIFAR10VGGModel().to(device)
history1 = fit(
    epochs=20,
    lr=0.0005,
    model=model1,
    train_loader=train_loader,
    val_loader=test_loader,
    opt_func=torch.optim.Adam,
)

In [None]:
# Plot validation loss and accuracy per epoch. Epochs are numbered from 1 to match the "Epoch [N]" lines
# printed during training.
epoch_numbers = range(1, len(history1) + 1)
val_loss = [element["val_loss"] for element in history1]
val_accuracy = [element["val_acc"] for element in history1]

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
fig.suptitle("Learning Trend")
ax1.plot(epoch_numbers, val_loss, label="Validation Loss")
ax1.set_ylabel("Loss")
ax1.legend()
ax2.plot(epoch_numbers, val_accuracy, label="Validation Accuracy")
ax2.set_xlabel("Epochs")
ax2.set_ylabel("Accuracy")
ax2.legend()
plt.show()

## 2. Implement SGD with Hypergradient Descent, and compare it with SGD without Hypergradient Descent

### This is already done in our code base (VGG with SGD on the CIFAR-10 dataset); please refer to it for the solution.

## 3. Compare MARTHE on SGD to SGD-HD.
Implement MARTHE, use its learning-rate schedule to tune the learning rate of SGD, and compare the results with SGD with HD (SGD-HD) on VGG.

### Import MARTHE Code

In [None]:
from adatune.data_loader import *
from adatune.mu_adam import MuAdam
from adatune.mu_sgd import MuSGD
from adatune.network import *
from adatune.utils import *

### Define Training procedure

In [None]:
def train_rtho(network_name, dataset, num_epoch, batch_size, optim_name, lr, momentum, wd, hyper_lr, alpha,
               grad_clipping, first_order, seed, mu=1.0, skip_training=True):
    """Train a network while a MARTHE hyper-optimizer (MuAdam / MuSGD) adapts the learning rate online.

    Args:
        network_name: architecture name passed to ``network`` (e.g. "vgg").
        dataset: dataset name passed to ``network`` and ``data_loader`` (e.g. "cifar_10").
        num_epoch: number of passes over the training data.
        batch_size: mini-batch size passed to ``data_loader``.
        optim_name: "adam" selects Adam + MuAdam; any other value selects SGD + MuSGD.
        lr: initial learning rate of the inner optimizer.
        momentum: SGD momentum (unused for Adam).
        wd: weight decay of the inner optimizer.
        hyper_lr, grad_clipping, first_order, mu, alpha: forwarded to MuAdam / MuSGD.
        seed: value passed to ``torch.manual_seed``.
        skip_training: when True (the default), only seed the RNG and return without training.
            Pass False to actually run the training loop (requires CUDA).

    Returns:
        None when ``skip_training`` is True, otherwise the best validation accuracy over all epochs.
    """
    torch.manual_seed(seed)
    # Skipping is the default so that merely running this cell does not start a CUDA training run.
    if skip_training:
        return None

    # We are using cuda for training - no point trying out on CPU for ResNet
    device = torch.device("cuda")

    net = network(network_name, dataset)
    net.to(device).apply(init_weights)

    criterion = nn.CrossEntropyLoss().to(device)
    best_val_accuracy = 0.0

    # NOTE(review): this passes the ``network`` factory itself rather than ``network_name`` — confirm what
    # adatune's data_loader expects as its first argument.
    train_data, test_data = data_loader(network, dataset, batch_size)

    if optim_name == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=wd, eps=1e-4)
        hyper_optim = MuAdam(optimizer, hyper_lr, grad_clipping, first_order, mu, alpha, device)
    else:
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
        hyper_optim = MuSGD(optimizer, hyper_lr, grad_clipping, first_order, mu, alpha, device)

    # NOTE(review): the hyper-step uses gradients on ``test_data``, the same split that val_acc is reported
    # on below, so val_acc is not an independent held-out estimate.
    vg = ValidationGradient(test_data, nn.CrossEntropyLoss(), device)
    for epoch in range(num_epoch):
        train_correct = 0
        train_loss = 0

        for inputs, labels in train_data:
            net.train()

            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            train_loss += loss.item()

            train_pred = outputs.argmax(1)
            train_correct += train_pred.eq(labels).sum().item()

            # create_graph=True keeps these gradients differentiable; presumably compute_hg needs that to
            # form the hypergradient.
            # NOTE(review): ``ag`` is not imported in this notebook; presumably it is torch.autograd brought in
            # by one of the adatune star imports — confirm.
            first_grad = ag.grad(loss, net.parameters(), create_graph=True, retain_graph=True)

            hyper_optim.compute_hg(net, first_grad)

            # Install the computed gradients so the inner optimizer can step with them.
            for params, gradients in zip(net.parameters(), first_grad):
                params.grad = gradients

            optimizer.step()
            hyper_optim.hyper_step(vg.val_grad(net))
            clear_grad(net)

        train_acc = 100.0 * (train_correct / len(train_data.dataset))
        val_loss, val_acc = compute_loss_accuracy(net, test_data, criterion, device)

        if val_acc > best_val_accuracy:
            best_val_accuracy = val_acc

        print('train_accuracy at epoch :{} is : {}'.format(epoch, train_acc))
        print('val_accuracy at epoch :{} is : {}'.format(epoch, val_acc))
        print('best val_accuracy is : {}'.format(best_val_accuracy))

        # Report the learning rate of the last parameter group (the one the hyper-optimizer may have changed).
        cur_lr = optimizer.param_groups[-1]['lr']
        print('learning_rate after epoch :{} is : {}'.format(epoch, cur_lr))

    return best_val_accuracy

In [None]:
# NOTE(review): section 3 asks for MARTHE on SGD, but optim_name="adam" selects Adam + MuAdam inside
# train_rtho — confirm which optimizer is intended.
train_rtho(
    network_name="vgg",
    dataset="cifar_10",
    num_epoch=10,
    batch_size=16,
    optim_name="adam",
    lr=0.0001,
    momentum=0.9,
    wd=0,
    hyper_lr=0.0001,
    alpha=1e-6,
    grad_clipping=100.0,
    first_order=False,
    seed=42,
)

<img src='./figures/cifar10vgg_one_ACC-1.png' width="400" height="400">
<img src='./figures/cifar10vgg_one_LOSS-1.png' width="400" height="400">

#### This is what it should look like for the validation loss and validation accuracy of MARTHE and Hypergradient Descent with VGG and SGD on CIFAR10.

Use the graphs from the MARTHE paper to compare HD with MARTHE (dataset: CIFAR-10, model: VGG, optimizer: SGD). This is what we expect from the students:  
What is MARTHE doing differently, and why does MARTHE perform better than HD? (Graduate-level question.)


## 4. Task for Graduate Students
### a. Repeat task 1 and task 2, this time using ResNet and CIFAR100.
<img src='./figures/cifar100resnet_one_ACC-1.png' width="400" height="400">
<img src='./figures/cifar100resnet_one_LOSS-1.png' width="400" height="400">

#### This is what the validation loss and validation accuracy of MARTHE and Hypergradient Descent should look like with ResNet and SGD on CIFAR-100.

### b. Students should also explain their findings: read the MARTHE paper and explain why it performs better than Hypergradient Descent.