# Train Model

In [1]:
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torch.autograd import Variable
from networks.data_utils import get_imdb_data
from torch.utils.data import DataLoader, Dataset
import random
import torch.nn.functional as F
from networks.relay_net import ReLayNet
# from networks.solver import Solver
from random import randint

In [2]:
class RandomTransforms(object):
    """Random shift-and-crop plus flip augmentation applied jointly to a sample.

    With probability ``prob1`` the image, target and weight map are
    mirror-padded by ``border`` and a ``height`` x ``width`` window is cropped
    from the padded arrays at a random offset (the same offset for all three).
    Independently, with probability ``prob2`` all three arrays are flipped
    along axis 1.

    Args:
        height: number of rows of the inputs and of the cropped outputs.
        width: number of columns of the inputs and of the cropped outputs.
        layers: number of target channels; stored on the instance only.
        prob1: probability of applying the pad-and-crop shift.
        prob2: probability of applying the flip.
        border: padding in pixels, ordered (top, bottom, left, right).

    Raises:
        ValueError: if ``border`` does not hold four non-negative values.
    """

    def __init__(self, height, width, layers, prob1=1, prob2=0.5, border=(8, 8, 8, 8)):
        self.height = height
        self.width = width
        self.layers = layers
        self.prob1 = prob1
        self.prob2 = prob2
        # Copy into a fresh list so instances never share (or mutate) the default.
        self.b = list(border)
        if len(self.b) != 4 or min(self.b) < 0:
            raise ValueError(
                "border must be four non-negative values (top, bottom, left, right)")

    def _mirror_pad(self, array):
        """Return ``array`` mirror-padded on its first two axes by ``self.b``.

        The result is float64, as a zero-initialised padding buffer would be.
        Any trailing axes (e.g. channels) are left unpadded.
        """
        array = np.asarray(array, dtype=np.float64)
        if array.shape[:2] != (self.height, self.width):
            raise ValueError(
                "expected leading shape {}, got {}".format(
                    (self.height, self.width), array.shape[:2]))
        top, bottom, left, right = self.b
        pad_width = [(top, bottom), (left, right)] + [(0, 0)] * (array.ndim - 2)
        # 'symmetric' repeats the edge pixel, i.e. border row k mirrors row k
        # counted from the edge; it also works for a zero-width border.
        return np.pad(array, pad_width, mode='symmetric')

    def __call__(self, image, target, weight):
        """Augment one sample.

        Args:
            image: array whose first two axes are (height, width).
            target: array whose first two axes are (height, width).
            weight: array whose first two axes are (height, width).

        Returns:
            Tuple ``(image, target, weight)`` after the random shift and flip.
        """
        if random.random() < self.prob1:
            top, bottom, left, right = self.b
            padded = [self._mirror_pad(a) for a in (image, target, weight)]

            # Offsets are derived from the border instead of a fixed 16.  The
            # upper bound stays at (total padding - 1), which for the default
            # border is the 0..15 range; it is clamped so a zero border works.
            row = random.randint(0, max(top + bottom - 1, 0))
            col = random.randint(0, max(left + right - 1, 0))
            rows = slice(row, row + self.height)
            cols = slice(col, col + self.width)
            image, target, weight = (a[rows, cols] for a in padded)

        if random.random() < self.prob2:
            # Flip all three arrays along axis 1 so they stay aligned.
            image = np.flip(image, 1)
            target = np.flip(target, 1)
            weight = np.flip(weight, 1)
        return image, target, weight
                
class ImdbData(Dataset):
    """Dataset wrapping image, label and weight-map arrays.

    Each item is returned as ``(img, label, weight)`` tensors: ``img`` and
    ``weight`` as float tensors and ``label`` as a long tensor.  Images and
    labels are moved to channels-last before the optional ``transform`` is
    applied and back to channels-first afterwards.

    Args:
        config: mapping with ``config['general']`` holding ``HEIGHT``,
            ``WIDTH`` and ``layers``.
        X: indexable collection of images, one 3-D array per item.
        y: indexable collection of labels, one 3-D array per item.
        W: indexable collection of weight maps.
        transform: optional callable ``(img, label, weight) -> (img, label, weight)``.
    """

    def __init__(self, config, X, y, W, transform=None):
        self.X, self.y, self.w = X, y, W
        general = config['general']
        self.height = general['HEIGHT']
        self.width = general['WIDTH']
        self.layers = general['layers']
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of ``X``)."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``index``-th sample as ``(img, label, weight)`` tensors."""
        to_channels_last = (1, 2, 0)
        to_channels_first = (2, 0, 1)

        sample = (
            np.transpose(self.X[index], to_channels_last),
            np.transpose(self.y[index], to_channels_last),
            self.w[index],
        )
        if self.transform is not None:
            sample = self.transform(*sample)
        img_arr, label_arr, weight_arr = sample

        # .copy() gives torch a fresh buffer to wrap (a transform may hand
        # back flipped views of its inputs).
        img = torch.from_numpy(img_arr.copy()).float().permute(*to_channels_first)
        label = torch.from_numpy(label_arr.copy()).long().permute(*to_channels_first)
        weight = torch.from_numpy(weight_arr.copy()).float()
        return img, label, weight

