# Training our UNet model

In [None]:
%matplotlib inline
import torch
import os, sys
import matplotlib.pyplot as plt
import matplotlib.image as mpimg 
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import random_split, DataLoader

## Set up the environment 

In [None]:
# Fix PyTorch's random seed so that runs of this notebook are repeatable
torch.manual_seed(202042)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU
gpu_available = torch.cuda.is_available()
device = torch.device("cuda") if gpu_available else torch.device("cpu")
print(device)

In [None]:
# Google Drive only needs to be mounted when the notebook runs inside Colab,
# which is detected by the google.colab module already being loaded.
running_in_colab = "google.colab" in sys.modules
if running_in_colab:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True, use_metadata_server=False)

In [6]:
from src import training
from src.models.unet import UNet
from src.metrics import DiceLoss
from src.image_mask_dataset import ImageMaskDataset, FullSubmissionImageDataset

## Load the data

In [None]:
# Locations of the training images, their ground-truth masks and the test
# images, all relative to the shared drive folder.
root_dir = "/content/drive/Shareddrives/road-segmentation/data/"
image_dir = root_dir + "training/images/"
gt_dir = root_dir + "training/groundtruth/"
test_dir = root_dir + "test_set_images/"

# Start from the untransformed image/mask pairs
dataset = ImageMaskDataset(image_dir, gt_dir)

# Perform data augmentation by rotation and shearing: each transformation
# appends one more transformed copy of the training set.
#
# The current loop value is bound as a default argument of the lambda.
# A plain closure would look the variable up only when it is called, so if the
# dataset applies its transformation lazily (TODO confirm), every augmented
# copy would end up using the LAST angle/shear of the list.
angles = [15, -10, 45, -60, 78]
for angle in angles:
    rotation = lambda img, angle=angle: TF.rotate(img, angle)
    dataset += ImageMaskDataset(image_dir, gt_dir, rotation)

shears = [[15, 20], [10, 30], [30, -17], [-3, 20], [-5, -10]]
for shear in shears:
    transformation = lambda img, shear=shear: TF.affine(
        img, angle=0, scale=1.0, translate=[0, 0], shear=shear
    )
    dataset += ImageMaskDataset(image_dir, gt_dir, transformation)

# Total number of samples: original set plus one copy per transformation
print(len(dataset))

In [None]:
batch_size = 5

# Split the data in 80/20 for training and validation.
# The validation length is computed as the remainder: truncating both 80% and
# 20% independently can leave samples unassigned, and random_split raises when
# the lengths do not sum to the size of the dataset.
data_len = len(dataset)
train_len = int(data_len * 0.8)
test_len = data_len - train_len

dataset_train, dataset_test = random_split(dataset, [train_len, test_len])
print(len(dataset_train), len(dataset_test))

# Load the data using dataloaders that serve shuffled mini-batches
dataloader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True
)

dataloader_test = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=True
)

## Learning Rate Finder

In [None]:
!pip install torch-lr-finder

In [None]:
from torch_lr_finder import LRFinder

# Arguments passed to the UNet constructor
# (presumably input channels and base number of filters -- TODO confirm)
NUM_CHANNELS = 3
NUM_FILTERS = 64

# Throw-away model used only for the learning-rate search
model = UNet(NUM_CHANNELS, NUM_FILTERS).to(device)

criterion = torch.nn.BCELoss()
# The search starts from this (very small) learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

# Use the LR-finder to find the optimal learning rate.
# The finder runs on the device selected earlier in the notebook rather than
# on a hard-coded "cuda", so this cell also works on a CPU-only runtime.
lr_finder = LRFinder(model, optimizer, criterion, device=device)
lr_finder.range_test(dataloader_train, end_lr=1, num_iter=100)
lr_finder.plot()  # Plot resulting graph: loss as a function of the learning rate
lr_finder.reset()  # Put model and optimizer back in their initial state

## Training the model

In [None]:
#@title Setup
# Parameters of the training run, editable through the Colab form.
# NOTE(review): the default run name mentions "Dice" while the training cell
# builds torch.nn.BCELoss -- confirm the name matches the loss actually used.
# Name of the run (used as prefix of the checkpoint files and in the name of
# the scores file):
run_name = "Unet_paper_Adam_Dice"   #@param {type:"string"}
# Path to the drive (folder where checkpoints and scores are read/written):
drive_path = "/content/drive/Shareddrives/road-segmentation/"   #@param {type:"string"}
# Starting epoch (if not 0, the checkpoint saved for that epoch is loaded):
starting_epoch = 0   #@param {type:"integer",  min:0}
# Epoch step (number of epochs between each save):
epoch_step = 10   #@param {type:"integer",  min:1}
# Total number of training epochs
total_iterations =     20#@param {type:"integer", min:1}
# Learning rate (please run above cell and use best found):
learning_rate = 5e-4 #@param {type:"number", min:1e-6}

