# Train Model

#### References
* https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import models
import utils_train
import change_dataset_np
import matplotlib.pyplot as plt
print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Hyperparameters (per-device values; batch size and LR are scaled below)
num_classes = 3
batch_size = 100
img_size = 224
base_lr = 1e-4

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)
num_gpu = torch.cuda.device_count()
# Scale batch size and learning rate linearly with the number of GPUs that
# nn.DataParallel will split each batch across. On a CPU-only machine
# device_count() is 0, so clamp the multiplier to 1: otherwise batch_size
# would become 0 (DataLoader rejects it) and base_lr would become 0 (Adam
# would make no updates).
gpu_scale = max(num_gpu, 1)
batch_size *= gpu_scale
base_lr *= gpu_scale
print('Number of GPUs Available:', num_gpu)

train_pickle_file = 'change_dataset_train.pkl'
# NOTE(review): validation reads the TRAINING pickle, so "val" metrics (and
# any best-model selection driven by them) measure fit on training data.
# Point this at a held-out split once one is available.
val_pickle_file = 'change_dataset_train.pkl'

#### Define Transformations

In [None]:
# Normalization is currently disabled. To enable it, append
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# (the mean/std used in the referenced torchvision tutorial) to each pipeline.
train_steps = [
    transforms.RandomResizedCrop(img_size),  # random crop, resized to img_size x img_size
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
val_steps = [
    transforms.Resize(img_size),      # int size: match the shorter edge to img_size
    transforms.CenterCrop(img_size),  # deterministic img_size x img_size crop
    transforms.ToTensor(),
]
data_transforms = {
    'train': transforms.Compose(train_steps),
    'val': transforms.Compose(val_steps),
}

#### Load Dataset

In [None]:
# Create training and validation datasets
train_loader = change_dataset_np.ChangeDatasetNumpy(train_pickle_file, data_transforms['train'])
val_loader = change_dataset_np.ChangeDatasetNumpy(val_pickle_file, data_transforms['val'])
image_datasets = {'train': train_loader, 'val': val_loader}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=8) for x in ['train', 'val']}

#### Create TensorBoard Writer

In [None]:
# Default directory "runs"
# log_dir=None is SummaryWriter's default: event files go under
# runs/<datetime>_<hostname>; view with `tensorboard --logdir runs`.
writer = SummaryWriter(log_dir=None)

#### Initialize Model

In [None]:
# ChangeNet's forward is traced with two image tensors (reference, test);
# these random single-sample inputs exist only to record the graph.
dummy_shape = (1, 3, img_size, img_size)
img_reference_dummy = torch.randn(dummy_shape)
img_test_dummy = torch.randn(dummy_shape)
change_net = models.ChangeNet(num_classes=num_classes)

# Log the traced model graph to TensorBoard (done on CPU, before the model
# is wrapped/moved in the next cell).
writer.add_graph(change_net, [img_reference_dummy, img_test_dummy])

#### Send Model to GPUs (If Available)

In [None]:
# Wrap in nn.DataParallel only when more than one GPU is visible; in every
# case move the (possibly wrapped) model's parameters to `device`.
change_net = (nn.DataParallel(change_net) if num_gpu > 1 else change_net).to(device)

#### Initialize Loss Function, Optimizer, and LR Scheduler

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(change_net.parameters(), lr=base_lr)
# TODO(review): sc_plt is not passed to utils_train.train_model below and is
# not stepped in any visible cell, so the learning rate presumably stays at
# base_lr. To take effect it must be driven with sc_plt.step(val_loss) once
# per epoch. `verbose=True` was dropped: it is deprecated in recent PyTorch
# releases and never printed anything here since the scheduler is not stepped.
sc_plt = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=4)

#### Train Model

In [None]:
best_model = utils_train.train_model(change_net, dataloaders_dict, criterion, optimizer, writer, device, num_epochs=25)
# Flush pending TensorBoard events to disk now that training has finished.
writer.close()
# A model trained through nn.DataParallel stores its weights under a
# 'module.' prefix. Save the inner module instead so multi-GPU checkpoints
# have the same keys as single-GPU ones and load into a plain ChangeNet.
model_to_save = best_model.module if isinstance(best_model, nn.DataParallel) else best_model
torch.save(model_to_save.state_dict(), './best_model.pkl')