---
# Setup

In [None]:
# @title Install dependencies
# Pin exact versions: OpenCV 4.1.2.30 and the CUDA 10.2 builds of
# torch 1.9.0 / torchvision 0.10.0. The "+cu102" wheels are looked up in the
# extra index passed with -f.
!pip install opencv-python==4.1.2.30 --quiet
!pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
# Imports
import os
import cv2
import tqdm
import hashlib
import requests

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

from numba import jit

from scipy.ndimage import gaussian_filter
from scipy.ndimage import find_objects, binary_fill_holes
from scipy.ndimage import generate_binary_structure, label
from scipy.optimize import linear_sum_assignment

# Our import functions
import glob
import matplotlib.image as mpimg
from pathlib import Path

In [None]:
import os

# Print the full path of every file found below /kaggle.
for dirname, _, filenames in os.walk('/kaggle'):
    full_paths = [os.path.join(dirname, filename) for filename in filenames]
    for full_path in full_paths:
        print(full_path)

In [None]:
# Load the image stack and the matching label stack from the Kaggle input dataset.
# NOTE(review): the "with_blanks" file names suggest that slices without any
# labelled structure are included -- confirm against the dataset description.
original_images=np.load('/kaggle/input/hippocampus-segmentation/images_with_blanks.npy')
label_images=np.load('/kaggle/input/hippocampus-segmentation/labels_with_blanks.npy')

In [None]:
# Display the shapes of both arrays; they are split together below, so their
# first dimensions (number of images) have to agree.
original_images.shape, label_images.shape

In [None]:
# Split dataset into train, validation and test sets
from sklearn.model_selection import train_test_split

# Per-image maximum label value, used as the stratification key so that every
# split receives the same proportion of images for each maximum value.
label_images_abs = np.max(label_images, axis=(1, 2))

# First split: hold out 20% of all images as the test set.
images_train, images_test, label_train, label_test = train_test_split(
    original_images, label_images,
    test_size=0.2, random_state=42, stratify=label_images_abs)

# Second split: take 20% of the remaining images as the validation set.
label_train_abs = np.max(label_train, axis=(1, 2))
images_train, images_val, label_train, label_val = train_test_split(
    images_train, label_train,
    test_size=0.2, random_state=42, stratify=label_train_abs)

# Insert a singleton axis at position 1 of every array.
images_train, images_val, images_test = (
    np.expand_dims(split, axis=1)
    for split in (images_train, images_val, images_test))
label_train, label_val, label_test = (
    np.expand_dims(split, axis=1)
    for split in (label_train, label_val, label_test))

In [None]:
# Augmented images: build a sharpened, brightness-jittered copy of every training image
# and append it to the training set (the labels are duplicated unchanged).
# NOTE(review): the original comment mentioned "high contrast" and "training the GAN";
# the ColorJitter below only varies brightness (contrast/saturation/hue are 0) and the
# model trained in this notebook is a U-Net -- confirm the intent.
# NOTE(review): ToPILImage followed by ToTensor presumably rescales intensities to [0, 1];
# if images_train is not already in that range the augmented copies end up on a different
# scale than the originals they are stacked with -- TODO confirm the input dtype/range.
# NOTE(review): no random seed is set for the jitter or the shuffle below.

data_transforms = transforms.Compose([
    transforms.ToPILImage(),
    # transforms.Scale(256),
    transforms.ToTensor(),
    transforms.RandomAdjustSharpness(sharpness_factor=10, p=1),  # p=1: sharpening is always applied
    transforms.ColorJitter(brightness=2, contrast=0, saturation=0, hue=0)
    ])

# Transform the single channel (index 0) of every training image.
image_trans = []
for i in range(images_train.shape[0]):
    image_trans.append(data_transforms(torch.tensor(images_train[i][0])).numpy())
image_trans = np.array(image_trans)


# Stack originals and transformed copies along the first axis; labels are repeated in the same order.
images_train_aug = np.vstack([images_train,image_trans])
label_train_aug = np.vstack([label_train,label_train])

idx = torch.randperm(images_train_aug.shape[0]) # random permutation applied to images AND labels, so pairs stay aligned
images_train_aug = images_train_aug[idx]
label_train_aug = label_train_aug[idx]