In [None]:
# Arguments passed to the UNet constructor
# (presumably input channels and base number of filters -- TODO confirm)
NUM_CHANNELS = 3
NUM_FILTERS = 64

# Optional learning-rate decay, currently disabled:
# decay_rate = 0.95

# Initialize our model with the right optimizer
model = UNet(NUM_CHANNELS, NUM_FILTERS).to(device)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decay_rate, verbose=True)

# Load the model if we're continuing training
if starting_epoch > 0:
    # Checkpoints are named "<run_name>_<epoch>.pkt" by the training cell
    loading_model_file = run_name + f"_{starting_epoch}.pkt"

    if os.path.isfile(drive_path + loading_model_file):
        print("Loading model from " + loading_model_file)
        # map_location moves the saved tensors onto the current device, so a
        # checkpoint written on a GPU can also be resumed on a CPU runtime.
        state_dicts = torch.load(drive_path + loading_model_file, map_location=device)
        model.load_state_dict(state_dicts['model_state_dict'])
        optimizer.load_state_dict(state_dicts['optimizer_state_dict'])
        # scheduler.load_state_dict(state_dicts['scheduler_state_dict'])
    else:
        # Best effort: training continues with the freshly initialized model
        print("Unable to load model from " + loading_model_file)

In [None]:
# File collecting one line of validation scores per trained epoch
score_file_name = "scores_" + run_name + ".csv"
score_file_path = drive_path + score_file_name

for first_epoch in range(starting_epoch, total_iterations, epoch_step):
    # Never train past total_iterations: the last round is shortened when the
    # remaining number of epochs is not a multiple of epoch_step.
    round_epochs = min(epoch_step, total_iterations - first_epoch)

    # Train the model for a step of epochs
    accuracies, f1_scores, iou_scores = training.train(
        model,
        criterion,
        dataloader_train,
        dataloader_test,
        optimizer,
        num_epochs=round_epochs,
    )

    # Save the model's state; the file name carries the number of epochs
    # completed so far, which is what starting_epoch expects when resuming.
    torch.save({'model_state_dict': model.state_dict(),
#               'scheduler_state_dict': scheduler.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
        drive_path + run_name + f"_{first_epoch + round_epochs}.pkt")

    # Save intermediate scores: the header is written only when the file is
    # created, later rounds (and resumed runs) append to it.
    write_header = not os.path.exists(score_file_path)
    with open(score_file_path, "a") as f:
        if write_header:
            f.write("accuracy, f1_score, iou_score\n")
        # assumes training.train returns one score per trained epoch -- TODO confirm
        for accuracy, f1_score, iou_score in zip(accuracies, f1_scores, iou_scores):
            f.write(f"{accuracy}, {f1_score}, {iou_score}\n")

### Show predicted output

In [None]:
from src.scripts.helpers import concatenate_images

# Index of the training image to display
i = 0

# Show predicted outputs on the training images.
# The listing is sorted so that a given index always selects the same file
# (os.listdir returns entries in arbitrary order).
files = sorted(os.listdir(image_dir))
img = mpimg.imread(image_dir + files[i])
gt = mpimg.imread(gt_dir + files[i])

# Run inference in eval mode and without building the autograd graph, then
# put the model back in the mode it was in so later cells are unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    # Reorder the axes from (H, W, C) to (C, H, W) and add the batch dimension
    batch = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
    output = model(batch)
model.train(was_training)

# First output plane of the single image in the batch
prediction = output[0][0].cpu().numpy()

# Display image, ground truth and prediction in a single figure
fig1 = plt.figure(figsize=(14, 10))
plt.imshow(concatenate_images(concatenate_images(img, gt), prediction))

## Predict output for testing images

In [None]:
# Wrap the test images in a dataset and serve them one image per batch
submission_dataset = FullSubmissionImageDataset(test_dir)
submission_dataloader = DataLoader(submission_dataset, batch_size=1)

In [None]:
# Set the model in eval state
model.eval()
toPIL = transforms.ToPILImage()

# Folder receiving one predicted mask per test image
output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)

# Run predictions and save outputs.
# Gradients are not needed for inference; disabling them avoids building an
# autograd graph for every test image.
with torch.no_grad():
    for indexes, images in submission_dataloader:
        # assumes the model output for one batch holds exactly 2 * 608 * 608
        # values; only the first 608x608 plane is saved -- TODO confirm what
        # the second plane represents
        out = model(images.to(device)).view(2, 608, 608).cpu()
        toPIL(out[0]).save(output_dir + "/file_{:03d}.png".format(indexes.view(-1).item()))

In [None]:
# Create the submission.csv file from every file found in output_dir.
# NOTE(review): masks_to_submission is not imported in any cell visible here
# (only concatenate_images is imported from src.scripts.helpers) -- confirm
# where it is defined and import it before running this cell.
# NOTE(review): os.listdir returns entries in arbitrary order -- verify that
# masks_to_submission does not depend on the order of the file names.
masks_to_submission("submission.csv", *[output_dir + "/" + f for f in os.listdir(output_dir)])