In [None]:
# Generate a 4096-bit RSA key pair for this runtime.
!ssh-keygen -t rsa -b 4096
# Append GitHub's RSA host key to known_hosts.
!ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts
# Print the public key (presumably so it can be registered on GitHub before the next cell — confirm).
!cat /root/.ssh/id_rsa.pub

In [None]:
# Check that SSH authentication against GitHub works.
!ssh -T git@github.com
# Set the global git identity.
!git config --global user.email "justin.deschenauxy@epfl.com"
!git config --global user.name "Justin-Collab"
# Clone the project and move its (non-hidden) contents into the working directory,
# so that the `datasets`, `models` and `utils` imports below resolve.
!git clone git@github.com:deschena/colab_unet_train.git
!mv colab_unet_train/* .
# Mount Google Drive; the training cells write checkpoints and loss histories under it.
from google.colab import drive
drive.mount('/content/gdrive')
# List the GPU(s) attached to this runtime.
!nvidia-smi -L

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
import matplotlib.image as mpimg
import os, sys, io, random
from PIL import Image
from collections import OrderedDict

from datasets.AugmDataset import AugmDataset
from models.Unet import Unet
from models.DenseUnet import DenseUnet
from utils import *
%matplotlib inline

# Model Selection

In [None]:
# Device used by train_net for the network and the batches.
device = "cuda"
# Root folder of the datasets (train_net defines its own local copy of this value).
root_path = "datasets/augmented_dataset/"
# Dataset names passed to AugmDataset for the training and validation splits.
train_name = "msel_train/"
valid_name = "msel_valid/"

In [None]:
def train_net(net, train_name, valid_name, seed=999, max_epoch=50, net_name="DEFAULT", patience=5, verbose=True, batch_size=4):
    """Train `net` with Adam and a BCE-with-logits loss, checkpointing the best epoch.

    Args:
        net: Network to train. It must output raw logits, because the sigmoid is
            folded into the loss. It is moved to the global `device`.
        train_name: Dataset name passed to AugmDataset for the training split.
        valid_name: Dataset name passed to AugmDataset for the validation split.
        seed: Seed applied to torch, `random` and numpy before training.
        max_epoch: Number of epochs to run.
        net_name: Base name of the checkpoint file written to Google Drive.
        patience: Patience of the ReduceLROnPlateau scheduler.
        verbose: If True, print progress every 10 epochs (also forwarded to the scheduler).
        batch_size: Training batch size; validation uses twice this value.

    Returns:
        (training_loss, validation_loss): two lists with one entry per epoch.
        `training_loss` holds the loss of the last mini-batch of each epoch as a
        Python float; `validation_loss` holds the values returned by `validation_perf`.
    """
    torch.random.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    root_path = "datasets/augmented_dataset/"

    # Since we had the best results with only the binary cross entropy, we combine the final sigmoid
    # activation with the loss, since that way we have a numerically more stable result, as the
    # log-sum-exp trick is used.
    criterion = nn.BCEWithLogitsLoss()

    train_set = AugmDataset(root_dir=root_path, name=train_name)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=batch_size)

    validation_set = AugmDataset(root_dir=root_path, name=valid_name)
    validation_loader = DataLoader(validation_set, batch_size=2*batch_size, shuffle=False, num_workers=2*batch_size)

    # Send to GPU, prepare optimizer and learning rate scheduler
    net.to(device)
    optimizer = optim.Adam(net.parameters())
    # NOTE(review): ReduceLROnPlateau defaults to mode="min", i.e. it treats v_perf as a
    # quantity to minimise, while the checkpoint test below treats a higher v_perf as
    # better. Confirm what validation_perf returns and align the two.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=patience, verbose=verbose)

    validation_loss = []
    training_loss = []
    loss = -1
    best_current_loss = -1

    for epoch in range(max_epoch):
        net.train()
        for batch_train, batch_gt in train_loader:

            # Send data to gpu
            batch_train = batch_train.to(device)
            batch_gt = batch_gt.to(device)

            # Clear accumulated gradients & compute prediction
            optimizer.zero_grad()
            output = net(batch_train)
            # Compute loss, gradient & update parameters
            loss = criterion(output, batch_gt)
            loss.backward()
            optimizer.step()
        # After each epoch, compute & save loss on training and validation sets
        v_perf = validation_perf(net, validation_loader)
        validation_loss.append(v_perf)
        # Store a plain float: keeping the tensor would retain its autograd graph and
        # device memory, and np.save cannot convert such tensors. float() also handles
        # the -1 placeholder used when the loader yields no batch.
        training_loss.append(float(loss))
        # Check if scheduler must decrease learning rate
        scheduler.step(v_perf)
        if v_perf > best_current_loss:
            # Save best net
            torch.save(net.state_dict(), f"/content/gdrive/My Drive/ML files/model_selection/{net_name}.pth")
            # Remember the best score so later, worse epochs do not overwrite the checkpoint.
            best_current_loss = v_perf
        if verbose and epoch % 10 == 0:
            print(f"{epoch} epochs elapsed")

    return training_loss, validation_loss

## Train the models
**Models considered**:
1. Standard Unet
2. Attention Unet (channel attention)
3. Attention Unet (pixel attention)
4. Dense Unet
5. Dense Attention Unet (channel attention)
6. Dense Attention Unet (pixel attention)

**Important note**: The final sigmoid layer is deactivated during training because it is folded into the loss (`BCEWithLogitsLoss`), which is numerically more stable thanks to the "log-sum-exp" trick. When the model is in eval mode, or when `activation_output` is True, the final sigmoid is applied.

In [None]:
%%time
# Model 1: standard U-Net, built with activation_output=False for training.
net1 = Unet(activation_output=False)
net1_tr, net1_val = train_net(net1, train_name, valid_name, net_name="unet", seed=123123)
# Persist both per-epoch histories to Drive (np.save appends the ".npy" extension).
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net1_tr_loss", net1_tr)
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net1_val_loss", net1_val)

In [None]:
%%time
# Model 2: U-Net with channel attention, built with activation_output=False for training.
net2 = Unet(attention="channel", activation_output=False)
net2_tr, net2_val = train_net(net2, train_name, valid_name, net_name="channel_unet", seed=4325443)
# Persist both per-epoch histories to Drive (np.save appends the ".npy" extension).
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net2_tr_loss", net2_tr)
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net2_val_loss", net2_val)

In [None]:
%%time
# Model 3: U-Net with attention="grid" (presumably the "pixel attention" variant
# listed above — confirm), built with activation_output=False for training.
net3 = Unet(attention="grid", activation_output=False)
net3_tr, net3_val = train_net(net3, train_name, valid_name, net_name="grid_unet", seed=989873)
# Persist both per-epoch histories to Drive (np.save appends the ".npy" extension).
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net3_tr_loss", net3_tr)
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net3_val_loss", net3_val)

In [None]:
%%time
# Model 4: Dense U-Net without attention, built with activation_output=False for training.
net4 = DenseUnet(down_config=(4, 8, 16, 32), bottom=64, up_channels=(256, 128, 64, 32), activation_output=False)
net4_tr, net4_val = train_net(net4, train_name, valid_name, net_name="dense_unet", seed=776834)
# Persist both per-epoch histories to Drive (np.save appends the ".npy" extension).
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net4_tr_loss", net4_tr)
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net4_val_loss", net4_val)

In [None]:
%%time
# Model 5: Dense U-Net with channel attention, built with activation_output=False for training.
net5 = DenseUnet(down_config=(4, 8, 16, 32), bottom=64, up_channels=(256, 128, 64, 32), activation_output=False, attention="channel")
net5_tr, net5_val = train_net(net5, train_name, valid_name, net_name="dense_channel_unet", seed=445366)
# Persist both per-epoch histories to Drive (np.save appends the ".npy" extension).
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net5_tr_loss", net5_tr)
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net5_val_loss", net5_val)

In [None]:
%%time
# Model 6: Dense U-Net with attention="grid", built with activation_output=False for training.
net6 = DenseUnet(down_config=(4, 8, 16, 32), bottom=64, up_channels=(256, 128, 64, 32), activation_output=False, attention="grid")
# NOTE(review): this seed (445366) is identical to the one used for net5, whereas every
# other model has its own seed — confirm this is intentional and not a copy-paste leftover.
net6_tr, net6_val = train_net(net6, train_name, valid_name, net_name="dense_grid_unet", seed=445366)
# Persist both per-epoch histories to Drive (np.save appends the ".npy" extension).
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net6_tr_loss", net6_tr)
np.save(f"/content/gdrive/My Drive/ML files/model_selection/net6_val_loss", net6_val)

### Evaluating the performance of each architecture
After training those 6 U-nets, we wanted to evaluate their performance on the test set (20% of the original data).

In [None]:
def load_net_params(net, name, device="cuda"):
    """Load saved weights into `net`, switch it to eval mode and move it to `device`.

    Args:
        net: Network whose architecture matches the saved state dict.
        name: Checkpoint base name; the file read is
            "experiments/model_selection/<name>.pth".
        device: Device the network is moved to. Defaults to "cuda", the value that
            was previously hard-coded.

    Returns:
        The same `net` object, in eval mode and located on `device`.
    """
    # NOTE(review): train_net writes its checkpoints under the Google Drive folder, not
    # under this directory; presumably the files are copied here beforehand — confirm.
    path = "experiments/model_selection/" + name + ".pth"
    # Deserialize onto the CPU so that a checkpoint saved from a GPU run can be read on
    # any machine and is not held a second time in GPU memory; the network itself is
    # moved to the target device below.
    params = torch.load(path, map_location="cpu")
    net.load_state_dict(params)
    net.eval()
    net.to(device)
    return net

In [None]:
def perf_on_test_set(net, dataset_path, dataset_name):
    """Evaluate `net` on the dataset `dataset_name` located under `dataset_path`.

    The samples are served in order (no shuffling) in batches of 8, and the
    value computed by `validation_perf` over that loader is returned.
    """
    test_loader = DataLoader(
        AugmDataset(root_dir=dataset_path, name=dataset_name),
        batch_size=8,
        shuffle=False,
        num_workers=8,
    )
    return validation_perf(net, test_loader)

In [None]:
# Rebuild the standard U-Net, load the "unet" checkpoint and print its test-set score.
net1 = Unet(activation_output=True) # This time we don't combine it with the loss, so we want to have the last activation
net1 = load_net_params(net1, "unet")
print(perf_on_test_set(net1, "datasets/augmented_dataset/", "msel_test/"))
del net1 # To avoid filling the GPU

In [None]:
# Rebuild the channel-attention U-Net with its output activation enabled, load the
# "channel_unet" checkpoint and print its test-set score.
net2 = Unet(attention="channel", activation_output=True)
net2 = load_net_params(net2, "channel_unet")
print(perf_on_test_set(net2, "datasets/augmented_dataset/", "msel_test/"))
del net2  # drop the reference so the model does not stay on the GPU

In [None]:
# Rebuild the attention="grid" U-Net with its output activation enabled, load the
# "grid_unet" checkpoint and print its test-set score.
net3 = Unet(attention="grid", activation_output=True)
net3 = load_net_params(net3, "grid_unet")
print(perf_on_test_set(net3, "datasets/augmented_dataset/", "msel_test/"))
del net3  # drop the reference so the model does not stay on the GPU

In [None]:
# Rebuild the Dense U-Net with its output activation enabled, load the "dense_unet"
# checkpoint and print its test-set score.
net4 = DenseUnet(down_config=(4, 8, 16, 32), bottom=64, up_channels=(256, 128, 64, 32), activation_output=True)
net4 = load_net_params(net4, "dense_unet")
print(perf_on_test_set(net4, "datasets/augmented_dataset/", "msel_test/"))
del net4  # drop the reference so the model does not stay on the GPU

In [None]:
# Rebuild the channel-attention Dense U-Net with its output activation enabled, load
# the "dense_channel_unet" checkpoint and print its test-set score.
net5 = DenseUnet(down_config=(4, 8, 16, 32), bottom=64, up_channels=(256, 128, 64, 32), activation_output=True, attention="channel")
net5 = load_net_params(net5, "dense_channel_unet")
print(perf_on_test_set(net5, "datasets/augmented_dataset/", "msel_test/"))
del net5  # drop the reference so the model does not stay on the GPU

In [None]:
# Rebuild the attention="grid" Dense U-Net with its output activation enabled, load
# the "dense_grid_unet" checkpoint and print its test-set score.
net6 = DenseUnet(down_config=(4, 8, 16, 32), bottom=64, up_channels=(256, 128, 64, 32), activation_output=True, attention="grid")
net6 = load_net_params(net6, "dense_grid_unet")
print(perf_on_test_set(net6, "datasets/augmented_dataset/", "msel_test/"))
# Release this cell's model. `net5` was already deleted by its own cell, so deleting it
# here would raise NameError and leave net6 on the GPU.
del net6