# Bonus exercises and further learning

- These exercises have no TODOs; feel free to run them to get a sense of a few more tricks you can use for instance segmentation.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import torch
import datetime

from glob import glob
from skimage import color
from skimage.io import imread
from natsort import natsorted
import albumentations as A
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from tqdm.auto import tqdm
from unet import *
from utils import *

torch.backends.cudnn.benchmark = True

In [None]:
# Let's start by loading our images into lists so we can visualize and get oriented with our data.
# natsorted sorts strings "naturally" (e.g. 'img2' before 'img10'), which the regular sort does not.
# glob collects file paths that match a wildcard pattern. Feel free to inspect these lists.


def _collect_split(split):
    """Return the (cyto masks, nuclei masks, raw images) file lists of one woodshole split.

    Raw images are the .tif files whose names do not end in 's.tif', which filters out
    the mask files (e.g. *_cyto_masks.tif).
    """
    cyto_files = natsorted(glob(f'woodshole/{split}/*cyto*'))
    nuclei_files = natsorted(glob(f'woodshole/{split}/*nuclei*'))
    raw_files = [path for path in natsorted(glob(f'woodshole/{split}/*.tif')) if 's.tif' not in path]
    return cyto_files, nuclei_files, raw_files


train_cyto, train_nuclei, train_raw = _collect_split('train')
test_cyto, test_nuclei, test_raw = _collect_split('test')

In [None]:
# Let's define the device we'll be using throughout the notebook:
# the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Select a random cytoplasm mask file, then load the matching nuclei mask and raw image
# (the file names only differ in their '_cyto_masks' / '_nuclei_masks' part).
cyto_file = random.choice(train_cyto)

# skimage.io.imread reads each file into a numpy array.
cyto = imread(cyto_file)
nuclei = imread(cyto_file.replace('cyto', 'nuclei'))
raw = imread(cyto_file.replace('_cyto_masks', ''))

# Our raw data has shape (c, h, w) with only two channels. To view it as an RGB image we
# append an all-zero third channel and move the channels last, giving shape (h, w, 3).
raw = np.vstack((raw, np.zeros_like(raw)[:1])).transpose(1, 2, 0)

# Visualize the data - execute this cell a few times to see different examples.
# You can also change train_cyto to test_cyto above to see some test data; it is pretty similar.
fig, axes = plt.subplots(1, 5, figsize=(20, 10), sharex=True, sharey=True, squeeze=False)

# One entry per panel: (background image, background colormap, optional label overlay, title).
panels = [
    (raw[:, :, 0], 'gray', None, 'Raw nuclei channel'),
    (raw[:, :, 1], 'gray', None, 'Raw cyto channel'),
    (raw, None, None, 'Raw overlay'),
    (raw[:, :, 0], 'gray', nuclei, 'nuclei mask'),
    (raw[:, :, 1], 'gray', cyto, 'cyto mask'),
]

for ax, (background, cmap, labels, title) in zip(axes[0], panels):
    ax.imshow(background, cmap=cmap)
    if labels is not None:
        # Semi-transparent label colouring on top of the matching raw channel.
        ax.imshow(create_lut(labels), alpha=0.5)
    ax.title.set_text(title)

