#### Import and set up paths

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from skimage.io import imread
from pathlib import Path
import pandas as pd
import cv2

import torch
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
import logging
import torchvision 
from torchvision import datasets, transforms, utils
from torchsummary import summary
import torch.optim as optim

from google.colab import drive
drive.mount('/content/drive')

In [None]:
%load_ext tensorboard

In [None]:
# Colab working directory and the project folder on the mounted Google Drive.
# drive_path is also read by EmbryoDataset.__init__ (Time_Annotation.xlsx).
root_directory = Path("/content/")
drive_path = root_directory / 'drive/MyDrive' / 'CS101' 

In [None]:
# Path to zip file in my drive
!unzip '/content/drive/My Drive/CS101/CS101_z_anish.zip'

In [None]:
if not os.path.exists("pytorch_unet.py"):
  if not os.path.exists("pytorch_unet"):
    !git clone https://github.com/usuyama/pytorch-unet.git

  %cd pytorch-unet

import pytorch_unet
%cd /content

#### Initialize Data

In [None]:
class EmbryoDataset(Dataset):
    """Embryo DIC stacks paired with segmentation and fluorescence images.

    ``data_dir`` must contain three list files -- ``DIC.list``, ``seg.list``
    and ``fluo.list`` -- read in lockstep: line i of each names (relative to
    ``data_dir``) the i-th sample's DIC stack (loaded from ``<line>.npy``),
    its segmentation image and its fluorescence image. Per-embryo onset and
    z-count annotations come from ``Time_Annotation.xlsx`` in the
    module-level ``drive_path`` directory.

    Each item is ``(image, gt, fluo, tags)``:
      * image -- float32 tensor from ``image_transform`` (scaled, normalized,
        with a new leading axis of size 1);
      * gt    -- float32 tensor from ``gt_transform`` (scaled by 1/255 and
        squeezed);
      * fluo  -- the fluorescence image as returned by ``imread``;
      * tags  -- dict with ``embryo_num``, ``timestep`` and ``onset``.
    """

    # Normalization statistics applied to the DIC stack after scaling by 1/255.
    DIC_MEAN = 0.4111
    DIC_STD = 0.2077

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.image_list = self._read_list('DIC.list')
        self.gt_list = self._read_list('seg.list')
        self.fluo_list = self._read_list('fluo.list')

        onset_df = pd.read_excel(drive_path / 'Time_Annotation.xlsx')
        embryo_ids = onset_df["embryo_index"].astype(int)
        self.onset_dict = dict(zip(embryo_ids, onset_df["onset"].astype(int)))
        self.z_dict = dict(zip(embryo_ids, onset_df["z_num"].astype(int)))

        self.length = len(self.image_list)
        # Stored for API compatibility; not applied anywhere in this class.
        self.transform = transform

    def _read_list(self, name):
        """Return the lines of list file ``name`` located in ``data_dir``."""
        with open(os.path.join(self.data_dir, name), 'r') as f:
            return f.read().splitlines()

    def __getitem__(self, index):
        image_path = os.path.join(self.data_dir, self.image_list[index]) + '.npy'
        gt_path = os.path.join(self.data_dir, self.gt_list[index])
        fluo_path = os.path.join(self.data_dir, self.fluo_list[index])
        try:
            image = np.load(image_path)
        except (OSError, ValueError, EOFError) as err:
            # The old bare ``except`` only printed the path and then crashed
            # with an unrelated UnboundLocalError on ``image``; keep the cause.
            raise RuntimeError("Failed to load DIC stack {}".format(image_path)) from err
        gt = imread(gt_path, as_gray=True)
        fluo = imread(fluo_path, as_gray=True)
        image = self.image_transform(image)
        gt = self.gt_transform(gt)

        # Parse embryo number / timestep from the gt path: the parent directory
        # name minus its first 6 characters, and the file name minus its first
        # character and last 4 (presumably e.g. "embryo12/t034.png" -- confirm).
        dir_name = os.path.basename(os.path.dirname(gt_path))
        file_name = os.path.basename(gt_path)
        embryo_num = int(dir_name[6:])
        timestep = int(file_name[1:-4])
        onset = self.onset_dict[embryo_num]

        tags = {"embryo_num": embryo_num, "timestep": timestep, "onset": onset}
        return image, gt, fluo, tags

    def __len__(self):
        return self.length

    def image_transform(self, image):
        """Scale a DIC stack to [0, 1], normalize it, and add a leading axis.

        Computes ``(image / 255 - DIC_MEAN) / DIC_STD`` and returns a float32
        tensor with one extra leading dimension of size 1.
        """
        image = torch.from_numpy(image / 255.0)
        # Every slice shares the same mean/std, so scalar arithmetic matches
        # Normalize([mean] * 32, [std] * 32) but no longer requires exactly
        # 32 slices.
        image = (image - self.DIC_MEAN) / self.DIC_STD
        return torch.unsqueeze(image, 0).float()

    def gt_transform(self, gt):
        """Scale the segmentation by 1/255 and drop singleton dims (float32)."""
        gt = gt / 255.0
        return torch.squeeze(torch.from_numpy(gt)).float()

