In [None]:
import os
import random
import sys

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim
import torchvision
from PIL import Image
from torch.optim import lr_scheduler
from torchsummary import summary

from dataloader import *
from models import *
from train import *

## Parameters ##

In [None]:
# All tunable settings for this notebook live in this cell.
# The project root can be overridden with the NUCLEI_SEG_ROOT environment variable so the
# notebook runs on other machines; the default is the original location, so the derived
# paths below are unchanged when the variable is not set.
project_root = os.environ.get('NUCLEI_SEG_ROOT', '/home/pete/melvin/nuclei_segmentation')
data_root = os.path.join(project_root, 'data')

train_input_path = os.path.join(data_root, 'images')
train_target_path = os.path.join(data_root, 'labels_3_classes')

eval_input_path = os.path.join(data_root, 'eval', 'images')
eval_target_path = os.path.join(data_root, 'eval', 'labels_3_classes')

stats_path = os.path.join(project_root, 'stats.txt')

# Passed to Nuclei_Dataset for both splits. The flip values are presumably probabilities
# — TODO confirm in dataloader.Nuclei_Dataset.
transform_params = {"cropping_width": 448, "cropping_height": 448, "h_flip": 0.5, "v_flip": 0.5, "normalise": False}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 8
n_class = 3
lr = 1e-4

# Seed the Python, NumPy and PyTorch RNGs so loader shuffling, any random augmentation and
# weight initialisation repeat across runs. Some CUDA kernels may still be nondeterministic.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

## Load dataset ##

In [None]:
# Build the training and evaluation datasets from the paths configured above.
# Both splits share transform_params; the eval split passes train=False, which presumably
# turns off random augmentation — TODO confirm in dataloader.Nuclei_Dataset.
train_dataset, eval_dataset = (
    Nuclei_Dataset(train_input_path, train_target_path, transform_params),
    Nuclei_Dataset(eval_input_path, eval_target_path, transform_params, train=False),
)

# TODO: a cross-validation split (images/annotations under
# .\Red_vs_Yellow_Results\ANN_L1_ModRes\CV\Original_Images and ...\CV\Annotated_Images)
# used to be stubbed here with an older MyDataset class; recreate it with Nuclei_Dataset
# if CV evaluation is needed.

In [None]:
# Wrap both datasets in DataLoaders that use the same batch size and shuffle.
# Shuffling the eval loader lets the prediction preview later in the notebook draw
# varying batches. It presumably does not affect aggregate eval metrics — confirm in train_model.
train_loader, eval_loader = (
    torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True)
    for ds in (train_dataset, eval_dataset)
)

## Collect the parameters passed to `train_model` ##

In [None]:
# Settings bundled for train_model (presumably provided by `from train import *`).
parameters = {"stats_path": stats_path,
              "device": device,
              # Use the same batch size the DataLoaders were built with. A separate literal
              # here (previously 4) silently disagreed with the loaders' batch_size (8).
              "batch_size": batch_size,
              "n_class": n_class,
              "datasets": [train_dataset, eval_dataset],
              "loaders": [train_loader, eval_loader]}

## See what a batch looks like ##

In [None]:
# Pull one batch from the training loader and show each input next to its target mask.
inputs, targets = next(iter(train_loader))

n_rows = inputs.shape[0]
# squeeze=False keeps `ax` 2-D even when the batch holds a single sample,
# so the ax[i, j] indexing below is always valid.
fig, ax = plt.subplots(n_rows, 2, figsize=(25, 25), squeeze=False)
for i in range(n_rows):
    # permute moves dim 0 (presumably channels) last, the layout imshow expects.
    ax[i, 0].imshow(inputs[i].permute(1, 2, 0))
    ax[i, 1].imshow(targets[i])
ax[0, 0].set_title('input')
ax[0, 1].set_title('target')
plt.tight_layout()
plt.show()

## Initialise the model and train it ##

In [None]:
import time  # NOTE(review): not used in this cell; kept in case later cells rely on it.

# Sweep over the loss weight of the middle class (index 1): weight = weight_factor / 2.
# Only weight_factor = 3 runs for now; the plan is to try 0 --> 50.
for weight_factor in range(3, 4):
    model = ResNetUNet(n_class).to(device)

    # Freeze the backbone layers; comment out this loop to fine-tune them as well.
    for layer in model.base_layers:
        for param in layer.parameters():
            param.requires_grad = False

    # Adam over all parameters. Frozen ones never receive gradients, so Adam skips them.
    # torch.optim is referenced explicitly rather than relying on a star import for `optim`.
    optimizer_ft = torch.optim.Adam(model.parameters(), lr=lr)

    # Learning-rate decay: multiply lr by 0.1 every 25 scheduler steps.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)

    # Per-class loss weights, created on the same device as the model instead of being
    # forced onto CUDA, so the cell also runs on CPU-only machines.
    weights = torch.tensor([1.0, weight_factor / 2, 1.0], dtype=torch.float32, device=device)
    model = train_model(model, parameters, optimizer_ft, exp_lr_scheduler, weights=weights,
                        progress_bars=False, perclass_stat=True, num_epochs=50)

In [None]:
num_of_batches_to_check = 5

# Show input | ground truth | prediction for a few evaluation batches.
# eval() switches any batch-norm/dropout layers to inference behaviour. no_grad() avoids
# building an autograd graph the preview never uses. The model's previous train/eval mode
# is restored afterwards so later cells see it unchanged.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        # One pass over a single iterator (instead of re-creating it per batch); zip
        # stops after num_of_batches_to_check batches, or earlier if the loader runs out.
        for _, (inputs, labels) in zip(range(num_of_batches_to_check), eval_loader):
            inputs = inputs.to(device)
            out = model(inputs)
            # Highest-scoring class per pixel (dim 1 is presumably the class axis).
            _, prediction = torch.max(out, 1)

            num_predictions = prediction.shape[0]
            # squeeze=False keeps `ax` 2-D even for a single-sample batch.
            fig, ax = plt.subplots(num_predictions, 3, figsize=(25, 25), squeeze=False)
            for i in range(num_predictions):
                ax[i, 0].imshow(inputs[i].permute(1, 2, 0).cpu().numpy())
                ax[i, 1].imshow(labels[i])
                ax[i, 2].imshow(prediction[i].cpu().numpy())
            fig.tight_layout()
            plt.show()
            plt.close(fig)  # free the figure; many are created in this loop
finally:
    model.train(was_training)