In [None]:
class TissueNetDataset(Dataset):
    """TissueNet images paired with a segmentation target of the requested type.

    Mask files match ``root_dir/split/*{mask}*``. Raw images are the ``*.tif`` files in the
    same directory whose names do not end in ``s.tif``, which excludes the ``*_masks.tif``
    files. Both lists are natsorted and paired by index.

    Args:
        root_dir: Directory holding the ``train`` and ``test`` sub-directories.
        split: Sub-directory to read. Only ``'train'`` items are augmented and cropped.
        mask: Substring identifying the mask files, e.g. ``'nuclei'`` or ``'cyto'``.
        crop_size: Crop size used for training augmentation. It must be set when
            ``split == 'train'``.
        val_split: Only used when ``split == 'test'``. True keeps the first 10 images
            (validation); False keeps the remaining ones.
        prediction_type: Target representation, one of ``PREDICTION_TYPES``.
        padding_size: Training crops are padded up to a multiple of this value.
    """

    # Target representations understood by create_target().
    PREDICTION_TYPES = ('two_class', 'three_class', 'sdt', 'affs')

    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 prediction_type='two_class',
                 padding_size=8
                ):

        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.prediction_type = prediction_type
        self.padding_size = padding_size

        # The test directory is divided into a small validation set (first 10 images)
        # and the actual test set (everything after them).
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)

    def get_padding(self, crop_size, padding_size):
        """Return ``crop_size`` rounded up to the next multiple of ``padding_size``."""
        if crop_size % padding_size == 0:
            return crop_size
        return padding_size * (int(crop_size / padding_size) + 1)

    def create_target(self, mask, prediction_type=None):
        """Convert a 2-D instance label image (0 = background) into a float32 training target.

        Args:
            mask: 2-D label image. It is eroded by one pixel with ``erode_border`` before
                building the target.
            prediction_type: Target type; defaults to ``self.prediction_type``.

        Returns:
            float32 array, depending on the prediction type:
            - ``'two_class'``: binary foreground.
            - ``'three_class'``: foreground (1) plus border pixels (2).
            - ``'sdt'``: the output of ``compute_sdt``.
            - ``'affs'``: the output of ``compute_affinities``.

        Raises:
            ValueError: if the prediction type is not one of ``PREDICTION_TYPES``.
        """
        if prediction_type is None:
            prediction_type = self.prediction_type

        # Validate before doing any work so a typo fails fast.
        if prediction_type not in self.PREDICTION_TYPES:
            raise ValueError('Choose from one of the following prediction types: two_class, three_class, sdt, affs')

        mask, border = erode_border(
                    mask,
                    iterations=1,
                    border_value=1)

        if prediction_type == 'two_class':
            target = (mask != 0)

        elif prediction_type == 'three_class':
            labels_two_class = (mask != 0)
            # Pixels flagged in `border` become class 2 and are added on top of the
            # binary foreground. Assumes erode_border marks the removed border pixels
            # as non-zero — TODO confirm against utils.erode_border.
            border[border != 0] = 2
            target = labels_two_class + border

        elif prediction_type == 'sdt':
            target = compute_sdt(mask)

        else:  # 'affs'
            target = compute_affinities(mask, nhood=[[0, 1], [1, 0]])

        return target.astype(np.float32)

    def __getitem__(self, idx):
        raw = imread(self.raw_files[idx])
        mask = imread(self.mask_files[idx])

        # Augmentation works on channels-last arrays: (c, h, w) -> (h, w, c), and the
        # mask gets a trailing channel axis.
        raw = raw.transpose([1, 2, 0]).astype(np.float32)
        mask = np.expand_dims(mask, axis=-1)

        if self.split == 'train':
            padding = self.get_padding(self.crop_size, self.padding_size)
            raw, mask = augment_data(raw, mask, padding, self.crop_size)

        # Back to channels-first.
        raw = raw.transpose([2, 0, 1])
        mask = mask.transpose([2, 0, 1])

        mask = self.create_target(mask[0], self.prediction_type)

        # Affinities already carry a channel axis; the other targets are 2-D.
        if self.prediction_type != 'affs':
            mask = np.expand_dims(mask, axis=0)

        return raw, mask

## Auxiliary learning

- Auxiliary learning is a powerful technique that can help to improve the results of our main objective by providing a helper task. Up until now, we have only shown our model representations of the data that are boundary specific. But the data is a lot richer than that - these objects have distinct shapes that could be leveraged in order to better learn the boundaries.

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Cellpose</h3>
    