In [None]:
# Instantiate train, validation, and test datasets and their dataloaders
# (the split folders are the ones unzipped under /content).
train_set = EmbryoDataset(root_directory / 'trainset')
test_set = EmbryoDataset(root_directory / 'testset')
val_set = EmbryoDataset(root_directory / 'valset')

D = 32 # Number of channels in input image
batch_size = 1

# Only train_loader shuffles. The unshuffled loaders keep a fixed order, which
# populate_metrics / visualize_metrics rely on: they re-locate images by batch
# index.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
train_loader_not_shuffled = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

# Phase name -> loader; train_model iterates dataloaders['train'] and ['val'].
dataloaders = {"train": train_loader, "test": test_loader, "val": val_loader}

#### Training helper functions

In [None]:
%cd pytorch-unet
from loss import dice_loss
%cd /content
from collections import defaultdict
import torch.nn.functional as F

def print_metrics(metrics, epoch_samples, phase):
    """Print every accumulated metric averaged over ``epoch_samples``.

    One line is printed: ``<phase>: <name>: <avg>, <name>: <avg>, ...``, with
    each average formatted as ``{:4f}`` and metrics in ``metrics`` order.
    """
    body = ", ".join(
        "{}: {:4f}".format(name, total / epoch_samples)
        for name, total in metrics.items()
    )
    print("{}: {}".format(phase, body))

def calculate_pixel_metrics(pred, label):
    '''
    Count soft pixel-level TP, FP and FN for a single image.

    TP = sum(pred * label), FP = sum((1 - label) * pred),
    FN = sum(label * (1 - pred)). Returns the three sums as Python numbers
    via ``.item()``.
    '''
    background = 1 - label
    missed = 1 - pred
    true_pos = torch.sum(pred * label)
    false_pos = torch.sum(background * pred)
    false_neg = torch.sum(label * missed)
    return true_pos.item(), false_pos.item(), false_neg.item()

def get_samples_for_viz(val_metrics, N, thresh):
    '''
    Pick the images with the highest, median and lowest IoU at ``thresh``.

    Args:
        val_metrics: ``{batch_num: {thresh: {pos: {"iou": ..., ...}}}}`` as
            filled by ``populate_metrics``; ``thresh`` must be a key of every
            batch's dict.
        N: number of images wanted per group (clamped to [0, #images]).
        thresh: threshold whose IoUs are ranked.

    Returns:
        ``(best, middle, worst)`` -- lists of ``(iou, batch_num, pos, info)``
        tuples, each in ascending IoU order and holding exactly
        ``min(N, #images)`` entries.
    '''
    ranked = sorted(
        (
            (info["iou"], batch, pos, info)
            for batch, per_thresh in val_metrics.items()
            for pos, info in per_thresh[thresh].items()
        ),
        key=lambda entry: entry[0],
    )
    total = len(ranked)
    # Never ask for more images than exist, nor a negative count.
    n = max(0, min(N, total))

    # Explicit start index: ranked[-0:] would return the whole list when n == 0.
    best_ious = ranked[total - n:]
    worst_ious = ranked[:n]

    # n consecutive entries centred on the median. Since n <= total the start
    # is never negative (no wrap-around) and the slice always holds n items;
    # the old end index (total//2 + N//2) dropped one item for odd N.
    start = total // 2 - n // 2
    middle_ious = ranked[start:start + n]

    return best_ious, middle_ious, worst_ious