In [3]:
# Read the experiment configuration.
# NOTE(review): yaml.load with FullLoader is fine for a trusted local file;
# prefer yaml.safe_load if the config could ever come from an untrusted source.
with open("./train.yaml") as file:
    config = yaml.load(file, Loader=yaml.FullLoader)

# Image geometry and number of label channels.
general_cfg = config['general']
HEIGHT, WIDTH, layers = general_cfg['HEIGHT'], general_cfg['WIDTH'], general_cfg['layers']

# Output and input locations.
filepaths_cfg = config['filepaths']
exp_dir_name = filepaths_cfg['exp_dir_name']
model_path = filepaths_cfg['model_path']
data_dir = filepaths_cfg['processed_data_path']

# Network hyper-parameters, handed to the model constructor as-is.
param = config['param']

(train_images, train_labels, train_wmaps,
 val_images, val_labels, val_wmaps) = get_imdb_data(data_dir)

# Reshape every split to a fixed layout: images get a singleton axis inserted
# at position 1, labels carry `layers` channels, weight maps stay 2-D per item.
image_shape = (-1, HEIGHT, WIDTH)
label_shape = (-1, layers, HEIGHT, WIDTH)

train_images2 = np.copy(train_images.reshape(image_shape)[:, np.newaxis])
train_labels2 = np.copy(train_labels.reshape(label_shape))
train_wmaps2 = np.copy(train_wmaps.reshape(image_shape))

val_images2 = np.copy(val_images.reshape(image_shape)[:, np.newaxis])
val_labels2 = np.copy(val_labels.reshape(label_shape))
val_wmaps2 = np.copy(val_wmaps.reshape(image_shape))

# combining train and validation since paper only did train
train_images3 = np.concatenate([train_images2, val_images2], axis=0)
train_labels3 = np.concatenate([train_labels2, val_labels2], axis=0)
train_wmaps3 = np.concatenate([train_wmaps2, val_wmaps2], axis=0)


In [4]:
from random import shuffle
import numpy as np
import torch.nn.functional as F
import torch
import pathlib
import torch.nn as nn
import pandas as pd
from torch.nn.modules.loss import _Loss
from torch.autograd import Function, Variable
from torch.autograd import Variable
from networks.net_api.losses import DiceLoss, CrossEntropyLoss2d
from torch.optim import lr_scheduler
import os
from tqdm import tqdm


class Solver(object):
    """Bundle of optimiser settings, loss function and training histories.

    Args:
        device: device object forwarded to the loss constructor and stored as
            ``self.device``.
        optim: optimiser class (not instance); defaults to ``torch.optim.SGD``.
        optim_args: optional dict of optimiser keyword arguments that override
            the entries of ``default_optim_args``.
    """

    # Default optimiser keyword arguments; per-instance overrides are merged
    # on top of a copy, so this dict itself is never modified.
    default_optim_args = {"lr": 0.1,
                          "momentum": 0.9,
                          "weight_decay": 0.0001}
    # Learning-rate decay factor and period (in epochs) for a step schedule.
    gamma = 0.1
    step_size = 30
    NumClass = 10 # TO CHANGE

    def __init__(self, device, optim=torch.optim.SGD, optim_args=None):
        # `None` instead of `{}` as default: a mutable default would be shared
        # between all calls.
        optim_args_merged = dict(self.default_optim_args)
        if optim_args:
            optim_args_merged.update(optim_args)
        self.optim_args = optim_args_merged
        self.optim = optim
        # NOTE(review): CombinedLoss is neither imported nor defined in the
        # visible part of this file — confirm where it is expected to come from.
        self.loss_func = CombinedLoss(device)
        self.device = device

        self._reset_histories()

    def _reset_histories(self):
        """Reset the training loss and training accuracy histories to empty lists."""
        self.train_loss_history = []
        self.train_acc_history = []

In [5]:
# Augmentation sized from the same config the dataset reads; the transform's
# default probabilities and border are used.
general_settings = config['general']
random_transform = RandomTransforms(
    general_settings['HEIGHT'],
    general_settings['WIDTH'],
    general_settings['layers'],
)

# Training set = combined train + validation arrays, augmented on the fly.
train_dataset = ImdbData(
    config,
    train_images3,
    train_labels3,
    train_wmaps3,
    transform=random_transform,
)



In [None]:

# Batches of 50 samples, reshuffled every epoch, loaded by four workers.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=50, shuffle=True, num_workers=4)
# val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=50, shuffle=False, num_workers=4)

# Use the GPU when present, otherwise fall back to the CPU so the notebook
# still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

relaynet_model = ReLayNet(param)
solver = Solver(device)
num_epochs = 60

# solver.train(relaynet_model, train_loader, model_path=model_path, num_epochs=num_epochs, log_nth=1,  exp_dir_name=exp_dir_name)

model = relaynet_model
log_nth = 1
# Overrides the experiment directory name read from the config.
exp_dir_name = 'exp_default'

