In [None]:
%matplotlib inline
from IPython.core.display import display, HTML
# display(HTML("<style>.container { width:100% !important; }</style>"))
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"


import torch
from __future__ import division, print_function
from torch.utils.data import Dataset
import glob
import os
import zarr
import random
import datetime
import numpy as np

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
# import h5py
# from imgaug import augmenters as iaa
# from imgaug.augmentables.segmaps import SegmentationMapsOnImage
# from imgaug.augmentables.heatmaps import HeatmapsOnImage
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchsummary import summary
from skimage import io
from utils.colormap import *
from unet_fov import *

plt.rcParams['image.cmap'] = 'gist_earth'
torch.backends.cudnn.benchmark = True

In [None]:
# Extract the bundled toy dataset (datasets/example_toy_data.zip) into the working directory.
from shutil import unpack_archive
unpack_archive(filename=os.path.join('datasets', 'example_toy_data.zip'), extract_dir='./')

In [None]:
def running_mean(x, n):
    """Simple moving average of `x` over windows of `n` consecutive values.

    The sum of each window is the difference of two prefix sums, so the
    result has ``len(x) - n + 1`` entries (empty when `x` is shorter than `n`).
    """
    prefix = np.cumsum(np.insert(x, 0, 0))
    window_totals = prefix[n:] - prefix[:-n]
    return window_totals / float(n)

def plot_history(history, window=9):
    """Plot smoothed training/validation loss and binary accuracy side by side.

    Every curve is smoothed with `running_mean` over `window` epochs. Each
    smoothed point is drawn at the (1-based) epoch where its window ends, so
    the x-axis shows real epoch numbers.

    Args:
        history: dict with per-epoch lists under 'loss', 'val_loss',
            'binary_accuracy' and 'val_binary_accuracy', as built by `train`.
        window: smoothing window in epochs (default 9). It is clamped to the
            number of recorded epochs. Histories shorter than the window
            therefore still produce a curve instead of an empty plot.
    """
    num_epochs = len(history['loss'])
    window = max(1, min(window, num_epochs))

    def _plot_smoothed(ax, key):
        # one smoothed value per full window; the k-th (0-based) window ends at epoch window + k
        smoothed = running_mean(history[key], window)
        ax.plot(np.arange(window, window + len(smoothed)), smoothed, label=key)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    _plot_smoothed(ax1, 'loss')
    _plot_smoothed(ax1, 'val_loss')
    ax1.set_title('train and validation loss')
    ax1.legend(loc='upper right')

    _plot_smoothed(ax2, 'binary_accuracy')
    _plot_smoothed(ax2, 'val_binary_accuracy')
    ax2.set_title('train and validation binary accuracy')
    ax2.legend(loc='lower right')

    for ax in (ax1, ax2):
        ax.set_xlabel('epoch')

    plt.show()

In [None]:
def show_predictions(raw, gt, pred, thresh=0.9):
    """Show input, ground truth and prediction for every sample in a grid.

    Channel 0 of each array is displayed, one row per sample.

    Args:
        raw: input images, indexed as raw[sample, channel].
        gt: ground-truth masks, indexed as gt[sample, channel].
        pred: predictions, indexed as pred[sample, channel, H, W].
        thresh: if a sample's maximum prediction reaches `thresh`, the
            prediction is shown binarised at `thresh`. Otherwise the raw
            prediction map is shown, so low-confidence outputs are still visible.
    """
    max_values = np.max(pred[:, 0], axis=(1, 2))
    if np.any(max_values < thresh):
        print("Heads up: If prediction is below {} then the prediction map is shown.".format(thresh))
        print("Max predictions: {}".format(max_values))

    num_samples = pred.shape[0]
    # squeeze=False keeps `ax` 2-D, so ax[i, j] also works when there is a single sample
    fig, ax = plt.subplots(num_samples, 3, sharex=True, sharey=True,
                           figsize=(12, num_samples * 4), squeeze=False)
    for i in range(num_samples):
        ax[i, 0].imshow(raw[i, 0], aspect="auto")
        ax[i, 1].imshow(gt[i, 0], aspect="auto")
        # check for prediction threshold
        if max_values[i] < thresh:
            ax[i, 2].imshow(pred[i, 0], aspect="auto")
        else:
            ax[i, 2].imshow(pred[i, 0] >= thresh, aspect="auto")

    ax[0, 0].set_title("Input")
    ax[0, 1].set_title("Ground truth")
    ax[0, 2].set_title("Prediction")
    fig.tight_layout()