def show_img_sample(axs, batch, inputs, labels, fls, tags, model, ious, subplot_row, thresh):
    '''
    Draw one figure row for every selected sample that belongs to ``batch``.

    Each row shows, left to right: the fluorescence image, the ground truth,
    the model's probability map, that map binarized at ``thresh`` (titled
    with the sample's IoU), and a text panel with embryo / timestep / onset.

    Args:
        axs: 2-D array of matplotlib axes with at least 5 columns.
        batch: index of the current batch in the (unshuffled) loader.
        inputs, labels, fls, tags: the batch exactly as yielded by the loader.
        model: model used to recompute predictions for this batch.
        ious: ``(iou, batch_num, pos, info)`` tuples from get_samples_for_viz.
        subplot_row: first free row of ``axs``.
        thresh: probability threshold for the binary prediction.

    Returns:
        The next free row index.
    '''
    preds_prob = preds = None
    for bound_iou, img_batch, pos, _ in ious:
        if img_batch != batch:
            continue
        if preds_prob is None:
            # Predict once per call (not once per selected sample), without
            # autograd since the result is only displayed.
            with torch.no_grad():
                preds_prob = model(inputs.to(device))
            preds = torch.where(preds_prob > thresh, 1.0, 0.0).int()

        # The default collate turns the integer tags into tensors; int()
        # keeps the text panel readable ("Embryo 5", not "Embryo tensor(5)").
        embryo_num = int(tags["embryo_num"][pos])
        timestep = int(tags["timestep"][pos])
        onset = int(tags["onset"][pos])

        axs[subplot_row, 0].set_title("FL")
        axs[subplot_row, 0].imshow(fls[pos], cmap="gray")
        axs[subplot_row, 1].set_title("GT")
        axs[subplot_row, 1].imshow(labels[pos], cmap="gray")
        axs[subplot_row, 2].set_title("Model Output")
        axs[subplot_row, 2].imshow(preds_prob[pos].cpu().numpy(), cmap="gray")
        axs[subplot_row, 3].set_title("Prediction IOU: {}".format(bound_iou))
        axs[subplot_row, 3].imshow(preds[pos].cpu().numpy(), cmap="gray")
        embryo_info_text = "Embryo {}\n Timestep {}\n Onset {}".format(embryo_num, timestep, onset)
        axs[subplot_row, 4].text(0.5, 0.5, embryo_info_text, horizontalalignment='center', verticalalignment='center')

        subplot_row += 1
    return subplot_row

def visualize_metrics(model, val_loader, val_metrics, N=2, non_boundary_metrics=False):
    '''
    Summarize per-threshold metrics and show example predictions.

    Prints the best mean IoU and its threshold, then draws one figure whose
    first row holds the mean precision/recall curve and mean IoU vs
    threshold, followed by up to ``3 * N`` rows of example images (the N
    best, N median and N worst IoUs at the best threshold), drawn in the
    order their batches appear in ``val_loader``.

    Args:
        model: segmentation model used to re-predict the example images.
        val_loader: the unshuffled loader that produced ``val_metrics``.
        val_metrics: dict filled by ``populate_metrics``.
        N: number of examples per group.
        non_boundary_metrics: unused.
    '''
    # Must be the same thresholds populate_metrics used as dict keys.
    threshs = np.linspace(0, 1, num=101)
    precisions = np.zeros_like(threshs)
    recalls = np.zeros_like(threshs)
    iou_per_thresh = np.zeros_like(threshs)

    # Count scored images from the metrics themselves; len(loader) * batch_size
    # over-counts when the last batch is short.
    num_images = sum(len(per_thresh[threshs[0]]) for per_thresh in val_metrics.values())
    if num_images == 0:
        print("No metrics to visualize.")
        return

    for i, thresh in enumerate(threshs):
        for per_thresh in val_metrics.values():
            for info in per_thresh[thresh].values():
                iou_per_thresh[i] += info['iou']
                precisions[i] += info['precision']
                recalls[i] += info['recall']
    iou_per_thresh /= num_images
    precisions /= num_images
    recalls /= num_images

    # First threshold reaching the highest mean IoU.
    best_idx = int(np.argmax(iou_per_thresh))
    max_threshold = threshs[best_idx]
    max_iou = iou_per_thresh[best_idx]

    print("\n###")
    print("Mean IoU at threshold {}: {}".format(max_threshold, max_iou))
    print("\n###")

    best_ious, middle_ious, worst_ious = get_samples_for_viz(val_metrics, N, max_threshold)

    # One summary row plus up to N rows for each of the three groups;
    # squeeze=False keeps axs 2-D even when there is only a single row.
    n_rows = 1 + 3 * max(0, N)
    _fig, axs = plt.subplots(n_rows, 5, figsize=(35, 5 * n_rows), squeeze=False)

    axs[0,0].set_title("Image PR Curve")
    axs[0,0].set_xlabel("Recall")
    axs[0,0].set_ylabel("Precision")
    axs[0,0].set_xlim([0,1])
    axs[0,0].set_ylim([0,1])
    axs[0,0].plot(recalls, precisions)

    axs[0,2].set_title("IoU vs Threshold")
    axs[0,2].set_xlabel("Threshold")
    axs[0,2].set_ylabel("IoU")
    axs[0,2].set_xlim([0,1])
    axs[0,2].set_ylim([0,1])
    axs[0,2].plot(threshs, iou_per_thresh)

    # Only batches holding a selected example need drawing; stop reading the
    # loader once the last of them has been handled.
    groups = (best_ious, middle_ious, worst_ious)
    wanted = {batch for group in groups for _, batch, _, _ in group}
    last_wanted = max(wanted, default=-1)

    subplot_row = 1
    for batch, (inputs, labels, fls, tags) in enumerate(val_loader):
        if batch > last_wanted:
            break
        if batch not in wanted:
            continue
        for group in groups:
            subplot_row = show_img_sample(axs, batch, inputs, labels, fls, tags, model, group, subplot_row, max_threshold)
    plt.show()

