In [None]:
# To dos
# * up_mode
# * hyperparameters: depth, wf, edge_weight, max/avg pool
# * train on whole training set
# * metrics: loss functions for image segmentation, Jaccard index
# * tweak model: hyperparameter tuning
# * more training (increase num_epochs)
# * important parameters: depth, edge_weight
# * should the transformations be the same for the viable and whole tasks?
# * change optimizer: read article on optimizers http://ruder.io/optimizing-gradient-descent/
# * verify loss at init
# * overfit a tiny subset of the data
#


In [None]:
# ---- run identification --------------------------------------------------
data_name = "paip2019"   # prefix used when naming the pytable inputs and the saved checkpoint
gpu_id = 0               # CUDA device index to train on


# ---- network architecture ------------------------------------------------
# Arguments for the UNet constructor. num_classes is also used by the loss and
# metrics; the remaining fields are recorded in the checkpoint.
num_classes = 2
in_channels = 3
padding = True
depth = 5
wf = 2
up_mode = 'upconv'
batch_norm = True


# ---- training ------------------------------------------------------------
batch_size = 5
patch_size = 256
num_epochs = 100
ignore_index = -100

# Edges tend to be segmented worst because they cover so little of the
# training area; following the original U-Net paper, pixels on the annotation
# edge get a boosted loss. Experiment with different values.
edge_weight = 2.0

stages = ["train", "val"]          # passes executed every epoch, in this order
validation_stages = ["val"]        # passes on which the confusion matrix is accumulated

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from unet import UNet

import PIL
import matplotlib.pyplot as plt
import cv2
import numpy as np
import sys, glob

from network import R2AttU_Net
from models import UNet16

import scipy.ndimage

import time
import math
import tables

import random
from sklearn.metrics import confusion_matrix