## (1) Load and visualize our toy data examples:


We have 12 training samples, 3 validation samples and 3 test samples.

In [None]:
# load tif images and reformat it
def load_dataset(in_folder):
    """Load paired raw / ground-truth tif images from `in_folder`.

    Every ``raw_<suffix>.tif`` file is paired with ``gt_<suffix>.tif`` in the
    same folder.

    Args:
        in_folder: directory containing the image pairs.

    Returns:
        (x, y): raw images and ground truth with a channel axis inserted at
        position 1, in sorted filename order. This is [N, 1, H, W] assuming
        every tif is a single 2-D image of the same size.

    Raises:
        FileNotFoundError: if `in_folder` contains no ``raw_*.tif`` file.
    """
    # glob returns files in filesystem-dependent order; sort for reproducible sample indices
    raw_files = sorted(glob.glob(os.path.join(in_folder, 'raw_*.tif')))
    if not raw_files:
        raise FileNotFoundError("no 'raw_*.tif' files found in {!r}".format(in_folder))
    x = []
    y = []
    for raw_file in raw_files:
        folder, name = os.path.split(raw_file)
        # rewrite only the 'raw' filename prefix; str.replace on the whole path would
        # also rewrite any 'raw' appearing in directory names
        gt_file = os.path.join(folder, 'gt' + name[len('raw'):])
        x.append(io.imread(raw_file))
        y.append(io.imread(gt_file))
    x = np.array(x)[:, np.newaxis]  # shape [N, 1, H, W]
    y = np.array(y)[:, np.newaxis]  # shape [N, 1, H, W]
    return x, y

In [None]:
# Load the three splits (in the order train, val, test) as [N, 1, H, W] arrays.
(x_train, y_train), (x_val, y_val), (x_test, y_test) = (
    load_dataset('example_toy_data/' + split) for split in ('train', 'val', 'test'))

batch_size = 4
# Each loader iterates over (image, label) pairs; only training batches are shuffled.
train_loader = DataLoader(list(zip(x_train, y_train)),
                          batch_size=batch_size,
                          shuffle=True,
                          pin_memory=True)
val_loader, test_loader = (
    DataLoader(list(zip(x_val, y_val)), batch_size=1, pin_memory=True),
    DataLoader(list(zip(x_test, y_test)), batch_size=1),
)

In [None]:
# show training examples (up to 3; fewer if the training set is smaller)
num_samples = min(3, len(x_train))
# squeeze=False keeps `ax` 2-D, so ax[i, j] indexing also works for a single row
fig, ax = plt.subplots(num_samples, 2, sharey=True, figsize=(8, num_samples * 4), squeeze=False)
for i in range(num_samples):
    ax[i, 0].imshow(x_train[i, 0], aspect="auto")
    ax[i, 1].imshow(y_train[i, 0], aspect="auto")
ax[0, 0].set_title("Input")
ax[0, 1].set_title("Ground truth")
fig.tight_layout()

## (2) #TODO: Create and train our model

This model is built by stacking one UNet instance and one convolution layer. The UNet class is defined in the **unet_fov.py** file. We use the default setting, in which the number of feature maps output by the UNet equals the number of feature maps of its first convolution. A final convolution layer with kernel_size=1 then maps these feature maps to the desired number of output channels.

For the meaning of parameters of the UNet class, please refer to the **unet_fov.py** file.

#TODO:
- 1. Define the number of output channels of the model (`out_channels`).
- 2. Define `d_factors`, which is passed as the `downsample_factors` parameter of the UNet class.
- 3. Define the `activation` variable. Note: this variable is used later in the functions **training_step** and **predict**; don't confuse it with the `activation` parameter of the UNet class.
- 4. Define a UNet that uses same padding, has depth=4 and a downsampling factor of 2 in each dimension. Set the activation layers inside the UNet to "ReLU", passing the name as a string.

