# Train Model

#### References
* https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
* https://pytorch.org/tutorials/beginner/saving_loading_models.html
* https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import models
import losses
import utils_train
import change_dataset_np
import matplotlib.pyplot as plt
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

import helper_augmentations

from IPython.display import clear_output, display
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# Hyperparameters
num_epochs = 50
num_classes = 2
batch_size = 20
img_size = 224
base_lr = 1e-4

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)
num_gpu = torch.cuda.device_count()
batch_size *= num_gpu
base_lr *= num_gpu
print('Number of GPUs Available:', num_gpu)

train_pickle_file = './proc_dataset/change_dataset_train.pkl'
val_pickle_file = './proc_dataset/change_dataset_train.pkl'

PyTorch Version:  1.1.0
Torchvision Version:  0.3.0
Device: cuda:0
Number of GPUs Available: 8


#### Define Transformation

In [2]:
#transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
data_transforms = {
    'train': transforms.Compose([
        #transforms.RandomResizedCrop(img_size),
        #transforms.RandomHorizontalFlip(),
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        #helper_augmentations.SwapReferenceTest(),
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0, hue=0),
        #helper_augmentations.JitterGamma(),
        transforms.ToTensor(),
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

#### Load Dataset

In [3]:
# Create training and validation datasets
train_dataset = change_dataset_np.ChangeDatasetNumpy(train_pickle_file, data_transforms['train'])
val_dataset = change_dataset_np.ChangeDatasetNumpy(val_pickle_file, data_transforms['val'])
image_datasets = {'train': train_dataset, 'val': val_dataset}
# Create training and validation dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=16)
#dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=8) for x in ['train', 'val']}
dataloaders_dict = {'train': train_loader, 'val': val_loader}

#### Start Tensorboard Interface

In [4]:
# Default directory "runs"
writer = SummaryWriter()

#### Initialize Model

In [5]:
img_reference_dummy = torch.randn(1,3,img_size,img_size)
img_test_dummy = torch.randn(1,3,img_size,img_size)
change_net = models.ChangeNet(num_classes=num_classes)

# Add on Tensorboard the Model Graph
writer.add_graph(change_net, [img_reference_dummy, img_test_dummy])



#### Send Model to GPUs (If Available)

In [6]:
if num_gpu > 1:
    change_net = nn.DataParallel(change_net)
change_net = change_net.to(device)

#### Load Weights

In [7]:
#checkpoint = torch.load('./best_model.pkl')
#change_net.load_state_dict(checkpoint);

#### Initialize Loss Functions and Optimizers

In [8]:
#criterion = nn.CrossEntropyLoss()
# If there are more than 2 classes the alpha need to be a list
criterion = losses.FocalLoss(gamma=2.0, alpha=0.25)
optimizer = optim.Adam(change_net.parameters(), lr=base_lr)    
sc_plt = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)    

#### Train Model

In [None]:
best_model, _ = utils_train.train_model(change_net, dataloaders_dict, criterion, optimizer, sc_plt, writer, device, num_epochs=num_epochs)
torch.save(best_model.state_dict(), './best_model.pkl')

Epoch 0/49
----------
train Loss: 0.0043
val Loss: 0.0023
Epoch 1/49
----------
train Loss: 0.0019
val Loss: 0.0019
Epoch 2/49
----------
train Loss: 0.0016
val Loss: 0.0013
Epoch 3/49
----------
train Loss: 0.0013
val Loss: 0.0021
Epoch 4/49
----------
train Loss: 0.0013
val Loss: 0.0011
Epoch 5/49
----------
train Loss: 0.0011
train Loss: 0.0019
val Loss: 0.0015
Epoch 7/49
----------
train Loss: 0.0015
val Loss: 0.0012
Epoch 8/49
----------
train Loss: 0.0011
val Loss: 0.0010
Epoch 9/49
----------
train Loss: 0.0010
val Loss: 0.0009
Epoch 10/49
----------
train Loss: 0.0009
val Loss: 0.0009
Epoch 11/49
----------
train Loss: 0.0009
val Loss: 0.0009
Epoch 12/49
----------
train Loss: 0.0009
val Loss: 0.0021
Epoch 13/49
----------
train Loss: 0.0012
val Loss: 0.0009
Epoch 14/49
----------
train Loss: 0.0009
val Loss: 0.0008
Epoch 15/49
----------
train Loss: 0.0008
val Loss: 0.0008
Epoch 16/49
----------
train Loss: 0.0008
val Loss: 0.0007
Epoch 17/49
----------
val Loss: 0.0007
Epoch 

In [None]:
@interact(idx=widgets.IntSlider(min=0,max=len(val_dataset)-1))
def explore_validation_dataset(idx):
    best_model.eval()
    sample = val_dataset[idx]
    reference = sample['reference'].unsqueeze(0)
    reference_img = sample['reference'].permute(1, 2, 0).cpu().numpy()
    test_img = sample['test'].permute(1, 2, 0).cpu().numpy()
    test = sample['test'].unsqueeze(0)
    #label = sample['label'].type(torch.LongTensor).squeeze(0).cpu().numpy()
    label = (sample['label']>0).type(torch.LongTensor).squeeze(0).cpu().numpy()
    pred = best_model([reference, test])
    #print(pred.shape)
    _, output = torch.max(pred, 1)
    output = output.squeeze(0).cpu().numpy()
    fig=plt.figure(figsize=(8, 8))
    fig.add_subplot(2, 2, 1)
    plt.imshow(reference_img)
    plt.title('Reference')
    fig.add_subplot(2, 2, 2)
    plt.imshow(test_img)
    plt.title('Test')
    fig.add_subplot(2, 2, 3)
    plt.imshow(label)
    plt.title('Label')
    fig.add_subplot(2, 2, 4)
    plt.imshow(output)
    plt.title('ChangeNet Output')
    plt.show()

In [None]:
1+1