In [None]:
#helper function for pretty printing of current time and remaining time
def asMinutes(s):
    """Render a duration given in seconds as '<minutes>m <seconds>s'.

    Minutes are floored; the leftover seconds are truncated to an integer.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '{}m {}s'.format(whole_minutes, int(leftover_seconds))

def timeSince(since, percent):
    """Format the elapsed time since `since` and the estimated time remaining.

    Args:
        since: start timestamp, as returned by time.time().
        percent: fraction of the work completed so far (0..1).

    Returns:
        '<elapsed> (- <remaining>)', both parts formatted by asMinutes.
    """
    now = time.time()
    s = now - since
    # the small epsilon keeps the estimate finite when percent == 0
    es = s / (percent + .00001)
    # at percent == 1 the epsilon makes es slightly smaller than s; clamp so the
    # remaining time prints as '0m 0s' instead of a negative '-1m 59s'
    rs = max(es - s, 0.0)
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [None]:
# Select the compute device: the configured GPU when CUDA is available, otherwise the CPU.
# get_device_properties / set_device raise without CUDA, so they are only called on the GPU path.
if torch.cuda.is_available():
    print(torch.cuda.get_device_properties(gpu_id))
    torch.cuda.set_device(gpu_id)
    device = torch.device(f'cuda:{gpu_id}')
else:
    print('CUDA is not available; running on CPU')
    device = torch.device('cpu')

In [None]:
# Build the segmentation network and move it to `device`.
# Alternative built from the architecture parameters in the config cell:
# model = UNet(n_classes=num_classes, in_channels=in_channels, padding=padding, depth=depth, wf=wf,
#             up_mode=up_mode, batch_norm=batch_norm).to(device)

model = UNet16(num_classes=num_classes, num_filters=8, pretrained=True).to(device)

#print(model)
# numel() gives an exact integer element count per tensor (np.prod of an empty
# torch.Size would return the float 1.0 for 0-dim parameters)
print(f"total params: \t{sum(p.numel() for p in model.parameters())}")

In [None]:
class WholeDataset(Dataset):
    """
        Data model for the whole mask segmentation task, backed by a PyTables file.

        Each item is an (image, whole_mask, whole_weight) triple read from the `wsi` and
        `whole` arrays of the pytable. The same random spatial augmentation is replayed on
        all three, so they stay aligned.

        Attributes:
         - file_name (string) : path to the pytable containing the data
         - image_transform (callable, optional): transform applied to the WSI patch;
           when None the raw array is returned
         - mask_transform (callable, optional): transform applied to the mask and to the
           edge map. It is called on an HxWx3 copy of the 2-D array and channel 0 of its
           output is kept. When None the raw 2-D arrays are returned.
         - edge_weight (float or bool): when truthy, `whole_weight` is a 0/1 map of the
           pixels added by dilating the foreground (== 1) twice. When falsy it is all zeros.
           Either way it is an indicator, not a multiplier.
         - tables (pytable object): the file handle used while reading metadata in
           __init__ (closed afterwards)
         - num_items (int): the number of data samples in the dataset
         - wnumpixels (numpy array): copy of the `wnumpixels` array stored in the pytable
    """

    def __init__(self, file_name, image_transform=None, mask_transform=None, edge_weight=False):
        self.file_name = file_name
        self.edge_weight = edge_weight

        self.image_transform = image_transform
        self.mask_transform = mask_transform

        # only metadata is read here; `with` guarantees the file is closed even if a node is missing
        with tables.open_file(self.file_name) as db:
            self.tables = db
            self.num_items = db.root.wsi.shape[0]
            self.wnumpixels = db.root.wnumpixels[:]

        self.images = None
        self.whole_masks = None

    def __getitem__(self, index):
        # the file is reopened per item, so no handle stays open between items
        with tables.open_file(self.file_name, 'r') as db:
            image = db.root.wsi[index, :, :, :]
            whole_mask = db.root.whole[index, :, :]
        return self._prepare_sample(image, whole_mask)

    def __len__(self):
        return self.num_items

    def _prepare_sample(self, image, whole_mask):
        """Build the edge map and apply the (seed-synchronised) transforms to one sample."""
        whole_weight = self._edge_map(whole_mask)

        # one seed per sample, replayed before every transform so the random spatial steps line up
        seed = random.randrange(sys.maxsize)

        image_new = image
        if self.image_transform is not None:
            image_new = self._seeded(self.image_transform, image, seed)

        whole_mask_new, whole_weight_new = whole_mask, whole_weight
        if self.mask_transform is not None:
            whole_mask_new = self._transform_mask(whole_mask, seed)
            whole_weight_new = self._transform_mask(whole_weight, seed)

        return image_new, whole_mask_new, whole_weight_new

    def _edge_map(self, whole_mask):
        """Return a 0/1 map (same dtype as `whole_mask`) of the pixels just outside the foreground."""
        if self.edge_weight:
            # the original paper up-weights annotation edges; a fast approximation is to
            # dilate the foreground and keep only the pixels the dilation added
            foreground = whole_mask == 1
            dilated = scipy.ndimage.binary_dilation(foreground, iterations=2)
            return (dilated & ~foreground).astype(whole_mask.dtype)
        # no edge pixels: a caller computing edge_weight ** map then gets a factor of 1 everywhere
        return np.zeros(whole_mask.shape, dtype=whole_mask.dtype)

    @staticmethod
    def _seeded(transform, data, seed):
        """Apply `transform` after seeding both Python's and torch's RNG with `seed`.

        Older torchvision transforms draw from `random`, newer ones from torch, so both are
        seeded. The torch RNG state is forked so the global torch stream is left untouched.
        """
        random.seed(seed)
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            return transform(data)

    def _transform_mask(self, mask, seed):
        """Run mask_transform on a 2-D array and return the 2-D result."""
        # replicate to HxWx3 so torchvision's PIL conversion accepts it, then keep channel 0
        stacked = mask[:, :, None].repeat(3, axis=2)
        return np.asarray(self._seeded(self.mask_transform, stacked, seed))[:, :, 0].squeeze()

In [None]:
# The image and mask pipelines must apply the *same* random spatial augmentation.
# WholeDataset replays one seed before each pipeline, so the random draws only line up if
# both pipelines start with the identical sequence of spatial steps built below.
# Colour augmentations are appended to the image pipeline only, after the spatial steps,
# so they cannot shift the random draws the spatial steps consume.
def _spatial_steps(resized_crop_kwargs):
    """Return fresh instances of the spatial augmentations shared by the image and mask pipelines."""
    return [
        transforms.ToPILImage(),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=(patch_size, patch_size), pad_if_needed=True),
        transforms.RandomResizedCrop(size=patch_size, **resized_crop_kwargs),
        transforms.RandomRotation(180),
    ]


# image: library-default interpolation in the resized crop, then colour augmentation and tensor conversion
image_transform = transforms.Compose(
    _spatial_steps({})
    + [
        transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=.5),
        transforms.RandomGrayscale(),
        transforms.ToTensor(),
    ]
)

# mask: nearest-neighbour interpolation so class labels are never blended
# TODO: try different transformations for the viable and whole masks
mask_transform = transforms.Compose(_spatial_steps({'interpolation': PIL.Image.NEAREST}))


In [None]:
# One dataset/loader pair per stage, each backed by ./pytables/<data_name>_<stage>.pytable.
# NOTE(review): the "val" stage reuses the random training augmentations and shuffle=True,
# so validation metrics are computed on randomly augmented patches. Consider a
# deterministic transform (and shuffle=False) for validation.
dataset, data_loader = {}, {}
for stage in stages:
    table_path = f'./pytables/{data_name}_{stage}.pytable'
    stage_dataset = WholeDataset(table_path,
                                 image_transform=image_transform,
                                 mask_transform=mask_transform,
                                 edge_weight=edge_weight)
    dataset[stage] = stage_dataset
    data_loader[stage] = DataLoader(stage_dataset, batch_size=batch_size, shuffle=True,
                                    num_workers=0, pin_memory=True)

In [None]:
# Visualize a single augmented training example to check that the image, the mask and
# the edge map stay aligned after augmentation.
idx = 677
(img, whole_mask, whole_mask_weight) = dataset["train"][idx]
print(img.shape)
fig, ax = plt.subplots(1, 3, figsize=(10, 5))  # 1 row, 3 columns

# panels: augmented patch, whole mask, edge-weight map
ax[0].set_title("original wsi")
ax[0].imshow(np.moveaxis(img.numpy(), 0, -1))  # channels-first tensor -> channels-last for imshow

ax[1].set_title("whole mask")
ax[1].imshow(whole_mask)

ax[2].set_title("whole edge weight")
ax[2].imshow(whole_mask_weight)

# raw string: in a normal literal the backslashes of this Windows path are invalid escape
# sequences (SyntaxWarning on Python 3.12+); the resulting path text is unchanged.
# NOTE(review): absolute local path; consider making the output directory a config constant.
plt.savefig(rf'D:\work2019-2020\PAIP_2019\data\Patches\whole_patch{idx}.png')

In [None]:
# Optimizer: Adam with the library-default hyperparameters is the robust starting point.
# SGD with momentum (commented out below) may reach better final performance and is worth
# comparing; see http://ruder.io/optimizing-gradient-descent/ for an overview.
optim = torch.optim.Adam(params=model.parameters())
# optim = torch.optim.SGD(model.parameters(), lr=.1, momentum=0.9, weight_decay=0.0005)

In [None]:
# Weight each class by 1 - (its share of the counted training pixels), so the rarer class
# contributes more to the loss and the more frequent class does not dominate it.
nclasses = dataset["train"].wnumpixels.shape[1]
print(nclasses)
assert nclasses >= num_classes, f"pytable has counts for {nclasses} classes, model has {num_classes}"

# row 1 is used as the per-class pixel counts (presumably row 0 holds the class labels;
# TODO confirm against the script that built the pytable). Take one column per model class.
whole_class_weight = dataset["train"].wnumpixels[1, :num_classes]
print(whole_class_weight)
whole_class_weight = torch.from_numpy(1 - whole_class_weight / whole_class_weight.sum()).float().to(device)

print(f'whole class weight: {whole_class_weight}')

# reduction='none' keeps one loss value per pixel so it can be scaled by the edge weights
criterion = nn.CrossEntropyLoss(weight=whole_class_weight, ignore_index=ignore_index, reduction='none')
#criterion = LossBinary(viable_class_weight, jaccard_weight=0.5)

In [None]:
# Initial settings
best_loss_on_test = np.inf  # np.Infinity was removed in NumPy 2.0
# as_tensor returns an existing float32 tensor on `device` unchanged, so re-running this
# cell does not re-wrap edge_weight
edge_weight = torch.as_tensor(edge_weight, dtype=torch.float32, device=device)
start_time = time.time()

In [None]:
for epoch in range(num_epochs):

    # per-epoch accumulators: batch losses per stage (kept on the device to avoid a host
    # sync every batch) and a confusion matrix over the validation batches
    batch_losses = {key: [] for key in stages}
    epoch_loss = {}
    cmatrix = np.zeros((num_classes, num_classes))
    #jaccard = torch.zeros(0).to(device) # jaccard index to make an idea of overall performance

    for stage in stages:
        is_train = stage == "train"
        model.train(is_train)  # model.train(False) is what model.eval() does

        for X, y, y_weight in data_loader[stage]:
            X = X.to(device)                                     # [NBatch, 3, H, W]
            y_weight = y_weight.to(device, dtype=torch.float32)  # [NBatch, H, W] exponent for edge_weight
            y = y.to(device, dtype=torch.long)                   # [NBatch, H, W] with class indexes

            with torch.set_grad_enabled(is_train):
                prediction = model(X)  # [NBatch, NClass, H, W]
                loss_matrix = criterion(prediction, y)
                # edge_weight ** 0 == 1: only pixels with a non-zero y_weight are re-weighted
                loss = (loss_matrix * (edge_weight ** y_weight)).mean()

                # backpropagation
                if is_train:
                    optim.zero_grad()  # clear previous gradients
                    loss.backward()    # compute gradients of all variables wrt the loss
                    optim.step()       # update parameters using the computed gradients

            batch_losses[stage].append(loss.detach())

            # on validation stages, accumulate the confusion matrix
            if stage in validation_stages:
                class_pred = prediction.detach().argmax(dim=1).flatten().cpu().numpy()
                mask = y.flatten().cpu().numpy()
                cmatrix += confusion_matrix(mask, class_pred, labels=range(num_classes))

        losses = batch_losses[stage]
        # nan (instead of a "mean of empty slice" warning) when the loader yielded no batches
        epoch_loss[stage] = torch.stack(losses).mean().item() if losses else float('nan')

    # accuracy is only meaningful over the validation batches, so compute it once after all
    # stages (computing it after the train stage divided an all-zero matrix by zero)
    total = cmatrix.sum()
    epoch_acc = np.trace(cmatrix) / total if total else float('nan')

    print('%s ([%d/%d] %d%%), train loss: %.4f test loss: %.4f accuracy: %.4f%%'
          % (timeSince(start_time, (epoch + 1) / num_epochs), epoch + 1, num_epochs, (epoch + 1) / num_epochs * 100,
             epoch_loss["train"], epoch_loss["val"], epoch_acc * 100), end="")

    # if the current validation loss is the best seen so far, save model and optimizer state
    if epoch_loss["val"] < best_loss_on_test:
        best_loss_on_test = epoch_loss["val"]
        print("  **")
        # NOTE(review): 'best_loss_on_test' stores the whole per-stage loss dict, and the
        # architecture fields are the UNet constructor arguments from the config cell rather
        # than those of the UNet16 model built in this notebook. Confirm before relying on
        # them to rebuild the model.
        state = {'epoch': epoch + 1,
                 'model_dict': model.state_dict(),
                 'optim_dict': optim.state_dict(),
                 'best_loss_on_test': epoch_loss,
                 'n_classes': num_classes,
                 'in_channels': in_channels,
                 'padding': padding,
                 'depth': depth,
                 'wf': wf,
                 'up_mode': up_mode,
                 'batch_norm': batch_norm
                }

        torch.save(state, f"{data_name}_unet_whole.pth")   # best loss 0.1030
    else:
        print("")
    