In [None]:
# Model-definition exercise: fill in the template between the START/END markers.
# The code after the END box calls `net[0].get_fov()`, so `net` must be a
# torch.nn.Sequential whose first element is the UNet. `activation` is read later by
# `training_step` and `predict`, which apply it to the network's raw outputs (logits).
# set seed
# seed torch's global RNG before the network is built so its initial weights are reproducible
torch.manual_seed(42)

###########################################################################
# TODO: Define the net and uncomment the following code                   #                     
###########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****


# out_channels = 
# d_factors =
# activation = 

# net = torch.nn.Sequential(
#     UNet(in_channels=,
#     num_fmaps=,
#     fmap_inc_factors=,
#     downsample_factors=d_factors,
#     activation='ReLU',
#     padding=,
#     num_fmaps_out=,
#     constant_upsample=False
#     ),
#     torch.nn.Conv2d(in_channels=, out_channels=out_channels, kernel_size=1, padding=0, bias=True))

# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###########################################################################
#                             END OF YOUR CODE                            #
###########################################################################
receptive_field, _ = net[0].get_fov()

# print network layers
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = net.to(device)
# torchsummary's summary() builds its dummy input on "cuda" unless told otherwise, which
# fails on CPU-only machines; pass the type of the device the model actually lives on.
summary(net, (1, 512, 512), device=device.type)
print("Receptive field: ", receptive_field)

### Receptive Field of View

The number of convolutions and the depth of the U-Net are the major factors that determine the
receptive field of the network. The term is borrowed from biology, where it describes the "portion of sensory space that can elicit neuronal responses when stimulated" (Wikipedia). Each output pixel depends on an input patch of that diameter centered at its position.
Based on this patch, the network has to be able to make a decision about the prediction for the respective pixel.
A larger receptive field gives the network more context for this decision, but it also increases the computation time significantly.

The following code snippet visualizes the field of view of the center pixel for networks with varying depth.

In [None]:
# Pick a random training image as the background for the field-of-view boxes.
rnd = random.randrange(len(x_train))
image = np.squeeze(x_train[rnd])

# Field of view per depth, as computed by net[0].rec_fov (see unet_fov.py).
# Four levels match the depth-4, 2x2-downsampling UNet requested above. A dedicated
# name is used so this cell does not overwrite the notebook's `d_factors`.
num_fov_levels = 4
fovs = []
for level in range(num_fov_levels):
    fov_tmp, _ = net[0].rec_fov(level, (1, 1), 1)
    fovs.append(fov_tmp[0])

colors = ["yellow", "red", "green", "blue", "magenta"]
height, width = image.shape[:2]
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(image, cmap='gray')
for idx, fov_t in enumerate(fovs):
    color = colors[idx % len(colors)]  # cycle colours if there are more levels than colours
    print("Field of view at depth {}: {:3d} (color: {})".format(idx+1, fov_t, color))
    # square of side fov_t centred on the image: x spans the width, y spans the height
    xmin, xmax = width / 2 - fov_t / 2, width / 2 + fov_t / 2
    ymin, ymax = height / 2 - fov_t / 2, height / 2 + fov_t / 2
    ax.hlines([ymin, ymax], xmin, xmax, color=color, lw=3)
    ax.vlines([xmin, xmax], ymin, ymax, color=color, lw=3)
plt.show()