def populate_metrics(inputs, labels, preds_prob, batch_num, val_metrics):
    '''
    ASSUMES VAL_LOADER IS NOT SHUFFLED -- results are keyed by batch index.

    Score one batch at 101 evenly spaced thresholds (np.linspace(0, 1, 101))
    and store the result in ``val_metrics`` (in place, replacing any previous
    entry for ``batch_num``). ``inputs`` is not used.

    For every threshold t, predictions are binarized as ``preds_prob > t``;
    each image ``pos`` of the batch then gets its pixel counts and
    1e-6-smoothed IoU / precision / recall:

    val_metrics[batch_num][thresh][pos] = {
        'TP', 'FP', 'FN', 'iou', 'precision', 'recall'
    }
    '''
    eps = 1e-6
    num_images = labels.shape[0]

    def _score(tp, fp, fn):
        # Smoothed ratios; key order matches what downstream code has seen.
        return {
            'TP': tp,
            'FP': fp,
            'FN': fn,
            'iou': (tp + eps) / (tp + fp + fn + eps),
            'precision': (tp + eps) / (tp + fp + eps),
            'recall': (tp + eps) / (tp + fn + eps),
        }

    per_threshold = {}
    for thresh in np.linspace(0, 1, num=101):
        binarized = torch.where(preds_prob > thresh, 1.0, 0.0).int()
        per_threshold[thresh] = {
            pos: _score(*calculate_pixel_metrics(binarized[pos], labels[pos]))
            for pos in range(num_images)
        }
    val_metrics[batch_num] = per_threshold

#### Declare Model and Training loop

In [None]:
import torch.nn as nn
import torchvision.models
from collections import OrderedDict