- In [**Cellpose**](https://cellpose.readthedocs.io/en/latest/), cells are turned into flow representations. We create these flow representations by simulating diffusion from the center of the cell to get the spatial gradients for each pixel that point towards the center of the cell. At test time, we treat the flows as a dynamical system, and all pixels that converge to the same point are assigned to the same cell. The flows shown below are displayed with the HSV colormap commonly used in the optical flow literature.
    
    
- We also predict the foreground / background -- the two classes you predicted in exercise 1. In Cellpose we call this the cell probability. We threshold this to decide which pixels are in cells -- we only use these pixels to run the dynamical system.
    
    
- The flow representation allows the learning of non-convex shapes, because pixels can flow around corners. It also prevents merging, as flows for two cells that are touching are opposite.
</div>

![cellpose_flows](static/cellpose_flows.png)

In [None]:
from cellpose import models

# Create a Cellpose model using the built-in model trained on TissueNet
# (the first time you run this cell the model will download).
# `gpu` expects a bool. A torch.device object is always truthy (even the CPU device),
# so derive the flag from the device type instead of passing the device itself.
model = models.CellposeModel(gpu=(device.type == 'cuda'), model_type='tissuenet')

test_dataset = TissueNetDataset(root_dir='woodshole', split='test')
test_loader = DataLoader(test_dataset, batch_size=1)

### IMPORTANT: these are the channels used for the segmentation
# the first one is the channel to segment, and the second one is the optional nuclear channel
# red = 1
# green = 2
# blue = 3

channels = [2, 1]

masks_cp = []
for idx, (image, mask) in enumerate(test_loader):
    image = image.cpu().detach().numpy()
    mask_cp, flows, styles = model.eval(image, diameter=25, channels=channels)
    masks_cp.append(mask_cp)

    fig, axes = plt.subplots(1, 4, figsize=(20, 20), sharex=True, sharey=True, squeeze=False)

    # For display: drop the batch axis, add an empty third channel and move channels last.
    image = np.squeeze(image)
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1, 2, 0)

    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')

    axes[0][1].imshow(create_lut(mask_cp))
    axes[0][1].title.set_text('Predicted Labels')

    axes[0][2].imshow(flows[0])
    axes[0][2].title.set_text('Predicted cellpose')

    axes[0][3].imshow(flows[2])
    axes[0][3].title.set_text('Predicted cell probability')

    # Only show the first three test images.
    if idx == 2:
        break

In [None]:
# we could also use the nuclear channel ONLY and run a nuclear model in cellpose
# we set the second channel = 0 because we do not have an additional channel now
channels = [1, 0]

# initialize nuclei model (can also try "cyto" model if this doesn't work)
# the "nuclei" model in cellpose has been trained on lots of nuclear data (but not the tissuenet dataset)
# the "cyto" model in cellpose has been trained on many cellular images (but not the tissuenet dataset)
# `gpu` expects a bool. A torch.device object is always truthy (even the CPU device),
# so derive the flag from the device type instead of passing the device itself.
model = models.CellposeModel(gpu=(device.type == 'cuda'), model_type='nuclei')

masks_cp = []
for idx, (image, mask) in enumerate(test_loader):
    image = image.cpu().detach().numpy()
    mask_cp, flows, styles = model.eval(image, diameter=20, channels=channels)
    masks_cp.append(mask_cp)

    fig, axes = plt.subplots(1, 4, figsize=(20, 20), sharex=True, sharey=True, squeeze=False)

    # For display: drop the batch axis, add an empty third channel and move channels last.
    image = np.squeeze(image)
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1, 2, 0)

    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')

    axes[0][1].imshow(create_lut(mask_cp))
    axes[0][1].title.set_text('Predicted Labels')

    axes[0][2].imshow(flows[0])
    axes[0][2].title.set_text('Predicted cellpose')

    axes[0][3].imshow(flows[2])
    axes[0][3].title.set_text('Predicted cell probability')

    # Only show the first three test images.
    if idx == 2:
        break

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Local Shape Descriptors</h3>

- Another example of auxiliary learning is [**LSDs**](https://localshapedescriptors.github.io/). This embedding also encodes object shape, but it is computed within a Gaussian window of fixed size (sigma) that is restricted to each label. This yields consistent gradients regardless of object shape, which makes it a good candidate for segmenting complex objects such as neurons in large electron microscopy datasets.


- The LSDs are combined with nearest neighbor affinities to improve the boundary representations. The improved affinities then produce nice segmentations when using a hierarchical agglomeration approach and can be easily parallelized to allow for scaling to massive volumes. 

![example_image](static/lsd_schematic.png)

In [None]:
# Import lsds, calculate them on a small patch and visualize the descriptor components
# next to the raw image they were computed for.

from lsd.train import local_shape_descriptor

file = random.choice(train_nuclei)

# 64x64 patch of the nuclei labels and of the matching raw image.
nuclei = imread(file)[0:64, 0:64]
raw = imread(file.replace('_nuclei_masks', ''))[:, 0:64, 0:64]

# Just to visualize: append an empty third channel and move channels last -> (h, w, 3).
raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1, 2, 0)

