In [None]:
%%capture
!pip install torchdata==0.4.1

In [None]:
import torch
from torch.optim import Adam, SGD
import torch.nn.functional as F

from train_utils import train_and_checkpoint
from models import init_model, dataset_hyperparams
from datasets import get_data

import matplotlib.pyplot as plt
from os import path, makedirs

In [None]:
# How many seeds (one training run per seed) to sweep for every
# dataset/optimizer pair.
NUM_INIT_POINTS: int = 5

# The Training Pipeline

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def calc_accuracy(y_pred, y):
    """Return the fraction of predictions whose top-scoring class equals the label.

    Args:
        y_pred: Score tensor whose class axis is dim 1 (presumably
            (batch, num_classes) logits — confirm against the models).
        y: Integer class labels, compared element-wise with the predicted indices.

    Returns:
        A 0-dim tensor: number of correct predictions divided by ``y.nelement()``.
    """
    predicted_labels = y_pred.argmax(dim=1)
    num_correct = (predicted_labels == y).sum()
    return num_correct / y.nelement()

## Register the optimizer classes to compare (name → class)

In [None]:
class SGDMomentum(SGD):
    """``torch.optim.SGD`` with momentum switched on (0.9 unless overridden).

    Wrapping it as its own class lets it be constructed exactly like SGD and
    Adam in the training loop, i.e. ``Opt(params, lr=...)``.
    """

    def __init__(self, params, lr, momentum=0.9):
        super(SGDMomentum, self).__init__(params, momentum=momentum, lr=lr)

In [None]:
# Name -> optimizer class for every optimizer under comparison. The names are
# reused as per-optimizer learning-rate keys, checkpoint sub-directories and
# plot legend labels in the training loop.
Optimizers = dict(SGD=SGD, SGDMomentum=SGDMomentum, Adam=Adam)

## Config

In [None]:
# Every artifact is written below one root directory (the working directory).
root = '.'
checkpoints_dir, figures_dir = (path.join(root, subdir) for subdir in ('checkpoints', 'figures'))

In [None]:
# Create the output directories up front; exist_ok=True keeps re-runs safe.
for output_dir in (checkpoints_dir, figures_dir):
    makedirs(output_dir, exist_ok=True)

In [None]:
# Seeds are NUM_INIT_POINTS consecutive integers starting at BASE_SEED;
# change BASE_SEED to draw a different (but still reproducible) set of runs.
BASE_SEED = 42
seeds = list(range(BASE_SEED, BASE_SEED + NUM_INIT_POINTS))

# Snapshot the dataset names as a plain list rather than a live dict_keys view,
# so the sweep order is fixed here and the value displays as an ordinary list.
dataset_names = list(dataset_hyperparams.keys())

In [None]:
# Show which datasets the training sweep below will iterate over.
dataset_names

## Train and Plot

In [None]:
# Sweep every (dataset, seed, optimizer) combination: train and checkpoint each
# run, then save one epoch-vs-loss and one time-vs-loss figure per (dataset, seed).
for dataset_name in dataset_names:
    hyperparams = dataset_hyperparams[dataset_name]
    for seed in seeds:
        # Fresh figures per (dataset, seed) so curves from earlier seeds can never
        # leak into later plots, whatever the matplotlib backend does on show().
        epoch_fig, epoch_ax = plt.subplots()
        time_fig, time_ax = plt.subplots()

        # .items() yields (name, class) pairs; iterating the dict itself yields
        # only the name strings, which cannot be unpacked into two variables.
        for opt_name, Opt in Optimizers.items():
            # Re-seed before data loading and model init so each optimizer
            # starts from the same RNG state for a given seed.
            torch.manual_seed(seed)
            train_dataloader, test_dataloader = get_data(dataset_name)

            model = init_model(dataset_name).to(device)
            optimizer = Opt(model.parameters(), lr=hyperparams['lr'][opt_name])

            optimizer_checkpoints_dir = path.join(checkpoints_dir, dataset_name, f'seed_{seed}', opt_name)
            makedirs(optimizer_checkpoints_dir, exist_ok=True)

            info = train_and_checkpoint(model, train_dataloader, test_dataloader, optimizer, F.cross_entropy,
                                        calc_accuracy, device=device, num_epochs=hyperparams['epochs'],
                                        path_to_save=optimizer_checkpoints_dir)

            # Add current optimizer results to both plots (legend = optimizer).
            epoch_ax.plot(info['train_losses'], label=opt_name)

            # The time axis spreads info['time'] evenly over the epochs and
            # assumes train_losses has info['epochs'] + 1 entries (presumably the
            # loss before training plus one per epoch) — TODO confirm in train_utils.
            time_per_epoch = info['time'] / info['epochs']
            time_stamps = [i * time_per_epoch for i in range(info['epochs'] + 1)]
            time_ax.plot(time_stamps, info['train_losses'], label=opt_name)

        # Title identifies dataset and seed; the legend identifies the optimizer.
        title = f'{dataset_name}, seed {seed}'

        epoch_ax.set(xlabel='Epoch', ylabel='Train Loss', title=title)
        epoch_ax.legend()
        epoch_fig.savefig(path.join(figures_dir, f'{dataset_name}_{seed}_epochloss.pdf'))

        time_ax.set(xlabel='Time (s)', ylabel='Train Loss', title=title)
        time_ax.legend()
        time_fig.savefig(path.join(figures_dir, f'{dataset_name}_{seed}_timeloss.pdf'))

        plt.show()
        # Release both figures explicitly so a long sweep does not accumulate open figures.
        plt.close(epoch_fig)
        plt.close(time_fig)