# Instance Segmentation

So far we have only been interested in classes: what is background and what is foreground, where the cells are, or which pixels are person versus car. In many cases, however, we want to know not only whether a pixel belongs to a cell, but also which cell it belongs to.

For isolated objects this is trivial: all connected foreground pixels form one instance. Often, however, instances are very close together or even overlapping. In that case we need to think more carefully about how to formulate the loss for our network and how to extract the instances from its predictions.
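
To make the problem concrete, here is a minimal sketch (not part of the exercise code) that labels connected foreground pixels with `scipy.ndimage.label`. The toy arrays are made up for illustration: two isolated squares come out as two instances, while two touching squares collapse into one.

```python
import numpy as np
from scipy.ndimage import label

# toy foreground masks (illustrative only): two 3x3 squares per image
isolated = np.zeros((5, 9), dtype=bool)
isolated[1:4, 1:4] = True
isolated[1:4, 5:8] = True   # one background column separates the squares

touching = np.zeros((5, 9), dtype=bool)
touching[1:4, 1:4] = True
touching[1:4, 4:7] = True   # the squares share a border, there is no gap

# label() assigns one integer id per connected component
# (4-connectivity is the default in 2D)
_, n_isolated = label(isolated)
_, n_touching = label(touching)

print(n_isolated, n_touching)
```

The second count is the failure case this exercise is about: a foreground mask alone cannot tell the two touching objects apart.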

## Exercise 0.0: Importing packages

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import torch
import datetime

from glob import glob
from natsort import natsorted
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from tqdm.notebook import tqdm
from unet import *

torch.backends.cudnn.benchmark = True

## Exercise 1.0: Creating a simple model

<div class="alert alert-block alert-info"><h3>Exercise 1.1: Load and visualize data</h3>
    
    
- For this exercise we will be using data from [TissueNet](https://datasets.deepcell.org/)
- For our purposes we will use 50 training images and 20 testing images. The data is stored as tifs using the following structure:
```
woodshole/
    ├── test
    │   ├── img_0_cyto_masks.tif
    │   ├── img_0_nuclei_masks.tif
    │   ├── img_0.tif
    │   ├── img_1_cyto_masks.tif
    │   ├── img_1_nuclei_masks.tif
    │   └── img_1.tif
    │  
    └── train
        ├── img_0_cyto_masks.tif
        ├── img_0_nuclei_masks.tif
        ├── img_0.tif
        ├── img_1_cyto_masks.tif
        ├── img_1_nuclei_masks.tif
        └── img_1.tif
```
- Each raw image is stored as `img_{n}.tif` and is already stored as float32 so does not need to be normalized for training purposes. There are two channels in the raw data, one for nuclei and one for cytoplasm. 
- The corresponding mask files contain instance segmentations for the raw data. The nuclei masks correspond to the first channel of the raw data. The cytoplasm masks correspond to nucleus + cytoplasm. We will start with the nuclei data, as it is likely easier to segment than the cytoplasm, but you will be able to apply the techniques you learn on the harder data at the end of the exercise, time permitting.

</div>
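
As a small sketch of how the naming scheme above can be used, the snippet below derives both mask paths from a raw image path and filters raw files out of a mixed listing. The helper names `mask_paths_for` and `is_raw` are hypothetical and only rely on the file names shown in the directory tree.

```python
import os

def mask_paths_for(raw_path):
    """Return (nuclei_mask_path, cyto_mask_path) for a raw image path.

    Hypothetical helper: it only relies on the naming scheme listed above.
    """
    stem, ext = os.path.splitext(raw_path)
    return f"{stem}_nuclei_masks{ext}", f"{stem}_cyto_masks{ext}"

def is_raw(path):
    # raw files are the only .tif files whose name does not end in '_masks.tif'
    return path.endswith(".tif") and not path.endswith("_masks.tif")

files = [
    "woodshole/train/img_0.tif",
    "woodshole/train/img_0_cyto_masks.tif",
    "woodshole/train/img_0_nuclei_masks.tif",
]
raw_files = [f for f in files if is_raw(f)]
print(raw_files)
print(mask_paths_for(raw_files[0]))
```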

In [None]:
# let's start by loading our image paths into lists so we can visualize and get oriented with our data
# natsorted (from the natsort package) sorts file names in natural order, e.g. img_2 before img_10
# glob returns the paths of all files that match a pattern

train_cyto = natsorted(glob('woodshole/train/*cyto*'))
train_nuclei = natsorted(glob('woodshole/train/*nuclei*'))
train_raw = [i for i in natsorted(glob('woodshole/train/*.tif')) if 's.tif' not in i]

test_cyto = natsorted(glob('woodshole/test/*cyto*'))
test_nuclei = natsorted(glob('woodshole/test/*nuclei*'))
test_raw = [i for i in natsorted(glob('woodshole/test/*.tif')) if 's.tif' not in i]

In [None]:
# convenience functions for viewing labels as rgb, and reading files into numpy arrays
from skimage import color
from skimage.io import imread

# select a random cytoplasm mask file
cyto_file = random.choice(train_cyto)

# use skimage.io.imread to read our data into numpy arrays
cyto = imread(cyto_file)
nuclei = imread(cyto_file.replace('cyto', 'nuclei'))
raw = imread(cyto_file.replace('_cyto_masks', ''))

# our raw data shape is (c, h, w) and there are only two channels.
# to visualize it as an rgb image we need to add an empty third channel and then transpose so that it is (h, w, 3)
raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

# visualize the data - execute this cell a few times to see different examples.
# you can also change train_cyto to test_cyto above to see some test data. it is pretty similar 
fig, axes = plt.subplots(1,5,figsize=(20, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(raw[:,:,0], cmap='gray')
axes[0][0].title.set_text('Raw nuclei channel')

axes[0][1].imshow(raw[:,:,1], cmap='gray')
axes[0][1].title.set_text('Raw cyto channel')

axes[0][2].imshow(raw)
axes[0][2].title.set_text('Raw overlay')

axes[0][3].imshow(raw[:,:,0], cmap='gray')
axes[0][3].imshow(color.label2rgb(nuclei), alpha=0.5)
axes[0][3].title.set_text('nuclei mask')

axes[0][4].imshow(raw[:,:,1], cmap='gray')
axes[0][4].imshow(color.label2rgb(cyto), alpha=0.5)
axes[0][4].title.set_text('cyto mask')

<div class="alert alert-block alert-info"><h3>Exercise 1.2: Create dataset</h3>
    
- Let's start by creating a simple dataset (similar to what was done in the image segmentation exercise)
- For now our dataset should just load the raw and mask data like above. We will just use the first channel of the raw data (nuclei) for now (`raw[0]`)

</div>

In [None]:
class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei'
                ):
        
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
                
    def __len__(self):
        return len(self.raw_files)
        
    def __getitem__(self, idx):
        
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        # for now just do single channel training
        raw = raw[0]

        return raw, mask

In [None]:
train_dataset = TissueNetDataset(root_dir='woodshole', split='train')

In [None]:
# get a random sample

raw, mask = train_dataset[random.randrange(len(train_dataset))]
    
plt.imshow(raw, cmap='gray')
plt.imshow(color.label2rgb(mask), alpha=0.5)

<div class="alert alert-block alert-info"><h3>Exercise 1.3: Add augmentations</h3>
    
- You have already learned about the importance of augmenting your training data.
- For our exercise we will use an augmentation library called [albumentations](https://albumentations.ai/) which provides easy to use, fast transforms
- Here is a nice tutorial: https://albumentations.ai/docs/examples/example_kaggle_salt/
- To start, we will add a few simple augmentations to both our raw and mask data:
    - randomly crop a 64x64 patch
    - horizontally flip with a 50% probability
    - vertically flip with a 50% probability
- Below is a simple example. You should then add a function to your dataset to augment your batch. 

</div>
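
The key point of these augmentations is that every spatial transform must be applied identically to the raw image and its mask, otherwise the labels no longer line up with the pixels. The following NumPy-only sketch illustrates the idea without albumentations. The function `paired_crop_and_flip` and the toy arrays are assumptions made for illustration, not part of the exercise code.

```python
import numpy as np

def paired_crop_and_flip(raw, mask, size, rng):
    """Apply one random crop and random flips identically to raw and mask."""
    h, w = raw.shape
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    raw = raw[top:top + size, left:left + size]
    mask = mask[top:top + size, left:left + size]

    if rng.random() < 0.5:            # horizontal flip
        raw, mask = raw[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:            # vertical flip
        raw, mask = raw[::-1, :], mask[::-1, :]
    return raw, mask

rng = np.random.default_rng(0)
toy_mask = rng.integers(0, 4, size=(128, 128))   # made-up label image
toy_raw = toy_mask.astype(np.float32) * 10.0     # raw is a known function of the mask

aug_raw, aug_mask = paired_crop_and_flip(toy_raw, toy_mask, size=64, rng=rng)

# the pixel-wise relationship between raw and mask survives the augmentation
print(np.array_equal(aug_raw, aug_mask.astype(np.float32) * 10.0))
```

Passing `image=` and `mask=` to a single albumentations transform, as in the cell below, serves the same purpose.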

In [None]:
import albumentations as A

file = random.choice(train_nuclei)

full_mask_nuclei = imread(file)
full_raw_nuclei = imread(file.replace('_nuclei_masks', ''))[0]

transform = A.Compose([
              A.RandomCrop(width=64, height=64),
              A.HorizontalFlip(p=0.5),
              A.VerticalFlip(p=0.5)
            ])

transformed = transform(image=full_raw_nuclei, mask=full_mask_nuclei)
          
aug_raw, aug_mask = transformed['image'], transformed['mask']

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=False,sharey=False,squeeze=False)

axes[0][0].imshow(full_raw_nuclei, cmap='gray')
axes[0][0].imshow(color.label2rgb(full_mask_nuclei), alpha=0.5)

axes[0][1].imshow(aug_raw, cmap='gray')
axes[0][1].imshow(color.label2rgb(aug_mask), alpha=0.5)

In [None]:
# add an augmentation function

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        
    def __len__(self):
        return len(self.raw_files)
    
    def augment_data(self, raw, mask):
        
        transform = A.Compose([
              A.RandomCrop(width=self.crop_size, height=self.crop_size),
              A.HorizontalFlip(p=0.5),
              A.VerticalFlip(p=0.5)
            ])

        transformed = transform(image=raw, mask=mask)

        raw, mask = transformed['image'], transformed['mask']
        
        return raw, mask
    
    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        # for now just do single channel training
        raw = raw[0]

        if self.split == 'train':
            raw, mask = self.augment_data(raw, mask)
            
        return raw, mask

In [None]:
train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)

In [None]:
# get a random sample

raw, mask = train_dataset[random.randrange(len(train_dataset))]
    
plt.imshow(raw, cmap='gray')
plt.imshow(color.label2rgb(mask), alpha=0.5)

<div class="alert alert-block alert-info"><h3>Exercise 1.4: Create fg/bg representation</h3>
    
- It would be ideal to directly predict unique labels in a dataset. Unfortunately this requires global information which can become difficult as datasets increase in size. Consequently, alternative approaches aim to solve the problem locally.
- We will start with the most trivial approach: learning a foreground / background mask and then relabeling connected pixels as unique objects. While this approach might suffice on simple datasets, you will see how it can become problematic on datasets in which objects are tightly packed.

</div>
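
A foreground mask can only separate touching objects if there is at least a thin strip of background between them. The sketch below shows the idea behind a one-pixel, per-label erosion using plain NumPy: every pixel that has a differently labelled 4-neighbour is set to background. The function name `remove_touching_borders` and the toy label image are assumptions for illustration.

```python
import numpy as np

def remove_touching_borders(labels):
    """Set every pixel with a differently labelled 4-neighbour to background."""
    padded = np.pad(labels, 1, mode="edge")   # the image edge counts as 'same label'
    center = padded[1:-1, 1:-1]
    same = (
        (center == padded[:-2, 1:-1]) &   # neighbour above
        (center == padded[2:, 1:-1]) &    # neighbour below
        (center == padded[1:-1, :-2]) &   # neighbour to the left
        (center == padded[1:-1, 2:])      # neighbour to the right
    )
    return np.where(same, labels, 0)

# two touching 3x3 instances
labels = np.zeros((5, 8), dtype=np.int32)
labels[1:4, 1:4] = 1
labels[1:4, 4:7] = 2

foreground_before = labels != 0
foreground_after = remove_touching_borders(labels) != 0

print(foreground_before[2].astype(int))
print(foreground_after[2].astype(int))
```

Before the erosion the middle row is one unbroken run of foreground. Afterwards only the centre pixel of each square remains, so the two objects are no longer connected.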

In [None]:
# function to erode boundary pixels
from scipy.ndimage import binary_erosion

def erode(labels, iterations, border_value):

    # copy labels so that the input array is not modified
    labels = np.copy(labels)

    # create zeros array for foreground
    foreground = np.zeros_like(labels, dtype=bool)

    # loop through unique labels
    for label in np.unique(labels):

        # skip background
        if label == 0:
            continue

        # mask to label
        label_mask = labels == label

        # erode labels
        eroded_mask = binary_erosion(
                label_mask,
                iterations=iterations,
                border_value=border_value)

        # get foreground
        foreground = np.logical_or(eroded_mask, foreground)

    # and background...
    background = np.logical_not(foreground)

    # set eroded pixels to zero
    labels[background] = 0

    return labels

In [None]:
# visualize the representation (repeatedly run cell)

raw, mask = train_dataset[random.randrange(len(train_dataset))]

labels = erode(
    mask,
    iterations=1,
    border_value=1)

labels_two_class = (labels != 0).astype(np.float32)

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(raw, cmap='gray')
axes[0][0].imshow(color.label2rgb(mask), alpha=0.5)
axes[0][0].title.set_text('Raw + segmentation')

axes[0][1].imshow(labels_two_class)
axes[0][1].title.set_text('Foreground / background')

In [None]:
# add fg / bg representation to dataset. let's also create an optional val/test split

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)
    
    def augment_data(self, raw, mask):
        
        transform = A.Compose([
              A.RandomCrop(width=self.crop_size, height=self.crop_size),
              A.HorizontalFlip(p=0.5),
              A.VerticalFlip(p=0.5)
            ])

        transformed = transform(image=raw, mask=mask)

        raw, mask = transformed['image'], transformed['mask']
        
        return raw, mask

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        # for now just do single channel training
        raw = raw[0]

        if self.split == 'train':
            raw, mask = self.augment_data(raw, mask)
            
        fg = erode(
                    mask,
                    iterations=1,
                    border_value=1)
        
        mask = (fg != 0).astype(np.float32)

        # add channel dim for network
        raw = np.expand_dims(raw, axis=0)
        mask = np.expand_dims(mask, axis=0)
        
        return raw, mask

In [None]:
train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test')
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True)

In [None]:
# run cell repeatedly to see different crops

raw, mask = train_dataset[random.randrange(len(train_dataset))]

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(np.squeeze(raw), cmap='gray')
axes[0][0].title.set_text('Raw')

axes[0][1].imshow(np.squeeze(mask))
axes[0][1].title.set_text('Mask')

<div class="alert alert-block alert-info"><h3>Exercise 1.5: Create shallow network, visualize receptive field</h3>
    
- Let's create a shallow two-layer network and visualize the receptive field. We will see later how this receptive field changes as we add more layers and change our input image size
  
</div>
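
To get a feeling for where receptive field sizes come from, here is a rough sketch that computes them for the downsampling path only. It assumes two 3x3 convolutions per level followed by 2x2 max pooling, which is a common U-Net layout but not necessarily what the `UNet` class used here does. The function `receptive_fields` is hypothetical and its numbers may differ from what `net.rec_fov` reports.

```python
def receptive_fields(num_levels, convs_per_level=2, kernel=3, pool=2):
    """Receptive field (in input pixels) at the end of each encoder level.

    Assumes each level applies `convs_per_level` convolutions of size `kernel`
    with stride 1, followed by `pool` x `pool` max pooling between levels.
    """
    rf, jump = 1, 1          # jump = distance in input pixels between adjacent features
    fields = []
    for level in range(num_levels):
        for _ in range(convs_per_level):
            rf += (kernel - 1) * jump
        fields.append(rf)
        if level < num_levels - 1:
            rf += (pool - 1) * jump   # the pooling window widens the field ...
            jump *= pool              # ... and coarsens the sampling grid
    return fields

print(receptive_fields(num_levels=3))  # [5, 14, 32]
```

Each additional level roughly doubles the receptive field, which is why deeper networks can take more context into account.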

In [None]:
raw, mask = train_dataset[random.randrange(len(train_dataset))]

net_t = raw
fovs = []
d_factors = [[2,2],[2,2]]

net = UNet(in_channels=1,
           num_fmaps=6,
           fmap_inc_factors=2,
           downsample_factors=d_factors,
           padding='same'
          )

for level in range(len(d_factors)+1):
    fov_tmp, _ = net.rec_fov(level , (1, 1), 1)
    fovs.append(fov_tmp[0])

fig=plt.figure(figsize=(8, 8))
colors = ["yellow", "red", "green"]

plt.imshow(np.squeeze(raw), cmap='gray')

for idx, fov_t in enumerate(fovs):
    print("Field of view at depth {}: {:3d} (color: {})".format(idx+1, fov_t, colors[idx]))
    xmin = raw.shape[2]/2 - fov_t/2
    xmax = raw.shape[2]/2 + fov_t/2
    ymin = raw.shape[1]/2 - fov_t/2
    ymax = raw.shape[1]/2 + fov_t/2
    plt.hlines(ymin, xmin, xmax, color=colors[idx], lw=3)
    plt.hlines(ymax, xmin, xmax, color=colors[idx], lw=3)
    plt.vlines(xmin, ymin, ymax, color=colors[idx], lw=3)
    plt.vlines(xmax, ymin, ymax, color=colors[idx], lw=3)
plt.show()

<div class="alert alert-block alert-info"><h3>Exercise 1.6: Set hyperparameters, create model</h3>
    
- Let's start by setting some hyperparameters. Since we are just doing a fg/bg prediction to start, this will be pretty similar to the semantic segmentation exercise. 
    - We will have a single output channel
    - Our loss function will be `BCEWithLogitsLoss()`. This is the same as `BCELoss` but applies the sigmoid activation internally. It can be more numerically stable, but either loss is fine to use.
    - Our tensor dtype will be a torch float tensor. You can see a list of tensor types [here](https://pytorch.org/docs/stable/tensors.html)
    - For our model, we will create a two-layer U-Net with the following parameters:
        - downsample by a factor of 2 in each layer
        - single input channel
        - 32 feature maps in the first layer
        - multiply by a factor of 2 between layers
        - `same` padding
        - constant upsampling
</div>
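
To see why fusing the sigmoid into the loss can be more numerically stable, here is a NumPy sketch that compares a naive "sigmoid, then binary cross entropy" with an algebraically equivalent formulation on logits. This is an illustration of the idea under our own assumptions, not PyTorch's actual implementation, and the function names are made up.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_naive(logits, targets):
    # sigmoid first, then binary cross entropy on the probabilities
    p = sigmoid(logits)
    return -(targets * np.log(p) + (1 - targets) * np.log(1 - p)).mean()

def bce_with_logits(logits, targets):
    # algebraically identical, but never takes log(0)
    return (np.maximum(logits, 0) - logits * targets
            + np.log1p(np.exp(-np.abs(logits)))).mean()

logits = np.array([-2.0, 0.5, 3.0])
targets = np.array([0.0, 1.0, 1.0])
agree = np.isclose(bce_naive(logits, targets), bce_with_logits(logits, targets))

# with an extreme logit the naive version becomes infinite, the fused one stays finite
extreme = np.array([40.0])
wrong = np.array([0.0])
with np.errstate(divide="ignore"):
    naive_extreme = bce_naive(extreme, wrong)
stable_extreme = bce_with_logits(extreme, wrong)

print(agree, naive_extreme, stable_extreme)
```

For ordinary logits both versions agree. For a confidently wrong prediction the sigmoid saturates to exactly 1.0 in floating point, so the naive loss takes `log(0)`, whereas the fused version returns a value close to the logit itself.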

In [None]:
# some hyperparams

out_channels = 1
activation = torch.nn.Sigmoid()
loss_fn = torch.nn.BCEWithLogitsLoss()
dtype = torch.FloatTensor

batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

In [None]:
torch.manual_seed(42)
d_factors = [[2,2],[2,2]]

in_channels=1
num_fmaps=32
fmap_inc_factors=2

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=d_factors,
        activation='ReLU',
        padding='same',
        constant_upsample=False)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=1,
    padding=0,
    bias=True)

net = torch.nn.Sequential(unet, final_conv)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = net.to(device)

summary(net, (in_channels, 384, 384))

<div class="alert alert-block alert-info"><h3>Exercise 1.7: Run training loop</h3>
    
- Just train for 500 steps to start.
- Use a learning rate of 1e-4
  
</div>

In [None]:
def training_step(model, loss_fn, optimizer, feature, label, activation):
    # speedup version of setting gradients to zero
    for param in model.parameters():
        param.grad = None
    # forward
    logits = model(feature) # B x C x H x W

    loss_value = loss_fn(input=logits, target=label)  # logits.shape == label.shape == [B, C, H, W]
    # backward only if the model is in training mode
    if model.training:
        loss_value.backward()
        optimizer.step()
    
    output = activation(logits)

    outputs = {
        'pred': output,
        'logits': logits,
    }
    return loss_value, outputs

In [None]:
training_steps = 500

# create a logdir for each run and a corresponding summary writer
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

# make sure net and loss are cast to our device (should be gpu, can check by printing device)
net = net.to(device)
loss_fn = loss_fn.to(device)

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [None]:
# set flags
net.train() 
loss_fn.train()
step = 0

with tqdm(total=training_steps) as pbar:
    while step < training_steps:
        # reset data loader to get random augmentations
        np.random.seed()
        tmp_loader = iter(train_loader)
        for feature, label in tmp_loader:
            label = label.type(dtype)
            label = label.to(device)
            feature = feature.to(device)
            loss_value, pred = training_step(net, loss_fn, optimizer, feature, label, activation)
            writer.add_scalar('loss',loss_value.cpu().detach().numpy(),step)
            step += 1
            pbar.update(1)
            if step % 100 == 0:
                net.eval()
                tmp_val_loader = iter(val_loader)
                acc_loss = []
                for feature, label in tmp_val_loader:                    
                    label = label.type(dtype)
                    label = label.to(device)
                    feature = feature.to(device)
                    loss_value, _ = training_step(net, loss_fn, optimizer, feature, label, activation)
                    acc_loss.append(loss_value.cpu().detach().numpy())
                writer.add_scalar('val_loss',np.mean(acc_loss),step)
                net.train()

                print(np.mean(acc_loss))

In [None]:
# to view runs in tensorboard you can call either (uncommented):

#%reload_ext tensorboard
#%tensorboard --logdir logs

#or run:

# !tensorboard --logdir=logs 

# to view in separate window

# Note that if running over ssh you will also need to forward the TensorBoard port (usually 6006)
# you can also do this by passing the host (relevant machine ip address), eg:

# !tensorboard --logdir=logs --host hostname

<div class="alert alert-block alert-info"><h3>Exercise 1.8: Visualize results</h3>
  
</div>

In [None]:
# convenience functions to threshold mask and relabel connected components as unique objects
from skimage.filters import threshold_otsu
from skimage.measure import label as relabel_cc

net.eval()

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
    
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
    pred = np.squeeze(pred.cpu().detach().numpy())
            
    pred = np.squeeze(pred)
            
    thresh = threshold_otsu(pred)
    binary = pred >= thresh
    
    gt_labels = imread(test_loader.dataset.mask_files[idx])
    labeled = relabel_cc(binary)
    
    fig, axes = plt.subplots(1,5,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    axes[0][0].imshow(image, cmap='gray')
    axes[0][0].title.set_text('Raw')

    axes[0][1].imshow(mask)
    axes[0][1].title.set_text('GT mask')
    
    axes[0][2].imshow(color.label2rgb(gt_labels))
    axes[0][2].title.set_text('GT seg')
    
    axes[0][3].imshow(pred)
    axes[0][3].title.set_text('Predicted Mask')

    axes[0][4].imshow(color.label2rgb(labeled))
    axes[0][4].title.set_text('Predicted Seg')

    break

## Exercise 2.0: Improving the model

- As you can see, our predicted segmentation isn't very good. When objects are tightly packed together, a simple foreground / background representation gives results that aren't much better than thresholding the raw data and relabelling connected components.
- So, let's improve our model. We can do a few things to enhance our results:
    1. Add more complex representations
        * three class
        * signed distance transform
        * edge affinities
    2. Use both channels of the raw data as input to the network
    3. Add better augmentations
    4. Increase the input size to our network
    5. Use a bigger network (eg increase layers, number of feature maps)
    6. Train for longer
    7. Use a better post-processing strategy (e.g. seeded watershed)

<div class="alert alert-block alert-info"><h3>Exercise 2.1: Add extra representations</h3>
    
- Three-class model

This is an extension of the basic foreground/background (two-class) model in which a third class is introduced: the boundary. Even if two instances touch, there is a boundary between them, so they can be separated. Instead of a single output (where zero is one class and one is the other), the network outputs three values per pixel, one per class, and the loss function changes from binary to (sparse) categorical cross entropy.
    
- Signed Distance Transform

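The label for each pixel is its distance to the closest boundary, with opposite signs inside and outside of instances. In the `compute_sdt` function used in this exercise, values are positive inside instances and negative outside, and the distance is additionally squashed with `tanh` so that it lies between -1 and 1. As the output is not a probability but a continuous scalar (in principle unbounded before the `tanh`), the mean squared error loss function is used.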
    
- Edge Affinities

Here we consider not just the pixel but also its direct neighbors (in 2D the left neighbor and the upper neighbor are sufficient; right and down are redundant with the next pixel's left and upper neighbor). Imagine there is an edge between two neighboring pixels if they belong to the same instance and no edge if they do not. If we then take all pixels that are directly or indirectly connected by edges, we get an instance. As a loss we can use mean squared error.

</div>
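
The last step of the affinity idea, "take all pixels that are connected by edges", can be sketched with a small union-find. The snippet below is an illustration under our own assumptions, not the post-processing used later in the exercise: the function `affinities_to_instances` is hypothetical, edges point to the right and downward neighbor, and pixels without any edge (including single-pixel objects) are treated as background.

```python
import numpy as np

def affinities_to_instances(aff_right, aff_down):
    """Group pixels connected by edges into instances with a small union-find.

    aff_right[y, x] == 1 links (y, x) with (y, x + 1);
    aff_down[y, x] == 1 links (y, x) with (y + 1, x).
    Pixels without any edge are treated as background (label 0).
    """
    h, w = aff_right.shape
    parent = list(range(h * w))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]   # path halving
            i = parent[i]
        return i

    def union(i, j):
        parent[find(i)] = find(j)

    connected = np.zeros((h, w), dtype=bool)
    for y in range(h):
        for x in range(w):
            if x + 1 < w and aff_right[y, x]:
                union(y * w + x, y * w + x + 1)
                connected[y, x] = connected[y, x + 1] = True
            if y + 1 < h and aff_down[y, x]:
                union(y * w + x, (y + 1) * w + x)
                connected[y, x] = connected[y + 1, x] = True

    # give every union-find root a consecutive instance id, in raster order
    instances = np.zeros((h, w), dtype=np.int32)
    ids = {}
    for y in range(h):
        for x in range(w):
            if connected[y, x]:
                root = find(y * w + x)
                instances[y, x] = ids.setdefault(root, len(ids) + 1)
    return instances

# toy ground truth: two touching instances
seg = np.zeros((4, 6), dtype=np.int32)
seg[1:3, 1:3] = 1
seg[1:3, 3:5] = 2

# an edge exists only between two foreground pixels of the same instance
aff_right = np.zeros_like(seg)
aff_down = np.zeros_like(seg)
aff_right[:, :-1] = (seg[:, :-1] == seg[:, 1:]) & (seg[:, :-1] > 0)
aff_down[:-1, :] = (seg[:-1, :] == seg[1:, :]) & (seg[:-1, :] > 0)

recovered = affinities_to_instances(aff_right, aff_down)
print(recovered)
```

Although the two instances touch, there is no edge across their shared border, so they are recovered as separate objects.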

In [None]:
from scipy.ndimage import distance_transform_edt

def erode(labels, iterations, border_value):

    # copy labels to memory, create border array
    labels = np.copy(labels)
    border = np.array(labels)

    # create zeros array for foreground
    foreground = np.zeros_like(labels, dtype=bool)

    # loop through unique labels
    for label in np.unique(labels):

        # skip background
        if label == 0:
            continue

        # mask to label
        label_mask = labels == label

        # erode labels
        eroded_mask = binary_erosion(
                label_mask,
                iterations=iterations,
                border_value=border_value)

        # get foreground
        foreground = np.logical_or(eroded_mask, foreground)

    # and background...
    background = np.logical_not(foreground)

    # set eroded pixels to zero
    labels[background] = 0

    # get eroded pixels
    border = labels - border

    return labels, border


# utility function to compute a signed distance transform
def compute_sdt(labels, constant=0.5, scale=5):

    inner = distance_transform_edt(binary_erosion(labels))
    outer = distance_transform_edt(np.logical_not(labels))

    distance = (inner - outer) + constant

    distance = np.tanh(distance / scale)

    return distance


# utility function to compute edge affinities
def compute_affinities(seg, nhood):

    nhood = np.array(nhood)

    shape = seg.shape
    nEdge = nhood.shape[0]
    dims = nhood.shape[1]
    aff = np.zeros((nEdge,) + shape, dtype=np.int32)

    for e in range(nEdge):
        aff[e, \
          max(0,-nhood[e,0]):min(shape[0],shape[0]-nhood[e,0]), \
          max(0,-nhood[e,1]):min(shape[1],shape[1]-nhood[e,1])] = \
                      (seg[max(0,-nhood[e,0]):min(shape[0],shape[0]-nhood[e,0]), \
                          max(0,-nhood[e,1]):min(shape[1],shape[1]-nhood[e,1])] == \
                        seg[max(0,nhood[e,0]):min(shape[0],shape[0]+nhood[e,0]), \
                          max(0,nhood[e,1]):min(shape[1],shape[1]+nhood[e,1])] ) \
                      * ( seg[max(0,-nhood[e,0]):min(shape[0],shape[0]-nhood[e,0]), \
                          max(0,-nhood[e,1]):min(shape[1],shape[1]-nhood[e,1])] > 0 ) \
                      * ( seg[max(0,nhood[e,0]):min(shape[0],shape[0]+nhood[e,0]), \
                          max(0,nhood[e,1]):min(shape[1],shape[1]+nhood[e,1])] > 0 )
                          

    return aff

In [None]:
# compute each representation and visualize

file = random.choice(train_nuclei)

full_mask_nuclei = imread(file)
full_raw_nuclei = imread(file.replace('_nuclei_masks', ''))[0]

transform = A.Compose([
              A.RandomCrop(width=64, height=64),
              A.HorizontalFlip(p=0.5),
              A.VerticalFlip(p=0.5)
            ])

transformed = transform(image=full_raw_nuclei, mask=full_mask_nuclei)
          
aug_raw, aug_mask = transformed['image'], transformed['mask']
    
labels, border = erode(
    aug_mask,
    iterations=1,
    border_value=1)

labels_two_class = (labels != 0)
border[border!=0] = 2

labels_three_class = (labels_two_class + border)
sdt = compute_sdt(labels)
affs = compute_affinities(labels, nhood=[[0,1],[1,0]])

fig, axes = plt.subplots(1,5,figsize=(20, 10),sharex=True,sharey=True,squeeze=False)

for idx, (ds_name, data) in enumerate([
    ('raw', aug_raw),
    ('fg/bg', labels_two_class),
    ('three class', labels_three_class),
    ('sdt', sdt),
    ('affinities', affs[0] + affs[1])]
):

    cmap = 'gray' if ds_name == 'raw' else 'viridis'

    axes[0][idx].imshow(data.astype(np.float32), cmap=cmap)
    axes[0][idx].title.set_text(ds_name)

In [None]:
# add representation functions to dataset

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 prediction_type='two_class'
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.prediction_type = prediction_type
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)
    
    def augment_data(self, raw, mask):
        
        transform = A.Compose([
              A.RandomCrop(width=self.crop_size, height=self.crop_size),
              A.HorizontalFlip(p=0.5),
              A.VerticalFlip(p=0.5)
            ])

        transformed = transform(image=raw, mask=mask)

        raw, mask = transformed['image'], transformed['mask']
        
        return raw, mask

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        # for now just do single channel training
        raw = raw[0]

        if self.split == 'train':
            raw, mask = self.augment_data(raw, mask)
                
        mask, border = erode(
                    mask,
                    iterations=1,
                    border_value=1)
        
        if self.prediction_type == 'two_class':
            mask = (mask != 0)

        elif self.prediction_type == 'three_class':
            labels_two_class = (mask != 0)
            border[border!=0] = 2
            
            mask = labels_two_class + border

        elif self.prediction_type == 'sdt':
            mask = compute_sdt(mask)

        elif self.prediction_type == 'affs':
            mask = compute_affinities(mask, nhood=[[0,1],[1,0]])

        else:
            raise Exception('Choose from one of the following prediction types: two_class, three_class, sdt, affs')

        mask = mask.astype(np.float32)
        
        # add channel dim for network
        raw = np.expand_dims(raw, axis=0)
        
        if self.prediction_type != 'affs':
            mask = np.expand_dims(mask, axis=0)
            
        return raw, mask

In [None]:
# try each prediction type
prediction_type = 'three_class'

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64, prediction_type=prediction_type)

raw, mask = train_dataset[random.randrange(len(train_dataset))]

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(np.squeeze(raw), cmap='gray')
axes[0][0].title.set_text('Raw')

if prediction_type == 'affs':
    # affs has two channels (x/y), so sum them for display
    axes[0][1].imshow(mask[0]+mask[1])
else:
    axes[0][1].imshow(np.squeeze(mask))
axes[0][1].title.set_text('Mask')

<div class="alert alert-block alert-info"><h3>Exercise 2.2: Use both channels of raw data</h3>
    
- Up until now, we have only been using the nuclei channel of the raw data as input into our network. But if we have extra channels available, it will help our network to see them. We should give our network as much information as possible to learn from, even if it is only tasked with learning a single channel output.

</div>
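
Using both channels mostly comes down to bookkeeping of array layouts. The raw data is stored as (C, H, W), image libraries such as albumentations generally work on channel-last (H, W, C) arrays, and the network expects channel-first input again. The sketch below walks through that round trip with made-up arrays and a fixed crop standing in for the random augmentation.

```python
import numpy as np

raw = np.random.rand(2, 96, 96).astype(np.float32)   # made-up (C, H, W) image
mask = np.zeros((96, 96), dtype=np.int32)            # made-up (H, W) labels

# move channels last before a spatial crop, so raw and mask share the (H, W) axes
raw_hwc = raw.transpose(1, 2, 0)          # (H, W, C)
crop_raw = raw_hwc[16:80, 16:80]
crop_mask = mask[16:80, 16:80]

# move channels first again for the network
net_input = np.ascontiguousarray(crop_raw.transpose(2, 0, 1))   # (C, H, W)
in_channels = net_input.shape[0]

print(net_input.shape, crop_mask.shape, in_channels)
```

Note that a network trained on this input would need `in_channels=2` instead of 1.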

In [None]:
# use both raw channels as input to the network

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 prediction_type='two_class'
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.prediction_type = prediction_type
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)
    
    def augment_data(self, raw, mask):
        
        transform = A.Compose([
              A.RandomCrop(width=self.crop_size, height=self.crop_size),
              A.HorizontalFlip(p=0.5),
              A.VerticalFlip(p=0.5)
            ])

        transformed = transform(image=raw, mask=mask)

        raw, mask = transformed['image'], transformed['mask']
        
        return raw, mask

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        raw = raw.transpose([1,2,0])
        
        mask = np.expand_dims(mask, axis=0)
        mask = mask.transpose([1,2,0])

        if self.split == 'train':
            raw, mask = self.augment_data(raw, mask)
            
        raw = raw.transpose([2,0,1])
        mask = mask.transpose([2,0,1])
                
        mask, border = erode(
                    mask[0],
                    iterations=1,
                    border_value=1)
        
        if self.prediction_type == 'two_class':
            mask = (mask != 0)

        elif self.prediction_type == 'three_class':
            labels_two_class = (mask != 0)
            border[border!=0] = 2
            
            mask = labels_two_class + border

        elif self.prediction_type == 'sdt':
            mask = compute_sdt(mask)

        elif self.prediction_type == 'affs':
            mask = compute_affinities(mask, nhood=[[0,1],[1,0]])

        else:
            raise Exception('Choose from one of the following prediction types: two_class, three_class, sdt, affs')

        mask = mask.astype(np.float32)
        
        if self.prediction_type != 'affs':
            mask = np.expand_dims(mask, axis=0)
        
        return raw, mask

In [None]:
# try each prediction type
prediction_type = 'three_class'

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64, prediction_type=prediction_type)

raw, mask = train_dataset[random.randrange(len(train_dataset))]

raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(raw)
axes[0][0].title.set_text('Raw')

try:
    axes[0][1].imshow(np.squeeze(mask))
    axes[0][1].title.set_text('Mask')
except:
    # affs has two channels (x/y)
    axes[0][1].imshow(mask[0]+mask[1])
    axes[0][1].title.set_text('Mask')

<div class="alert alert-block alert-info"><h3>Exercise 2.3: Add more augmentations</h3>
    
- We were using some pretty simple augmentations (crop and flips). Since we want to create a more robust model, we should augment our data more so that it performs better on data that it hasn't seen. This is also a good way to effectively increase our training sample size. 
- This is a good tutorial for adding useful augmentations: https://albumentations.ai/docs/examples/example_kaggle_salt/
- We will still crop and flip our data as before. Additionally, we will:
    1. Pad our data if needed (This is useful if our network input size is not compatible with our max pooling layers)
    2. Randomly rotate by 90 degrees
    3. Transpose
    4. Elastically warp 
    5. Randomly adjust our brightness and contrast
</div>
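
The padding in step 1 amounts to rounding the crop size up to the next size the network can handle. A small sketch of that arithmetic, assuming the size has to be a multiple of 8; `padded_size` is an illustrative name and not part of albumentations.

```python
def padded_size(crop_size, multiple=8):
    """Smallest size >= crop_size that is divisible by `multiple`."""
    # ceiling division without floating point
    return -(-crop_size // multiple) * multiple

for size in (64, 65, 100):
    print(size, '->', padded_size(size))
```

A crop of 64 is left alone, while a crop of 65 would be padded to 72 before it reaches the network.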

In [None]:
# add representation functions to dataset

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 prediction_type='two_class',
                 padding_size=8
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.prediction_type = prediction_type
        self.padding_size = padding_size
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)
    
    def get_padding(self, crop_size, padding_size):
    
        # quotient
        q = int(crop_size / padding_size)
    
        if crop_size % padding_size != 0:
            padding = (padding_size * (q + 1))
        else:
            padding = crop_size
    
        return padding
    
    def augment_data(self, raw, mask, padding):
        
        transform = A.Compose([
              A.RandomCrop(
                  width=self.crop_size,
                  height=self.crop_size),
              A.PadIfNeeded(
                  min_height=padding,
                  min_width=padding,
                  p=1,
                  border_mode=0),
              A.HorizontalFlip(p=0.3),
              A.VerticalFlip(p=0.3),
              A.RandomRotate90(p=0.3),
              A.Transpose(p=0.3),
              A.ElasticTransform(
                  p=0.3,
                  alpha=0.1,
                  sigma=0.1,
                  alpha_affine=0.3),
              A.RandomBrightnessContrast(p=0.3)
            ])

        transformed = transform(image=raw, mask=mask)

        raw, mask = transformed['image'], transformed['mask']
        
        return raw, mask

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        raw = raw.transpose([1,2,0])
        
        mask = np.expand_dims(mask, axis=0)
        mask = mask.transpose([1,2,0])
        
        if self.split == 'train':
            padding = self.get_padding(self.crop_size, self.padding_size)
            raw, mask = self.augment_data(raw, mask, padding)
            
        raw = raw.transpose([2,0,1])
        mask = mask.transpose([2,0,1])
                        
        mask, border = erode(
                    mask[0],
                    iterations=1,
                    border_value=1)
        
        if self.prediction_type == 'two_class':
            mask = (mask != 0)

        elif self.prediction_type == 'three_class':
            labels_two_class = (mask != 0)
            border[border!=0] = 2
            
            mask = labels_two_class + border

        elif self.prediction_type == 'sdt':
            mask = compute_sdt(mask)

        elif self.prediction_type == 'affs':
            mask = compute_affinities(mask, nhood=[[0,1],[1,0]])

        else:
            raise Exception('Choose from one of the following prediction types: two_class, three_class, sdt, affs')

        mask = mask.astype(np.float32)
        
        if self.prediction_type != 'affs':
            mask = np.expand_dims(mask, axis=0)
                                        
        return raw, mask

In [None]:
# try each prediction type 
# what happens if you increase crop size to 65? - how does the padding come into play?

prediction_type = 'sdt'
crop_size = 64

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=crop_size, prediction_type=prediction_type)

raw, mask = train_dataset[random.randrange(len(train_dataset))]

raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(raw)
axes[0][0].title.set_text('Raw')

try:
    axes[0][1].imshow(np.squeeze(mask))
    axes[0][1].title.set_text('Mask')
except:
    # affs has two channels (x/y)
    axes[0][1].imshow(mask[0]+mask[1])
    axes[0][1].title.set_text('Mask')

<div class="alert alert-block alert-info"><h3>Exercise 2.4: Add extra hyperparameters based on prediction type</h3>
    
- Now that we have more representations, we need to make sure that our hyperparameters (number of output channels, final activation, loss function and label dtype) are consistent with whichever representation we choose. 

</div>
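
To see why the three-class case needs integer labels without a channel axis, here is a rough NumPy sketch of what a per-pixel cross-entropy computes. It is meant as an illustration of the idea, not as a replacement for `torch.nn.CrossEntropyLoss`; the function name and the toy arrays are assumptions.

```python
import numpy as np

def cross_entropy(logits, target):
    """Mean per-pixel cross-entropy.

    logits: float array of shape (N, C, H, W), raw network outputs
    target: integer array of shape (N, H, W) holding class ids 0..C-1
    """
    # numerically stable log-softmax over the channel axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # pick the log-probability of the true class at every pixel
    picked = np.take_along_axis(log_probs, target[:, None], axis=1)
    return -picked.mean()

# toy batch: 1 image, 3 classes (background, interior, border), 2x2 pixels
target = np.array([[[0, 1], [2, 1]]])
confident = np.full((1, 3, 2, 2), -5.0)
np.put_along_axis(confident, target[:, None], 5.0, axis=1)
uniform = np.zeros((1, 3, 2, 2))

print(cross_entropy(confident, target))  # close to 0
print(cross_entropy(uniform, target))    # log(3)
```

The target is used as an index into the channel axis, which is why it has to be an integer array of shape `(N, H, W)` rather than a float array with a channel dimension.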

In [None]:
def get_hyperparams(prediction_type):

    if prediction_type == "two_class":
        out_channels = 1
        activation = torch.nn.Sigmoid()
        loss_fn = torch.nn.BCEWithLogitsLoss()
        dtype = torch.FloatTensor

    elif prediction_type == "three_class":
        out_channels = 3
        activation = torch.nn.Softmax(dim=1)
        loss_fn = torch.nn.CrossEntropyLoss()
        dtype = torch.LongTensor

    elif prediction_type == "sdt":
        out_channels = 1
        activation = torch.nn.Tanh()
        loss_fn = torch.nn.MSELoss()
        dtype = torch.FloatTensor

    elif prediction_type == "affs":
        out_channels = 2
        activation = torch.nn.Sigmoid()
        loss_fn = torch.nn.MSELoss()
        dtype = torch.FloatTensor

    else:
        raise RuntimeError("invalid prediction type")
        
    params = {
        'out_channels': out_channels,
        'activation': activation,
        'loss_function': loss_fn,
        'dtype': dtype
    }
        
    return params

In [None]:
prediction_type = 'three_class'

params = get_hyperparams(prediction_type)

print(params)

<div class="alert alert-block alert-info"><h3>Exercise 2.5: Increase batch crop size</h3>
    
- Before we were using a smaller crop size (64). Since we are training a 2D network with a relatively small batch size (4), we can afford to increase the crop size to 128 so that the network sees more data per batch.

</div>

In [None]:
# try each prediction type
prediction_type = 'three_class'
crop_size = 128

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=crop_size, prediction_type=prediction_type)

raw, mask = train_dataset[random.randrange(len(train_dataset))]

raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(raw)
axes[0][0].title.set_text('Raw')

try:
    axes[0][1].imshow(np.squeeze(mask))
    axes[0][1].title.set_text('Mask')
except:
    # affs has two channels (x/y)
    axes[0][1].imshow(mask[0]+mask[1])
    axes[0][1].title.set_text('Mask')

<div class="alert alert-block alert-info"><h3>Exercise 2.6: Increase features in network</h3>
    
- Before we were training with a pretty small network, e.g. two downsampling layers. Let's see how the receptive fields change as we increase our network to three downsampling layers.

</div>
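
The growth of the receptive field with depth can also be estimated by hand. The sketch below assumes two 3x3 convolutions per level followed by 2x downsampling, which is a common U-Net layout but may differ from the `UNet` class used here, so the numbers need not match `rec_fov` exactly.

```python
def receptive_fields(num_downsamples, convs_per_level=2, kernel=3, factor=2):
    """Receptive field (in input pixels) after the convolutions of each level."""
    fov = 1   # a single input pixel
    jump = 1  # distance in input pixels between neighbouring features
    fovs = []
    for level in range(num_downsamples + 1):
        for _ in range(convs_per_level):
            fov += (kernel - 1) * jump
        fovs.append(fov)
        # downsampling widens the field slightly and multiplies the jump
        fov += (factor - 1) * jump
        jump *= factor
    return fovs

print(receptive_fields(2))  # two downsampling layers
print(receptive_fields(3))  # three downsampling layers
```

Under these assumptions the deepest level sees 32 pixels with two downsampling layers and 68 pixels with three, so each extra level roughly doubles the context available to the network.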

In [None]:
raw, mask = train_dataset[random.randrange(len(train_dataset))]

net_t = raw
fovs = []
d_factors = [[2,2],[2,2],[2,2]]

net = UNet(in_channels=1,
           num_fmaps=6,
           fmap_inc_factors=2,
           downsample_factors=d_factors,
           padding='same'
          )

for level in range(len(d_factors)+1):
    fov_tmp, _ = net.rec_fov(level , (1, 1), 1)
    fovs.append(fov_tmp[0])

fig=plt.figure(figsize=(8, 8))
colors = ["yellow", "red", "green", "blue", "magenta"]

raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

plt.imshow(raw)

for idx, fov_t in enumerate(fovs):
    print("Field of view at depth {}: {:3d} (color: {})".format(idx+1, fov_t, colors[idx]))
    xmin = raw.shape[1]/2 - fov_t/2
    xmax = raw.shape[1]/2 + fov_t/2
    # vertical extent uses the image height (axis 0)
    ymin = raw.shape[0]/2 - fov_t/2
    ymax = raw.shape[0]/2 + fov_t/2
    plt.hlines(ymin, xmin, xmax, color=colors[idx], lw=3)
    plt.hlines(ymax, xmin, xmax, color=colors[idx], lw=3)
    plt.vlines(xmin, ymin, ymax, color=colors[idx], lw=3)
    plt.vlines(xmax, ymin, ymax, color=colors[idx], lw=3)
plt.show()

In [None]:
prediction_type = 'three_class'
params = get_hyperparams(prediction_type)

In [None]:
torch.manual_seed(42)

d_factors = [[2,2],[2,2],[2,2]]

in_channels=2
num_fmaps=32
fmap_inc_factors=3
out_channels=params['out_channels']

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=d_factors,
        activation='ReLU',
        padding='same',
        constant_upsample=False)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=1,
    padding=0,
    bias=True)

net = torch.nn.Sequential(unet, final_conv)

net = net.to(device)

summary(net, (in_channels, 384, 384))

<div class="alert alert-block alert-info"><h3>Exercise 2.7: Train for longer</h3>

</div>

In [None]:
def training_step(model, loss_fn, optimizer, feature, label, prediction_type, activation):
    # speedup version of setting gradients to zero
    for param in model.parameters():
        param.grad = None
    # forward
    logits = model(feature) # B x C x H x W
       
    if prediction_type == "three_class":
        label=torch.squeeze(label,1) #label.shape=[N,H,W]

    loss_value = loss_fn(input=logits, target=label)  #logits.shape=[N,C,H,W] label.shape=[N,H,W]
    # backward pass and weight update only if the model passed in is in training mode
    if model.training:
        loss_value.backward()
        optimizer.step()
        
    output = activation(logits)
    
    outputs = {
        'pred': output,
        'logits': logits,
    }
    return loss_value, outputs

In [None]:
training_steps = 2000
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

net = net.to(device)
loss_fn = params['loss_function'].to(device)
activation = params['activation']
dtype = params['dtype']

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

### create datasets

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=128, prediction_type=prediction_type)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, prediction_type=prediction_type)

batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

In [None]:
# set flags
net.train() 
loss_fn.train()
step = 0

with tqdm(total=training_steps) as pbar:
    while step < training_steps:
        # reset data loader to get random augmentations
        np.random.seed()
        tmp_loader = iter(train_loader)
        for feature, label in tmp_loader:
            label = label.type(dtype)
            label = label.to(device)
            feature = feature.to(device)
                                
            loss_value, pred = training_step(net, loss_fn, optimizer, feature, label, prediction_type, activation)
            writer.add_scalar('loss',loss_value.cpu().detach().numpy(),step)
            step += 1
            pbar.update(1)
            if step % 100 == 0:
                net.eval()
                # validate on the held-out validation split, not on the test split
                tmp_val_loader = iter(val_loader)
                acc_loss = []
                for feature, label in tmp_val_loader:                    
                    label = label.type(dtype)
                    label = label.to(device)
                    feature = feature.to(device)
                    loss_value, _ = training_step(net, loss_fn, optimizer, feature, label, prediction_type, activation)
                    acc_loss.append(loss_value.cpu().detach().numpy())
                writer.add_scalar('val_loss',np.mean(acc_loss),step)
                net.train()

                print(np.mean(acc_loss))

In [None]:
net.eval()

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
    
    pred = np.squeeze(pred.cpu().detach().numpy())
        
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    if prediction_type == 'three_class':
        pred = np.argmax(pred, axis=0)
        
    try:
        axes[0][1].imshow(mask)
        axes[0][1].title.set_text('GT mask')
        axes[0][2].imshow(pred)
        axes[0][2].title.set_text('Predicted')
    except:
        axes[0][1].imshow(mask[0] + mask[1])
        axes[0][1].title.set_text('GT mask')
        axes[0][2].imshow(unpad(pred[0], 8) + unpad(pred[1], 8))
        axes[0][2].title.set_text('Predicted')
      
    if idx == 2:
        break

## Exercise 3.0: Post-processing

<div class="alert alert-block alert-info"><h3>Exercise 3.1: Introduce watershed</h3>
    
- Before we were just thresholding our predictions and then relabeling connected components. This is a totally fine approach in the cases where we don't have touching objects. Now we will use a better approach commonly used for instance segmentation called seeded watershed. See here for a nice overview: https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_watershed.html
- To compute our seeded watershed, we first need to get a boundary mask from our predictions. This is done slightly differently for the various representations, but generally speaking our boundary mask will just be a boolean indicating our foreground regions. From this boundary mask we compute boundary distances using a distance transform. These will then give us local maxima that can be used to extract seed points. The watershed algorithm then expands each seed out in a local "basin" until the segments touch.
- Because every seed grows into its own segment, an object that contains several local maxima is split into several pieces, so it is often not sufficient to use watershed alone on complex datasets. In most cases the resulting objects are referred to as fragments (or supervoxels), which can then be stitched together using the underlying predictions as edge weights through a process called agglomeration.
- Agglomeration is out of the scope of this exercise, but you can find a nice overview here: https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_boundary_merge.html
</div>
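
The steps described above can be tried on a toy mask before applying them to network predictions. The sketch below builds two overlapping disks (made-up data) and compares connected components with a seeded watershed. Unlike the function in the next cell, it restricts the seeds to the foreground explicitly, which is a choice of this sketch.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt, label, maximum_filter
from skimage.segmentation import watershed

# toy foreground mask: two overlapping disks of radius 12
rows, cols = np.mgrid[:64, :64]
disk_a = (rows - 32) ** 2 + (cols - 20) ** 2 <= 12 ** 2
disk_b = (rows - 32) ** 2 + (cols - 42) ** 2 <= 12 ** 2
mask = disk_a | disk_b

# connected components merge the touching disks into one object
_, n_components = label(mask)

# the distance to the background peaks near each disk centre
distances = distance_transform_edt(mask)
# a foreground pixel is a seed if it is the maximum of its 10x10 neighbourhood
maxima = (maximum_filter(distances, size=10) == distances) & mask
seeds, n_seeds = label(maxima)

# flood the inverted distance map from the seeds, restricted to the foreground
segmentation = watershed(distances.max() - distances, seeds, mask=mask)

print(n_components, n_seeds)
print(segmentation[32, 20], segmentation[32, 42])
```

The two disk centres end up with different labels even though thresholding plus connected components would report a single object.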

In [None]:
from skimage.segmentation import watershed
from scipy.ndimage import distance_transform_edt, label, maximum_filter

def watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask,
        id_offset=0,
        min_seed_distance=10):

    # get our seeds 
    max_filtered = maximum_filter(boundary_distances, min_seed_distance)
    maxima = max_filtered==boundary_distances
    seeds, n = label(maxima)

    if n == 0:
        # no seeds found: return an empty segmentation (same return type as below)
        return np.zeros(boundary_distances.shape, dtype=np.uint64)

    seeds[seeds!=0] += id_offset

    # calculate our segmentation
    segmentation = watershed(
        boundary_distances.max() - boundary_distances,
        seeds,
        mask=boundary_mask)
    
    return segmentation

def get_boundary_mask(pred, prediction_type, thresh=None):
    
    if prediction_type == 'two_class' or prediction_type == 'sdt':
        # simple threshold
        boundary_mask = pred > thresh

    elif prediction_type == 'three_class':
        # Return the indices of the maximum values along channel axis, then set mask to cell interior (1)
        boundary_mask = np.argmax(pred, axis=0)
        boundary_mask = boundary_mask == 1

    elif prediction_type == 'affs':
        # take mean of combined affs then threshold
        boundary_mask = 0.5 * (pred[0] + pred[1]) > thresh
    else:
        raise Exception('Choose from one of the following prediction types: two_class, three_class, sdt, affs')
        
    return boundary_mask

In [None]:
net.eval()

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
        
    pred = np.squeeze(pred.cpu().detach().numpy())
        
    thresh = np.mean(pred)
            
    boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)
    boundary_distances = distance_transform_edt(boundary_mask)
    
    seg = watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask
    )
    
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
    
    gt_labels = imread(test_loader.dataset.mask_files[idx])
    
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(color.label2rgb(gt_labels))
    axes[0][1].title.set_text('GT Labels')
    
    axes[0][2].imshow(color.label2rgb(seg))
    axes[0][2].title.set_text('Predicted Labels')
    
    if idx == 2:
        break

<div class="alert alert-block alert-info"><h3>Exercise 3.2: Evaluate</h3>
    
- There are several ways to evaluate accuracy of models for instance segmentation.
- For our purposes, we will calculate the intersection over union (IoU) between the ground truth labels and the predicted labels. Here is a nice overview: https://www.jeremyjordan.me/evaluating-image-segmentation-models/
- From IoU, we can then evaluate:
    1. True positives
    2. False positives
    3. False negatives
    4. Precision
    5. Recall
    6. Average precision
- These will already give a good indication of model performance. You can easily look up some information on each of these metrics, and you will have a good idea about why they are used following the previous exercises and lectures.
</div>
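
Before looking at the full matching code, it can help to compute the basic quantities on a tiny example. This is a simplified sketch with made-up masks and counts: it computes the IoU of a single pair of objects and derives the scores from given match counts, whereas the `evaluate` function in the next cell also has to match predicted and ground truth objects first.

```python
import numpy as np

def iou(mask_a, mask_b):
    """Intersection over union of two boolean masks."""
    intersection = np.logical_and(mask_a, mask_b).sum()
    union = np.logical_or(mask_a, mask_b).sum()
    return intersection / union if union > 0 else 0.0

def detection_metrics(tp, fp, fn):
    """Precision, recall and the TP / (TP + FP + FN) score from match counts."""
    precision = tp / max(1, tp + fp)
    recall = tp / max(1, tp + fn)
    ap = tp / max(1, tp + fp + fn)
    return precision, recall, ap

# two 4x4 squares shifted by two columns: 8 shared pixels, 24 in the union
gt = np.zeros((8, 8), dtype=bool)
pred = np.zeros((8, 8), dtype=bool)
gt[0:4, 0:4] = True
pred[0:4, 2:6] = True

print(iou(gt, pred))
print(detection_metrics(tp=8, fp=2, fn=2))
```

With an IoU threshold of 0.5, the pair above (IoU of one third) would not count as a true positive, even though the objects clearly overlap.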

In [None]:
from scipy.optimize import linear_sum_assignment
from skimage.segmentation import relabel_sequential

def evaluate(gt_labels, pred_labels):
    pred_labels_rel, _, _ = relabel_sequential(pred_labels)
    gt_labels_rel, _, _ = relabel_sequential(gt_labels)

    overlay = np.array([pred_labels_rel.flatten(),
                        gt_labels_rel.flatten()])

    # get overlaying cells and the size of the overlap
    overlay_labels, overlay_labels_counts = np.unique(
        overlay, return_counts=True, axis=1)
    overlay_labels = np.transpose(overlay_labels)

    # get gt cell ids and the size of the corresponding cell
    gt_labels_list, gt_counts = np.unique(gt_labels_rel, return_counts=True)
    gt_labels_count_dict = {}
    
    for (l, c) in zip(gt_labels_list, gt_counts):
        gt_labels_count_dict[l] = c

    # get pred cell ids
    pred_labels_list, pred_counts = np.unique(pred_labels_rel,
                                              return_counts=True)

    pred_labels_count_dict = {}
    for (l, c) in zip(pred_labels_list, pred_counts):
        pred_labels_count_dict[l] = c

    num_pred_labels = int(np.max(pred_labels_rel))
    num_gt_labels = int(np.max(gt_labels_rel))
    num_matches = min(num_gt_labels, num_pred_labels)
    
    # create iou table
    iouMat = np.zeros((num_gt_labels+1, num_pred_labels+1),
                      dtype=np.float32)

    for (u, v), c in zip(overlay_labels, overlay_labels_counts):
        iou = c / (gt_labels_count_dict[v] + pred_labels_count_dict[u] - c)
        iouMat[int(v), int(u)] = iou

    # remove background
    iouMat = iouMat[1:, 1:]

    # default threshold
    th = 0.5
    if num_matches > 0 and np.max(iouMat) > th:
        costs = -(iouMat > th).astype(float) - iouMat / (2*num_matches)
        gt_ind, pred_ind = linear_sum_assignment(costs)
        assert num_matches == len(gt_ind) == len(pred_ind)
        match_ok = iouMat[gt_ind, pred_ind] > th
        tp = np.count_nonzero(match_ok)
    else:
        tp = 0
    fp = num_pred_labels - tp
    fn = num_gt_labels - tp
    ap = tp / max(1, tp + fn + fp)
    precision = tp / max(1, tp + fp)
    recall = tp / max(1, tp + fn)

    return ap, precision, recall, tp, fp, fn

In [None]:
net.eval()

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
        
    pred = np.squeeze(pred.cpu().detach().numpy())
        
    thresh = 0.6
                  
    boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)
    boundary_distances = distance_transform_edt(boundary_mask)
    
    pred_labels = watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask
    )
    
    gt_labels = imread(test_loader.dataset.mask_files[idx])
    
    ap, precision, recall, tp, fp, fn = evaluate(gt_labels, pred_labels)
    
    print(
        f'Average precision: {ap} \n',
        f'Precision: {precision} \n',
        f'Recall: {recall} \n',
        f'True positives: {tp} \n',
        f'False positives: {fp} \n',
        f'False negatives: {fn} \n'
    )
    
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
        
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(color.label2rgb(gt_labels))
    axes[0][1].title.set_text('GT Labels')
    
    axes[0][2].imshow(color.label2rgb(pred_labels))
    axes[0][2].title.set_text('Predicted Labels')
    
    break

<div class="alert alert-block alert-info"><h3>Exercise 3.3: Loop over batches</h3>
    
- Now we can loop over all test set images and get the average precision of the model.

</div>

In [None]:
avg = 0.0

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
        
    pred = np.squeeze(pred.cpu().detach().numpy())
        
    thresh = 0.6
            
    boundary_mask = get_boundary_mask(pred, prediction_type, thresh)
    boundary_distances = distance_transform_edt(boundary_mask)
    
    pred_labels = watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask
    )
    
    gt_labels = imread(test_loader.dataset.mask_files[idx])
    
    ap, precision, recall, tp, fp, fn = evaluate(gt_labels, pred_labels)
    
    avg += ap
    
    print(
        f'Average precision: {ap} \n',
        f'Precision: {precision} \n',
        f'Recall: {recall} \n',
        f'True positives: {tp} \n',
        f'False positives: {fp} \n',
        f'False negatives: {fn} \n'
    )
    
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
        
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(color.label2rgb(gt_labels))
    axes[0][1].title.set_text('GT Labels')
    
    axes[0][2].imshow(color.label2rgb(pred_labels))
    axes[0][2].title.set_text('Predicted Labels')
    
    plt.show()
        
avg /= (idx+1)
    
print("average precision on test set: {}".format(avg))

<div class="alert alert-block alert-info"><h3>Exercise 3.4 (time permitting): Improve accuracy / segment cyto</h3>
    
- It is likely that even with the extra things we added, we still aren't achieving the level of accuracy we would like to see. Production level models are usually trained for many more iterations and use lots of tricks to maximize the accuracy. It is more important for now to conceptually understand the basics of instance segmentation and different approaches to increasing model robustness.
- If you have time (now or in the future), try to improve the accuracy of your model. How accurate can you get it? Can you also get an accurate model on the cytoplasm masks? 
- If you have time (now or in the future), try the following bonus exercises - which show more advanced approaches to getting good instance segmentation results
</div>

## Exercise 4.0: Bonus exercises

<div class="alert alert-block alert-info"><h3>Exercise 4.1: Add early stopping</h3>
    
- Early stopping ends training once the validation loss has not improved for a set number of checks (the `patience`), which saves training time and helps limit overfitting.

</div>
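
The patience logic behind early stopping can be followed on a hand-written list of validation losses. This is a simplified sketch with made-up numbers; it counts ties as "no improvement", which is an assumption of the sketch rather than a property of the class below.

```python
def stopping_index(val_losses, patience=3, min_delta=0.0):
    """Index of the validation check at which training would stop, or None."""
    best = None
    counter = 0
    for i, loss in enumerate(val_losses):
        if best is None or best - loss > min_delta:
            # new best loss: remember it and reset the counter
            best = loss
            counter = 0
        else:
            # no improvement: count this check towards the patience
            counter += 1
            if counter >= patience:
                return i
    return None

# made-up validation losses: improvement stalls after the third check
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.6]
print(stopping_index(losses, patience=3))
```

Training would stop at the sixth check here, so the later improvement to 0.6 is never seen. A larger patience trades longer training for a lower risk of stopping too early.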

In [None]:
class EarlyStopping():
    """
    Early stopping to stop the training when the loss does not improve after
    certain epochs.
    Code from https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/
    """
    def __init__(self, patience=20, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        elif self.best_loss - val_loss < self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [None]:
prediction_type = 'affs'

params = get_hyperparams(prediction_type)

print(params)

In [None]:
torch.manual_seed(42)

d_factors = [[2,2],[2,2]]

in_channels=2
num_fmaps=32
fmap_inc_factors=4
out_channels=params['out_channels']

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=d_factors,
        activation='ReLU',
        padding='same',
        constant_upsample=False)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=1,
    padding=0,
    bias=True)

net = torch.nn.Sequential(unet, final_conv)

net = net.to(device)

In [None]:
training_steps = 2000
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

net = net.to(device)
loss_fn = params['loss_function'].to(device)
activation = params['activation']
dtype = params['dtype']

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

### create datasets

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=128, prediction_type=prediction_type)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, prediction_type=prediction_type)

batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

In [None]:
# set flags
net.train() 
loss_fn.train()
step = 0

early_stopping = EarlyStopping(patience=100)

with tqdm(total=training_steps) as pbar:
    while step < training_steps and not early_stopping.early_stop:
        # reset data loader to get random augmentations
        np.random.seed()
        tmp_loader = iter(train_loader)
        for feature, label in tmp_loader:
            label = label.type(dtype)
            label = label.to(device)
            feature = feature.to(device)
            
            #print(label.shape, feature.shape)
                    
            loss_value, pred = training_step(net, loss_fn, optimizer, feature, label, prediction_type, activation)
            writer.add_scalar('loss',loss_value.cpu().detach().numpy(),step)
            step += 1
            pbar.update(1)
            
            if step % 100 == 0:
                net.eval()
                tmp_val_loader = iter(test_loader)
                acc_loss = []
                for feature, label in tmp_val_loader:                    
                    label = label.type(dtype)
                    label = label.to(device)
                    feature = feature.to(device)
                    loss_value, _ = training_step(net, loss_fn, optimizer, feature, label, prediction_type, activation)
                    acc_loss.append(loss_value.cpu().detach().numpy())
                writer.add_scalar('val_loss',np.mean(acc_loss),step) 
                net.train()
                print(np.mean(acc_loss))
                
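                # check early stopping right after a validation pass, when acc_loss is fresh
                # (patience counts validation passes, one every 100 steps, not training steps)
                early_stopping(np.mean(acc_loss))
                if early_stopping.early_stop:
                    print('Early stopping after step', step)
                    break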
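
The loop above relies on an `EarlyStopping` helper that takes a `patience` argument, is called with the validation loss, and exposes an `early_stop` flag. As a rough illustration of the idea only, here is a minimal sketch of a patience-based stopper. The class name `PatienceStopper` and its `min_delta` argument are assumptions made for this example and are not the helper used in this notebook.

```python
class PatienceStopper:
    """Hypothetical patience-based stopper (illustrative, not the notebook's EarlyStopping)."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # allowed number of checks without improvement
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.bad_checks = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # new best validation loss: reset the counter
            self.best = val_loss
            self.bad_checks = 0
        else:
            # no improvement: count it and stop once patience is used up
            self.bad_checks += 1
            if self.bad_checks >= self.patience:
                self.early_stop = True


stopper = PatienceStopper(patience=2)
history = [0.9, 0.7, 0.71, 0.72, 0.5]   # made-up validation losses
checks_used = 0
for val_loss in history:
    stopper(val_loss)
    checks_used += 1
    if stopper.early_stop:
        break
```

With a stopper like this, `patience` is measured in validation checks, so how often validation runs determines how many training steps pass before stopping.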

<div class="alert alert-block alert-info"><h3>Exercise 4.2: Balance labels, use weighted loss</h3>
    
- Until now we have treated every pixel as equally important, but that is not always appropriate. Let's look at the three-class representation as an example.
- There are significantly more background pixels (label 0) than border pixels (label 2), so an unweighted loss is dominated by the background. We can compensate by weighting the classes differently and giving the border pixels a higher weight.
</div>

In [None]:
prediction_type = 'three_class'

params = get_hyperparams(prediction_type)
params

In [None]:
train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=128, prediction_type=prediction_type)

background = []
inner = []
border = []

for raw, mask in train_dataset:

    # get unique labels and how many pixels they have (this should just be 0,1,2 for three class)
    labels, counts = np.unique(mask, return_counts=True)
    
    # skip crops that do not contain all three classes (some crops could be entirely background)
    if len(counts) != 3:
        continue
    
    background.append(counts[0])
    inner.append(counts[1])
    border.append(counts[2])
    
# get the average count per class across all training crops
# to make sure we don't have a misrepresented weighting from a single crop
all_counts = [np.mean(i) for i in (background, inner, border)]

# normalize by the total of the averaged counts so that the proportions sum up to 1
# (using `counts` or `labels` here would only reflect the last crop of the loop)
total_count = sum(all_counts)
label_proportions = [(label, count / total_count) for label, count in zip((0, 1, 2), all_counts)]

print(label_proportions)

# our loss function can take a weight. this should be the inverse of the label proportions. 
# eg:
    # if label 0 has a proportion of 0.55, it would have a weight of 1/0.55 = 1.8
    # if label 2 has a proportion of 0.09 it would have a weight of 1/0.09 = 11

params['loss_function'] = torch.nn.CrossEntropyLoss(
    weight=torch.tensor(
        [1 / p for _, p in label_proportions],
        dtype=torch.float32,
    ))
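
To make the weighting concrete, here is a small self-contained sketch on a made-up three-class label image. The array `toy_mask` and its class layout are assumptions chosen for illustration; only `np.unique` and `torch.nn.CrossEntropyLoss(weight=...)` correspond to calls used in the cell above.

```python
import numpy as np
import torch

# toy three-class label image: background (0), inner (1), border (2)
toy_mask = np.zeros((10, 10), dtype=np.int64)
toy_mask[2:8, 2:8] = 2      # a 6x6 square, initially all "border"
toy_mask[3:7, 3:7] = 1      # the 4x4 centre becomes "inner", leaving a border ring

# pixel counts per class, turned into proportions that sum to 1
labels, counts = np.unique(toy_mask, return_counts=True)
proportions = counts / counts.sum()

# inverse-proportion weights: rare classes get a larger weight
weights = torch.tensor(1.0 / proportions, dtype=torch.float32)

# weighted cross entropy on uniform (all-zero) logits
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
logits = torch.zeros(1, 3, 10, 10)                 # B x C x H x W
target = torch.from_numpy(toy_mask).unsqueeze(0)   # B x H x W
loss = loss_fn(logits, target)

print(dict(zip(labels.tolist(), weights.tolist())), loss.item())
```

With the default `reduction='mean'`, PyTorch divides the weighted sum by the sum of the weights of the target pixels, so the weights change the relative influence of each class rather than the overall scale of the loss.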

In [None]:
torch.manual_seed(42)

d_factors = [[2,2],[2,2],[2,2]]

in_channels=2
num_fmaps=32
fmap_inc_factors=3
out_channels=params['out_channels']

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=d_factors,
        activation='ReLU',
        padding='same',
        constant_upsample=False)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=1,
    padding=0,
    bias=True)

net = torch.nn.Sequential(unet, final_conv)

net = net.to(device)

In [None]:
training_steps = 2000
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

net = net.to(device)
loss_fn = params['loss_function'].to(device)
activation = params['activation']
dtype = params['dtype']

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

### create datasets

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=128, prediction_type=prediction_type)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, prediction_type=prediction_type)

batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

In [None]:
# set flags
net.train() 
loss_fn.train()
step = 0

with tqdm(total=training_steps) as pbar:
    while step < training_steps:
        # reset data loader to get random augmentations
        np.random.seed()
        tmp_loader = iter(train_loader)
        for feature, label in tmp_loader:
            label = label.type(dtype)
            label = label.to(device)
            feature = feature.to(device)
                        
            #print(label.shape, feature.shape)
                    
            loss_value, pred = training_step(net, loss_fn, optimizer, feature, label, prediction_type, activation)
            writer.add_scalar('loss',loss_value.cpu().detach().numpy(),step)
            step += 1
            pbar.update(1)
            
            if step % 100 == 0:
                net.eval()
                tmp_val_loader = iter(test_loader)
                acc_loss = []
                for feature, label in tmp_val_loader:                    
                    label = label.type(dtype)
                    label = label.to(device)
                    feature = feature.to(device)
                    loss_value, _ = training_step(net, loss_fn, optimizer, feature, label, prediction_type, activation)
                    acc_loss.append(loss_value.cpu().detach().numpy())
                writer.add_scalar('val_loss',np.mean(acc_loss),step) 
                net.train()
                print(np.mean(acc_loss))

In [None]:
net.eval()

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
        
    pred = np.squeeze(pred.cpu().detach().numpy())
    pred = np.argmax(pred, axis=0)
        
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(mask)
    axes[0][1].title.set_text('GT mask')

    axes[0][2].imshow(pred)
    axes[0][2].title.set_text('Predicted')
    
    if idx == 2:
        break

<div class="alert alert-block alert-info"><h3>Exercise 4.3: Auxiliary learning</h3>
    
- Auxiliary learning can improve the results on our main objective by training the model on an additional helper task. Up until now, we have only trained on representations of the data that are boundary specific. But the data is a lot richer than that: these objects have distinct shapes that can be leveraged to learn the boundaries better.
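- Cellpose is one example of this idea: the network predicts a cell probability together with horizontal and vertical flow fields that point toward the center of each cell, and the instances are recovered by following those flows.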
</div>

In [None]:
from cellpose import models

In [None]:
# `gpu` expects a boolean flag, not a torch device
model = models.CellposeModel(gpu=torch.cuda.is_available(), model_type='tissuenet')

In [None]:
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)
test_loader = DataLoader(test_dataset, batch_size=1)

In [None]:
channels = [2, 1]

masks_cp = []
for idx, (image, mask) in enumerate(test_loader):
    image = image.cpu().detach().numpy()
    mask_cp, flows, styles = model.eval(image, diameter=25, channels=channels)
    masks_cp.append(mask_cp)
    
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    image = np.squeeze(image)
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(color.label2rgb(mask_cp))
    axes[0][1].title.set_text('Predicted Labels')

    axes[0][2].imshow(flows[0])
    axes[0][2].title.set_text('Predicted cellpose')
    
    if idx == 2:
        break

In [None]:
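# end of cellpose, start of lsds

# Local shape descriptors (LSDs) are another auxiliary target. For every pixel
# inside an object they summarize the object's shape in a local neighborhood:
# the mean offset to the local center of mass, the covariance of the pixel
# coordinates, and the size. In 2D this gives the six channels plotted below.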

In [None]:
from lsd.train import local_shape_descriptor

file = random.choice(train_nuclei)

nuclei = imread(file)[0:64, 0:64]
raw = imread(file.replace('_nuclei_masks', ''))[:, 0:64, 0:64]

#just to visualize
raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

lsds = local_shape_descriptor.get_local_shape_descriptors(
              segmentation=nuclei,
              sigma=(5,)*2,
              voxel_size=(1,)*2)

fig, axes = plt.subplots(
            1,
            6,
            figsize=(20, 20),
            sharex=False,
            sharey=True,
            squeeze=False)
  
axes[0][0].imshow(np.squeeze(lsds[0]), cmap='jet')
axes[0][0].title.set_text('Mean offset Y')

axes[0][1].imshow(np.squeeze(lsds[1]), cmap='jet')
axes[0][1].title.set_text('Mean offset X')

axes[0][2].imshow(np.squeeze(lsds[2]), cmap='jet')
axes[0][2].title.set_text('Covariance Y-Y')

axes[0][3].imshow(np.squeeze(lsds[3]), cmap='jet')
axes[0][3].title.set_text('Covariance X-X')

axes[0][4].imshow(np.squeeze(lsds[4]), cmap='jet')
axes[0][4].title.set_text('Covariance Y-X')

axes[0][5].imshow(np.squeeze(lsds[5]), cmap='jet')
axes[0][5].title.set_text('Size')
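
To build some intuition for the two "mean offset" channels, the sketch below computes a deliberately simplified version with NumPy only. It is an assumption-laden illustration, not the `lsd` implementation: the hypothetical function `toy_mean_offsets` uses the center of mass of the whole instance, whereas the plots above come from descriptors computed in a local window controlled by `sigma`.

```python
import numpy as np


def toy_mean_offsets(segmentation):
    """Per-pixel (dy, dx) offset to the centre of mass of the pixel's own instance.

    Hypothetical simplification: the centre is computed over the whole
    instance rather than within a local window. Background (label 0) stays 0.
    """
    offsets = np.zeros((2,) + segmentation.shape, dtype=np.float32)
    yy, xx = np.mgrid[:segmentation.shape[0], :segmentation.shape[1]]
    for instance_id in np.unique(segmentation):
        if instance_id == 0:
            continue
        inside = segmentation == instance_id
        # centre of mass of this instance
        cy, cx = yy[inside].mean(), xx[inside].mean()
        offsets[0][inside] = cy - yy[inside]
        offsets[1][inside] = cx - xx[inside]
    return offsets


toy_seg = np.zeros((8, 8), dtype=np.int32)
toy_seg[1:4, 1:4] = 1    # 3x3 square, centre at (2, 2)
toy_seg[5:7, 4:8] = 2    # 2x4 rectangle, centre at (5.5, 5.5)
toy_offsets = toy_mean_offsets(toy_seg)
print(toy_offsets[:, 1, 1], toy_offsets[:, 5, 4])
```

Pixels of two touching objects point toward different centers, which is why such offsets carry information about where one instance ends and the next begins.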

In [None]:
class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 padding_size=8
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.padding_size = padding_size
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)
    
    def get_padding(self, crop_size, padding_size):
    
        # quotient
        q = int(crop_size / padding_size)
    
        if crop_size % padding_size != 0:
            padding = (padding_size * (q + 1))
        else:
            padding = crop_size
    
        return padding
    
    def augment_data(self, raw, mask, padding):
        
        transform = A.Compose([
              A.RandomCrop(
                  width=self.crop_size,
                  height=self.crop_size),
              A.PadIfNeeded(
                  min_height=padding,
                  min_width=padding,
                  p=1,
                  border_mode=0),
              A.HorizontalFlip(p=0.3),
              A.VerticalFlip(p=0.3),
              A.RandomRotate90(p=0.3),
              A.Transpose(p=0.3),
              A.ElasticTransform(
                  p=0.3,
                  alpha=0.1,
                  sigma=0.1,
                  alpha_affine=0.3),
              A.RandomBrightnessContrast(p=0.3)
            ])

        transformed = transform(image=raw, mask=mask)

        raw, mask = transformed['image'], transformed['mask']
        
        return raw, mask

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        raw = raw.transpose([1,2,0])
        
        mask = np.expand_dims(mask, axis=0)
        mask = mask.transpose([1,2,0])
                
        padding = self.get_padding(self.crop_size, self.padding_size)
        raw, mask = self.augment_data(raw, mask, padding)
            
        raw = raw.transpose([2,0,1])
        mask = mask.transpose([2,0,1])
        
        mask, border = erode(
                    mask[0],
                    iterations=1,
                    border_value=1)

        affs = compute_affinities(mask, nhood=[[0,1],[1,0]])
                        
        lsds = local_shape_descriptor.get_local_shape_descriptors(
              segmentation=mask,
              sigma=(5,)*2,
              voxel_size=(1,)*2)

        lsds = lsds.astype(np.float32)
        affs = affs.astype(np.float32)
                                        
        return raw, lsds, affs
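
The dataset above returns nearest-neighbor affinities for the offsets `[0,1]` and `[1,0]`. As a sketch of what such a target can look like, the hypothetical function `toy_affinities` below marks a pixel with 1 when it and its neighbor at the given offset share the same non-zero label. The exact convention (which of the two pixels stores the value, and how the border is treated) is an assumption here and may differ from `compute_affinities`.

```python
import numpy as np


def toy_affinities(segmentation, nhood=((0, 1), (1, 0))):
    """Hypothetical stand-in for `compute_affinities`: one channel per offset.

    A pixel at (y, x) gets affinity 1 for offset (dy, dx) when it and the pixel
    at (y - dy, x - dx) carry the same non-zero label. Offsets are assumed to be
    non-negative; pixels without such a neighbour stay 0.
    """
    affs = np.zeros((len(nhood),) + segmentation.shape, dtype=np.float32)
    h, w = segmentation.shape
    for channel, (dy, dx) in enumerate(nhood):
        here = segmentation[dy:, dx:]
        there = segmentation[:h - dy, :w - dx]
        affs[channel, dy:, dx:] = (here == there) & (here > 0)
    return affs


toy_seg = np.array([
    [1, 1, 0, 2],
    [1, 1, 0, 2],
    [0, 0, 0, 2],
])
toy_affs = toy_affinities(toy_seg)
print(toy_affs[0])   # affinity to the left neighbour
print(toy_affs[1])   # affinity to the upper neighbour
```

Affinities are 0 across object boundaries and in the background, so summing the two channels, as in the plotting cell that follows, highlights object interiors.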

In [None]:
train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)

raw, lsds, affs = train_dataset[random.randrange(len(train_dataset))]

raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

fig, axes = plt.subplots(
            1,
            7,
            figsize=(20, 20),
            sharex=False,
            sharey=True,
            squeeze=False)
  
axes[0][0].imshow(np.squeeze(lsds[0]), cmap='jet')
axes[0][0].title.set_text('Mean offset Y')

axes[0][1].imshow(np.squeeze(lsds[1]), cmap='jet')
axes[0][1].title.set_text('Mean offset X')

axes[0][2].imshow(np.squeeze(lsds[2]), cmap='jet')
axes[0][2].title.set_text('Covariance Y-Y')

axes[0][3].imshow(np.squeeze(lsds[3]), cmap='jet')
axes[0][3].title.set_text('Covariance X-X')

axes[0][4].imshow(np.squeeze(lsds[4]), cmap='jet')
axes[0][4].title.set_text('Covariance Y-X')

axes[0][5].imshow(np.squeeze(lsds[5]), cmap='jet')
axes[0][5].title.set_text('Size')

axes[0][6].imshow(np.squeeze(affs[0]+affs[1]), cmap='jet')
axes[0][6].title.set_text('Affs')

In [None]:
class MtlsdModel(torch.nn.Module):

    def __init__(
        self,
        in_channels,
        num_fmaps,
        fmap_inc_factors,
        downsample_factors,
        padding='same',
        constant_upsample=False,
    ):
        super().__init__()

        self.unet = UNet(
            in_channels=in_channels,
            num_fmaps=num_fmaps,
            fmap_inc_factors=fmap_inc_factors,
            downsample_factors=downsample_factors,
            padding=padding,
            constant_upsample=constant_upsample)

        self.lsd_head = torch.nn.Conv2d(in_channels=num_fmaps,out_channels=6, kernel_size=1)
        self.aff_head = torch.nn.Conv2d(in_channels=num_fmaps,out_channels=2, kernel_size=1)

    def forward(self, input):

        z = self.unet(input)
        lsds = self.lsd_head(z)
        affs = self.aff_head(z)

        return lsds, affs

# combine the lsds and affs losses

class CombinedLoss(torch.nn.MSELoss):

    def __init__(self):
        super(CombinedLoss, self).__init__()

    def forward(self, lsds_prediction, lsds_target, affs_prediction, affs_target):

        loss1 = super(CombinedLoss, self).forward(lsds_prediction,lsds_target)
        loss2 = super(CombinedLoss, self).forward(affs_prediction, affs_target)
        
        return loss1 + loss2
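
The structure above (one shared backbone, two 1x1 convolution heads, and a summed loss) can be checked quickly on random data. The sketch below is only an illustration: `ToyTwoHeadNet` replaces the U-Net with a single convolution so that it runs without the `unet` module, and the task weights `lsd_weight` and `aff_weight` are assumed knobs that the `CombinedLoss` above does not have.

```python
import torch


class ToyTwoHeadNet(torch.nn.Module):
    """Hypothetical stand-in for MtlsdModel: one conv layer replaces the U-Net."""

    def __init__(self, in_channels=2, num_fmaps=8):
        super().__init__()
        self.backbone = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, num_fmaps, kernel_size=3, padding=1),
            torch.nn.ReLU(),
        )
        # both heads read the same shared feature maps
        self.lsd_head = torch.nn.Conv2d(num_fmaps, 6, kernel_size=1)
        self.aff_head = torch.nn.Conv2d(num_fmaps, 2, kernel_size=1)

    def forward(self, x):
        z = self.backbone(x)
        return self.lsd_head(z), self.aff_head(z)


torch.manual_seed(0)
toy_net = ToyTwoHeadNet()
x = torch.rand(4, 2, 16, 16)               # B x C x H x W
toy_lsds, toy_affs = toy_net(x)

# random targets with the same shapes as the two outputs
lsd_target = torch.rand(4, 6, 16, 16)
aff_target = torch.rand(4, 2, 16, 16)

mse = torch.nn.MSELoss()
lsd_weight, aff_weight = 1.0, 1.0          # assumed task weights
toy_loss = lsd_weight * mse(toy_lsds, lsd_target) + aff_weight * mse(toy_affs, aff_target)
toy_loss.backward()                        # gradients reach the backbone from both heads
```

Because both heads share the backbone, the gradient of the summed loss updates the shared features with information from both tasks, which is the point of the auxiliary target.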

In [None]:
torch.manual_seed(42)

d_factors = [[2,2],[2,2],[2,2]]

in_channels=2
num_fmaps=32
fmap_inc_factors=4

net = MtlsdModel(in_channels,num_fmaps,fmap_inc_factors,d_factors)

loss_fn = CombinedLoss().to(device)

net = net.to(device)

In [None]:
training_steps = 5000
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

net = net.to(device)
dtype = torch.FloatTensor

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

### create datasets

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', crop_size=128)
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, crop_size=64)

batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

In [None]:
def training_step(model, loss_fn, optimizer, feature, gt_lsds, gt_affs, activation=torch.nn.Sigmoid()):
    # speedup version of setting gradients to zero
    for param in model.parameters():
        param.grad = None
    # forward
    lsd_logits, affs_logits = model(feature) # each B x C x H x W

    # apply the activation before the loss, so that training optimizes
    # the same values that are used as predictions at inference time
    lsd_output = activation(lsd_logits)
    affs_output = activation(affs_logits)

    loss_value = loss_fn(lsd_output, gt_lsds, affs_output, gt_affs)
    # backward if training mode (check the model passed in, not the global net)
    if model.training:
        loss_value.backward()
        optimizer.step()
   
    outputs = {
        'pred_lsds': lsd_output,
        'pred_affs': affs_output,
        'lsds_logits': lsd_logits,
        'affs_logits': affs_logits,
    }
    return loss_value, outputs

In [None]:
# set flags
net.train() 
loss_fn.train()
step = 0

with tqdm(total=training_steps) as pbar:
    while step < training_steps:
        # reset data loader to get random augmentations
        np.random.seed()
        tmp_loader = iter(train_loader)
        for feature, gt_lsds, gt_affs in tmp_loader:
            gt_lsds = gt_lsds.type(dtype)
            gt_lsds = gt_lsds.to(device)
            gt_affs = gt_affs.type(dtype)
            gt_affs = gt_affs.to(device)
            feature = feature.to(device)
                        
            #print(label.shape, feature.shape)
                    
            loss_value, pred = training_step(net, loss_fn, optimizer, feature, gt_lsds, gt_affs)
            writer.add_scalar('loss',loss_value.cpu().detach().numpy(),step)
            step += 1
            pbar.update(1)
            
            if step % 100 == 0:
                net.eval()
                tmp_val_loader = iter(test_loader)
                acc_loss = []
                for feature, gt_lsds, gt_affs in tmp_val_loader:                    
                    gt_lsds = gt_lsds.type(dtype)
                    gt_lsds = gt_lsds.to(device)
                    gt_affs = gt_affs.type(dtype)
                    gt_affs = gt_affs.to(device)
                    feature = feature.to(device)
                    loss_value, _ = training_step(net, loss_fn, optimizer, feature, gt_lsds, gt_affs)
                    acc_loss.append(loss_value.cpu().detach().numpy())
                writer.add_scalar('val_loss',np.mean(acc_loss),step) 
                net.train()
                print(np.mean(acc_loss))

In [None]:
net.eval()

activation = torch.nn.Sigmoid()

for idx, (image, gt_lsds, gt_affs) in enumerate(test_loader):
    image = image.to(device)
    lsds_logits, affs_logits = net(image)
    pred_lsds = activation(lsds_logits)
    pred_affs = activation(affs_logits)
        
    image = np.squeeze(image.cpu())
    gt_lsds = np.squeeze(gt_lsds.cpu().numpy())
    gt_affs = np.squeeze(gt_affs.cpu().numpy())
    
    pred_lsds = np.squeeze(pred_lsds.cpu().detach().numpy())
    pred_affs = np.squeeze(pred_affs.cpu().detach().numpy())
    
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
  
    axes[0][1].imshow(np.squeeze(pred_lsds[0]), cmap='jet')
    axes[0][1].imshow(np.squeeze(pred_lsds[1]), cmap='jet', alpha=0.5)
    axes[0][1].title.set_text('Mean offset')

    axes[0][2].imshow(np.squeeze(pred_affs[0]+pred_affs[1]), cmap='jet')
    axes[0][2].title.set_text('Affs')
    
    break
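
The predicted affinities are not yet an instance segmentation. One very simple way to get labels from them is to threshold the mean affinity and label the connected components. The sketch below shows this on a synthetic prediction. The function `toy_instances_from_affs` and the threshold of 0.5 are assumptions for illustration, and this baseline cannot split objects whose predicted interiors touch.

```python
import numpy as np
from scipy import ndimage


def toy_instances_from_affs(affs, threshold=0.5):
    """Hypothetical post-processing: threshold the mean affinity, then label components."""
    foreground = affs.mean(axis=0) > threshold
    # scipy labels 4-connected components in 2D by default
    labels, num_instances = ndimage.label(foreground)
    return labels, num_instances


# synthetic "prediction": two blobs with high affinity on a zero background
toy_pred = np.zeros((2, 6, 6), dtype=np.float32)
toy_pred[:, 0:2, 0:2] = 0.9
toy_pred[:, 3:6, 3:6] = 0.8
toy_labels, toy_count = toy_instances_from_affs(toy_pred)
print(toy_count)
print(toy_labels)
```

In practice `pred_affs` from the cell above could be passed in place of `toy_pred`. More robust alternatives, such as a watershed seeded from the object interiors, are usually preferred for crowded images.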

# Further learning

* Instance segmentation can be challenging, and this exercise only scratches the surface of what is possible.
* This notebook assumes that images fit into memory, but often this is not the case (especially in biology).
    1. For an example of predicting over an image in chunks and stitching the results together, see this notebook: https://github.com/dlmbl/instance_segmentation/blob/main/3_tile_and_stitch.ipynb
    2. For a more advanced library that makes it easier to do machine learning on massive datasets, see gunpowder (navigate to the tutorials, or browse the API): http://funkey.science/gunpowder/
* We did not cover more complex loss functions. Here are some nice explanations and implementations of other loss functions that are useful for instance segmentation: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook
* A more complex (but powerful) approach is called metric learning. It is covered in last year's exercise: https://github.com/dlmbl/instance_segmentation/blob/main/2_instance_segmentation.ipynb
* We did not cover StarDist in this tutorial, and barely scratched the surface of Cellpose and LSDs. For more tutorials, see:
    1. StarDist: https://github.com/maweigert/tutorials/tree/main/stardist
    2. Cellpose: https://github.com/MouseLand/cellpose#run-cellpose-10-without-local-python-installation
    3. LSDs: https://github.com/funkelab/lsd#notebooks
    
### Good luck on your instance segmentation endeavors!!