lsds = local_shape_descriptor.get_local_shape_descriptors(
              segmentation=nuclei,
              sigma=(5,)*2,
              voxel_size=(1,)*2)

# Titles of the six LSD channels, in channel order.
lsd_titles = ['Mean offset Y', 'Mean offset X', 'Covariance Y-Y',
              'Covariance X-X', 'Covariance Y-X', 'Size']

fig, axes = plt.subplots(
            1,
            1 + len(lsd_titles),
            figsize=(20, 20),
            sharex=False,
            sharey=True,
            squeeze=False)

# The raw patch prepared above was previously never plotted; show it as the first panel.
axes[0][0].imshow(raw)
axes[0][0].title.set_text('Raw')

for ax, component, title in zip(axes[0][1:], lsds, lsd_titles):
    ax.imshow(np.squeeze(component), cmap='jet')
    ax.title.set_text(title)

In [None]:
# Slightly modify our dataset just for simplicity: every item now yields the raw image
# together with two targets, the LSDs and the affinities.

class TissueNetDataset(Dataset):
    """TissueNet crops with LSD and affinity targets.

    Mask files match ``root_dir/split/*{mask}*``. Raw images are the ``*.tif`` files whose
    names do not end in ``s.tif``. Both lists are natsorted and paired by index.

    Args:
        root_dir: Directory holding the ``train`` and ``test`` sub-directories.
        split: Sub-directory to read.
        mask: Substring identifying the mask files, e.g. ``'nuclei'``.
        crop_size: Random-crop size. It is applied to every split, so it must always be set.
        val_split: Only used when ``split == 'test'``. True keeps the first 10 images,
            False keeps the rest.
        padding_size: Crops are padded up to a multiple of this value.

    Items are ``(raw, lsds, affs)``, all float32 and channels-first.
    """

    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 padding_size=8
                ):

        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.padding_size = padding_size

        # The test directory is divided into a validation set (first 10 images) and a test set.
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)

    def get_padding(self, crop_size, padding_size):
        """Return ``crop_size`` rounded up to the next multiple of ``padding_size``."""
        if crop_size % padding_size == 0:
            return crop_size
        return padding_size * (int(crop_size / padding_size) + 1)

    def augment_data(self, raw, mask, padding):
        """Randomly crop, pad, flip, rotate and transpose ``raw`` and ``mask`` together.

        Inputs are channels-last (h, w, c) arrays. Brightness/contrast jitter is a
        pixel-level transform, which albumentations applies to the image only.
        Returns the transformed ``(raw, mask)``.
        """
        transform = A.Compose([
              A.RandomCrop(
                  width=self.crop_size,
                  height=self.crop_size),
              A.PadIfNeeded(
                  min_height=padding,
                  min_width=padding,
                  p=1,
                  border_mode=0),
              A.HorizontalFlip(p=0.3),
              A.VerticalFlip(p=0.3),
              A.RandomRotate90(p=0.3),
              A.Transpose(p=0.3),
              A.RandomBrightnessContrast(p=0.3)
            ])

        transformed = transform(image=raw, mask=mask)

        return transformed['image'], transformed['mask']

    def __getitem__(self, idx):
        raw = imread(self.raw_files[idx])
        mask = imread(self.mask_files[idx])

        # Channels-last for albumentations: (c, h, w) -> (h, w, c); mask gets a channel axis.
        raw = raw.transpose([1, 2, 0])
        mask = np.expand_dims(mask, axis=-1)

        # just do this regardless of split to make val/test faster for demo purposes
        padding = self.get_padding(self.crop_size, self.padding_size)
        raw, mask = self.augment_data(raw, mask, padding)

        # Back to channels-first. Cast raw to float32 (as the first version of this dataset
        # does) so batches match the network's float32 weights whatever dtype the .tif files
        # store. Casting after augmentation keeps the augmentation behaviour unchanged.
        raw = raw.transpose([2, 0, 1]).astype(np.float32)
        mask = mask.transpose([2, 0, 1])

        labels, border = erode_border(
                    mask[0],
                    iterations=1,
                    border_value=1)

        affs = compute_affinities(labels, nhood=[[0, 1], [1, 0]])

        lsds = local_shape_descriptor.get_local_shape_descriptors(
              segmentation=labels,
              sigma=(5,)*2,
              voxel_size=(1,)*2)

        return raw, lsds.astype(np.float32), affs.astype(np.float32)