### TODO: Define the functions for training
When training a semantic segmentation model, we usually record two basic metrics: the loss and the pixel accuracy.
The loss function used here is binary cross-entropy applied after a sigmoid (PyTorch's `BCEWithLogitsLoss`).
The pixel accuracy, or accuracy for short, is the percentage of pixels in the image that are correctly classified.

#TODO:
- 1. Define the **calc_accuracy** function according to the activation you defined before. It should return the pixel accuracy of the predictions averaged over one batch.
- 2. Call the **train** function to start the training.

In [None]:
###########################################################################
# TODO:  calc_accuracy function                                           #                     
###########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

def calc_accuracy(y_pred, y_true):
    """Pixel accuracy of `y_pred` against `y_true`, averaged over one batch.

    Exercise stub: it still returns None, which makes `train` and `predict`
    fail. Both callers pass the activation output as `y_pred`. `train` passes
    labels converted to float, while `predict` passes them unconverted. Both
    then call `float(result.cpu().detach().numpy())`, so the result must be a
    single-element torch tensor.
    """
    # Please remember to get the mean accuracy of one batch
    return 


# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###########################################################################
#                             END OF YOUR CODE                            #
###########################################################################

In [None]:
dtype = torch.FloatTensor
def training_step(model, loss_fn, optimizer, feature, label):
    """Run one forward pass and, if `model` is in training mode, one optimizer update.

    Args:
        model: the network; its own `.training` flag decides whether a
            backward pass and optimizer step happen.
        loss_fn: loss computed on the raw logits, called as loss_fn(input=..., target=...).
        optimizer: optimizer stepped after backpropagation (unused in eval mode).
        feature: input batch (callers move it to the model's device first).
        label: target batch passed to `loss_fn`.

    Returns:
        (loss_value, outputs), where outputs['logits'] are the raw model outputs
        and outputs['pred'] is the notebook-level `activation` applied to them
        (or the logits themselves when `activation` is None).
    """
    import contextlib

    # speedup version of setting gradients to zero
    for param in model.parameters():
        param.grad = None
    # In eval mode there is no backward pass, so skip building the autograd graph.
    grad_context = contextlib.nullcontext() if model.training else torch.no_grad()
    with grad_context:
        logits = model(feature)  # B x C x H x W
        loss_value = loss_fn(input=logits, target=label)  # label.squeeze(0) for three_class
    # backward only in training mode -- decided by the model passed in, not a global `net`
    if model.training:
        loss_value.backward()
        optimizer.step()
    if activation is not None:
        output = activation(logits)
    else:
        output = logits
    outputs = {
        'pred': output,
        'logits': logits,
    }
    return loss_value, outputs

def train(net, epochs, learning_rate, start_epoch=0, optimizer=None, history=None, early_stopping=None):
    """Train `net` with BCE-with-logits loss and validate after every epoch.

    Reads the notebook-level `train_loader`, `val_loader`, `calc_accuracy`
    and `dtype`, and delegates each forward/backward pass to `training_step`.

    Args:
        net: model to train; moved to the GPU when one is available.
        epochs: maximum number of epochs to run in this call.
        learning_rate: Adam learning rate, only used when `optimizer` is None.
        start_epoch: offset added to the printed epoch numbers (useful when
            resuming a previous run).
        optimizer: optimizer to keep using; a new Adam optimizer is created if None.
        history: dict of per-epoch metric lists ('loss', 'val_loss',
            'binary_accuracy', 'val_binary_accuracy') to append to; created if None.
        early_stopping: optional callable (e.g. EarlyStopping) that receives the
            mean validation loss after each epoch. Training stops once its
            `early_stop` attribute is True.

    Returns:
        (net, history, optimizer), so training can be resumed later.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = net.to(device)
    loss_fn = torch.nn.BCEWithLogitsLoss().to(device)
    # set optimizer
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    if history is None:
        history = {'loss': [],
                   'val_loss': [],
                   'binary_accuracy': [],
                   'val_binary_accuracy': []}

    pbar = tqdm(total=epochs * len(train_loader))
    try:
        for epoch in range(epochs):
            # --- training pass ---
            # (no per-epoch np.random.seed(): there are no numpy-based augmentations,
            # and re-seeding from OS entropy only made runs irreproducible)
            net.train()
            train_losses = []
            train_accuracies = []
            for feature, label in train_loader:
                label = label.type(dtype).to(device)
                feature = feature.to(device)
                loss_value, outputs = training_step(net, loss_fn, optimizer, feature, label)
                pbar.update(1)
                train_losses.append(loss_value.item())
                accuracy = calc_accuracy(outputs['pred'], label)
                train_accuracies.append(float(accuracy.cpu().detach().numpy()))
            history['loss'].append(np.mean(train_losses))
            history['binary_accuracy'].append(np.mean(train_accuracies))

            # --- validation pass: no weight updates, so no autograd graph is needed ---
            net.eval()
            val_losses = []
            val_accuracies = []
            with torch.no_grad():
                for feature, label in val_loader:
                    label = label.type(dtype).to(device)
                    feature = feature.to(device)
                    loss_value, outputs = training_step(net, loss_fn, optimizer, feature, label)
                    val_losses.append(loss_value.item())
                    accuracy = calc_accuracy(outputs['pred'], label)
                    val_accuracies.append(float(accuracy.cpu().detach().numpy()))
            history['val_loss'].append(np.mean(val_losses))
            history['val_binary_accuracy'].append(np.mean(val_accuracies))

            # report the epoch before a possible early stop so its metrics are always printed
            print(f'Epoch {epoch+start_epoch+1}, train-loss: {np.mean(train_losses):.4f} - train_accuracy:{np.mean(train_accuracies):.4f}'+
                  f' - val_loss:{np.mean(val_losses):.4f} -val_accuracy:{np.mean(val_accuracies):.4f}')
            if early_stopping is not None:
                early_stopping(np.mean(val_losses))
                if early_stopping.early_stop:
                    # same numbering as the per-epoch report above
                    print('Early stopping after epoch', epoch + start_epoch + 1)
                    break
    finally:
        # close the progress bar even when training stops early or raises
        pbar.close()
    return net, history, optimizer

In [None]:
num_epochs1 = 250
# `train` returns (net, history, optimizer); the next cell plots `history`.
# This first run does not pass an early_stopping object, so it runs all num_epochs1 epochs.
###########################################################################
# TODO:  train the net without using early_stopping                       #                     
###########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

# net, history, optimizer = train(net, num_epochs1,...)

# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###########################################################################
#                             END OF YOUR CODE                            #
###########################################################################

In [None]:
# Smoothed loss / accuracy curves of the first training run.
plot_history(history=history)

## (3) Test and evaluate our model

In [None]:
# predict the test set
def predict(net, test_loader):
    """Run `net` over `test_loader` and collect activated predictions.

    Uses the notebook-level `device`, `activation` and `calc_accuracy`.

    Args:
        net: trained model; it is switched to eval mode.
        test_loader: iterable of (image, label) batches.

    Returns:
        (predictions, mean_accuracy):
        - predictions: a list with one array per sample (batch axis removed).
        - mean_accuracy: the mean of the per-batch accuracies, weighted by
          batch size, or NaN if the loader is empty.
    """
    net.eval()
    predictions = []
    batch_accuracies = []
    batch_sizes = []
    # inference only: no autograd graph is needed
    with torch.no_grad():
        for image, label in test_loader:
            image = image.to(device)
            label = label.to(device)
            pred = activation(net(image))
            accuracy = calc_accuracy(pred, label)
            batch_accuracies.append(float(accuracy.cpu().detach().numpy()))
            batch_sizes.append(image.shape[0])
            # iterating over axis 0 splits the batch into per-sample arrays (any batch size)
            predictions.extend(pred.cpu().numpy())
    if not batch_sizes:
        return predictions, float('nan')  # no batches: accuracy is undefined
    return predictions, float(np.average(batch_accuracies, weights=batch_sizes))
    
# Evaluate the trained net on the test set, then plot input / ground truth / prediction.
predictions, mean_accuracy = predict(net, test_loader)
predictions = np.stack(predictions)  # per-sample arrays -> one array with a leading sample axis
print('Accuracy: {:.3f}'.format(mean_accuracy))
show_predictions(raw=x_test, gt=y_test, pred=predictions)

### A1: #TODO: Continue training for more epochs

In [None]:
# continue training, takes ~3min
# heads up: the "net" variable still carries all the information from the previous training
# Passing the previous `optimizer` makes `train` reuse it instead of creating a new Adam
# optimizer, and passing `history` makes `train` append to the existing curves.
# `start_epoch` only shifts the printed epoch numbers, e.g. start_epoch=len(history['loss']).
num_epochs2=160
###########################################################################
# TODO:  continue to train the net without using early_stopping           #   
# remember to give the history parameter this time                        # 
###########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

# net, history, optimizer = train(net, num_epochs2,...)

# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###########################################################################
#                             END OF YOUR CODE                            #
###########################################################################


print('Finished Training')

In [None]:
# Smoothed curves again, now including the epochs from the continued training.
plot_history(history=history)

In [None]:
# Re-evaluate after the continued training and plot the new test predictions.
predictions, mean_accuracy = predict(net, test_loader)
predictions = np.stack(predictions)  # per-sample arrays -> one array with a leading sample axis
print('Accuracy: {:.3f}'.format(mean_accuracy))
show_predictions(raw=x_test, gt=y_test, pred=predictions)

#### The training of the network depends on many hyperparameters, such as
- network architecture: #layers, #fmaps
- batch size, learning rate
- number and distribution of the training samples

#### You can experiment with these settings and see how they influence the learning curve.
![image.png](utils/example_learning_curves/lc_all.png)

![](example_learning_curves/lc_all.png)

### A2:  #TODO: Use early stopping to avoid overfitting
#TODO:
- 1. Define the net.
- 2. Look at the EarlyStopping class defined below and create an EarlyStopping instance.
- 3. Train the net with early stopping enabled.

In [None]:
class EarlyStopping():
    """
    Early stopping to stop the training when the loss does not improve after
    certain epochs.
    Code from https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/
    """
    def __init__(self, patience=20, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0         # consecutive calls without a sufficient improvement
        self.best_loss = None    # lowest loss seen so far (None before the first call)
        self.early_stop = False  # becomes True once counter reaches patience

    def __call__(self, val_loss):
        """Record `val_loss` and update `counter` / `early_stop`.

        A loss counts as an improvement only if it is more than `min_delta`
        below the best loss so far. Every other call increments the patience
        counter, including a loss exactly equal to the best.
        """
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            # A strict `< min_delta` test here would skip improvements of exactly
            # min_delta, so with min_delta=0 an unchanged loss would never count.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [None]:
# define model
# Re-defining `net` here starts from freshly initialised weights (torch is not re-seeded in
# this cell). `early_stopping` is passed to `train`, which calls it with the mean validation
# loss after every epoch and stops once `early_stopping.early_stop` becomes True.

###########################################################################
# TODO: Define the net and uncomment the following code                   #
# You can basically copy the code that you use to define the net before   # 
# Please also define the early_stopping                                   # 
###########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****


# out_channels = 
# d_factors =
# activation = 
# early_stopping = 

# net = torch.nn.Sequential(
#     UNet(in_channels=,
#     num_fmaps=,
#     fmap_inc_factors=,
#     downsample_factors=d_factors,
#     activation='ReLU',
#     padding=,
#     num_fmaps_out=,
#     constant_upsample=False
#     ),
#     torch.nn.Conv2d(in_channels=, out_channels=out_channels, kernel_size=1, padding=0, bias=True))

# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###########################################################################
#                             END OF YOUR CODE                            #
###########################################################################


In [None]:
# train model
# `epochs` is only an upper bound here; early stopping may end training sooner.
# `train` returns (net, history, optimizer); the history is kept as `history_w_ea`
# because the next cell plots that name.
epochs = 500
###########################################################################
# TODO:  train the net and use early_stopping                             #   
###########################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

# net, history_w_ea,_ = train(net, epochs,...)

# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
###########################################################################
#                             END OF YOUR CODE                            #
###########################################################################
print('Finished Training')

In [None]:
# Smoothed curves of the run trained with early stopping.
plot_history(history=history_w_ea)

In [None]:
# Evaluate the early-stopped model on the test set and plot its predictions.
predictions, mean_accuracy = predict(net, test_loader)
predictions = np.stack(predictions)  # per-sample arrays -> one array with a leading sample axis
print('Accuracy: {:.3f}'.format(mean_accuracy))
show_predictions(raw=x_test, gt=y_test, pred=predictions)