print(images_train.shape, images_train_aug.shape)
print(label_train.shape, label_train_aug.shape)

In [None]:
# Standardize all image sets with the statistics of the un-augmented training
# images: subtract their mean and divide by their standard deviation.
# NOTE: re-running this cell standardizes the already standardized arrays again.
mean_train, std_train = np.mean(images_train), np.std(images_train)

images_train_aug, images_train, images_val, images_test = (
    (split - mean_train) / std_train
    for split in (images_train_aug, images_train, images_val, images_test)
)

In [None]:
# Preview the first image of the shuffled, normalized augmented training set
# together with its intensity scale.
fig, ax = plt.subplots()
shown = ax.imshow(images_train_aug[0, 0], cmap='gray')
ax.axis('off')
fig.colorbar(shown, ax=ax)

In [None]:
# Preview the label mask stored at the same index (0) as the image shown above.
plt.imshow(label_train_aug[0,0],cmap='gray')
plt.axis('off')

In [None]:
# Load images into pytorch DataLoader

In [None]:
# Train U-Net segmentation network

# Class imbalance problem:
# Since there are many images with no hippocampus, we want to count these images less towards the loss

# When calculating the loss on each forward pass:
# LOSS = 0.85 * mean(loss(x_with, y_with)) + 0.15 * mean(loss(x_without, y_without))