In [None]:
# Visualize one randomly chosen training sample: its raw image, LSDs and affinities.

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)

raw, lsds, affs = train_dataset[random.randrange(len(train_dataset))]

# Append an empty third channel and move channels last -> (h, w, 3) RGB for display.
raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1, 2, 0)

# Titles of the six LSD channels, in channel order.
lsd_titles = ['Mean offset Y', 'Mean offset X', 'Covariance Y-Y',
              'Covariance X-X', 'Covariance Y-X', 'Size']

# One panel for the raw image, one per LSD channel and one for the affinities.
fig, axes = plt.subplots(
            1,
            len(lsd_titles) + 2,
            figsize=(20, 20),
            sharex=False,
            sharey=True,
            squeeze=False)

# The RGB raw image prepared above was previously never plotted; show it first.
axes[0][0].imshow(raw)
axes[0][0].title.set_text('Raw')

for ax, component, title in zip(axes[0][1:], lsds, lsd_titles):
    ax.imshow(np.squeeze(component), cmap='jet')
    ax.title.set_text(title)

# Sum of the two affinity channels as a single boundary map.
axes[0][-1].imshow(np.squeeze(affs[0] + affs[1]), cmap='jet')
axes[0][-1].title.set_text('Affs')

In [None]:
# Our network needs two output heads: one for the LSDs and one for the affinities.
# To get them we subclass torch.nn.Module and build the UNet inside it.
# Previously there was a single final convolution; now each head has its own.
# In the forward pass the image goes through the UNet, then the UNet output goes through each head.

class MtlsdModel(torch.nn.Module):
    """UNet backbone followed by two 1x1-convolution heads (6 LSD channels, 2 affinity channels).

    Args:
        in_channels: Number of input image channels.
        num_fmaps: UNet feature maps; also used as the input width of both heads
            (assumes the UNet emits ``num_fmaps`` channels — TODO confirm against unet.UNet).
        fmap_inc_factors: Forwarded to ``UNet``.
        downsample_factors: Forwarded to ``UNet``.
        padding: Forwarded to ``UNet``.
    """

    def __init__(
        self,
        in_channels,
        num_fmaps,
        fmap_inc_factors,
        downsample_factors,
        padding='same'
    ):
        super().__init__()

        backbone_kwargs = dict(
            in_channels=in_channels,
            num_fmaps=num_fmaps,
            fmap_inc_factors=fmap_inc_factors,
            downsample_factors=downsample_factors,
            padding=padding,
        )
        self.unet = UNet(**backbone_kwargs)

        # Heads are created in this order (LSDs, then affinities) so seeded weight
        # initialisation stays the same.
        def make_head(out_channels):
            return torch.nn.Conv2d(in_channels=num_fmaps, out_channels=out_channels, kernel_size=1)

        self.lsd_head = make_head(6)
        self.aff_head = make_head(2)

    def forward(self, input):
        """Return ``(lsds, affs)`` logits for ``input``."""
        features = self.unet(input)
        return self.lsd_head(features), self.aff_head(features)

# We want to minimize the sum of the LSD loss and the affinity loss.
# We do this by subclassing our loss function (torch.nn.MSELoss) and overriding its forward method.

