# Prolegomena

Depending on whether you are running locally or on Google Colab, uncomment the appropriate PATH cell.
By default, the local cell is uncommented.

 ## Imports

In [None]:
import os
import torch
torch.manual_seed(42)

import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from tqdm import tqdm

from torchvision import datasets, transforms

import gc

## PATH (local run)

In [None]:
# Local run: every path is resolved relative to the current working directory.
PATH_TO_ROOT = "."
PATH_TO_DATASETS = os.path.join(PATH_TO_ROOT, "datasets")

# Project-local training helper and model classes.
from modules.utils import Trainer
from modules.models import DFC_LeNet_5, Deep_RetiNet


## Mount drive, PATH (on google colab)

In [None]:
# Google Colab setup: uncomment this whole cell (and comment out the local
# PATH cell) to mount Google Drive and import the project modules from there.
# from google.colab import drive
# import os, sys

# drive.mount('./mnt')
# !ls mnt/MyDrive/Code/ -l

# PATH_TO_ROOT = "mnt/MyDrive/Code/RetiNet"
# PATH_TO_DATASETS = os.path.join(
#     PATH_TO_ROOT, "datasets"
# )

# sys.path.append(PATH_TO_ROOT)
# from modules.utils import Trainer
# from modules.models import DFC_LeNet_5, Deep_RetiNet


In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'Using: {device}')

# Training

## Input panel

Training parameters input section. When a limited set of options is available, a list of the possibilities is specified as a comment above the parameter definition line.
When training a new model you can start directly from here.

In [None]:
# Dataset to train on. Options: MNIST - FashionMNIST - SVHN
dataset_name = "MNIST"  #@param

# Architecture to train. Options: LeNet - Deep_RetiNet
model_name = "Deep_RetiNet"  #@param

# Kernel size and depth handed to Deep_RetiNet; `rks` is a short alias.
retinic_kernel_size = 7  #@param
rks = retinic_kernel_size

depth = 3  #@param

# Optimisation settings. `optimizer` holds the optimizer *class* here; it is
# instantiated once the model exists.
optimizer = Adam
batch_size = 128
start_lr = 1e-3
# NOTE(review): the training cells visible in this notebook pass their own
# `epochs` values — confirm whether this one is still used anywhere.
epochs = 20

loss_fn = CrossEntropyLoss()

Datasets loading, dataloaders

In [None]:
# Dataset constructors keyed by the names accepted in the input panel.
IL = {
    "MNIST": datasets.MNIST,
    "FashionMNIST": datasets.FashionMNIST,
    "SVHN": datasets.SVHN,
}

# Fail early with a readable message instead of a bare KeyError on a typo.
if dataset_name not in IL:
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(IL)}"
    )

dataset_cls = IL[dataset_name]

if dataset_name == "SVHN":
    # SVHN is built with `split=` rather than `train=`, and is kept in its
    # own sub-folder of the datasets directory.
    dataset_root = os.path.join(PATH_TO_DATASETS, "SVHN")
    transform = transforms.ToTensor()
    train_selector = {"split": "train"}
    test_selector = {"split": "test"}
else:
    dataset_root = PATH_TO_DATASETS
    # Convert to a tensor, then pad by 2 on each border
    # (presumably to bring 28x28 images to 32x32 — TODO confirm).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Pad(2),
    ])
    train_selector = {"train": True}
    test_selector = {"train": False}

trainset = dataset_cls(
    dataset_root,
    download=True,
    transform=transform,
    **train_selector,
)

testset = dataset_cls(
    dataset_root,
    download=True,
    transform=transform,
    **test_selector,
)

# The first dimension of a sample tensor is its channel count.
in_channels = testset[0][0].shape[0]
train_size = len(trainset)
test_size = len(testset)

# Training batches are reshuffled every epoch.
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# Evaluation keeps the dataset order and uses large fixed-size batches.
testloader = DataLoader(
    testset,
    batch_size=5000,
    shuffle=False,
    num_workers=1,
)

Model and optimizer instantiation

In [None]:
# Build the network selected in the input panel and move it to `device`.
if model_name == "LeNet":
  model = DFC_LeNet_5(in_channels).to(device)
  model_save_name = "DFC_LeNet_5"