In [None]:
def convbatchrelu(in_channels, out_channels, sz):
  """Build one Conv2d -> BatchNorm2d -> ReLU stage.

  Args:
    in_channels: number of channels of the input feature map.
    out_channels: number of channels produced by the convolution.
    sz: kernel size; the padding is ``sz // 2``.

  Returns:
    An ``nn.Sequential`` holding the three layers in that order.
  """
  conv = nn.Conv2d(in_channels, out_channels, sz, padding=sz//2)
  norm = nn.BatchNorm2d(out_channels, eps=1e-5)
  activation = nn.ReLU(inplace=True)
  return nn.Sequential(conv, norm, activation)


class convdown(nn.Module):
  """Encoder block: two consecutive conv/batchnorm/ReLU stages.

  The first stage maps ``in_channels`` to ``out_channels``; the second keeps
  ``out_channels``. The stages are registered as ``conv_0`` and ``conv_1``.
  """

  def __init__(self, in_channels, out_channels, kernel_size):
    super().__init__()
    self.conv = nn.Sequential()
    # Input width of each stage: only the first one sees in_channels.
    for t, stage_in in enumerate((in_channels, out_channels)):
      stage = convbatchrelu(stage_in, out_channels, kernel_size)
      self.conv.add_module('conv_%d'%t, stage)

  def forward(self, x):
    """Apply both stages in order and return the result."""
    for stage in self.conv:
      x = stage(x)
    return x


class downsample(nn.Module):
  """Encoder: a chain of ``convdown`` blocks with 2x2 max-pooling in between.

  ``nbase`` lists the channel counts; block ``n`` maps ``nbase[n]`` to
  ``nbase[n + 1]``, so ``len(nbase) - 1`` blocks are created.
  """

  def __init__(self, nbase, kernel_size):
    super().__init__()
    self.down = nn.Sequential()
    self.maxpool = nn.MaxPool2d(2, 2)
    channel_pairs = zip(nbase[:-1], nbase[1:])
    for n, (n_in, n_out) in enumerate(channel_pairs):
      self.down.add_module('conv_down_%d'%n, convdown(n_in, n_out, kernel_size))

  def forward(self, x):
    """Return the list of feature maps produced by every block, first block first."""
    xd = []
    for n, block in enumerate(self.down):
      # The first block sees the raw input; later blocks see the pooled
      # output of the block before them.
      y = x if n == 0 else self.maxpool(xd[-1])
      xd.append(block(y))
    return xd


class convup(nn.Module):
  """Decoder block: conv stage, additive skip connection, conv stage."""

  def __init__(self, in_channels, out_channels, kernel_size):
    super().__init__()
    reduce_stage = convbatchrelu(in_channels, out_channels, kernel_size)
    merge_stage = convbatchrelu(out_channels, out_channels, kernel_size)
    self.conv = nn.Sequential()
    self.conv.add_module('conv_0', reduce_stage)
    self.conv.add_module('conv_1', merge_stage)

  def forward(self, x, y):
    """Run ``x`` through the first stage, add the skip tensor ``y``, run the second stage.

    ``y`` is summed (not concatenated) with the first stage's output, so the
    two must have matching shapes.
    """
    reduced = self.conv[0](x)
    return self.conv[1](reduced + y)


class upsample(nn.Module):
  """Decoder: a chain of ``convup`` blocks with nearest-neighbour 2x upsampling in between.

  Blocks are created from the last entries of ``nbase`` towards the first;
  the block named ``conv_up_<n-1>`` maps ``nbase[n]`` to ``nbase[n - 1]``.
  """

  def __init__(self, nbase, kernel_size):
    super().__init__()
    self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')
    self.up = nn.Sequential()
    for n in reversed(range(1, len(nbase))):
      block = convup(nbase[n], nbase[n - 1], kernel_size)
      self.up.add_module('conv_up_%d'%(n - 1), block)

  def forward(self, xd):
    """Decode the list of encoder feature maps ``xd`` and return the final map.

    Starts from the last entry of ``xd``; block ``n`` receives the entry
    ``n`` positions from the end of ``xd`` as its skip tensor.
    """
    x = xd[-1]
    for n, block in enumerate(self.up):
      if n:
        # Every block after the first works at twice the previous resolution.
        x = self.upsampling(x)
      x = block(x, xd[-1 - n])
    return x


class Unet(nn.Module):
  """U-Net built from a ``downsample`` encoder, an ``upsample`` decoder and a final convolution.

  Args:
    nbase: channel counts, input channels first (e.g. ``[1, 32, 64, 128, 256]``).
    nout: number of output channels of the final convolution.
    kernel_size: kernel size used by every convolution.
  """

  def __init__(self, nbase, nout, kernel_size):
    super(Unet, self).__init__()
    self.nbase = nbase
    self.nout = nout
    self.kernel_size = kernel_size
    self.downsample = downsample(nbase, kernel_size)
    # Decoder widths: the encoder widths without the input channel count, with
    # the deepest width repeated. A new list is built so that `nbase` is never
    # modified and tuples are accepted as well.
    nbaseup = list(nbase[1:]) + [nbase[-1]]
    self.upsample = upsample(nbaseup, kernel_size)
    self.output = nn.Conv2d(nbase[1], self.nout, kernel_size,
                            padding=kernel_size//2)

  def forward(self, data):
    """Return the ``nout``-channel output map for the input batch ``data``."""
    T0 = self.downsample(data)
    T0 = self.upsample(T0)
    T0 = self.output(T0)
    return T0

  def save_model(self, filename):
    """Write the model's ``state_dict`` to ``filename``."""
    torch.save(self.state_dict(), filename)

  def load_model(self, filename, cpu=False):
    """Load a ``state_dict`` previously written by ``save_model``.

    Args:
      filename: path of the saved state dict.
      cpu: if True, move the model to the CPU and map every stored tensor to
        the CPU while loading, so weights saved on a GPU can be loaded on a
        machine without one.
    """
    if cpu:
      # The previous implementation re-ran __init__ with a non-existent
      # `self.concatenation` argument and therefore always raised here.
      cpu_device = torch.device('cpu')
      self.to(cpu_device)
      state = torch.load(filename, map_location=cpu_device)
    else:
      state = torch.load(filename)
    self.load_state_dict(state)

In [None]:
def resize(X,xy=(200, 200), interpolation=None):
    """Resize channel 0 of every image in a batch with ``cv2.resize``.

    Args:
        X: array indexed as ``X[n, 0]`` -> 2-D image (only channel 0 is used).
        xy: target size handed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``.
        interpolation: OpenCV interpolation flag. ``None`` (default) keeps the
            previous behaviour, ``cv2.INTER_LINEAR``. For label masks consider
            ``cv2.INTER_NEAREST``: linear interpolation creates fractional
            values at mask borders, which are lost when the result is cast
            back to integers.

    Returns:
        Array of shape ``(len(X), 1, xy[1], xy[0])``.
    """
    if X.shape[0] == 0:
        # Nothing to resize; still return an array with the documented layout.
        return np.empty((0, 1, xy[1], xy[0]), dtype=X.dtype)
    if interpolation is None:
        interpolation = cv2.INTER_LINEAR
    resized = [cv2.resize(X[n, 0], xy, interpolation=interpolation)
               for n in range(X.shape[0])]
    # Re-insert the singleton channel axis that was dropped by X[n, 0].
    return np.expand_dims(np.array(resized), axis=1)


In [None]:
# Network hyper-parameters and construction.
kernel_size = 3
nbase = [1, 32, 64, 128, 256]  # number of channels per layer (first entry = input channels)
nout = 2  # number of outputs; used as per-pixel class scores by CrossEntropyLoss / argmax below

net = Unet(nbase, nout, kernel_size)
# put on GPU here if you have it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net.to(device);  # remove semi-colon to see net structure

In [None]:
from datetime import datetime

# train the network
# parameters related to training the network
batch_size = 8 # number of images per batch -- amount of required memory
              # for training will increase linearly in batchsize
### you will want to increase n_epochs!
n_epochs = 50  # number of times to cycle through all the data during training
learning_rate = 0.1 # initial learning rate
weight_decay = 1e-5 # L2 regularization of weights
momentum = 0.9 # how much to use previous gradient direction
n_epochs_per_save = 25 # how often to save the network
val_frac = 0.05 # not used below: validation data comes from images_val / label_val
augmentation = True


# where to save the network
# make sure to clean these out every now and then, as you will run out of space
now = datetime.now()
timestamp = now.strftime('%Y%m%dT%H%M%S')

# choose the training arrays once instead of branching inside the batch loop
if augmentation:
    train_images, train_labels = images_train_aug, label_train_aug
else:
    train_images, train_labels = images_train, label_train
n_train = train_images.shape[0]
n_val = images_val.shape[0]


def _make_batch(images, labels, inds):
    """Resize the selected images/labels to 200x200 and move them to `device`.

    Returns float32 images and int64 labels as torch tensors.
    NOTE(review): labels go through resize()'s default interpolation and are
    then truncated to int, which can shave pixels off mask borders.
    """
    imgs = resize(images[inds], xy=(200, 200))
    lbls = resize(labels[inds].astype(float), xy=(200, 200)).astype(int)
    imgs = torch.from_numpy(imgs).to(device=device, dtype=torch.float32)
    lbls = torch.from_numpy(lbls).to(device=device, dtype=torch.int64)
    return imgs, lbls


# gradient descent flavor
optimizer = torch.optim.SGD(net.parameters(),
                            lr=learning_rate,
                            weight_decay=weight_decay,
                            momentum=momentum)  # was hard-coded to 0.9, ignoring the setting above

# set learning rate schedule: linear warm-up from 0 to learning_rate over the
# first 10 epochs (the very first epoch therefore runs with lr = 0), then
# constant; for long runs (> 250 epochs) the rate is halved every 10 epochs
# over the final 100 epochs
LR = np.linspace(0, learning_rate, 10)
if n_epochs > 250:
    LR = np.append(LR, learning_rate*np.ones(n_epochs-100))
    for i in range(10):
        LR = np.append(LR, LR[-1]/2 * np.ones(10))
else:
    LR = np.append(LR, learning_rate * np.ones(max(0, n_epochs - 10)))

criterion = nn.CrossEntropyLoss()

# store mean loss per epoch (NaN = epoch not run yet)
epoch_losses = np.full(n_epochs, np.nan)
val_losses = np.full(n_epochs, np.nan)

# when we last saved the network
saveepoch = None

# loop through entire training data set nepochs times
for epoch in range(n_epochs):

    epoch_loss = 0
    val_loss = 0
    n_train_batches = 0
    n_val_batches = 0
    for param_group in optimizer.param_groups:
        param_group['lr'] = LR[epoch]
    with tqdm.tqdm(total=n_train, desc=f"Epoch {epoch + 1}/{n_epochs}", unit='img') as pbar:
        # loop through each batch in the training data
        net.train() # put in train mode (affects batchnorm)
        for ibatch in np.arange(0, n_train, batch_size):
            inds = np.arange(ibatch, min(n_train, ibatch+batch_size))
            imgs, lbls = _make_batch(train_images, train_labels, inds)

            # compute the loss
            y = net(imgs)
            loss = criterion(y, lbls[:, 0])
            epoch_loss += loss.item()
            pbar.set_postfix(**{'loss (batch)': loss.item()})
            # gradient descent
            optimizer.zero_grad()
            loss.backward()
            #nn.utils.clip_grad_value_(net.parameters(), 0.1)
            optimizer.step()
            n_train_batches += 1
            pbar.update(imgs.shape[0])

        # validation pass: eval mode and no autograd graph, so no memory is
        # spent on gradients that are never used
        net.eval()
        with torch.no_grad():
            for ibatch in np.arange(0, n_val, batch_size):
                inds = np.arange(ibatch, min(n_val, ibatch+batch_size))
                imgs_val, lbls_val = _make_batch(images_val, label_val, inds)
                output = net(imgs_val)
                loss = criterion(output, lbls_val[:, 0])
                val_loss += loss.item()
                n_val_batches += 1
                pbar.set_postfix(**{'val (batch)': loss.item()})

        # mean loss per batch: divide by the number of batches actually run
        # (n / batch_size is fractional whenever the last batch is partial)
        epoch_losses[epoch] = epoch_loss / n_train_batches if n_train_batches else np.nan
        val_losses[epoch] = val_loss / n_val_batches if n_val_batches else np.nan

        pbar.set_postfix(**{'loss (epoch)': epoch_losses[epoch], 'Val (epoch)': val_losses[epoch]})

    # save checkpoint networks every now and then
    # (triggers on epoch index 0, n_epochs_per_save, ...; file names use index + 1)
    if epoch % n_epochs_per_save == 0:
        print(f"\nSaving network state at epoch {epoch+1}")
        saveepoch = epoch
        savefile = f"unet_epoch_aug3_{saveepoch+1}.pth"
        net.save_model(savefile)

# final save of the weights and of both loss curves
print(f"\nSaving network state at epoch {epoch+1}")
if augmentation:
    fname = 'unet_epoch_'+str(epoch+1)+'_aug3.pth'
    np.save('aug3_epoch_loss',epoch_losses)
    np.save('aug3_val_losses',val_losses)
else:
    fname = 'unet_epoch_'+str(epoch+1)+'_no_augmentation.pth'
    np.save('no_augmentation_epoch_loss',epoch_losses)
    np.save('no_augmentation_val_losses',val_losses)
net.save_model(fname)

In [None]:
# Print the full path of every file found below /kaggle (e.g. to locate the
# files written by the training cell).
for dirname, _, filenames in os.walk('/kaggle'):
    full_paths = [os.path.join(dirname, filename) for filename in filenames]
    for full_path in full_paths:
        print(full_path)

In [None]:
# Load previously saved training / validation loss curves from a Kaggle input dataset.
# NOTE(review): these files are named "augmentation2_*", whereas the training cell in this
# notebook writes "aug3_*" files -- confirm that the intended run is the one being plotted.
train_loss = np.load('/kaggle/input/high-contrast-unet/augmentation2_epoch_loss.npy')
val_loss = np.load('/kaggle/input/high-contrast-unet/augmentation2_val_losses.npy')
print(train_loss.shape, val_loss.shape)

In [None]:
plt.rcParams.update({'font.size': 18})

def plot_loss(epoch_losses, val_losses, max_epoch=10):
    """Plot training and validation loss against the epoch index.

    Args:
        epoch_losses: per-epoch training losses.
        val_losses: per-epoch validation losses.
        max_epoch: right limit of the x axis (default 10, the previously
            hard-coded value); pass ``None`` to let matplotlib choose.

    The y axis runs from 0 to the largest finite loss. NaN entries (epochs that
    were never run are stored as NaN by the training cell) and curves of
    different lengths are tolerated.
    """
    epoch_losses = np.asarray(epoch_losses, dtype=float)
    val_losses = np.asarray(val_losses, dtype=float)

    fig = plt.figure(figsize=(6,6))
    plt.plot(epoch_losses,'black',label='Training')
    plt.plot(val_losses,color='r',linestyle=':',label='Validation')
    plt.xlabel('# Epoch')
    plt.ylabel('Loss')

    # Axis limits must be finite, so ignore NaN/inf when looking for the maximum.
    all_losses = np.concatenate([epoch_losses.ravel(), val_losses.ravel()])
    finite_losses = all_losses[np.isfinite(all_losses)]
    if finite_losses.size:
        plt.ylim([0, finite_losses.max()])
    if max_epoch is not None:
        plt.xlim([0, max_epoch])
    plt.legend()

plot_loss(train_loss,val_loss)

In [None]:
# load saved weights from a Kaggle input dataset into the existing `net`
# (uncomment the next line to start from a freshly built model instead)
# net = Unet(nbase, nout, kernel_size) # build a new model
path = '/kaggle/input/unet-aug3/unet_epoch_50_aug3.pth'
net.load_state_dict(torch.load(path,map_location=torch.device('cpu'))) # map stored tensors to the CPU, then copy them into net's parameters

## Model architecture (u-net)

A u-net is commonly used for biological image segmentation because its shape allows local and global features to be combined to create highly precise segmentations.

A u-net is shaped like an autoencoder. It has:

1. a standard convolutional network with downsampling, like one used for ImageNet
2. upsampling layers that ultimately return an image at the same size as the input image

In addition to these downsampling and upsampling blocks, it has skip connections from the downsampling blocks to the upsampling blocks, which allow it to propagate more precise local information to the later layers.

Adapted from [cellpose/resnet_torch.py](https://github.com/MouseLand/cellpose/blob/master/cellpose/resnet_torch.py).


### Define the network

### Train the network

Here we've implemented code to train the network.

Note we probably should be evaluating test performance throughout training -- implement that yourself.

### Test performance

Let's see how the network performs on a test image.

In [None]:
#torch.cuda.empty_cache()
# !nvidia-smi

In [None]:
# compute results on the first n test images
# (all images are resized to resize()'s default 200x200 before being fed to the network)
n = 5
image = resize(images_test)
label = resize(label_test.astype(float)).astype(int)
net.eval()
img_torch = torch.tensor(image).to(device)
img_torch = img_torch.to(torch.float32)
print(img_torch.shape)
# inference only: skip building the autograd graph
with torch.no_grad():
    out = net(img_torch[0:n])
out[0].shape

In [None]:
# Figure with one row per example image and four columns:
# input image, predicted segmentation, true segmentation, and an overlap map.
from matplotlib import gridspec

# NOTE(review): color_code (assigned twice) and legend_pos are not used in this cell.
color_code = '#0699C8'#'#e377c2'
color_code = '#006666'#'#e377c2'
legend_pos = 'lower left'
fig = plt.figure(figsize=(12,12))
nrows=5;  # must not exceed the number of predictions held in `out`
ncols=4;
gs = gridspec.GridSpec(nrows, ncols,figure=fig)
# gs.update(wspace=0.6, hspace=0.3) # set the spacing between axes. 

for rowIndx in range(nrows):
    # Per-pixel class = index of the largest output channel.
    prediction = np.array(out[rowIndx].detach().cpu())
    seg = np.argmax(prediction,axis=0)
    for cIndx in range(ncols):
        ax = fig.add_subplot(gs[rowIndx, cIndx]) 
        ax.axis('off')
        if cIndx == 0:
            plt.imshow(image[rowIndx][0],cmap='gray')
            plt.colorbar(fraction=0.046, pad=0.04)
            if rowIndx == 0:
                plt.title('Normalized MRI')
        if cIndx == 1: 
            plt.imshow(seg, cmap = 'gray')
            plt.colorbar(fraction=0.046, pad=0.04)
            if rowIndx == 0:
                plt.title('Predicted seg.')
        if cIndx == 2: 
            plt.imshow(label[rowIndx][0], cmap = 'gray')
            plt.colorbar(fraction=0.046, pad=0.04)
            if rowIndx == 0:
                plt.title('True seg.')  
        # Overlap code, assuming binary (0/1) masks -- TODO confirm:
        # 0 = neither, 1 = label only, 2 = prediction only, 3 = both.
        temp = label[rowIndx][0] +  2*seg        
        if cIndx == 3: 
            # Crop to a fixed window before plotting the overlap map.
            plt.imshow(temp[50:-80,60:-60], cmap = 'bwr')
            cbar = plt.colorbar(fraction=0.046, pad=0.04,ticks=[0, 1.5, 3])
            cbar.ax.set_yticklabels(['Background','No overlap','Overlap'],rotation=0)
            if rowIndx == 0:
                plt.title('Overlap')                   
plt.rc('font', family='sans-serif', serif='Helvetica')
plt.suptitle('Example segmentation for test data (contrast augmentation)', fontsize = 15)
plt.tight_layout()
fig.subplots_adjust(top=0.92)  # leave room for the suptitle
plt.savefig('Example_with_contrast_aug.png', dpi=600,bbox_inches='tight')
# plt.savefig('Example_with_augmentation.svg', dpi=600,bbox_inches='tight')

In [None]:
# Save one label mask from the augmented training set as a standalone figure.
# NOTE(review): `n` is assigned but not used in this cell, and the index shown here (13)
# differs from the image index exported in the next cell (23) -- confirm whether the
# two exports are meant to be a matching image/label pair.
n=12
plt.imshow(label_train_aug[13,0],cmap='gray')
plt.axis('off')
plt.savefig('Example_seg.png', dpi=600,bbox_inches='tight')

In [None]:
# Save one (normalized) image from the augmented training set as a standalone figure.
plt.imshow(images_train_aug[23,0],cmap='gray')
plt.axis('off')
plt.savefig('Example_im.png', dpi=600,bbox_inches='tight')

In [None]:
# Class map (arg-max over the channel axis) for every image currently held in
# `out`; display the one at index 4.
prediction = out.detach().cpu().numpy().copy()
seg = prediction.argmax(axis=1)
plt.imshow(seg[4])

In [None]:
# Test network performance & plot performance metrics
# https://scikit-learn.org/stable/modules/classes.html?highlight=metric#module-sklearn.metrics
from sklearn.metrics import precision_score, accuracy_score, f1_score

def compute_metrics(all_labels, all_segs):
    """Compute per-image precision and F1 score of predicted masks.

    Args:
        all_labels: ground-truth masks, one image per entry of the first axis
            (e.g. ``(n, H, W)`` or ``(n, 1, H, W)``). A 2-D array is treated
            as one single image.
        all_segs: predicted masks with the same number of images and pixels.

    Returns:
        ``(p, f)``: two lists with one precision / F1 value per image. An image
        whose label AND prediction are both all-zero scores 1.0 for both.

    Raises:
        ValueError: if labels and predictions do not match in size.
    """
    def _one_row_per_image(stack):
        # Flatten every image to a pixel vector while keeping the image axis.
        # (np.squeeze would drop that axis for a batch of one, after which the
        # rows of the single image would be scored as if they were images.)
        stack = np.asarray(stack)
        if stack.ndim <= 2:
            return stack.reshape(1, -1)
        return stack.reshape(stack.shape[0], -1)

    labels = _one_row_per_image(all_labels)
    segs = _one_row_per_image(all_segs)
    if labels.shape != segs.shape:
        raise ValueError('labels and segmentations do not match: %s vs %s'
                         % (labels.shape, segs.shape))

    p = []
    f = []
    for true_pixels, pred_pixels in zip(labels, segs):
        if np.max(true_pixels) == 0 and np.max(pred_pixels) == 0:
            # Nothing to find and nothing predicted: count as a perfect result
            # (and avoid sklearn's undefined-metric warning for this case).
            p.append(1.0)
            f.append(1.0)
        else:
            p.append(precision_score(true_pixels, pred_pixels))
            f.append(f1_score(true_pixels, pred_pixels))
    return p, f

In [None]:
import time
start = time.time()
net.eval()
image = resize(images_test)
label = resize(label_test.astype(float)).astype(int)
batch_size = 16  # NOTE: rebinds the global batch_size used by the training cell
seg = np.empty((label.shape[0], label.shape[2], label.shape[3]))  # predicted class per pixel, filled batch by batch
n_test = images_test.shape[0]
# inference only: no autograd graph is needed
with torch.no_grad():
    for ibatch in np.arange(0, n_test, batch_size):
        inds = np.arange(ibatch, min(n_test, ibatch+batch_size))
        img_torch = torch.tensor(image[inds]).to(device)
        img_torch = img_torch.to(torch.float32)
        out = net(img_torch)
        # predicted class = index of the largest output channel
        out = np.argmax(out.cpu().numpy(), axis=1)
        seg[inds] = out
# restore the singleton channel axis so seg has the same layout as the labels
seg = np.expand_dims(seg, 1)
end = time.time()
print(end - start)

In [None]:
# Resize the test labels the same way as the network inputs, then score every
# test image (p: per-image precision, f: per-image F1).
resize_label_test = resize(label_test.astype(float)).astype(int)
p, f = compute_metrics(np.squeeze(resize_label_test), seg)

In [None]:
# Summary statistics of the per-image precision (p) and F1 score (f).
summary_rows = [
    ('Mean of p: ', np.mean, p),
    ('Median of p: ', np.median, p),
    ('STD of p: ', np.std, p),
    ('Mean of F1: ', np.mean, f),
    ('Median of F1: ', np.median, f),
    ('STD of F1: ', np.std, f),
]
for caption, statistic, scores in summary_rows:
    print(caption, statistic(scores))

In [None]:
# plot precision and F1 score: one 20-bin histogram per metric, with a black
# vertical line at the median; each figure is saved as a PNG
plt.figure(figsize=(2.5,2.5))
plt.hist(p,20)
plt.vlines(np.median(p),0,200,color='black')  # line height (0-200) is fixed, not derived from the counts
plt.xlabel('Prediction precision')
plt.ylabel('Counts')
plt.suptitle('High-contrast aug.')

plt.savefig('Highcontrast_aug_precision.png', dpi=600,bbox_inches='tight')

plt.figure(figsize=(2.5,2.5))
plt.hist(f,20)
plt.vlines(np.median(f),0,300,color='black')  # line height (0-300) is fixed, not derived from the counts
plt.xlabel('Prediction F1 score')
plt.ylabel('Counts')
plt.suptitle('High-contrast aug.')

plt.savefig('Highcontrast_aug_f1.png', dpi=600,bbox_inches='tight')



In [None]:
# Collect the summary statistics in dictionaries keyed by experiment/metric.
mean, median, STD = {}, {}, {}

# augmentation with contrast
for key, scores in (('con_p', p), ('con_f', f)):
    mean[key] = np.mean(scores)
    median[key] = np.median(scores)
    STD[key] = np.std(scores)

In [None]:
from sklearn.metrics import average_precision_score

# Average precision for the first test image whose precision lies in (0.2, 0.3).
# The indices are computed here because this cell runs before the cell that
# defines `indx`, and `nn` is bound to torch.nn, so it cannot serve as an index.
low_precision_idx = np.where((np.array(p) < 0.3) & (np.array(p) > 0.2))[0]
if len(low_precision_idx) > 0:
    example_idx = low_precision_idx[0]
    # average_precision_score expects 1-D inputs: flatten mask and prediction
    average_precision = average_precision_score(resize_label_test[example_idx].ravel(),
                                                seg[example_idx].ravel())
    print(average_precision)
else:
    print('No test image with precision between 0.2 and 0.3')


In [None]:
# find images with low prediction precision 
indx = np.where((np.array(p)<0.3) & (np.array(p)>0.2))[0]
indx

In [None]:
# Precision of the first low-precision image found above (None if there is none).
# `nn` is bound to torch.nn in this notebook, so it cannot be used as an index.
p[indx[0]] if len(indx) else None

In [None]:
# Recompute precision / F1 for the first low-precision image only.
# `nn` is bound to torch.nn in this notebook, so it cannot be used as an index;
# the one-element slices keep the leading image axis.
if len(indx):
    first_low = indx[0]
    p_, f_ = compute_metrics(resize_label_test[first_low:first_low+1], seg[first_low:first_low+1])
    print(p_,f_)
else:
    print('No test image with precision between 0.2 and 0.3')