class UNet(nn.Module):
    """3D U-Net that maps a z-stack to a 2D probability map.

    Four encoder levels (a ``_block`` followed by 2x max-pooling over depth,
    height and width), a bottleneck, and four decoder levels that upsample
    with 2x transposed convolutions and concatenate the matching encoder
    output. A 1x1x1 convolution produces ``out_channels`` maps, which are
    max-pooled over the whole depth axis and passed through a sigmoid.

    Input ``x`` is ``(batch, in_channels, depth, H, W)``; depth, H and W must
    each be divisible by 16 so the four pool/upsample stages line up for the
    skip concatenations. With ``out_channels == 1`` the output is
    ``(batch, H, W)``.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels produced by the final 1x1x1 convolution.
        init_features: width of the first encoder level, doubled per level.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=16):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose3d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose3d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose3d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose3d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv3d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

        # Max over the entire depth axis, whatever its size (previously a
        # MaxPool3d with a hard-coded 32-slice kernel). Has no parameters, so
        # existing checkpoints load unchanged.
        self.final_pool = nn.AdaptiveMaxPool3d((1, None, None))

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        # Decoder: upsample, concatenate the skip connection on the channel
        # axis, then refine.
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        z_stack = self.conv(dec1)
        logits = self.final_pool(z_stack)

        # (batch, 1, 1, H, W) -> (batch, H, W): drop the channel axis, then the
        # pooled depth axis (each squeeze is a no-op if that dim is not 1).
        logits = torch.squeeze(logits, dim=1)
        logits = torch.squeeze(logits, dim=1)
        return torch.sigmoid(logits)

    @staticmethod
    def _block(in_channels, features, name):
        """Two (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) stages; layer names are
        prefixed with ``name`` and form part of the state-dict keys."""
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv3d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm3d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv3d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm3d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

In [None]:
class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss for binary segmentation.

    ``inputs`` must already be probabilities; no activation is applied here.
    With TP, FP, FN the soft true-positive / false-positive / false-negative
    sums over all elements,

        Tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)
        loss    = (1 - Tversky) ** gamma

    so ``alpha`` weights false positives and ``beta`` false negatives.
    ``weight`` and ``size_average`` are accepted for signature compatibility
    and ignored.
    """

    def __init__(self, weight=None, size_average=True):
        super(FocalTverskyLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1, alpha=0.75, beta=0.25, gamma=0.75):
        # Flatten prediction and label tensors. reshape (unlike view) also
        # accepts non-contiguous tensors such as transposed/permuted outputs.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        TP = (inputs * targets).sum()
        FP = ((1 - targets) * inputs).sum()
        FN = (targets * (1 - inputs)).sum()

        tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)
        return (1 - tversky) ** gamma

def calc_loss(pred, target, metrics):
    """Compute the Focal Tversky loss for one batch and accumulate it.

    Args:
        pred: model output probabilities.
        target: ground-truth masks with the same number of elements.
        metrics: accumulator; the batch loss is added to ``metrics['loss']``
            (callers pass a ``defaultdict(float)``).

    Returns:
        The scalar loss tensor, still attached to the graph for backward().
    """
    # Call the module (not .forward) so any registered hooks would run.
    loss = FocalTverskyLoss()(pred, target)
    # .item() detaches and copies the scalar to the host as a Python float.
    metrics['loss'] += loss.item()
    return loss