elif model_name == "Deep_RetiNet":
  model = Deep_RetiNet(depth, rks, in_channels).to(device)
  model_save_name = f"Deep_RetiNet_d{depth}_rks{rks}"

else:
  # Without this branch an unknown name would only fail later, with a
  # NameError on `model`.
  raise ValueError(
      f"Unknown model_name {model_name!r}; expected 'LeNet' or 'Deep_RetiNet'"
  )

# `optimizer` arrives as an optimizer class and is rebound to an instance
# below. If this cell is re-run, it is already an instance, so recover its
# class first instead of trying to call the instance.
if isinstance(optimizer, torch.optim.Optimizer):
  optimizer_cls = type(optimizer)
else:
  optimizer_cls = optimizer

optimizer = optimizer_cls(model.parameters(), lr=start_lr)

print(model)

trainer = Trainer(model)

## Grid search: uncomment
If you need to run a grid search, uncomment this cell and run it.

In [None]:
# Grid search over learning rate and batch size: each combination trains a
# fresh model for one epoch. Uncomment the whole cell to use it.
# NOTE(review): `betass` and `epss` are defined but never used by the loops below.
# NOTE(review): `LeNet_5` is not imported by the PATH cells of this notebook
# (they import `DFC_LeNet_5` and `Deep_RetiNet`) — confirm the intended model.
# lrs = [1e-3, 2e-3, 3e-3]
# batch_sizes = [128]
# betass = [(0.7, 0.999), (0.8, 0.999), (0.9, 0.999)]
# epss = [1e-7, 1e-8, 1e-9]
# gc.collect()

# for lr in lrs:
#   for batch_size in batch_sizes:
#     print(f"Learning rate: {lr} \t batch size: {batch_size}")

#     gs_model = LeNet_5().to(device)
#     gs_trainer = Trainer(gs_model)

#     gs_trainloader = DataLoader(trainset, 
#                             batch_size = batch_size, 
#                             shuffle = True, 
#                             num_workers = 2)

#     gs_optimizer = Adam(gs_model.parameters(),
#                        lr = lr)

#     gs_trainer.train(trainloader = gs_trainloader,
#                      validloader = testloader,
#                      optimizer = gs_optimizer,
#                      epochs = 1,
#                      loss_fn = loss_fn,
#                      retitrain = False,
#                      plotting = False)

## Training routine
3+1+1-scheduled training routine: three stages at the initial learning rate, then one stage at 1e-4 and one at 1e-5.
This gives a rough form of learning rate scheduling, and we avoid dropping the Colab GPU memory.

In [None]:
# Stage 1: run a garbage-collection pass, then train for 5 epochs.
gc.collect()

trainer.train(
    trainloader=trainloader,
    validloader=testloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=5,
    plotting=True,
)

In [None]:
# Stage 2: run a garbage-collection pass, then train for 5 more epochs
# with the same optimizer.
gc.collect()

trainer.train(
    trainloader=trainloader,
    validloader=testloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=5,
    plotting=True,
)

In [None]:
# Stage 3: run a garbage-collection pass, then train for 5 more epochs
# with the same optimizer.
gc.collect()

trainer.train(
    trainloader=trainloader,
    validloader=testloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=5,
    plotting=True,
)

In [None]:
# Stage 4: set the learning rate of every parameter group to 1e-4,
# then train for 5 more epochs.
gc.collect()

for param_group in optimizer.param_groups:
    param_group["lr"] = 1e-4

trainer.train(
    trainloader=trainloader,
    validloader=testloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=5,
    plotting=True,
)

In [None]:
# Stage 5: set the learning rate of every parameter group to 1e-5,
# then train for a final 3 epochs.
gc.collect()

for param_group in optimizer.param_groups:
    param_group["lr"] = 1e-5

trainer.train(
    trainloader=trainloader,
    validloader=testloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=3,
    plotting=True,
)

Save the state dict

In [None]:
# Persist the trained weights under <root>/trained_models/<dataset_name>/.
PATH_TO_SAVE = os.path.join(PATH_TO_ROOT, "trained_models", dataset_name)

# Create the target directory first so that a missing folder does not make
# the save fail after the whole training run has completed.
os.makedirs(PATH_TO_SAVE, exist_ok=True)

torch.save(
    model.state_dict(),
    os.path.join(PATH_TO_SAVE, f"trained_{model_save_name}_state_dict.pt"),
)