class CombinedLoss(torch.nn.MSELoss):
    """Sum of two MSE terms: LSD prediction vs. target plus affinity prediction vs. target."""

    def __init__(self):
        super().__init__()

    def forward(self, lsds_prediction, lsds_target, affs_prediction, affs_target):
        mse = super().forward
        return mse(lsds_prediction, lsds_target) + mse(affs_prediction, affs_target)

In [None]:
# Seed torch before building the network so its weight initialisation is reproducible.
torch.manual_seed(42)

# Three levels of 2x2 downsampling.
d_factors = [[2, 2] for _ in range(3)]

# 2 input channels; num_fmaps / fmap_inc_factors are passed to UNet
# (presumably base width and per-level growth factor — see unet.UNet).
in_channels = 2
num_fmaps = 32
fmap_inc_factors = 4

net = MtlsdModel(in_channels, num_fmaps, fmap_inc_factors, d_factors).to(device)
loss_fn = CombinedLoss().to(device)

In [None]:
training_steps = 3000

# One TensorBoard run directory per launch, named after the current local time.
logdir = os.path.join("logs", f"{datetime.datetime.now():%Y%m%d-%H%M%S}")
writer = SummaryWriter(logdir)

net = net.to(device)
dtype = torch.FloatTensor

# Optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Activation applied to the network logits
activation = torch.nn.Sigmoid()

### Datasets: random training crops, the held-out test images, and a small
### validation subset (the first 10 test images).
train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', crop_size=128)
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, crop_size=64)

batch_size = 4

# Data loaders: only the training loader batches and shuffles.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader, val_loader = (DataLoader(dataset, batch_size=1) for dataset in (test_dataset, val_dataset))

In [None]:
# Our training step now produces two sets of logits (LSDs and affinities) and two predictions.

def model_step(model, loss_fn, optimizer, feature, gt_lsds, gt_affs, activation, train_step=True):
    """Run one forward pass and, when training, one optimizer update.

    The loss compares the *activated* outputs with the targets. Predictions are always
    read after ``activation`` (the returned ``pred_*`` entries), so supervising the raw
    logits instead would train logits ~ targets, which the sigmoid then squashes
    (e.g. targets in [0, 1] would come out in roughly [0.5, 0.73]).

    Args:
        model: Module returning ``(lsd_logits, affs_logits)``.
        loss_fn: Called as ``loss_fn(lsd_pred, gt_lsds, affs_pred, gt_affs)``.
        optimizer: Used only when ``train_step`` is True.
        feature: Input batch.
        gt_lsds: LSD targets.
        gt_affs: Affinity targets.
        activation: Applied to both logits; if None the logits are used directly.
        train_step: If True, back-propagate and step the optimizer. If False, no
            autograd graph is built.

    Returns:
        ``(loss_value, outputs)``, where ``outputs`` has the keys ``'pred_lsds'``,
        ``'pred_affs'``, ``'lsds_logits'`` and ``'affs_logits'``.
    """
    # zero gradients if training
    if train_step:
        optimizer.zero_grad()

    # Only track gradients when we are going to back-propagate.
    with torch.set_grad_enabled(train_step):
        lsd_logits, affs_logits = model(feature)

        if activation is None:
            lsd_output, affs_output = lsd_logits, affs_logits
        else:
            lsd_output = activation(lsd_logits)
            affs_output = activation(affs_logits)

        loss_value = loss_fn(lsd_output, gt_lsds, affs_output, gt_affs)

    # backward if training mode
    if train_step:
        loss_value.backward()
        optimizer.step()

    outputs = {
        'pred_lsds': lsd_output,
        'pred_affs': affs_output,
        'lsds_logits': lsd_logits,
        'affs_logits': affs_logits,
    }

    return loss_value, outputs

In [None]:
# Train on both LSDs and affinities. Log the training loss every step and the
# validation loss every 100 steps.

# set flags
net.train()
loss_fn.train()
step = 0