def _collect_metrics(model, loader):
    """Run ``model`` over every batch of ``loader`` without autograd and return
    the ``populate_metrics`` dict keyed by batch index. ``loader`` must not be
    shuffled, because those indices are later used to re-locate images."""
    collected = {}
    with torch.no_grad():
        for batch_num, (inputs, labels, _, _) in enumerate(tqdm(loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            populate_metrics(inputs, labels, outputs, batch_num, collected)
    return collected


def train_model(model, optimizer, scheduler, checkpoint_path, num_epochs=25, start_epoch=0):
    """Train ``model`` with one validation pass per epoch.

    Each epoch runs a training pass over ``dataloaders['train']`` and then a
    validation pass over ``dataloaders['val']``. After validation:
      * validation metrics are visualized (and training-set metrics too when
        ``epoch % 5 == 1``);
      * weights are saved in ``checkpoint_path`` as ``best_<epoch>.pth`` when
        the validation loss improves, otherwise as ``<epoch>.pth``.
    Train/val losses are logged to the module-level TensorBoard ``writer``.

    Args:
        model: network to train; batches are moved to the module-level
            ``device`` before being fed to it.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch after training.
        checkpoint_path: checkpoint directory (created if it does not exist).
        num_epochs: one past the index of the last epoch to run.
        start_epoch: first epoch index. ``best_loss`` starts fresh, so after
            resuming the first validated epoch is always saved as ``best_``.
    """
    best_loss = 1e10
    # Create the directory up front instead of failing at the first save.
    os.makedirs(checkpoint_path, exist_ok=True)

    for epoch in range(start_epoch, num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()
        epoch_loss_full = defaultdict(float)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            val_metrics = {}

            for batch_num, (inputs, labels, _, _) in enumerate(tqdm(dataloaders[phase])):
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                # Track autograd history only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                epoch_samples += inputs.size(0)
                # 1-based batch counter (independent of a short final batch).
                iteration = batch_num + 1

                if not is_train:
                    populate_metrics(inputs, labels, outputs, batch_num, val_metrics)
                elif iteration % 50 == 1:
                    print_metrics(metrics, epoch_samples, phase)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples
            epoch_loss_full[phase] = epoch_loss

            if is_train:
                # Update the learning rate once per epoch.
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                continue

            # ---- validation-only work below ----
            if epoch % 5 == 1:
                print("Epoch {}: Training visualizations".format(epoch))
                train_metrics = _collect_metrics(model, train_loader_not_shuffled)
                with torch.no_grad():
                    visualize_metrics(model, train_loader_not_shuffled, train_metrics)

            print("Epoch {}: Validation visualizations".format(epoch))
            with torch.no_grad():
                visualize_metrics(model, dataloaders[phase], val_metrics)

            if epoch_loss < best_loss:
                # save the model weights
                print(f"saving best model to {checkpoint_path}")
                best_loss = epoch_loss
                torch.save(model.state_dict(), os.path.join(checkpoint_path, 'best_{}.pth'.format(epoch)))
            else:
                torch.save(model.state_dict(), os.path.join(checkpoint_path, '{}.pth'.format(epoch)))

        # Plot train/val loss on the same TensorBoard chart.
        writer.add_scalars('Loss', epoch_loss_full, epoch)

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best val loss: {:4f}'.format(best_loss))

#### Load Model

In [None]:
# Use the GPU when available; training/eval code moves every batch to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device', device)

# UNet with its default arguments (1 input channel, 1 output map, 16 base features).
model = UNet().to(device)
#summary(model, input_size=(1, 32, 512, 512), batch_size=batch_size)

device cuda


In [None]:
# Run if you want to load model from checkpoint.
# map_location lets a checkpoint saved from a GPU run load on a CPU-only runtime.
model_path = drive_path / 'exps' / 'anish' / 'exp20' / 'checkpoints' / 'best_7.pth'
model.load_state_dict(torch.load(model_path, map_location=device))

#### Train Model

In [None]:
# Experiment outputs: TensorBoard logs and per-epoch checkpoints.
log_path = drive_path / 'exps' / 'anish' / 'exp20' / 'logs'
checkpoint_path = drive_path / 'exps' / 'anish' / 'exp20' / 'checkpoints'
# Global writer; train_model logs the per-epoch 'Loss' scalars through it.
writer = SummaryWriter(log_path)

# To resume, set start_epoch to the next epoch index (and load the weights
# with the checkpoint cell above); the loop below replays the scheduler so
# the learning rate matches that epoch.
start_epoch = 0
optimizer_ft = optim.Adam(model.parameters(), lr=0.00005)
# Multiply the LR by 0.1 every 8 scheduler steps (train_model steps once per epoch).
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=8, gamma=0.1)
for i in range(start_epoch):
    exp_lr_scheduler.step()
train_model(model, optimizer_ft, exp_lr_scheduler, checkpoint_path, num_epochs=30, start_epoch = start_epoch)

In [None]:
# Write out any pending events, then open TensorBoard on the experiment logs.
# NOTE(review): the logdir literal duplicates log_path -- keep them in sync.
writer.flush()
writer.close()
%tensorboard --logdir='/content/drive/MyDrive/CS101/exps/anish/exp20/logs'

#### Evaluate performance on validation set

In [None]:
# Score every validation batch at all thresholds. val_loader is unshuffled,
# so the batch indices stay valid when visualize_metrics re-reads images.
val_metrics = {}
model.eval()
with torch.no_grad():  # inference only -- don't build autograd graphs
    for batch_num, (inputs, labels, fls, _) in enumerate(tqdm(val_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        populate_metrics(inputs, labels, outputs, batch_num, val_metrics)

In [None]:
# PR / IoU curves plus best, median and worst validation examples.
visualize_metrics(model, val_loader, val_metrics)

#### Evaluate performance on train set

In [None]:
# Same scoring on the training set, using the unshuffled loader so batch
# indices line up when visualize_metrics re-reads the images.
train_metrics = {}
model.eval()
with torch.no_grad():  # inference only -- don't build autograd graphs
    # The dataset yields (image, gt, fluo, tags); the old 3-way unpack raised
    # "too many values to unpack".
    for batch_num, (inputs, labels, _, _) in enumerate(tqdm(train_loader_not_shuffled)):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        populate_metrics(inputs, labels, outputs, batch_num, train_metrics)

In [None]:
# PR / IoU curves plus best, median and worst training examples.
visualize_metrics(model, train_loader_not_shuffled, train_metrics)