# Train `model` on `train_loader` for `num_epochs` epochs.
# Uses:
# - model: model object initialized from a torch.nn.Module
# - train_loader: train data in torch.utils.data.DataLoader
# - num_epochs: total number of training epochs
# - log_nth: intended logging interval (currently unused)
# NOTE: `per_class_dice` must already be defined when this cell runs.
optim = solver.optim(model.parameters(), **solver.optim_args)
# Learning-rate scheduler: multiply the LR by `gamma` every `step_size` epochs.
scheduler = lr_scheduler.StepLR(optim, step_size=solver.step_size,
                                gamma=solver.gamma)


iter_per_epoch = 1
# iter_per_epoch = len(train_loader)

model.to(solver.device)

print('START TRAIN.')
curr_iter = 0

per_epoch_train_acc = []

# Checkpoint directory is the same for every epoch, so create it once up
# front; this also keeps `full_save_path` defined if no epoch runs.
full_save_path = os.path.join(model_path, exp_dir_name)
pathlib.Path(full_save_path).mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    solver._reset_histories()
    model.train()
    iteration = 0

    batch = tqdm(enumerate(train_loader), total=len(train_loader))

    for i_batch, sample_batched in batch:
        # Move the batch to the same device as the model.  Plain tensors are
        # enough: the deprecated Variable wrapper is not needed and the input
        # does not need gradients of its own.
        X = sample_batched[0].to(solver.device)
        y = sample_batched[1].to(solver.device)
        w = sample_batched[2].to(solver.device)

        optim.zero_grad()
        output = model(X)
        loss = solver.loss_func(output, y, w)

        loss.backward()
        optim.step()

        # Collapse the channel axis of prediction and label to class indices
        # before computing the Dice metric.
        _, batch_output = torch.max(F.softmax(output, dim=1), dim=1)
        _, y = torch.max(y, dim=1)

        avg_dice = per_class_dice(batch_output, y, solver.NumClass)
        solver.train_loss_history.append(loss.item())

        solver.train_acc_history.append(avg_dice)
    per_epoch_train_acc.append(np.sum(np.asarray(solver.train_acc_history))/len(train_loader))

    # Step the LR schedule once per epoch, after the optimiser updates of that
    # epoch; stepping before any optimiser step shifts the schedule by one.
    scheduler.step()

    # Reports the Dice score of the last batch of the epoch.
    print('[Epoch : {} / {}]: {:.2f}'.format(epoch, num_epochs, float(avg_dice)))

    model.save(os.path.join(full_save_path, 'relaynet_epoch'+ str(epoch + 1) + '.model'))



In [None]:
def per_class_dice(y_pred, y_true, num_class, plot_class=None):
    """Return the Dice coefficient averaged over ``num_class`` classes.

    Every class index in ``0 .. num_class - 1`` is scored and the scores are
    averaged.  A class that is absent from both inputs scores 1.0.

    Args:
        y_pred: predicted class indices (torch tensor or array-like).
        y_true: ground-truth class indices, same shape as ``y_pred``.
        num_class: number of classes to score.
        plot_class: optional class index.  When given, the ground-truth and
            predicted masks of the first sample for that class are drawn with
            matplotlib as a debugging aid.  Nothing is plotted by default, so
            the function is safe to call once per training batch.

    Returns:
        numpy.float64: mean per-class Dice score in [0, 1] (0.0 when
        ``num_class`` is not positive).
    """
    y_pred, y_true = (
        a.detach().cpu().numpy() if isinstance(a, torch.Tensor) else np.asarray(a)
        for a in (y_pred, y_true)
    )

    eps = 0.0001  # smoothing term; keeps the ratio defined for empty classes
    dice_sum = np.float64(0.0)
    for cls in range(num_class):
        gt_mask = y_true == cls
        pred_mask = y_pred == cls
        inter = np.sum(gt_mask & pred_mask)
        total = np.sum(gt_mask) + np.sum(pred_mask)
        # eps is added once to numerator and denominator so that the score
        # never exceeds 1 (doubling a smoothed intersection would give 2.0
        # for a class missing from both masks).
        dice_sum += (2 * inter + eps) / (total + eps)

    if plot_class is not None:
        plt.figure()
        plt.imshow((y_true == plot_class)[0])
        plt.figure()
        plt.imshow((y_pred == plot_class)[0])

    if num_class <= 0:
        return dice_sum
    return dice_sum / num_class

# Debug check: score whatever `batch_output` and `y` currently hold (after the
# training cell these are the class-index tensors of the last processed batch).
per_class_dice(batch_output, y, 10)

In [None]:
# Write the per-epoch training Dice history next to the saved checkpoints.
# The frame is built from `per_epoch_train_acc`, the list filled by the
# training loop; the row index is the zero-based epoch number.
# NOTE(review): the column name 'train_acc' is chosen here — confirm it
# against anything that reads accuracy_history.csv.
history = {'train_acc': per_epoch_train_acc}
df = pd.DataFrame(data=history)
df.to_csv(os.path.join(full_save_path, 'accuracy_history.csv'))
print('FINISH.')