with tqdm(total=training_steps) as pbar:
    while step < training_steps:
        # Reseed numpy and restart the loader each pass to get fresh random augmentations.
        np.random.seed()
        for feature, gt_lsds, gt_affs in train_loader:
            feature = feature.to(device)
            gt_lsds = gt_lsds.to(device)
            gt_affs = gt_affs.to(device)

            loss_value, pred = model_step(net, loss_fn, optimizer, feature, gt_lsds, gt_affs, activation)
            writer.add_scalar('loss', loss_value.item(), step)
            step += 1
            pbar.update(1)

            if step % 100 == 0:
                net.eval()
                val_losses = []
                # Validate on the validation split (val_loader), not the test split, and
                # without building autograd graphs.
                with torch.no_grad():
                    for val_feature, val_lsds, val_affs in val_loader:
                        val_loss, _ = model_step(
                            net, loss_fn, optimizer,
                            val_feature.to(device), val_lsds.to(device), val_affs.to(device),
                            activation, train_step=False)
                        val_losses.append(val_loss.item())
                mean_val_loss = np.mean(val_losses)
                writer.add_scalar('val_loss', mean_val_loss, step)
                net.train()
                print(mean_val_loss)

            # Stop exactly at the requested number of steps, even in the middle of an epoch.
            if step >= training_steps:
                break

In [None]:
# Visualize a few predictions - have the LSDs helped to improve the affinities?
# For a future challenge you could try using a weighted combined loss and watershed + agglomeration to get strong segmentations.

net.eval()

activation = torch.nn.Sigmoid()

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for idx, (image, gt_lsds, gt_affs) in enumerate(test_loader):
        lsds_logits, affs_logits = net(image.to(device))
        pred_lsds = np.squeeze(activation(lsds_logits).cpu().numpy())
        pred_affs = np.squeeze(activation(affs_logits).cpu().numpy())

        gt_lsds = np.squeeze(gt_lsds.cpu().numpy())
        gt_affs = np.squeeze(gt_affs.cpu().numpy())

        # Raw image for display: drop the batch axis, add an empty third channel and
        # move channels last -> (h, w, 3).
        image = np.squeeze(image.cpu().numpy())
        image = np.vstack((image, np.zeros_like(image)[:1]))
        image = image.transpose(1, 2, 0)

        fig, axes = plt.subplots(1, 3, figsize=(20, 20), sharex=True, sharey=True, squeeze=False)

        axes[0][0].imshow(image)
        axes[0][0].title.set_text('Raw')

        # Mean offset Y with mean offset X overlaid semi-transparently.
        axes[0][1].imshow(np.squeeze(pred_lsds[0]), cmap='jet')
        axes[0][1].imshow(np.squeeze(pred_lsds[1]), cmap='jet', alpha=0.5)
        axes[0][1].title.set_text('Mean offsets')

        axes[0][2].imshow(np.squeeze(pred_affs[0] + pred_affs[1]), cmap='jet')
        axes[0][2].title.set_text('Affs')

        # Only show the first three test images.
        if idx == 2:
            break

### Further learning

* Instance segmentation can be challenging and this exercise just scratches the surface of what is possible.


* This notebook assumes images that fit into memory, but often this is not the case (especially in biology).
    1. To see an example for predicting over an image in chunks and stitching the results together, see this [notebook](https://github.com/dlmbl/instance_segmentation/blob/2021-solutions/3_tile_and_stitch.ipynb)
    2. For a more advanced library that makes it easier to do machine learning on massive datasets, see gunpowder (navigate to the tutorials, or browse the API): https://funkelab.github.io/gunpowder
    
    
* We did not cover more complex loss functions. Here are some nice explanations / implementations of other loss functions that are useful for instance segmentation: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook


* A more complex (but powerful) approach is called metric learning. This can be seen in last year's [exercise](https://github.com/dlmbl/instance_segmentation/blob/2021-solutions/2_instance_segmentation.ipynb)


* We did not cover StarDist in this tutorial, and barely scratched the surface of Cellpose and LSDs. For more, see these tutorials:
    1. Stardist: https://github.com/maweigert/tutorials/tree/main/stardist
    2. CellPose: https://github.com/MouseLand/cellpose#run-cellpose-10-without-local-python-installation
    3. LSDs: https://github.com/funkelab/lsd#notebooks
    
### Good luck on your instance segmentation endeavors!!