# Instance Segmentation

- So far we were only interested in semantic classes, e.g. foreground / background, cell types, person / car, etc. But in many cases we not only want to know whether a certain pixel belongs to an object, but also to **which** unique object it belongs.


- For isolated objects this is trivial: all connected foreground pixels form one instance. Often, however, instances are very close together or even overlapping. Then we need to think more carefully about how to formulate the inputs / loss of our network and how to extract the instances from its predictions.


- Below is an example of the differences between object detection, semantic segmentation, and instance segmentation. The raw data shows a 2D slice of neural tissue acquired with a high-resolution electron microscope. We could teach a network to simply detect each mitochondrion (object detection). If we want to assign every pixel to a class, we could do semantic segmentation (like in the previous exercise). In this case we would have two classes: neurons that contain mitochondria (green) and neurons that do not. Finally, we could assign every pixel to a unique object (neurons in this case). This is instance segmentation, an approach that can be very useful for biological data.

![example_image](static/instance_seg.png)


- If you are running this in JupyterLab, every markdown header is collapsible and all cells are collapsed by default. Click to the left of a cell to expand it, and keep expanding until the code cells show. The headings do not collapse in the classic Jupyter Notebook, but they still give you an idea of the breaks between exercises.


- Most TODOs build on the previous ones and require copying over classes/functions before adding more to them. To save you from constantly scrolling back and forth, each TODO and Task header links to the previous and next sections. **Note:** This only works for uncollapsed cells. If you collapse the cell containing TODO 1 and then try to return to it using the link in TODO 2, it won't work.


- If you have questions please let us know. Have fun!!

<div class="alert alert-block alert-success"><h1>Start here (AKA checkpoint 0)</h1>

</div>

## Task 0.0: Importing packages

* You should have already set up your conda environment by now. Import these packages and let us know if something fails so we can debug before moving on. 
* We're also importing a UNet model from the `unet` module and other predefined functions from `utils.py`. Please take a look at them as you start using them.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import torch
import datetime

from glob import glob
from skimage import color
from skimage.io import imread
from natsort import natsorted
import albumentations as A
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from tqdm.auto import tqdm
from unet import *
from utils import *

torch.backends.cudnn.benchmark = True

## Task 1.0: Creating a simple model

- Let's start by creating a simple model similar to the one in the semantic segmentation exercise. We will then improve it in the subsequent checkpoints.

Click [here](#task-20-improving-the-model) to go to the next task


<div class="alert alert-block alert-info"><h3>Task 1.1: Load and visualize data</h3>
    
    
- For this exercise we will be using data from [TissueNet](https://datasets.deepcell.org/) (paper [here](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9010346/)).
    
    
- For our purposes we will use 50 training images and 20 testing images. The data is stored as TIFF files using the following structure:
```
woodshole/
    ├── test
    │   ├── img_0_cyto_masks.tif
    │   ├── img_0_nuclei_masks.tif
    │   ├── img_0.tif
    │   ├── img_1_cyto_masks.tif
    │   ├── img_1_nuclei_masks.tif
    │   └── img_1.tif
    │  
    └── train
        ├── img_0_cyto_masks.tif
        ├── img_0_nuclei_masks.tif
        ├── img_0.tif
        ├── img_1_cyto_masks.tif
        ├── img_1_nuclei_masks.tif
        └── img_1.tif
```
- Each raw image is stored as `img_{n}.tif`. It is already stored as float32 between 0 and 1, so it does not need to be normalized for training. There are two channels in the raw data, one for nuclei and one for cytoplasm.
    
    
- The corresponding mask files contain instance segmentations for the raw data. The nuclei masks correspond to the first channel of the raw data. The cytoplasm masks correspond to nucleus + cytoplasm. We will start with the nuclei data, as it is likely easier to segment than the cytoplasm, but you will be able to apply the techniques you learn on the harder data at the end of the exercise, time permitting.

</div>

![example_image](static/example_image.png)

In [None]:
# let's start by loading our images into lists so we can visualize and get oriented with our data
# natsorted sorts file names "naturally" (e.g. img_2 before img_10), unlike the built-in sorted
# glob lists the files in a directory that match a pattern. Feel free to inspect these lists.
# raw images are the .tif files that don't end in "masks.tif", hence the 's.tif' filter below

train_cyto = natsorted(glob('woodshole/train/*cyto*'))
train_nuclei = natsorted(glob('woodshole/train/*nuclei*'))
train_raw = [i for i in natsorted(glob('woodshole/train/*.tif')) if 's.tif' not in i]

test_cyto = natsorted(glob('woodshole/test/*cyto*'))
test_nuclei = natsorted(glob('woodshole/test/*nuclei*'))
test_raw = [i for i in natsorted(glob('woodshole/test/*.tif')) if 's.tif' not in i]

In [None]:
# Let's define the device we'll be using throughout the notebook

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# select a random cytoplasm mask file
cyto_file = random.choice(train_cyto)

# use skimage.io.imread to read our data into numpy arrays
cyto = imread(cyto_file)
nuclei = imread(cyto_file.replace('cyto', 'nuclei'))
raw = imread(cyto_file.replace('_cyto_masks', ''))

# our raw data shape is (c, h, w) and there are only two channels.
# to visualize it as an RGB image we add a third, all-zero channel and then transpose so that it is (h, w, 3)
raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

# visualize the data - execute this cell a few times to see different examples.
# you can also change train_cyto to test_cyto above to see some test data. it is pretty similar 
fig, axes = plt.subplots(1,5,figsize=(20, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(raw[:,:,0], cmap='gray')
axes[0][0].title.set_text('Raw nuclei channel')

axes[0][1].imshow(raw[:,:,1], cmap='gray')
axes[0][1].title.set_text('Raw cyto channel')

axes[0][2].imshow(raw)
axes[0][2].title.set_text('Raw overlay')

axes[0][3].imshow(raw[:,:,0], cmap='gray')
axes[0][3].imshow(create_lut(nuclei), alpha=0.5)
axes[0][3].title.set_text('nuclei mask')

axes[0][4].imshow(raw[:,:,1], cmap='gray')
axes[0][4].imshow(create_lut(cyto), alpha=0.5)
axes[0][4].title.set_text('cyto mask')

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 1.2: Create simple augmentation function</h3>
    
- You have already learned about the importance of augmenting your training data.
- For our exercise we will use an augmentation library called [albumentations](https://albumentations.ai/), which provides easy-to-use, fast transforms
- Here is a nice tutorial: https://albumentations.ai/docs/examples/example_kaggle_salt/
- To start, we will add a few simple augmentations to both our raw and mask data:
    - randomly crop a 64x64 patch
    - horizontally flip with a 50% probability
    - vertically flip with a 50% probability
- Below is a simple example of a random crop augmentation. 

![example_augmentation](static/example_augmentation.png)

</div>
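The key property of augmenting segmentation data is that the image and its mask must receive exactly the same random transform; otherwise the labels no longer line up with the pixels. albumentations takes care of this when we pass `image=` and `mask=` to one `A.Compose` call. The numpy sketch below makes that explicit for a crop plus two flips. `paired_random_crop_flip` is a hypothetical helper for illustration only, not part of albumentations or `utils.py`.

```python
import numpy as np


def paired_random_crop_flip(raw, mask, crop_size, rng):
    """Apply the *same* random crop and flips to an image and its label mask (hypothetical helper)."""
    height, width = raw.shape[:2]
    # one random crop offset, shared by image and mask
    y0 = rng.integers(0, height - crop_size + 1)
    x0 = rng.integers(0, width - crop_size + 1)
    raw = raw[y0:y0 + crop_size, x0:x0 + crop_size]
    mask = mask[y0:y0 + crop_size, x0:x0 + crop_size]
    # one coin flip per axis (50% probability each), shared by image and mask
    if rng.random() < 0.5:
        raw, mask = raw[:, ::-1], mask[:, ::-1]  # horizontal flip
    if rng.random() < 0.5:
        raw, mask = raw[::-1, :], mask[::-1, :]  # vertical flip
    return raw, mask


rng = np.random.default_rng(0)
raw = np.arange(256 * 256, dtype=np.float32).reshape(256, 256)
mask = raw.astype(np.int64)  # a mask identical to the image, so we can check alignment
aug_raw, aug_mask = paired_random_crop_flip(raw, mask, 64, rng)
print(aug_raw.shape, np.array_equal(aug_raw, aug_mask))  # (64, 64) True
```

Because the mask is only sliced and flipped (never interpolated), its integer instance labels are preserved, which is also why crops and flips are safe augmentations for label masks.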

In [None]:
# Let's define a simple augmentation pipeline

file = random.choice(train_nuclei)

full_mask_nuclei = imread(file)
full_raw_nuclei = imread(file.replace('_nuclei_masks', ''))[0]

transform = A.Compose([
              A.RandomCrop(width=64, height=64),
              A.HorizontalFlip(p=0.5),
              A.VerticalFlip(p=0.5)
            ])

transformed = transform(image=full_raw_nuclei, mask=full_mask_nuclei)
          
aug_raw, aug_mask = transformed['image'], transformed['mask']

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=False,sharey=False,squeeze=False)

axes[0][0].imshow(full_raw_nuclei, cmap='gray')
axes[0][0].imshow(create_lut(full_mask_nuclei), alpha=0.5)
axes[0][0].set_title('original image')

axes[0][1].imshow(aug_raw, cmap='gray')
axes[0][1].imshow(create_lut(aug_mask), alpha=0.5)
axes[0][1].set_title('example random augmentation')

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 1.3: Create fg/bg representation</h3>
    
- It would be ideal to directly predict unique labels in a dataset. Unfortunately this requires global information which can become difficult as datasets increase in size. Consequently, alternative approaches aim to solve the problem locally.
    
    
- We will start with the most trivial approach: learning a foreground / background mask and then relabeling connected pixels as unique objects. While this approach might suffice on simple datasets, you will see how it can become problematic on datasets in which objects are tightly packed.

</div>
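To see why touching objects are a problem for the fg/bg approach, here is a small sketch of "relabeling connected pixels as unique objects", i.e. connected component labeling with 4-connectivity. `label_connected_components` is a hypothetical stand-in written in plain Python; the notebook itself uses `relabel_cc` from `utils.py`, whose implementation we have not checked (in practice, functions such as `scipy.ndimage.label` or `skimage.measure.label` do this job).

```python
from collections import deque

import numpy as np


def label_connected_components(fg):
    """Label 4-connected foreground regions of a 2D array with 1, 2, 3, ... (0 = background)."""
    fg = np.asarray(fg, dtype=bool)
    labels = np.zeros(fg.shape, dtype=np.int64)
    height, width = fg.shape
    next_label = 0
    for y in range(height):
        for x in range(width):
            if not fg[y, x] or labels[y, x] != 0:
                continue
            # found an unlabeled foreground pixel: start a new instance and flood-fill it (BFS)
            next_label += 1
            labels[y, x] = next_label
            queue = deque([(y, x)])
            while queue:
                cy, cx = queue.popleft()
                for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                    if 0 <= ny < height and 0 <= nx < width and fg[ny, nx] and labels[ny, nx] == 0:
                        labels[ny, nx] = next_label
                        queue.append((ny, nx))
    return labels


# two objects separated by a background column -> two instances
separated = np.array([[1, 1, 0, 1, 1],
                      [1, 1, 0, 1, 1]])
# the same two objects touching in the first row -> merged into a single instance
touching = np.array([[1, 1, 1, 1, 1],
                     [1, 1, 0, 1, 1]])

print(label_connected_components(separated).max())  # 2
print(label_connected_components(touching).max())   # 1
```

A single touching pixel is enough to merge two nuclei into one instance, which is exactly the failure mode we will run into on tightly packed data.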

![example_fgbg](static/example_fgbg.png)

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 1.4: Create simple dataset</h3>

- Now let's combine all this into a simple dataset (similar to what was done in the semantic segmentation exercise)
- For now our dataset should just load the raw and mask data, and apply simple augmentation.
- We will just use the first channel of the raw data (nuclei) for now (`raw[0]`)

<a id='first-todo'></a>

##### **TODO (1)**

* Create a basic dataset with simple data augmentations called `TissueNetDataset`
* The `__init__` method should load our **mask file names** and **raw file names** into sorted lists
* Add a parameter to your `TissueNetDataset` to define a `crop_size` of `64 pixels`, that will be used in the `augment_data` method
* Add a boolean parameter to your `TissueNetDataset` to create a validation split `val_split`
* The `augment_data` method should take in a raw and mask array and return the augmented raw and mask arrays.
* In the `__getitem__` method, we should only augment our data if our split is `train`. So you will need to also add split as an attribute in the `__init__` method
* After augmenting your data, create a fg / bg representation
* Make sure to return your fg/bg as float32 for training
* Make sure to add a dummy channel dimension to your arrays for training (PyTorch expects tensors of shape (batch, channel, height, width)). We will add the batch dimension later, once we create a data loader.
* For now we will just use the nuclei channel for training, so make sure to slice the correct channel of the raw data before returning
* You should return the raw and mask arrays
* If you want to learn more about the ETL (extract, transform, load) pipeline in PyTorch, click [here](https://pytorch.org/docs/stable/data.html)
* Click [here](#second-todo) if you need to go to the next **TODO**
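The skeleton hints at eroding the labels before building the fg/bg target (the solution calls `erode` from `utils.py`). We have not checked how that helper works, so the sketch below is only an assumption about its intent: shrink each instance by one pixel so that touching instances end up separated by a background gap. `erode_instances_to_fg` is a hypothetical name. Treating pixels outside the image as "same label" mimics the `border_value=1` choice (as in `scipy.ndimage.binary_erosion`, where it means the outside counts as foreground), so objects are not eroded along the image border.

```python
import numpy as np


def erode_instances_to_fg(labels):
    """Binary float32 fg/bg target in which touching instances are split by a background gap.

    A pixel stays foreground only if it is labeled and all 4-neighbours carry the same label.
    """
    labels = np.asarray(labels)
    # edge padding: pixels outside the image compare equal to their border neighbour
    padded = np.pad(labels, 1, mode="edge")
    center = padded[1:-1, 1:-1]
    same = (
        (padded[:-2, 1:-1] == center)    # up
        & (padded[2:, 1:-1] == center)   # down
        & (padded[1:-1, :-2] == center)  # left
        & (padded[1:-1, 2:] == center)   # right
    )
    return ((center != 0) & same).astype(np.float32)


# two instances (1 and 2) touching along a vertical line
labels = np.array([[1, 1, 2, 2],
                   [1, 1, 2, 2],
                   [1, 1, 2, 2]])
print(erode_instances_to_fg(labels))
```

Without erosion, `(labels != 0)` would be all ones and connected components would return a single object; after erosion the two columns of foreground are disconnected.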

In [None]:
class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False
                ):
        
        # make sure to store your split as an attribute, since we will use it in `__getitem__`
        self.split = ...
        # make sure to add your crop size
        self.crop_size = ...

        # using the root dir, split and mask create a path to files and sort it 
        # Hint: natsorted glob and os libraries could come in handy
        self.mask_files = ... # load mask files into sorted list
        self.raw_files = ... # load image files into sorted list
        
        # Add another parameter `val_split`. 
        # **There are 20 test files in total**
        # If `split` is `test` and `val_split` is True, take the first half of the test files
        # else take the second half
        
        
    def __len__(self):
        return len(self.raw_files)
    
    def augment_data(self, raw, mask):
        
        transform = ... # create your augmentations
        
        transformed = ... # call your augmentations
        
        raw, mask = ... # get your resulting arrays
        
        return raw, mask
        
    
    def __getitem__(self, idx):
        raw_file = ... # get raw file at index
        mask_file = ... # get mask file at index
        
        raw = ... # load raw to numpy array
        mask = ... # load mask to numpy array
        
        raw = ... # get nuclei channel
        
        # augment your data if split mode is train
        
        mask = ... # erode your labels, cast to float32. Hint: use function that returns just the mask

        raw = ... # add channel dimension to comply with pytorch standard (C, H, W)
        mask = ... # add channel dimension
            
        return raw, mask

In [None]:
#### Solution ####

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False
                ):
        
        self.split = split
        self.crop_size = crop_size
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)
    
    def augment_data(self, raw, mask):
        
        transform = A.Compose([
              A.RandomCrop(width=self.crop_size, height=self.crop_size),
              A.HorizontalFlip(p=0.5),
              A.VerticalFlip(p=0.5)
        ])

        transformed = transform(image=raw, mask=mask)

        raw, mask = transformed['image'], transformed['mask']
        
        return raw, mask

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        # for now just do single channel training
        raw = raw[0]

        if self.split == 'train':
            raw, mask = self.augment_data(raw, mask)
            
        fg = erode(
                    mask,
                    iterations=1,
                    border_value=1)
        
        mask = (fg != 0).astype(np.float32)

        # add channel dim for network
        raw = np.expand_dims(raw, axis=0).astype(np.float32)
        mask = np.expand_dims(mask, axis=0)
                
        return raw, mask

In [None]:
train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test')
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True)

In [None]:
# visualize the representation (repeatedly run cell)

raw, mask = train_dataset[random.randrange(len(train_dataset))]

labels = erode(
    mask,
    iterations=1,
    border_value=1)

labels_two_class = (labels != 0).astype(np.float32)

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(raw.squeeze(), cmap='gray')
axes[0][0].imshow(create_lut(mask.squeeze().astype(int)), alpha=0.5)
axes[0][0].title.set_text('Raw + mask overlay')

axes[0][1].imshow(labels_two_class.squeeze())
axes[0][1].title.set_text('Foreground / background')

In [None]:
# run cell repeatedly to see different crops

raw, mask = train_dataset[random.randrange(len(train_dataset))]

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(raw.squeeze(), cmap='gray')
axes[0][0].title.set_text('Raw')

axes[0][1].imshow(mask.squeeze())
axes[0][1].title.set_text('Mask')

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 1.5: Create shallow network, visualize receptive field</h3>
    
- Let's create a shallow two-level U-Net and visualize its receptive field. We will see later how the receptive field changes as we add more layers and change our input image size.
    
    
- The receptive field tells us how much of the image the network sees at each layer: this is the amount of spatial context the network can use to make a prediction.
    
    
- Run the following cell to see the network's receptive field. Try changing the downsampling factors to see how they affect the receptive field (e.g. try combinations of [1,1], [3,3], [4,4], etc.)
  
</div>
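The receptive field can also be computed by hand with the standard recurrence for a stack of layers: each layer with kernel size k grows the receptive field by (k - 1) times the current "jump" (the spacing, in input pixels, between neighbouring outputs), and each layer with stride s multiplies the jump by s. The sketch below assumes an encoder with two 3x3 convolutions per level and max pooling with kernel = stride = downsampling factor; that layout is an assumption about `unet.py`, so the numbers need not match `net.rec_fov` exactly (which may also account for the decoder). Both function names are hypothetical.

```python
def receptive_field(layers):
    """Receptive field (in input pixels) of layers given as (kernel_size, stride) pairs, in order."""
    rf, jump = 1, 1
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump  # a kernel spans (k - 1) extra steps of the current jump
        jump *= stride                  # striding spreads neighbouring outputs further apart
    return rf


def unet_encoder_layers(downsample_factors, convs_per_level=2, kernel_size=3):
    """Assumed encoder layout: `convs_per_level` convs then a max-pool per level, plus a bottom level."""
    layers = []
    for factor in downsample_factors:
        layers += [(kernel_size, 1)] * convs_per_level
        layers.append((factor, factor))  # max-pool, kernel size = stride = factor
    layers += [(kernel_size, 1)] * convs_per_level
    return layers


# one isotropic factor per level, i.e. [2, 2] corresponds to d_factors = [[2,2],[2,2]]
for factors in ([1, 1], [2, 2], [3, 3], [4, 4]):
    print(factors, receptive_field(unet_encoder_layers(factors)))
```

Larger downsampling factors increase the receptive field much faster than extra convolutions at full resolution, because every convolution after a pooling layer "sees" in steps of the accumulated jump.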

![example_RF](static/example_RF.png)

In [None]:
raw, mask = train_dataset[random.randrange(len(train_dataset))]

fovs = []
d_factors = [[2,2],[2,2]]

net = UNet(in_channels=1,
           num_fmaps=6,
           fmap_inc_factors=2,
           downsample_factors=d_factors,
           padding='same'
          )

# field of view at each level of the U-Net (level 0 = full resolution)
for level in range(len(d_factors)+1):
    fov_tmp, _ = net.rec_fov(level , (1, 1), 1)
    fovs.append(fov_tmp[0])

plt.figure(figsize=(5, 5))
colors = ["yellow", "red", "green"]

plt.imshow(np.squeeze(raw), cmap='gray')

# raw has shape (c, h, w): use the width for x and the height for y when centering the boxes
height, width = raw.shape[1], raw.shape[2]

for idx, fov_t in enumerate(fovs):
    # cycle through the colors in case you add more levels than there are colors
    color = colors[idx % len(colors)]
    print("Field of view at depth {}: {:3d} (color: {})".format(idx+1, fov_t, color))
    xmin = width/2 - fov_t/2
    xmax = width/2 + fov_t/2
    ymin = height/2 - fov_t/2
    ymax = height/2 + fov_t/2
    plt.hlines(ymin, xmin, xmax, color=color, lw=3)
    plt.hlines(ymax, xmin, xmax, color=color, lw=3)
    plt.vlines(xmin, ymin, ymax, color=color, lw=3)
    plt.vlines(xmax, ymin, ymax, color=color, lw=3)
plt.show()

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 1.6: Set hyperparameters, create model</h3>
    
- Let's start by setting some hyperparameters. Since we are just doing a fg/bg prediction to start, this will be pretty similar to the semantic segmentation exercise. 
    
</div>

<a id='second-todo'></a>

##### **TODO (2)**
    
- Decide how many output channels to have, remember we are starting with a binary task
- What loss function and final layer activation should we use? Think back to the semantic segmentation exercise.
- What type should our tensors be? You can see a list of tensor types [here](https://pytorch.org/docs/stable/tensors.html) - maybe the equivalent of a 32-bit floating point :)
- For our model, we will create a two level U-Net with the following parameters: (to learn more about the torch layers click [here](https://pytorch.org/docs/stable/nn.html))
    - downsample by a factor of 2 in each layer
    - single input channel
    - 32 input feature maps 
    - multiply by a factor of 2 between layers
    - `same` padding (this gives us the same input and output shapes)     
    - Since our U-Net outputs `num_fmaps` feature maps rather than our desired number of output channels, we need to add a final convolution to get to the desired number of output channels. We should use a final convolution with a kernel size of 1
    - To see parameter defs you can run `UNet?`
- How many trainable parameters does our network have? 

* Click [here](#first-todo) if you need to go back to the previous **TODO**
* Click [here](#third-todo) if you need to go to the next **TODO**
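For the last question, it helps to know how parameters add up in a single layer: a 2D convolution has `in_channels * out_channels * k * k` weights plus one bias per output channel (`torch.nn.Conv2d` uses `bias=True` by default). The helper below is a hypothetical illustration of that arithmetic; the 3x3 example assumes the U-Net's first convolution uses a 3x3 kernel, which we have not verified in `unet.py`.

```python
def conv2d_param_count(in_channels, out_channels, kernel_size, bias=True):
    """Trainable parameters of a 2D convolution with a square kernel."""
    weights = in_channels * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


# final 1x1 convolution: 32 feature maps -> 1 output channel
print(conv2d_param_count(32, 1, 1))  # 33
# a (hypothetical) first 3x3 convolution: 1 input channel -> 32 feature maps
print(conv2d_param_count(1, 32, 3))  # 320
```

`torchsummary` reports the total for you; in plain PyTorch the usual idiom is `sum(p.numel() for p in net.parameters() if p.requires_grad)`.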


In [None]:
# set hyperparams

out_channels = ... 
activation = ... 
loss_fn = ... 
dtype = ...

torch.manual_seed(42)

d_factors = ... # refer to the init method signature in unet.py module 
in_channels = ... 
num_fmaps = ... 
fmap_inc_factors = ... # refer to the init method signature in unet.py module 
padding = ... 
final_kernel_size = ... 

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=d_factors,
        padding=padding)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=final_kernel_size)

net = ... # create your network from your unet and final convolution (hint: torch.nn.Sequential might be useful)

net.to(device)

summary(net, (in_channels, 64, 64))

In [None]:
#### Solution ####

out_channels = 1
activation = torch.nn.Sigmoid()
loss_fn = torch.nn.BCELoss()
dtype = torch.FloatTensor

torch.manual_seed(42)

d_factors = [[2,2],[2,2]]
in_channels=1
num_fmaps=32
fmap_inc_factors=2
padding='same'
final_kernel_size=1

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=d_factors,
        padding=padding)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=final_kernel_size)

net = torch.nn.Sequential(unet, final_conv)

net = net.to(device)

summary(net, (in_channels, 64, 64))

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 1.7: Run training loop</h3>
    
- Here we create a train loop function which calls a train step function on each iteration.
    
    
- As you know by now, we pass our image or feature into our model to get our logits. We then pass the logits through our final activation to get our predictions.
    
    
- Our predictions are the input to our loss, with our target being the ground truth labels. We then backpropagate and step the optimizer. In our case we will use the same train step for validation so we have to make sure to only backprop and step if our net is in train mode. We can control this inside our train loop by setting our net to eval before the validation loop, and back to train once it's done.
</div>
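For the binary fg/bg task, "logits -> activation -> loss" means a sigmoid followed by binary cross-entropy. The numpy sketch below spells out what `torch.nn.BCELoss` computes with its default mean reduction, so the individual numbers are easy to check by hand. PyTorch protects against `log(0)` by clamping the log outputs at -100; the probability clipping here is a similar but not identical safeguard. `sigmoid` and `binary_cross_entropy` are our own illustrative names.

```python
import numpy as np


def sigmoid(logits):
    """Map unbounded logits to probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-logits))


def binary_cross_entropy(pred, target, eps=1e-7):
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    pred = np.clip(pred, eps, 1 - eps)  # avoid log(0)
    return float(np.mean(-(target * np.log(pred) + (1 - target) * np.log(1 - pred))))


logits = np.array([[0.0, 2.0], [-2.0, 0.0]])  # raw network outputs
target = np.array([[1.0, 1.0], [0.0, 0.0]])   # ground-truth fg (1) / bg (0)
pred = sigmoid(logits)                        # final activation
loss = binary_cross_entropy(pred, target)
print(loss)
```

Confident correct predictions (logit 2 for fg, -2 for bg) contribute little to the loss, while the undecided logit of 0 contributes log(2). As an alternative design, `torch.nn.BCEWithLogitsLoss` fuses the sigmoid and the loss, which is numerically more stable; with it you would pass logits instead of probabilities.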

In [None]:
def model_step(model, loss_fn, optimizer, feature, label, activation, prediction_type=None, train_step=True):
    
    # zero gradients if training
    if train_step:
        optimizer.zero_grad()
    
    # forward
    logits = model(feature)
    
    if prediction_type == "three_class":
        label=torch.squeeze(label,1)
        
    # final activation
    predicted = activation(logits)

    # pass through loss
    loss_value = loss_fn(input=predicted, target=label)
    
    # backward if training mode
    if train_step:
        loss_value.backward()
        optimizer.step()

    outputs = {
        'pred': predicted,
        'logits': logits,
    }
    
    return loss_value, outputs

In [None]:
def train(train_loader, val_loader, net, loss_fn, activation, optimizer, dtype, prediction_type=None):
    # note: relies on the globals `training_steps`, `writer` and `device` defined in later cells

    # set train flags, initialize step
    net.train() 
    loss_fn.train()
    step = 0

    with tqdm(total=training_steps) as pbar:
        while step < training_steps:
            # reset data loader to get random augmentations
            np.random.seed()
            tmp_loader = iter(train_loader)
            for feature, label in tmp_loader:
                label = label.type(dtype)
                label = label.to(device)
                feature = feature.to(device)
                loss_value, pred = model_step(net, loss_fn, optimizer, feature, label, activation, prediction_type)
                writer.add_scalar('loss',loss_value.cpu().detach().numpy(),step)
                step += 1
                pbar.update(1)
                if step % 100 == 0:
                    net.eval()
                    tmp_val_loader = iter(val_loader)
                    acc_loss = []
                    # validation needs no gradients; no_grad saves memory and time
                    with torch.no_grad():
                        for feature, label in tmp_val_loader:
                            label = label.type(dtype)
                            label = label.to(device)
                            feature = feature.to(device)
                            loss_value, _ = model_step(net, loss_fn, optimizer, feature, label, activation, prediction_type, train_step=False)
                            acc_loss.append(loss_value.cpu().detach().numpy())
                    writer.add_scalar('val_loss',np.mean(acc_loss),step)
                    net.train()

                    print(np.mean(acc_loss))
                # stop exactly at `training_steps` instead of finishing the current epoch
                if step >= training_steps:
                    break

<a id='third-todo'></a>

##### **TODO (3)**

- Now let's train a model.
- Create our data loaders from our datasets. The data loader will take care of batching. We should have 3 dataloaders, one for each dataset we created
- The train data loader should use a batch size of 4 (shape = 4, c, h, w) and our val/test data loaders should use a batch size of 1. Set shuffle and pin memory to `True` in the train loader. (for more info on dataloaders see [here](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader))
- For now let's just train for 1000 steps
- Use a learning rate of 1e-4 and an Adam optimizer
- Use the `train` function with all the required parameters to train the model
- Click [here](#second-todo) if you need to go back to the previous **TODO**
- Click [here](#fourth-todo) if you need to go to the next **TODO**

In [None]:
train_batch_size = ... # set train batch size
test_batch_size = ... # set test / val batch size

train_loader = ... # create train data loader
test_loader = ... # create test data loader
val_loader = ... # create val data loader

training_steps = ... # set your training steps

# create a logdir for each run and a corresponding summary writer
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

# make sure net and loss are cast to our device (should be gpu, can check by printing device)
net = net.to(device)
loss_fn = loss_fn.to(device)

# set optimizer
learning_rate = ... # set your learning rate 
optimizer = ... # create your optimizer

In [None]:
# run training loop... (eg call train)

In [None]:
#### Solution ####

train_batch_size = 4
test_batch_size = 1

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size)
val_loader = DataLoader(val_dataset, batch_size=test_batch_size)

training_steps = 1000

# create a logdir for each run and a corresponding summary writer
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

# make sure net and loss are cast to our device (should be gpu, can check by printing device)
net = net.to(device)
loss_fn = loss_fn.to(device)

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [None]:
# run training loop
train(train_loader, val_loader, net, loss_fn, activation, optimizer, dtype)

In [None]:
# To view runs in tensorboard you can run either of the following (uncomment first):

# %reload_ext tensorboard
# %tensorboard --logdir logs

# or to view in separate window, run:

# !tensorboard --logdir=logs 

# Note that if running over ssh you will also need to forward the tensorboard port (usually 6006)
# you can also do this by passing the host (relevant machine ip address), eg:

# !tensorboard --logdir=logs --host hostname

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 1.8: Visualize results</h3>
    
- Here we will run inference using our trained model

- We will iterate over the test loader and pass our image through the model

- Once we have our prediction (after passing logits through our activation function), we will do a simple thresholding and relabelling to get a segmentation
  
</div>
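The inference cell below picks its threshold with `threshold_otsu` (presumably scikit-image's `skimage.filters.threshold_otsu`, made available through `utils.py`). Otsu's method chooses the histogram cut that maximises the between-class variance `w0 * w1 * (mu0 - mu1) ** 2`, where `w` are the class sizes and `mu` the class means. The numpy sketch below is a simplified re-implementation for intuition; `otsu_threshold` is a hypothetical name, and edge-case details (such as which side of a bin the threshold lands on) may differ from scikit-image.

```python
import numpy as np


def otsu_threshold(values, nbins=256):
    """Return the bin edge that maximises the between-class variance of `values`."""
    values = np.asarray(values, dtype=np.float64).ravel()
    counts, edges = np.histogram(values, bins=nbins)
    centers = (edges[:-1] + edges[1:]) / 2
    # for every cut between bin i and bin i + 1: size and mean of the lower (0) and upper (1) class
    w0 = np.cumsum(counts)[:-1]
    w1 = np.cumsum(counts[::-1])[::-1][1:]
    mu0 = np.cumsum(counts * centers)[:-1] / np.maximum(w0, 1)
    mu1 = np.cumsum((counts * centers)[::-1])[::-1][1:] / np.maximum(w1, 1)
    between_var = w0 * w1 * (mu0 - mu1) ** 2
    best = np.argmax(between_var)
    return edges[best + 1]  # the edge separating the two classes


# a toy "prediction" with a background mode near 0.2 and a foreground mode near 0.8
rng = np.random.default_rng(0)
pred = np.concatenate([rng.normal(0.2, 0.05, 1000), rng.normal(0.8, 0.05, 1000)])
thresh = otsu_threshold(pred)
fg = pred >= thresh
print(thresh, fg.sum())
```

Otsu adapts the threshold to each image's histogram, which can help when the network is not well calibrated; trying fixed values such as 0.1 or 0.8 (as suggested below) shows how sensitive the segmentation is to this choice.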

In [None]:
# put the net in eval mode (layers such as dropout / batch norm, if present, switch to inference behaviour)
net.eval()

for idx, (image, mask) in enumerate(test_loader):
    
    # move image to device
    image = image.to(device)
    
    # pass image through network; no_grad skips building the autograd graph since we won't backprop
    with torch.no_grad():
        logits = net(image)
    
    # pass logits through activation
    pred = activation(logits)
        
    # get our tensors to numpy arrays so we can post-process / visualize
    image = image.cpu().numpy()
    mask = mask.cpu().numpy()
    
    # detaching is harmless here and keeps the conversion safe if no_grad is removed
    pred = pred.cpu().detach().numpy()

    # we also need to remove the batch/channel dimensions from the arrays for visualizing
    # (b,c,h,w) -> (h, w)
    image = np.squeeze(image)
    mask = np.squeeze(mask)
    pred = np.squeeze(pred)
                
    # get threshold value (How does the segmentation change if we change to 0.1 / 0.8?)
    thresh = threshold_otsu(pred)
    
    # fg = prediction greater than / equal to threshold
    boundary_mask = pred >= thresh
    
    # relabel all boundary_mask connected pixels as unique objects
    labeled = relabel_cc(boundary_mask)
    
    # get the corresponding gt labels to visualize
    gt_labels = imread(test_loader.dataset.mask_files[idx])
    
    fig, axes = plt.subplots(1,6,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
        
    axes[0][0].imshow(image, cmap='gray')
    axes[0][0].title.set_text('Raw')

    axes[0][1].imshow(mask)
    axes[0][1].title.set_text('GT mask')
    
    axes[0][2].imshow(create_lut(gt_labels))
    axes[0][2].title.set_text('GT seg')
    
    axes[0][3].imshow(pred)
    axes[0][3].title.set_text('Predicted Mask')
    
    axes[0][4].imshow(boundary_mask)
    axes[0][4].title.set_text('Thresholded Mask')

    axes[0][5].imshow(create_lut(labeled))
    axes[0][5].title.set_text('Predicted Seg')

    break

<hr style="height:2px;"><div class="alert alert-block alert-success"><h1>Checkpoint 1</h1>

</div>

## Task 2.0: Improving the model

- As you can see, our predicted segmentation isn't very good. When objects are tightly packed together, using a simple foreground / background representation gives us results that aren't much better than if we just thresholded our data and relabelled connected components.


- So, let's improve our model. We can do a few things to enhance our results:
    1. Add more complex representations
        * three class
        * signed distance transform
        * edge affinities
    2. Add better augmentations
    3. Increase the input size to our network
    4. Use a bigger network (e.g. increase layers, number of feature maps)
    5. Train for longer
    6. Use a better post-processing strategy (e.g. seeded watershed)


- Click [here](#task-10-creating-a-simple-model) to go back to the previous task
- Click [here](#task-20-improving-the-model) to go to the next task

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 2.1: Add more augmentations</h3>
    
- We were using some pretty simple augmentations (crop and flips). Since we want to create a more robust model, we should augment our data more so that it performs better on data that it hasn't seen. This is also a good way to effectively increase our training sample size. 
- This is a good tutorial for adding useful augmentations: https://albumentations.ai/docs/examples/example_kaggle_salt/
- We will still crop and flip our data as before. Additionally, we will:
    1. Pad our data if needed (this is useful if our network input size is not compatible with our max pooling layers)
    2. Randomly rotate by 90 degrees
    3. Transpose
    4. Randomly adjust our brightness and contrast
- For an intuition about the PadIfNeeded augmentation, see the tutorial above. The padding ensures compatibility with the max pooling layers in our U-Net. In the following cells we will increase the number of U-Net layers from 2 to 3. With 3 max pooling layers that each downsample by 2, the input size must be divisible by 2^3 = 8, so we pad the images up to the next multiple of 8. Since we are using a crop size of 64, which is already divisible by 8, we won't need to pad. But if our crop size were 65, we would need to pad to 72, the next multiple of 8. The PadIfNeeded() augmentation handles these cases for us, but we need to provide it with the correct padded size.
</div>
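
To see why the padded size matters, here is a small sketch of the size arithmetic in plain Python (not part of the exercise code). It assumes every 2x2 max pooling floors odd sizes, as `torch.nn.MaxPool2d` does with its default `ceil_mode=False`, and that each upsampling step doubles the size; the helper names `pooled_sizes`, `upsampled_size` and `next_multiple` are hypothetical.

```python
def pooled_sizes(size, num_pools, factor=2):
    """Spatial size at each level after repeated max pooling (floor division)."""
    sizes = [size]
    for _ in range(num_pools):
        sizes.append(sizes[-1] // factor)
    return sizes


def upsampled_size(bottom_size, num_ups, factor=2):
    """Spatial size after upsampling back up from the bottleneck."""
    return bottom_size * factor ** num_ups


def next_multiple(size, multiple):
    """Smallest number >= size that is divisible by `multiple` (integer-only ceil)."""
    return -(-size // multiple) * multiple


for crop in (64, 65):
    bottom = pooled_sizes(crop, 3)[-1]
    print(crop, pooled_sizes(crop, 3), "-> back up to", upsampled_size(bottom, 3),
          "| padded size:", next_multiple(crop, 2 ** 3))
```

For 65, pooling gives 65 -> 32 -> 16 -> 8, which upsamples back to 64 rather than 65, so the top skip connection no longer lines up. Padding to 72 avoids this, since 72 -> 36 -> 18 -> 9 upsamples back to exactly 72.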

In [None]:
# We're using probability of 1 (the padding should always occur if needed)
# We're using `border_mode` of 0 to pad it with zeros (reflect is the default which we don't want)

def augment_data(raw, mask, padding, crop_size):
    
    transform = A.Compose([
            A.RandomCrop(
                width=crop_size,
                height=crop_size),
            A.PadIfNeeded(
                min_height=padding,
                min_width=padding,
                p=1,
                border_mode=0),
            A.HorizontalFlip(p=0.3),
            A.VerticalFlip(p=0.3),
            A.RandomRotate90(p=0.3),
            A.Transpose(p=0.3),
            A.RandomBrightnessContrast(p=0.3)
        ])

    transformed = transform(image=raw, mask=mask)

    raw, mask = transformed['image'], transformed['mask']
    
    return raw, mask

<div class="alert alert-block alert-info"><h3>Task 2.2: Add extra representations</h3>
    
- Three-class model

This is an extension of the basic foreground/background (or two-class) model that introduces a third class: the boundary. Even if two instances are touching, there is a boundary between them, so they can be separated. Instead of a single output (where an output of zero means one class and an output of one means the other), the network outputs three values, one per class, and the loss function changes from binary to (sparse) categorical cross entropy.
    
- Signed Distance Transform

The label for each pixel is its distance to the closest object boundary. The sign tells us on which side of the boundary the pixel lies: the value is positive within instances and negative outside of them. As the output is not a probability but an (in principle) unbounded scalar, the mean squared error loss function is used.
    
- Edge Affinities

Here we consider not just the pixel but also its direct neighbors (in 2D, the left and upper neighbors are sufficient; the right and lower neighbors are redundant with the next pixel's left and upper neighbors). Imagine there is an edge between two neighboring pixels if they belong to the same instance, and no edge if they do not. If we then take all pixels that are directly or indirectly connected by edges, we get an instance. Essentially, we label edges between neighboring pixels as “connected” or “cut”, rather than labeling the pixels themselves. This representation can be especially useful for smaller objects that would otherwise be classified as background pixels. As with the SDT, we can use the mean squared error loss.

</div>
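
The three-class target can be sketched with plain NumPy. This is a hypothetical stand-in for the notebook's `erode_border`-based code, not a copy of it: it assumes a foreground pixel counts as boundary when one of its 4-neighbours carries a different label (background included), and uses the class ids 0 = background, 1 = interior, 2 = boundary.

```python
import numpy as np

def instance_boundaries(labels):
    """True for foreground pixels that touch a different label via a 4-neighbour."""
    # 'edge' padding so that the image border itself does not create boundaries
    padded = np.pad(labels, 1, mode="edge")
    h, w = labels.shape
    center = padded[1:h + 1, 1:w + 1]
    boundary = np.zeros(labels.shape, dtype=bool)
    for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        neighbour = padded[1 + dy:h + 1 + dy, 1 + dx:w + 1 + dx]
        boundary |= neighbour != center
    return boundary & (labels != 0)


def three_class_target(labels):
    target = (labels != 0).astype(np.int64)   # 1 = interior (initially all foreground)
    target[instance_boundaries(labels)] = 2   # 2 = boundary
    return target


# two 3x3 instances that touch along a vertical edge
labels = np.zeros((5, 8), dtype=np.int64)
labels[1:4, 1:4] = 1
labels[1:4, 4:7] = 2
print(three_class_target(labels))
```

Both touching columns end up in class 2, so the two interiors (a single pixel each in this tiny example) remain separate, which a plain foreground/background target could not express.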

![diff_costs](static/diff_costs.png)
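
A brute-force sketch of a signed distance transform, assuming the sign convention described above (positive inside, negative outside) and a `tanh` squashing with a hypothetical `scale` parameter, so that values stay in (-1, 1) like the output of a `tanh` activation. It is not the notebook's `compute_sdt`, and its O(N²) distance computation is only meant for toy images; `scipy.ndimage.distance_transform_edt` computes the same distances efficiently.

```python
import numpy as np

def brute_force_edt(mask):
    """Euclidean distance from each True pixel to the nearest False pixel (toy sizes only)."""
    dist = np.zeros(mask.shape, dtype=np.float64)
    others = np.argwhere(~mask)
    if len(others) == 0:
        return dist
    for y, x in np.argwhere(mask):
        dist[y, x] = np.sqrt(((others - (y, x)) ** 2).sum(axis=1).min())
    return dist


def signed_distance(labels, scale=5.0):
    foreground = labels != 0
    inner = brute_force_edt(foreground)    # > 0 inside objects, 0 outside
    outer = brute_force_edt(~foreground)   # > 0 outside objects, 0 inside
    return np.tanh((inner - outer) / scale)


labels = np.zeros((7, 7), dtype=np.int64)
labels[2:5, 2:5] = 1
sdt = signed_distance(labels)
print(np.round(sdt, 2))
```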

In [None]:
# compute each representation and visualize

file = random.choice(train_nuclei)

full_mask_nuclei = imread(file)
full_raw_nuclei = imread(file.replace('_nuclei_masks', ''))[0]

# We're applying a random crop to the image to keep the shape consistent
transform = A.Compose([
              A.RandomCrop(width=64, height=64),
            ])

transformed = transform(image=full_raw_nuclei, mask=full_mask_nuclei)
          
aug_raw, aug_mask = transformed['image'], transformed['mask']
    
labels, border = erode_border(
    aug_mask,
    iterations=1,
    border_value=1)

labels_two_class = (labels != 0)
border[border!=0] = 2

labels_three_class = (labels_two_class + border)
sdt = compute_sdt(labels)
affs = compute_affinities(labels, nhood=[[0,1],[1,0]])

fig, axes = plt.subplots(1,5,figsize=(20, 10),sharex=True,sharey=True,squeeze=False)

for idx, (ds_name, data) in enumerate([
    ('raw', aug_raw),
    ('fg/bg', labels_two_class),
    ('three class', labels_three_class),
    ('sdt', sdt),
    ('affinities', affs[0] + affs[1])]
):

    cmap = 'gray' if ds_name == 'raw' else 'viridis'

    axes[0][idx].imshow(data.astype(np.float32), cmap=cmap)
    axes[0][idx].title.set_text(ds_name)
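
To make the affinity idea concrete, here is a NumPy sketch that computes 2D affinities and then recovers instances by joining pixels along high-affinity edges with a union-find. It is not the notebook's `compute_affinities`; we assume channel 0 compares each pixel with its left neighbour and channel 1 with its upper neighbour, that an edge is 1 only between two pixels of the same non-background instance, and that a foreground mask is available (a single-pixel object has no edges at all).

```python
import numpy as np

def affinities_2d(labels):
    """Channel 0: pixel vs. left neighbour, channel 1: pixel vs. upper neighbour."""
    affs = np.zeros((2,) + labels.shape, dtype=np.float32)
    affs[0, :, 1:] = (labels[:, 1:] == labels[:, :-1]) & (labels[:, 1:] != 0)
    affs[1, 1:, :] = (labels[1:, :] == labels[:-1, :]) & (labels[1:, :] != 0)
    return affs


def instances_from_affinities(affs, foreground, threshold=0.5):
    """Pixels joined (directly or indirectly) by edges above `threshold` share an id."""
    h, w = foreground.shape
    parent = list(range(h * w))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving keeps the trees shallow
            i = parent[i]
        return i

    for y in range(h):
        for x in range(w):
            i = y * w + x
            if x > 0 and affs[0, y, x] > threshold:
                parent[find(i)] = find(i - 1)
            if y > 0 and affs[1, y, x] > threshold:
                parent[find(i)] = find(i - w)

    labels = np.zeros((h, w), dtype=np.int64)
    ids = {}
    for y in range(h):
        for x in range(w):
            if foreground[y, x]:
                labels[y, x] = ids.setdefault(find(y * w + x), len(ids) + 1)
    return labels


# two touching 3x3 instances: connected components of the foreground would merge them
gt = np.zeros((5, 8), dtype=np.int64)
gt[1:4, 1:4] = 1
gt[1:4, 4:7] = 2
affs = affinities_2d(gt)
recovered = instances_from_affinities(affs, gt != 0)
print(recovered)
```

The edge between columns 3 and 4 is "cut" (affinity 0), so the two touching squares come back as two separate instances.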

<a id='fourth-todo'></a>

##### **TODO (4)**

- Add `prediction_type` as a parameter to your dataset
- Add a function `create_target` which should take your augmented mask and a prediction type and return the correct representation as float32
- Create an extra argument to the dataset `padding_size`. This should be 2^(number of max pooling layers), which will be 2^3 -> `8` for us
- Update your dataset to use the augmentation function we defined earlier
- Create a function `get_padding` that takes the `crop_size` and `padding_size` and checks if the crop size is divisible by the padding size. If it is, return the crop size, otherwise get the next higher number that is divisible by the padding size and return this number. 
- Successful calls should look like:
    - `get_padding(64, 8) -> 64`
    - `get_padding(65, 8) -> 72`
    - `get_padding(72, 8) -> 72`
    - `get_padding(73, 8) -> 80`
- The returned padding should then be passed in as the `min_height` and `min_width` for the augmentation
- Your `__getitem__` should call `create_target` after augmenting your data
- Try each prediction type 
- What happens if you increase the crop size to 65? How does the padding come into play?
- Click [here](#third-todo) if you need to go back to the previous **TODO**
- Click [here](#fifth-todo) if you need to go to the next **TODO**

In [None]:
class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 prediction_type='two_class',
                 padding_size=8
                ):
        
        self.split = split
        self.crop_size = crop_size
        
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        
        # Add a parameter `prediction_type` to define which representation to choose.
        # Default should be `two_class`
        
        # Add a parameter `padding_size` to define when to pad the image
        # Default should be 8 (2^3 for 3 max pooling layers)

        # the first 10 test images are held out as the validation split
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]
        
        
    def __len__(self):
        return len(self.raw_files)
    
    
    def get_padding(self, crop_size, padding_size):
    
        padding = ... # calculate your padding
    
        return padding

    
    def create_target(self, mask, prediction_type):
        
        mask, border = ... # erode your labels, return inner and border
        
        if prediction_type == 'two_class':
            mask = ... # get two class
        elif prediction_type == 'three_class':
            mask = ... # get three class
        elif prediction_type == 'sdt':
            mask = ... # get sdt using the function defined above
        elif prediction_type == 'affs':
            mask = ... # get affs using the function defined above
            
        mask = ... # cast mask to float32
        
        return mask  
    
    
    def __getitem__(self, idx):
        raw_file = ... # get raw file at index
        mask_file = ... # get mask file at index
        
        raw = imread(raw_file)
        mask = imread(mask_file)
        
        raw = ... # get the nuclei channel, don't forget to cast to float32
        
        if self.split == 'train':
            padding = ... # get padding using the function we defined above 
            raw, mask = ... # augment your data using the correct parameters
        
        mask = ... # create your learning representation
        
        raw = ... # add raw channel dimension
        mask = ... # add mask channel dimension (**if not** using affs since you get two channels anyway)
            
        return raw, mask

In [None]:
#### Solution ####

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 prediction_type='two_class',
                 padding_size=8
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.prediction_type = prediction_type
        self.padding_size = padding_size
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)
    
    
    def get_padding(self, crop_size, padding_size):
    
        # quotient
        q = int(crop_size / padding_size)
    
        if crop_size % padding_size != 0:
            padding = (padding_size * (q + 1))
        else:
            padding = crop_size
    
        return padding


    def create_target(self, mask, prediction_type):
        
        mask, border = erode_border(
                    mask,
                    iterations=1,
                    border_value=1)
        
        # use the `prediction_type` argument so the method works for any representation
        if prediction_type == 'two_class':
            mask = (mask != 0)

        elif prediction_type == 'three_class':
            labels_two_class = (mask != 0)
            border[border!=0] = 2
            
            mask = labels_two_class + border

        elif prediction_type == 'sdt':
            mask = compute_sdt(mask)

        elif prediction_type == 'affs':
            mask = compute_affinities(mask, nhood=[[0,1],[1,0]])

        else:
            raise ValueError('Choose from one of the following prediction types: two_class, three_class, sdt, affs')
        
        return mask.astype(np.float32)
    

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        # for now just do single channel training
        raw = raw[0].astype(np.float32)

        if self.split == 'train':
            padding = self.get_padding(self.crop_size, self.padding_size)
            raw, mask = augment_data(raw, mask, padding, self.crop_size)
            
        mask = self.create_target(mask, self.prediction_type)
        
        # add channel dim for network
        raw = np.expand_dims(raw, axis=0)
        
        if self.prediction_type != 'affs':
            mask = np.expand_dims(mask, axis=0)
            
        return raw, mask

In [None]:
prediction_type = ... # try each of prediction types we defined in create_target function
crop_size = ...

train_dataset = TissueNetDataset(
    root_dir='woodshole',
    split='train',
    crop_size=crop_size,
    prediction_type=prediction_type)

raw, mask = train_dataset[random.randrange(len(train_dataset))]

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(np.squeeze(raw), cmap='gray')
axes[0][0].title.set_text('Raw')

if mask.shape[0] == 1:
    axes[0][1].imshow(np.squeeze(mask))
    axes[0][1].title.set_text('Mask')
else:
    # affs has two channels (x/y)
    axes[0][1].imshow(mask[0]+mask[1])
    axes[0][1].title.set_text('Mask')

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 2.3: Add extra hyperparameters based on prediction type</h3>
    
- Now that we have more representations, we need to be sure that our hyperparams are consistent with whichever representation we choose. 

</div>

<a id='fifth-todo'></a>

##### **TODO (5)**

- Create a function `get_hyperparams` that takes your `prediction_type` and returns a dictionary mapping our hyperparameters:   
    ```
    params = {
        'out_channels': out_channels,
        'activation': activation,
        'loss_function': loss_fn,
        'dtype': dtype
    }
    ```
- You already know what these need to be for the two class representation, let's figure out the others
- **Hint**: We will choose from the following loss functions: `nn.CrossEntropyLoss` and `nn.MSELoss`
- **Hint**: We will choose from the following final activations: `tanh`, `softmax`, `sigmoid`
- **Three class**:
    - We are now doing multi class classification, what type of loss function should we use?
    - Our number of out channels should be equal to the number of classes we are trying to predict
    - What final activation should we use? Along which dimension should it be computed? Remember our tensors will have shape (B, C, H, W), i.e. dimensions (0, 1, 2, 3)
    - What dtype should we have? It should be the torch equivalent of a 64-bit integer (signed) (see [here](https://pytorch.org/docs/stable/tensors.html) if stuck with tensor types)   
- **Sdt**:
    - Since we are computing a **signed** distance transform, we want to have negative values outside of our objects. Therefore we want an activation that can give us values between -1 and 1. 
    - We are doing regression now, what loss function can we use?
    - Our output will be a single number per pixel, so how many output channels should we have?
    - Our dtype should be the torch equivalent of a 32-bit floating point   
- **Affs**:
    - We are doing regression again, what should our loss be?
    - Our outputs will be between 0 and 1, what activation should we use?
    - We can use the same dtype as Sdt
    - We will have both x and y affinities, so how many channels should we have?
- Click [here](#fourth-todo) if you need to go back to the previous **TODO**
- Click [here](#sixth-todo) if you need to go to the next **TODO**
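
To check which dimension the softmax should run over, here is a NumPy sketch of a per-pixel softmax and categorical cross entropy for (B, C, H, W) logits and (B, H, W) integer targets. It mirrors what `torch.nn.CrossEntropyLoss` computes with its default mean reduction, but `softmax` and `pixelwise_cross_entropy` are our own illustrative helpers.

```python
import numpy as np

def softmax(logits, axis=1):
    """Softmax along the channel axis; subtracting the max keeps exp() stable."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def pixelwise_cross_entropy(logits, target):
    """logits: (B, C, H, W) raw scores, target: (B, H, W) integer class ids."""
    probs = softmax(logits, axis=1)
    # probability assigned to the true class at every pixel -> (B, H, W)
    true_class = np.take_along_axis(probs, target[:, None], axis=1)[:, 0]
    return float(-np.log(true_class).mean())


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3, 8, 8))        # batch of 4, three classes
target = rng.integers(0, 3, size=(4, 8, 8))   # one class id per pixel
print(softmax(logits).sum(axis=1).min())      # per-pixel class probabilities sum to 1
print(pixelwise_cross_entropy(logits, target))
```

Normalizing over `dim=1` makes the three class probabilities at each pixel sum to one; any other axis would normalize across pixels instead. Note also that `torch.nn.CrossEntropyLoss` applies log-softmax internally, so it expects raw logits and an integer (long) class-index target of shape (B, H, W).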

In [None]:
def get_hyperparams(prediction_type):
    
    if prediction_type == "two_class":
        out_channels = 1
        activation = torch.nn.Sigmoid()
        loss_fn = torch.nn.BCELoss()
        dtype = torch.FloatTensor
        
    elif prediction_type == "three_class":
        ... # get params
        
    elif prediction_type == "sdt":
        ... # get params
        
    elif prediction_type == "affs":
        ... # get params
        
    else:
        raise RuntimeError("invalid prediction type")
        
    params = ... # get dict
    
    return params

In [None]:
#### Solution ####

def get_hyperparams(prediction_type):

    if prediction_type == "two_class":
        out_channels = 1
        activation = torch.nn.Sigmoid()
        loss_fn = torch.nn.BCELoss()
        dtype = torch.FloatTensor

    elif prediction_type == "three_class":
        out_channels = 3
        activation = torch.nn.Softmax(dim=1)
        loss_fn = torch.nn.CrossEntropyLoss()
        dtype = torch.LongTensor

    elif prediction_type == "sdt":
        out_channels = 1
        activation = torch.nn.Tanh()
        loss_fn = torch.nn.MSELoss()
        dtype = torch.FloatTensor

    elif prediction_type == "affs":
        out_channels = 2
        activation = torch.nn.Sigmoid()
        loss_fn = torch.nn.MSELoss()
        dtype = torch.FloatTensor

    else:
        raise RuntimeError("invalid prediction type")
        
    params = {
        'out_channels': out_channels,
        'activation': activation,
        'loss_function': loss_fn,
        'dtype': dtype
    }
        
    return params

In [None]:
prediction_type = 'affs'

params = get_hyperparams(prediction_type)

print(params)

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 2.4: Increase patch crop size</h3>
    
- Before, we were using a smaller patch crop size (64). Since we are training a 2D network with a relatively small batch size (4), we can afford to increase our crop size (to 128) and let our network see more data.
- **Note** that increasing the crop size won't affect the receptive field of the network.

</div>

In [None]:
# try each prediction type
prediction_type = 'affs'
crop_size = 128

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=crop_size, prediction_type=prediction_type)

raw, mask = train_dataset[random.randrange(len(train_dataset))]

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(np.squeeze(raw), cmap='gray')
axes[0][0].title.set_text('Raw')

if mask.shape[0] == 1:
    axes[0][1].imshow(np.squeeze(mask))
    axes[0][1].title.set_text('Mask')
else:
    # affs has two channels (x/y)
    axes[0][1].imshow(mask[0]+mask[1])
    axes[0][1].title.set_text('Mask')

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 2.5: Increase features in network and train longer</h3>
    
- Before, we were training with a pretty small network, e.g. two downsampling levels. Let's see how the receptive field changes as we increase our network to 3 downsampling levels
    
    
- Try changing the downsampling factors as before. How does it change the receptive field?

</div>
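
The growth of the receptive field can be estimated with the standard recurrence for stacked layers: a layer with kernel size k widens the field by (k - 1) times the current stride in input pixels (the "jump"), and every downsampling multiplies that jump. The sketch assumes two 3x3 convolutions per level and 2x2 max pooling between levels, as in the original U-Net, and only follows the path down to the bottleneck, so it need not match the numbers printed by `net.rec_fov`. The input size never enters the computation, which is why a larger crop leaves the receptive field unchanged.

```python
def receptive_field(layers):
    """layers: sequence of (kernel_size, stride); returns the receptive field in input pixels."""
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump   # each layer sees (k - 1) more input-pixel steps
        jump *= stride              # downsampling spreads later layers further apart
    return rf


def unet_encoder(num_levels, convs_per_level=2, kernel=3, pool=2):
    """Hypothetical encoder layout: convolutions at every level, pooling between levels."""
    layers = []
    for level in range(num_levels):
        layers += [(kernel, 1)] * convs_per_level
        if level < num_levels - 1:
            layers.append((pool, pool))
    return layers


for levels in range(1, 5):
    print(levels, "levels:", receptive_field(unet_encoder(levels)))
```

This prints 5, 14, 32 and 68 for one to four levels. Passing `pool=3` mimics larger downsampling factors and makes the field grow faster.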

![deep_RF](static/deep_RF.png)

In [None]:
raw, mask = train_dataset[random.randrange(len(train_dataset))]

net_t = raw
fovs = []
d_factors = [[2,2],[2,2],[2,2]]

net = UNet(in_channels=1,
           num_fmaps=6,
           fmap_inc_factors=2,
           downsample_factors=d_factors,
           padding='same'
          )

for level in range(len(d_factors)+1):
    fov_tmp, _ = net.rec_fov(level , (1, 1), 1)
    fovs.append(fov_tmp[0])

fig=plt.figure(figsize=(8, 8))
colors = ["yellow", "red", "green", "blue"]

plt.imshow(np.squeeze(raw), cmap='gray')

for idx, fov_t in enumerate(fovs):
    print("Field of view at depth {}: {:3d} (color: {})".format(idx+1, fov_t, colors[idx]))
    # raw has shape (C, H, W): x runs along the width (axis 2), y along the height (axis 1)
    xmin = raw.shape[2]/2 - fov_t/2
    xmax = raw.shape[2]/2 + fov_t/2
    ymin = raw.shape[1]/2 - fov_t/2
    ymax = raw.shape[1]/2 + fov_t/2
    plt.hlines(ymin, xmin, xmax, color=colors[idx], lw=3)
    plt.hlines(ymax, xmin, xmax, color=colors[idx], lw=3)
    plt.vlines(xmin, ymin, ymax, color=colors[idx], lw=3)
    plt.vlines(xmax, ymin, ymax, color=colors[idx], lw=3)
plt.show()

<a id='sixth-todo'></a>

##### **TODO (6)**
    
- Choose a prediction type (other than two_class) to use and get the corresponding parameters
- Add another layer to your network with the same downsampling factors
- Increase the factor by which the number of feature maps grows between levels (`fmap_inc_factors`) to 3
- Make sure you correctly set the output channels in your final convolution (hint: get it from your param dict)
- How many trainable parameters do we have now? How does this compare to when we used two layers and a factor of 2 instead of 3?
- Create your datasets and loaders as before. Increase the crop patch size (e.g. 64 -> 128)    
- Use the same learning rate and optimizer as before   
- Train for longer (e.g. 1000 -> 3000 steps)
- Click [here](#fifth-todo) if you need to go back to the previous **TODO**
- Click [here](#seventh_todo) if you need to go to the next **TODO**
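
As a rough check on the parameter question: a 2D convolution with kernel size k, c_in input channels and c_out output channels has k * k * c_in * c_out weights plus c_out biases. The sketch applies this to a hypothetical encoder with two 3x3 convolutions per level and `num_fmaps * fmap_inc_factor ** level` feature maps at each level. The decoder is ignored, so the totals will be lower than what `summary` reports for the full network.

```python
def conv2d_params(kernel, c_in, c_out, bias=True):
    """Trainable parameters of a square 2D convolution."""
    return kernel * kernel * c_in * c_out + (c_out if bias else 0)


def encoder_params(in_channels, num_fmaps, fmap_inc_factor, num_levels, kernel=3):
    """Parameters of a hypothetical encoder: two convolutions per level, no decoder."""
    total, c_prev = 0, in_channels
    for level in range(num_levels):
        c = num_fmaps * fmap_inc_factor ** level
        total += conv2d_params(kernel, c_prev, c) + conv2d_params(kernel, c, c)
        c_prev = c
    return total


small = encoder_params(1, 32, fmap_inc_factor=2, num_levels=2)
large = encoder_params(1, 32, fmap_inc_factor=3, num_levels=3)
print(small, large)
```

The deepest level dominates: with a factor of 3 it has 288 feature maps, and its second convolution alone contributes 746,784 parameters.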

In [None]:

prediction_type = ... 
params = get_hyperparams(prediction_type)

torch.manual_seed(42)

downsample_factors = ... # refer to the init method signature in unet.py module 
in_channels=1
num_fmaps=32
fmap_inc_factors= ... # refer to the init method signature in unet.py module 
padding='same'
final_kernel_size=1
out_channels = params['out_channels']

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=downsample_factors,
        padding=padding)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=final_kernel_size)

net = torch.nn.Sequential(unet, final_conv)

net = net.to(device)

summary(net, (in_channels, 128, 128))

In [None]:
#### Solution ####

prediction_type = 'affs'
params = get_hyperparams(prediction_type)

torch.manual_seed(42)

d_factors = [[2,2],[2,2],[2,2]]
in_channels=1
num_fmaps=32
fmap_inc_factors=3
padding='same'
final_kernel_size=1
out_channels = params['out_channels']

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=d_factors,
        padding=padding)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=final_kernel_size)

net = torch.nn.Sequential(unet, final_conv)

net = net.to(device)

summary(net, (in_channels, 128, 128))

In [None]:
training_steps = ...
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

net = net.to(device)
loss_fn = params['loss_function'].to(device)
activation = params['activation']
dtype = params['dtype']

crop_size = ...

### create datasets
train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=crop_size, prediction_type=prediction_type)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, prediction_type=prediction_type)

batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [None]:
# run training loop
train(train_loader, val_loader, net, loss_fn, activation, optimizer, dtype, prediction_type)

In [None]:
#### Solution ####

training_steps = 3000
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

net = net.to(device)
loss_fn = params['loss_function'].to(device)
activation = params['activation']
dtype = params['dtype']

crop_size = 128

### create datasets
train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=crop_size, prediction_type=prediction_type)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, prediction_type=prediction_type)

batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [None]:
# run training loop
train(train_loader, val_loader, net, loss_fn, activation, optimizer, dtype, prediction_type)

In [None]:
# Visualize predictions

net.eval()

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
    
    pred = np.squeeze(pred.cpu().detach().numpy())
        
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    axes[0][0].imshow(image, cmap='gray')
    axes[0][0].title.set_text('Raw')
    
    if prediction_type == 'three_class':
        # get indices of the maximum values along channel axis
        pred = np.argmax(pred, axis=0)
        
    if prediction_type == 'affs':
        axes[0][1].imshow(mask[0] + mask[1])
        axes[0][1].title.set_text('GT mask')
        axes[0][2].imshow(pred[0] + pred[1])
        axes[0][2].title.set_text('Predicted')
        
    else:
        axes[0][1].imshow(mask)
        axes[0][1].title.set_text('GT mask')
        axes[0][2].imshow(pred)
        axes[0][2].title.set_text('Predicted')
      
    if idx == 2:
        break

<hr style="height:2px;"><div class="alert alert-block alert-success"><h1>Checkpoint 2</h1>

</div>

## Task 3.0: Post-processing / further improvements

- Before we were just thresholding and relabeling connected components. Now we will see a more advanced post-processing strategy called watershed.


- We also want to gauge our model performance so we will introduce some evaluation methods for instance segmentation.


- Finally, up until now we were just using a single input channel to our network - but since we have multiple channels in our raw data we should leverage them. You will get the opportunity to put everything together to try to improve your model.

- Click [here](#task-20-improving-the-model) to go to the previous task
- Click [here](#task-30-post-processing--further-improvements) to go to the next task

<div class="alert alert-block alert-info"><h3>Task 3.1: Introduce watershed</h3>
    
- Before we were just thresholding our predictions and then relabeling connected components. This is a totally fine approach in the cases where we don't have touching objects. Now we will use a better approach commonly used for instance segmentation called seeded watershed. See here for a nice overview: https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_watershed.html
    
    
- To compute our seeded watershed, we first need to get a boundary mask from our predictions. This is done slightly differently for the various representations, but generally speaking our boundary mask is just a boolean mask of our foreground regions. From this mask we compute boundary distances using a distance transform, i.e. the distance of each foreground pixel to the nearest background pixel. The local maxima of these distances, which lie near object centers, can be used as seed points. The watershed algorithm then grows each seed outward through its local "basin" until neighboring segments touch.
    
    
- Because every local maximum becomes a seed, irregularly shaped objects can contain several seeds and end up split into multiple segments (over-segmentation). For this reason, watershed alone is often not sufficient on complex datasets. In most cases the resulting objects are referred to as fragments (or supervoxels), which can then be stitched together using the underlying predictions as edge weights through a process called agglomeration.
    
    
- Agglomeration is out of the scope of this exercise, but you can find a nice overview here: https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_boundary_merge.html
</div>
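
The steps above can be sketched in plain NumPy on a toy image. This is an illustrative re-implementation, not the notebook's `watershed_from_boundary_distance`: the O(N²) distance transform only suits tiny images, and instead of detecting local maxima it uses a simpler common variant that takes connected regions whose distance exceeds a threshold (3 here, an assumed value) as seeds. Flooding uses a priority queue so that pixels far from the background are assigned first, i.e. a marker-controlled ("seeded") watershed on the negated distances.

```python
import heapq
from collections import deque

import numpy as np

NEIGHBOURS = ((-1, 0), (1, 0), (0, -1), (0, 1))  # 4-connectivity


def brute_force_edt(mask):
    """Distance from each foreground pixel to the nearest background pixel (toy sizes only)."""
    dist = np.zeros(mask.shape, dtype=np.float64)
    background = np.argwhere(~mask)
    for y, x in np.argwhere(mask):
        dist[y, x] = np.sqrt(((background - (y, x)) ** 2).sum(axis=1).min())
    return dist


def label_components(binary):
    """Breadth-first connected-component labelling."""
    labels = np.zeros(binary.shape, dtype=np.int64)
    h, w = binary.shape
    current = 0
    for sy, sx in np.argwhere(binary):
        if labels[sy, sx]:
            continue
        current += 1
        labels[sy, sx] = current
        queue = deque([(sy, sx)])
        while queue:
            y, x = queue.popleft()
            for dy, dx in NEIGHBOURS:
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w and binary[ny, nx] and not labels[ny, nx]:
                    labels[ny, nx] = current
                    queue.append((ny, nx))
    return labels


def seeded_watershed(dist, mask, seeds):
    """Grow each seed through the mask, always expanding the pixel with the largest distance first."""
    labels = seeds.copy()
    h, w = mask.shape
    heap, counter = [], 0

    def push_neighbours(y, x, label):
        nonlocal counter
        for dy, dx in NEIGHBOURS:
            ny, nx = y + dy, x + dx
            if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and labels[ny, nx] == 0:
                # negate: heapq pops the smallest key, but we want the largest distance
                heapq.heappush(heap, (-dist[ny, nx], counter, ny, nx, label))
                counter += 1

    for y, x in np.argwhere(seeds > 0):
        push_neighbours(y, x, seeds[y, x])
    while heap:
        _, _, y, x, label = heapq.heappop(heap)
        if labels[y, x] == 0:
            labels[y, x] = label
            push_neighbours(y, x, label)
    return labels


# two 7x7 squares joined by a thin bridge: one connected component, two objects
mask = np.zeros((9, 18), dtype=bool)
mask[1:8, 1:8] = True
mask[1:8, 10:17] = True
mask[4, 8:10] = True

dist = brute_force_edt(mask)
seeds = label_components(dist >= 3)   # assumed seed threshold of 3
segmentation = seeded_watershed(dist, mask, seeds)
print(segmentation)
```

Thresholding plus connected components alone would return one object here. Lowering the seed threshold to 1 would turn the whole mask into a single seed and produce a single segment again, which illustrates how strongly the result depends on the seeds.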

In [None]:
# get segmentations

net.eval()

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
        
    pred = np.squeeze(pred.cpu().detach().numpy())
    
    # feel free to try different thresholds
    thresh = np.mean(pred)
            
    boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)
    boundary_distances = distance_transform_edt(boundary_mask)
    
    seg = watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask
    )
    
    gt_labels = imread(test_loader.dataset.mask_files[idx])
    
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    axes[0][0].imshow(image, cmap='gray')
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(create_lut(gt_labels))
    axes[0][1].title.set_text('GT Labels')
    
    axes[0][2].imshow(create_lut(seg))
    axes[0][2].title.set_text('Predicted Labels')
    
    if idx == 2:
        break

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 3.2: Evaluate</h3>
    
- There are several ways to evaluate accuracy of models for instance segmentation.
    
    
- For our purposes, we will calculate the intersection over union (IoU) between the ground truth labels and the predicted labels. Here is a nice overview: https://www.jeremyjordan.me/evaluating-image-segmentation-models/
    
    
- We then choose an IoU threshold for defining a match. For example, with a threshold of 0.5, a ground truth label and a predicted label with IoU >= 0.5 are counted as a match, i.e. a true positive.
    
    
- Using IoU and the threshold, we match predicted labels to ground truth labels, and then evaluate:
    1. True positives
    2. False positives
    3. False negatives
    4. Precision
    5. Recall
    6. Average precision
    
    
- These will already give a good indication of model performance. You can easily look up some information on each of these metrics, and you will have a good idea about why they are used following the previous exercises and lectures.
</div>
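
A small sketch of the matching step in plain NumPy. It computes the IoU between every ground truth and predicted instance, pairs them one-to-one greedily in order of decreasing IoU (keeping only pairs at or above the threshold), and derives the counts and ratios. The `match_instances` helper is hypothetical and may differ from the notebook's `evaluate`; in particular, for the single-threshold "average precision" it assumes the common definition TP / (TP + FP + FN).

```python
import numpy as np

def iou_matrix(gt, pred):
    """IoU between every non-zero ground truth id (rows) and predicted id (columns)."""
    gt_ids = [g for g in np.unique(gt) if g != 0]
    pred_ids = [p for p in np.unique(pred) if p != 0]
    ious = np.zeros((len(gt_ids), len(pred_ids)))
    for i, g in enumerate(gt_ids):
        g_mask = gt == g
        for j, p in enumerate(pred_ids):
            p_mask = pred == p
            intersection = np.logical_and(g_mask, p_mask).sum()
            if intersection:
                ious[i, j] = intersection / np.logical_or(g_mask, p_mask).sum()
    return ious


def match_instances(gt, pred, iou_threshold=0.5):
    ious = iou_matrix(gt, pred)
    n_gt, n_pred = ious.shape
    candidates = sorted(
        ((ious[i, j], i, j) for i in range(n_gt) for j in range(n_pred)
         if ious[i, j] >= iou_threshold),
        reverse=True)
    matched_gt, matched_pred = set(), set()
    for _, i, j in candidates:  # each instance may be matched at most once
        if i not in matched_gt and j not in matched_pred:
            matched_gt.add(i)
            matched_pred.add(j)
    tp = len(matched_gt)
    fp = n_pred - tp   # predictions without a partner
    fn = n_gt - tp     # ground truth objects that were missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    ap = tp / (tp + fp + fn) if tp + fp + fn else 0.0
    return ap, precision, recall, tp, fp, fn


gt = np.zeros((6, 6), dtype=np.int64)
gt[0:2, 0:2] = 1      # found exactly by the prediction
gt[3:6, 3:6] = 2      # missed -> false negative
pred = np.zeros((6, 6), dtype=np.int64)
pred[0:2, 0:2] = 5    # true positive (IoU = 1)
pred[3:6, 0:2] = 7    # overlaps no object -> false positive
print(match_instances(gt, pred))
```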

In [None]:
# Evaluate on a single batch

net.eval()

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
        
    pred = np.squeeze(pred.cpu().detach().numpy())
        
    thresh = np.mean(pred)
                  
    boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)
    boundary_distances = distance_transform_edt(boundary_mask)
    
    pred_labels = watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask
    )
    
    gt_labels = imread(test_loader.dataset.mask_files[idx])
    
    ap, precision, recall, tp, fp, fn = evaluate(gt_labels, pred_labels)
    
    print(
        f'Computed with IoU threshold = 0.5\n'
        f'Average precision: {ap}\n'
        f'Precision: {precision}\n'
        f'Recall: {recall}\n'
        f'True positives: {tp}\n'
        f'False positives: {fp}\n'
        f'False negatives: {fn}\n'
    )
        
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    axes[0][0].imshow(image, cmap='gray')
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(create_lut(gt_labels))
    axes[0][1].title.set_text('GT Labels')
    
    axes[0][2].imshow(create_lut(pred_labels))
    axes[0][2].title.set_text('Predicted Labels')
    
    break

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 3.3: Loop over batches</h3>
    
- Now we can loop over all test set images and compute the average precision of the model

</div>

In [None]:
avg = 0.0

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
        
    pred = np.squeeze(pred.cpu().detach().numpy())
        
    thresh = np.mean(pred)
            
    boundary_mask = get_boundary_mask(pred, prediction_type, thresh=thresh)
    boundary_distances = distance_transform_edt(boundary_mask)
    
    pred_labels = watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask
    )
    
    gt_labels = imread(test_loader.dataset.mask_files[idx])
    
    ap, precision, recall, tp, fp, fn = evaluate(gt_labels, pred_labels)
    
    avg += ap
    
    print(
        f'Average precision: {ap}\n'
        f'Precision: {precision}\n'
        f'Recall: {recall}\n'
        f'True positives: {tp}\n'
        f'False positives: {fp}\n'
        f'False negatives: {fn}\n'
    )
        
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    axes[0][0].imshow(image, cmap='gray')
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(create_lut(gt_labels))
    axes[0][1].title.set_text('GT Labels')
    
    axes[0][2].imshow(create_lut(pred_labels))
    axes[0][2].title.set_text('Predicted Labels')
    
    plt.show()
        
avg /= (idx+1)
    
print("average precision on test set: {}".format(avg))

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 3.4: Use both channels of raw data</h3>

- Up until now, we have only been using the nuclei channel of the raw data as input to our network. If extra channels are available, showing them to the network can help. We should give our network as much information as possible to learn from, even if it is only tasked with learning a single channel output.
    
</div>

<a id='seventh_todo'></a>

##### **TODO (7)**

- Update your dataset to use both channels of raw data.
- It is very important to keep track of the shape of your data as it passes through the pipeline.
- By default our raw data is (C, H, W), but albumentations expects (H, W, C) for multi-channel images.
- Our mask data is (H, W) by default, so we need to add a channel dimension.
- After augmentation, we need to get both our raw and mask back to (C, H, W) for training.
- Make sure that when creating your target representation you handle the mask channel correctly (e.g. pass `mask[0]` instead of `mask`).
- Click [here](#sixth_todo) if you need to go back to the previous **TODO**
- Click [here](#eighth_todo) if you need to go to the next **TODO**
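To make the shape bookkeeping concrete, here is a minimal NumPy-only sketch of the round trip described above. It uses random arrays in place of a real TissueNet image and a hypothetical `flip_lr` function standing in for the albumentations pipeline; it only checks which axis goes where and is not meant as the dataset implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

raw = rng.random((2, 64, 64))                # (C, H, W): two input channels
mask = rng.integers(0, 5, size=(64, 64))     # (H, W): instance labels

# albumentations-style layout: channels last
raw_hwc = raw.transpose(1, 2, 0).astype(np.float32)   # (H, W, C)
mask_hwc = np.expand_dims(mask, axis=-1)               # (H, W, 1)


def flip_lr(image, target):
    """Hypothetical stand-in for an augmentation: flip along the width axis."""
    return image[:, ::-1], target[:, ::-1]


raw_hwc, mask_hwc = flip_lr(raw_hwc, mask_hwc)

# back to PyTorch layout: channels first
raw_chw = raw_hwc.transpose(2, 0, 1)     # (C, H, W)
mask_chw = mask_hwc.transpose(2, 0, 1)   # (1, H, W)

print(raw_chw.shape, raw_chw.dtype, mask_chw.shape)  # (2, 64, 64) float32 (1, 64, 64)
# mask_chw[0] is the (H, W) label image to hand to create_target
```

Whatever augmentation is applied in between, the raw image and the mask must be transformed together in the same (H, W, C) layout, otherwise pixels and labels no longer line up.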

In [None]:
# Define the dataset class 

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 prediction_type='two_class',
                 padding_size=8
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.prediction_type = prediction_type
        self.padding_size = padding_size
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]


    def __len__(self):
        return len(self.raw_files)
    
    
    def get_padding(self, crop_size, padding_size):
    
        # quotient
        q = int(crop_size / padding_size)
    
        if crop_size % padding_size != 0:
            padding = (padding_size * (q + 1))
        else:
            padding = crop_size
    
        return padding


    def create_target(self, mask, prediction_type):
        
        mask, border = erode_border(
                    mask,
                    iterations=1,
                    border_value=1)
        
        if self.prediction_type == 'two_class':
            mask = (mask != 0)

        elif self.prediction_type == 'three_class':
            labels_two_class = (mask != 0)
            border[border!=0] = 2
            
            mask = labels_two_class + border

        elif self.prediction_type == 'sdt':
            mask = compute_sdt(mask)

        elif self.prediction_type == 'affs':
            mask = compute_affinities(mask, nhood=[[0,1],[1,0]])

        else:
            raise Exception('Choose from one of the following prediction types: two_class, three_class, sdt, affs')
        
        return mask.astype(np.float32)
    

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)
        
        raw = ... # make sure the dimensions are in the right order (transpose could be useful)
        # and don't forget to cast to float32
        mask = np.expand_dims(mask, axis=-1)
        
        if self.split == 'train':
            padding = self.get_padding(self.crop_size, self.padding_size)
            raw, mask = augment_data(raw, mask, padding, self.crop_size)
            
        raw = ... # yet again, make sure the dimensions are in the right order
        mask = ... # yet again, make sure the dimensions are in the right order
        
        mask = self.create_target(mask[0], self.prediction_type)
        
        if self.prediction_type != 'affs':
            mask = np.expand_dims(mask, axis=0)
                                        
        return raw, mask

In [None]:
#### Solution ####

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 prediction_type='two_class',
                 padding_size=8
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.prediction_type = prediction_type
        self.padding_size = padding_size
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)
    
    def get_padding(self, crop_size, padding_size):
    
        # quotient
        q = int(crop_size / padding_size)
    
        if crop_size % padding_size != 0:
            padding = (padding_size * (q + 1))
        else:
            padding = crop_size
    
        return padding
    
    def create_target(self, mask, prediction_type):
        
        mask, border = erode_border(
                    mask,
                    iterations=1,
                    border_value=1)
        
        if self.prediction_type == 'two_class':
            mask = (mask != 0)

        elif self.prediction_type == 'three_class':
            labels_two_class = (mask != 0)
            border[border!=0] = 2
            
            mask = labels_two_class + border

        elif self.prediction_type == 'sdt':
            mask = compute_sdt(mask)

        elif self.prediction_type == 'affs':
            mask = compute_affinities(mask, nhood=[[0,1],[1,0]])

        else:
            raise Exception('Choose from one of the following prediction types: two_class, three_class, sdt, affs')
        
        return mask.astype(np.float32)

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)
        
        raw = raw.transpose([1,2,0]).astype(np.float32)
        mask = np.expand_dims(mask, axis=-1)
        
        if self.split == 'train':
            padding = self.get_padding(self.crop_size, self.padding_size)
            raw, mask = augment_data(raw, mask, padding, self.crop_size)
            
        raw = raw.transpose([2,0,1])
        mask = mask.transpose([2,0,1])
        
        mask = self.create_target(mask[0], self.prediction_type)
        
        if self.prediction_type != 'affs':
            mask = np.expand_dims(mask, axis=0)
                                        
        return raw, mask

In [None]:
# try each prediction type
prediction_type = 'affs'
crop_size = 128

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=crop_size, prediction_type=prediction_type)

raw, mask = train_dataset[random.randrange(len(train_dataset))]

raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

fig, axes = plt.subplots(1,2,figsize=(10, 10),sharex=True,sharey=True,squeeze=False)

axes[0][0].imshow(np.squeeze(raw), cmap='gray')
axes[0][0].title.set_text('Raw')

if mask.shape[0] == 1:
    axes[0][1].imshow(np.squeeze(mask))
    axes[0][1].title.set_text('Mask')
else:
    # affs has two channels (x/y)
    axes[0][1].imshow(mask[0]+mask[1])
    axes[0][1].title.set_text('Mask')

<a id='eighth_todo'></a>

##### **TODO (8)**   
    
- Create your network, hyperparameters and data loaders as last time (e.g. 3 levels, a feature map increase factor of 3, crop size 128, 3k iterations, etc.). Make sure to use the correct number of input channels for your network!
- Click [here](#seventh_todo) if you need to go back to the previous **TODO**
- Click [here](#final-todo) if you need to go to the next **TODO**

In [None]:
prediction_type = 'affs'
params = get_hyperparams(prediction_type)

torch.manual_seed(42)

d_factors = [[2,2],[2,2],[2,2]]
in_channels= ... # set correct number of input channels
num_fmaps=32
fmap_inc_factors=3
padding='same'
final_kernel_size=1
out_channels = params['out_channels']

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=d_factors,
        padding=padding)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=final_kernel_size)

net = torch.nn.Sequential(unet, final_conv)

net = net.to(device)

summary(net, (in_channels, 128, 128))

In [None]:
#### Solution ####

prediction_type = 'affs'
params = get_hyperparams(prediction_type)

torch.manual_seed(42)

d_factors = [[2,2],[2,2],[2,2]]
in_channels=2
num_fmaps=32
fmap_inc_factors=3
padding='same'
final_kernel_size=1
out_channels = params['out_channels']

unet = UNet(
        in_channels=in_channels,
        num_fmaps=num_fmaps,
        fmap_inc_factors=fmap_inc_factors,
        downsample_factors=d_factors,
        padding=padding)

final_conv = torch.nn.Conv2d(
    in_channels=num_fmaps,
    out_channels=out_channels,
    kernel_size=final_kernel_size)

net = torch.nn.Sequential(unet, final_conv)

net = net.to(device)

summary(net, (in_channels, 128, 128))

In [None]:
training_steps = 3000
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

net = net.to(device)
loss_fn = params['loss_function'].to(device)
activation = params['activation']
dtype = params['dtype']

crop_size = 128

### create datasets
train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=crop_size, prediction_type=prediction_type)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, prediction_type=prediction_type)

batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [None]:
# run training loop
train(train_loader, val_loader, net, loss_fn, activation, optimizer, dtype, prediction_type)

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 3.5: Predict and visualize</h3>

- Once the model is trained, we loop over the test dataset and visualize the results.
    
</div>

In [None]:
avg = 0.0

net.eval()

for idx, (image, mask) in enumerate(test_loader):
    image = image.to(device)
    logits = net(image)
    pred = activation(logits)
        
    image = np.squeeze(image.cpu())
    mask = np.squeeze(mask.cpu().numpy())
        
    pred = np.squeeze(pred.cpu().detach().numpy())
        
    thresh = np.mean(pred)
            
    boundary_mask = get_boundary_mask(pred, prediction_type, thresh)
    boundary_distances = distance_transform_edt(boundary_mask)
    
    pred_labels = watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask
    )
    
    gt_labels = imread(test_loader.dataset.mask_files[idx])
    
    ap, precision, recall, tp, fp, fn = evaluate(gt_labels, pred_labels)
    
    avg += ap
    
    print(
        f'Average precision: {ap} \n',
        f'Precision: {precision} \n',
        f'Recall: {recall} \n',
        f'True positives: {tp} \n',
        f'False positives: {fp} \n',
        f'False negatives: {fn} \n'
    )
    
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
        
    fig, axes = plt.subplots(1,3,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(create_lut(gt_labels))
    axes[0][1].title.set_text('GT Labels')
    
    axes[0][2].imshow(create_lut(pred_labels))
    axes[0][2].title.set_text('Predicted Labels')
    
    plt.show()
        
avg /= (idx+1)
    
print("average precision on test set: {}".format(avg))

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 3.6 (time permitting): Improve accuracy / segment cyto</h3>
    
- Even with the additions above, we are likely still not reaching the accuracy we would like. Production-level models are usually trained for many more iterations and use many tricks to maximize accuracy (the ground truth also contains some incorrect labels, which penalize our network's predictions even when they are correct). For now, it is more important to understand conceptually the basics of instance segmentation and the different approaches to making models more robust.
    
    
- If you have time (now or in the future), try to improve the accuracy of your model. How accurate can you get it? Can you also get an accurate model on the cytoplasm masks?
    
    
- If you have time (now or in the future), try the following bonus exercises, which show more advanced approaches to getting good instance segmentation results.
</div>

<hr style="height:2px;"><div class="alert alert-block alert-success"><h1>Checkpoint 3</h1>

</div>

## Task 4.0: Auxiliary learning

- Auxiliary learning is a powerful technique that can help to improve the results of our main objective by providing a helper task. Up until now, we have only shown our model representations of the data that are boundary specific. But the data is a lot richer than that - these objects have distinct shapes that could be leveraged in order to better learn the boundaries.

- Click [here](#task-30-post-processing--further-improvements) to go to the previous task.
- Click [here](#task-50-bonus-exercises-and-further-learning) to go to the bonus task.

<hr style="height:2px;"><div class="alert alert-block alert-info"><h3>Task 4.1: Cellpose</h3>
    
- In [**Cellpose**](https://cellpose.readthedocs.io/en/latest/), cells are turned into flow representations. We create these flow representations by simulating diffusion from the center of the cell to get the spatial gradients for each pixel that point towards the center of the cell. During test time, we use the flows as a dynamical system and all pixels that converge to the same point are defined as the pixels in a given cell. The flows shown below are represented by an HSV colormap used in the optic flow literature.
    
    
- We also predict the foreground / background, the two classes you predicted in exercise 1. In Cellpose this is called the cell probability. We threshold it to decide which pixels are in cells, and we only run the dynamical system on those pixels.
    
    
- The flow representation allows the learning of non-convex shapes, because pixels can flow around corners. It also prevents merging, as flows for two cells that are touching are opposite.
</div>

![cellpose_flows](static/cellpose_flows.png)
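The flow idea can be illustrated with a small, self-contained NumPy sketch. This is a simplification written for this notebook, not Cellpose's implementation. Its assumptions: the "center" is the in-object pixel closest to the median coordinate, heat is spread with a plain 3x3 mean filter for a fixed number of iterations (`n_iter`), and the flows are the normalized spatial gradients of the resulting heat map. `diffusion_flows` and `follow_flows` are hypothetical helper names. The example uses two touching disks to check that each pixel ends up near the center of its own object.

```python
import numpy as np


def diffusion_flows(labels, n_iter=200):
    """Simplified Cellpose-style flows: one heat diffusion per labelled object."""
    H, W = labels.shape
    flows = np.zeros((2, H, W))
    centers = {}
    for lab in np.unique(labels[labels > 0]):
        inside = labels == lab
        ys, xs = np.nonzero(inside)
        # "center": the in-object pixel closest to the median coordinate
        i = np.argmin((ys - np.median(ys)) ** 2 + (xs - np.median(xs)) ** 2)
        cy, cx = int(ys[i]), int(xs[i])
        centers[int(lab)] = (cy, cx)

        heat = np.zeros((H, W))
        for _ in range(n_iter):
            heat[cy, cx] += 1.0                 # constant heat source at the center
            padded = np.pad(heat, 1)            # zero padding around the image
            # 3x3 mean filter, written as a sum of shifted copies
            heat = sum(padded[1 + dy:1 + dy + H, 1 + dx:1 + dx + W]
                       for dy in (-1, 0, 1) for dx in (-1, 0, 1)) / 9.0
            heat[~inside] = 0.0                 # heat cannot leave this object
        gy, gx = np.gradient(heat)
        norm = np.sqrt(gy ** 2 + gx ** 2) + 1e-12
        flows[0][inside] = (gy / norm)[inside]  # unit vectors pointing uphill,
        flows[1][inside] = (gx / norm)[inside]  # i.e. towards the center
    return flows, centers


def follow_flows(flows, foreground, n_steps=50):
    """Treat the flows as a dynamical system: Euler-step every foreground pixel."""
    H, W = foreground.shape
    ys, xs = np.nonzero(foreground)
    pos = np.stack([ys, xs], axis=1).astype(float)
    for _ in range(n_steps):
        iy = np.clip(np.round(pos[:, 0]).astype(int), 0, H - 1)
        ix = np.clip(np.round(pos[:, 1]).astype(int), 0, W - 1)
        pos += np.stack([flows[0, iy, ix], flows[1, iy, ix]], axis=1)
    return ys, xs, pos


# two touching disks with different labels
yy, xx = np.mgrid[:32, :48]
labels = np.zeros((32, 48), dtype=int)
labels[(yy - 16) ** 2 + (xx - 15) ** 2 <= 49] = 1
labels[((yy - 16) ** 2 + (xx - 29) ** 2 <= 49) & (labels == 0)] = 2

flows, centers = diffusion_flows(labels)
ys, xs, final = follow_flows(flows, labels > 0)
for lab, (cy, cx) in centers.items():
    sel = labels[ys, xs] == lab
    dist = np.hypot(final[sel, 0] - cy, final[sel, 1] - cx)
    print(lab, (cy, cx), dist.max())   # every pixel ends up close to its own center
```

Because each object's heat map is computed with every other pixel (including the neighbouring cell) held at zero, the flows on either side of a shared border point in opposite directions, which is why the two touching disks are not merged. Turning the converged positions into label masks (grouping pixels that end up at the same point) is left out of this sketch.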

<a id='final-todo'></a>

##### **TODO (9)**

- Use a relevant pretrained model for prediction:
    - get the Model class from `cellpose.models`.  
    **Hint** Load the weights for TissueNet from Cellpose, refer to the [documentation](https://cellpose.readthedocs.io/en/latest/models.html#other-built-in-models)
    - call `model.eval` method with correct parameters
    - for the nuclei model, also set the correct channels  
    **Hint** refer to the [documentation](https://cellpose.readthedocs.io/en/latest/settings.html#channels)
- Click [here](#eighth_todo) if you need to go back to the previous **TODO**

In [None]:
from cellpose import models

# create a cellpose model on the gpu
# use a built-in model trained on tissuenet
# (the first time you run this cell the model will download)

model = ... # get Cellpose model

test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)
test_loader = DataLoader(test_dataset, batch_size=1)

### IMPORTANT: these are the channels used for the segmentation
# the first one is the channel to segment, and the second one is the optional nuclear channel
# red = 1
# green = 2
# blue = 3

channels = [2, 1]
# Diameter parameter for tissuenet dataset
diameter = 25

masks_cp = []
for idx, (image, mask) in enumerate(test_loader):
    image = image.cpu().detach().numpy()
    mask_cp, flows, styles = ... # call model in evaluation mode with image, diameter and channels parameters
    masks_cp.append(mask_cp)
    
    fig, axes = plt.subplots(1,4,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    image = np.squeeze(image)
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(create_lut(mask_cp))
    axes[0][1].title.set_text('Predicted Labels')

    axes[0][2].imshow(flows[0])
    axes[0][2].title.set_text('Predicted cellpose')
    
    axes[0][3].imshow(flows[2])
    axes[0][3].title.set_text('Predicted cell probability')
    
    
    if idx == 2:
        break

In [None]:
# we could also use the nuclear channel ONLY and run a nuclear model in cellpose

channels = ... # set correct channels
diameter = 20

# initialize nuclei model (can also try "cyto" model if this doesn't work)
# the "nuclei" model in cellpose has been trained on lots of nuclear data (but not the tissuenet dataset)
# the "cyto" model in cellpose has been trained on many cellular images (but not the tissuenet dataset)
model = ... # Get the model for nuclei

masks_cp = []
for idx, (image, mask) in enumerate(test_loader):
    image = image.cpu().detach().numpy()
    mask_cp, flows, styles = ... # call model in evaluation mode with image, diameter and channels parameters
    masks_cp.append(mask_cp)
    
    fig, axes = plt.subplots(1,4,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    image = np.squeeze(image)
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(create_lut(mask_cp))
    axes[0][1].title.set_text('Predicted Labels')

    axes[0][2].imshow(flows[0])
    axes[0][2].title.set_text('Predicted cellpose')
    
    axes[0][3].imshow(flows[2])
    axes[0][3].title.set_text('Predicted cell probability')
    
    if idx == 2:
        break

In [None]:
#### Solution ####


from cellpose import models

# create a cellpose model on the gpu
# use a built-in model trained on tissuenet
# (the first time you run this cell the model will download)
# `gpu` expects a boolean flag rather than a torch device
model = models.CellposeModel(gpu=torch.cuda.is_available(), model_type='tissuenet')

test_dataset = TissueNetDataset(root_dir='woodshole', split='test', prediction_type=prediction_type)
test_loader = DataLoader(test_dataset, batch_size=1)

### IMPORTANT: these are the channels used for the segmentation
# the first one is the channel to segment, and the second one is the optional nuclear channel
# red = 1
# green = 2
# blue = 3

channels = [2, 1]

masks_cp = []
for idx, (image, mask) in enumerate(test_loader):
    image = image.cpu().detach().numpy()
    mask_cp, flows, styles = model.eval(image, diameter=25, channels=channels)
    masks_cp.append(mask_cp)
    
    fig, axes = plt.subplots(1,4,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    image = np.squeeze(image)
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(create_lut(mask_cp))
    axes[0][1].title.set_text('Predicted Labels')

    axes[0][2].imshow(flows[0])
    axes[0][2].title.set_text('Predicted cellpose')
    
    axes[0][3].imshow(flows[2])
    axes[0][3].title.set_text('Predicted cell probability')
    
    
    if idx == 2:
        break

In [None]:
#### Solution ####


# we could also use the nuclear channel ONLY and run a nuclear model in cellpose
# we set the second channel = 0 because we do not have an additional channel now
channels = [1, 0]

# initialize nuclei model (can also try "cyto" model if this doesn't work)
# the "nuclei" model in cellpose has been trained on lots of nuclear data (but not the tissuenet dataset)
# the "cyto" model in cellpose has been trained on many cellular images (but not the tissuenet dataset)
model = models.CellposeModel(gpu=torch.cuda.is_available(), model_type='nuclei')

masks_cp = []
for idx, (image, mask) in enumerate(test_loader):
    image = image.cpu().detach().numpy()
    mask_cp, flows, styles = model.eval(image, diameter=20, channels=channels)
    masks_cp.append(mask_cp)
    
    fig, axes = plt.subplots(1,4,figsize=(20, 20),sharex=True,sharey=True,squeeze=False)
    
    image = np.squeeze(image)
    image = np.vstack((image, np.zeros_like(image)[:1]))
    image = image.transpose(1,2,0)
    
    axes[0][0].imshow(image)
    axes[0][0].title.set_text('Raw')
    
    axes[0][1].imshow(create_lut(mask_cp))
    axes[0][1].title.set_text('Predicted Labels')

    axes[0][2].imshow(flows[0])
    axes[0][2].title.set_text('Predicted cellpose')
    
    axes[0][3].imshow(flows[2])
    axes[0][3].title.set_text('Predicted cell probability')
    
    if idx == 2:
        break

<hr style="height:2px;"><div class="alert alert-block alert-success"><h1>Checkpoint 4</h1>

</div>

## Task 5.0: Bonus exercises and further learning

- These exercises do not have TODOs; feel free to run them to get a sense of a few more tricks you can use for instance segmentation.

##### **Local Shape Descriptors**

- Another example of auxiliary learning is [**LSDs**](https://localshapedescriptors.github.io/) (Local Shape Descriptors). Like the Cellpose flows, this embedding encodes object shape, but it is computed within a Gaussian window of fixed size, constrained to each label. This allows for consistent gradients regardless of object shape, which makes it a good candidate for segmenting complex objects such as neurons in large electron microscopy datasets.


- The LSDs are combined with nearest neighbor affinities to improve the boundary representations. The improved affinities then produce nice segmentations when using a hierarchical agglomeration approach and can be easily parallelized to allow for scaling to massive volumes. 

![example_image](static/lsd_schematic.png)
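To make the descriptor components concrete, here is a minimal sketch of the underlying local statistics for 2D label images, written with `scipy.ndimage.gaussian_filter`. It is an illustration under simplifying assumptions, not the `lsd` library's implementation: for each pixel it uses a Gaussian window of width `sigma`, restricted to the pixel's own label, and computes the window's mass ("size"), the offset from the pixel to the window's center of mass, and the covariance of the coordinates. The channel order follows the plot titles used below. The library additionally rescales the components, which this sketch omits, and the offset sign convention here is an assumption. `simple_lsds` is a hypothetical name.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def simple_lsds(labels, sigma=5.0):
    """Unnormalized 2D local shape statistics, one Gaussian window per pixel.

    Channels: mean offset Y, mean offset X, covariance Y-Y, covariance X-X,
    covariance Y-X, size (Gaussian-weighted mass of the pixel's own label).
    """
    H, W = labels.shape
    yy, xx = np.mgrid[:H, :W].astype(float)
    out = np.zeros((6, H, W))

    def blur(a):
        # treat everything outside the image as zero (no reflected copies)
        return gaussian_filter(a, sigma, mode="constant")

    for lab in np.unique(labels[labels > 0]):
        inside = labels == lab
        m = inside.astype(float)
        mass = blur(m)[inside]                  # > 0 for every in-object pixel
        mean_y = blur(m * yy)[inside] / mass
        mean_x = blur(m * xx)[inside] / mass
        out[0][inside] = mean_y - yy[inside]    # offset to local center of mass
        out[1][inside] = mean_x - xx[inside]
        out[2][inside] = blur(m * yy * yy)[inside] / mass - mean_y ** 2
        out[3][inside] = blur(m * xx * xx)[inside] / mass - mean_x ** 2
        out[4][inside] = blur(m * yy * xx)[inside] / mass - mean_y * mean_x
        out[5][inside] = mass
    return out


# a single disk: offsets vanish at its center and point inwards elsewhere
gy, gx = np.mgrid[:32, :32]
disk = ((gy - 16) ** 2 + (gx - 16) ** 2 <= 36).astype(int)
lsds_demo = simple_lsds(disk, sigma=5.0)
print(lsds_demo.shape)  # (6, 32, 32)
```

Because every statistic is restricted to one label and one Gaussian window, the values at a pixel depend only on the local part of its object, not on the object's overall extent.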

In [None]:
# import lsds, calculate on a small patch and visualize the descriptor components

from lsd.train import local_shape_descriptor

file = random.choice(train_nuclei)

nuclei = imread(file)[0:64, 0:64]
raw = imread(file.replace('_nuclei_masks', ''))[:, 0:64, 0:64]

# just to visualize: add an empty third channel so imshow treats it as RGB
raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

lsds = local_shape_descriptor.get_local_shape_descriptors(
              segmentation=nuclei,
              sigma=(5,)*2,
              voxel_size=(1,)*2)

fig, axes = plt.subplots(
            1,
            6,
            figsize=(20, 20),
            sharex=False,
            sharey=True,
            squeeze=False)
  
axes[0][0].imshow(np.squeeze(lsds[0]), cmap='jet')
axes[0][0].title.set_text('Mean offset Y')

axes[0][1].imshow(np.squeeze(lsds[1]), cmap='jet')
axes[0][1].title.set_text('Mean offset X')

axes[0][2].imshow(np.squeeze(lsds[2]), cmap='jet')
axes[0][2].title.set_text('Covariance Y-Y')

axes[0][3].imshow(np.squeeze(lsds[3]), cmap='jet')
axes[0][3].title.set_text('Covariance X-X')

axes[0][4].imshow(np.squeeze(lsds[4]), cmap='jet')
axes[0][4].title.set_text('Covariance Y-X')

axes[0][5].imshow(np.squeeze(lsds[5]), cmap='jet')
axes[0][5].title.set_text('Size')

In [None]:
# slightly modify our dataset just for simplicity

class TissueNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 split='train',
                 mask='nuclei',
                 crop_size=None,
                 val_split=False,
                 padding_size=8
                ):
        
        self.split = split
        self.mask_files = natsorted(glob(os.path.join(root_dir, split, f'*{mask}*')))
        self.raw_files = [i for i in natsorted(glob(os.path.join(root_dir, split, '*.tif'))) if 's.tif' not in i]
        self.crop_size = crop_size
        self.padding_size = padding_size
        
        if split == 'test':
            if val_split:
                self.mask_files = self.mask_files[:10]
                self.raw_files = self.raw_files[:10]
            else:
                self.mask_files = self.mask_files[10:]
                self.raw_files = self.raw_files[10:]

    def __len__(self):
        return len(self.raw_files)
    
    def get_padding(self, crop_size, padding_size):
    
        # quotient
        q = int(crop_size / padding_size)
    
        if crop_size % padding_size != 0:
            padding = (padding_size * (q + 1))
        else:
            padding = crop_size
    
        return padding
    
    def augment_data(self, raw, mask, padding):
        
        transform = A.Compose([
              A.RandomCrop(
                  width=self.crop_size,
                  height=self.crop_size),
              A.PadIfNeeded(
                  min_height=padding,
                  min_width=padding,
                  p=1,
                  border_mode=0),
              A.HorizontalFlip(p=0.3),
              A.VerticalFlip(p=0.3),
              A.RandomRotate90(p=0.3),
              A.Transpose(p=0.3),
              A.RandomBrightnessContrast(p=0.3)
            ])

        transformed = transform(image=raw, mask=mask)

        raw, mask = transformed['image'], transformed['mask']
        
        return raw, mask

    def __getitem__(self, idx):
        raw_file = self.raw_files[idx]
        mask_file = self.mask_files[idx]
        
        raw = imread(raw_file)
        mask = imread(mask_file)

        raw = raw.transpose([1,2,0])
        
        mask = np.expand_dims(mask, axis=0)
        mask = mask.transpose([1,2,0])
                
        # just do this regardless of split to make val/test faster for demo purposes
        padding = self.get_padding(self.crop_size, self.padding_size)
        raw, mask = self.augment_data(raw, mask, padding)
            
        raw = raw.transpose([2,0,1])
        mask = mask.transpose([2,0,1])
        
        mask, border = erode_border(
                    mask[0],
                    iterations=1,
                    border_value=1)

        affs = compute_affinities(mask, nhood=[[0,1],[1,0]])
                        
        lsds = local_shape_descriptor.get_local_shape_descriptors(
              segmentation=mask,
              sigma=(5,)*2,
              voxel_size=(1,)*2)

        lsds = lsds.astype(np.float32)
        affs = affs.astype(np.float32)
                                        
        return raw, lsds, affs

In [None]:
# visualize batch

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)

raw, lsds, affs = train_dataset[random.randrange(len(train_dataset))]

raw = np.vstack((raw, np.zeros_like(raw)[:1]))
raw = raw.transpose(1,2,0)

fig, axes = plt.subplots(
            1,
            7,
            figsize=(20, 20),
            sharex=False,
            sharey=True,
            squeeze=False)
  
axes[0][0].imshow(np.squeeze(lsds[0]), cmap='jet')
axes[0][0].title.set_text('Mean offset Y')

axes[0][1].imshow(np.squeeze(lsds[1]), cmap='jet')
axes[0][1].title.set_text('Mean offset X')

axes[0][2].imshow(np.squeeze(lsds[2]), cmap='jet')
axes[0][2].title.set_text('Covariance Y-Y')

axes[0][3].imshow(np.squeeze(lsds[3]), cmap='jet')
axes[0][3].title.set_text('Covariance X-X')

axes[0][4].imshow(np.squeeze(lsds[4]), cmap='jet')
axes[0][4].title.set_text('Covariance Y-X')

axes[0][5].imshow(np.squeeze(lsds[5]), cmap='jet')
axes[0][5].title.set_text('Size')

axes[0][6].imshow(np.squeeze(affs[0]+affs[1]), cmap='jet')
axes[0][6].title.set_text('Affs')

In [None]:
# we need two output heads for our network, one for lsds and one for affinities
# to do this we will subclass torch.nn.Module and create our UNet inside
# before we had a single final convolution. Now we have one for each head.
# then in the forward pass we pass our image through our unet and then the output through each head

class MtlsdModel(torch.nn.Module):

    def __init__(
        self,
        in_channels,
        num_fmaps,
        fmap_inc_factors,
        downsample_factors,
        padding='same'
    ):
        super().__init__()

        self.unet = UNet(
            in_channels=in_channels,
            num_fmaps=num_fmaps,
            fmap_inc_factors=fmap_inc_factors,
            downsample_factors=downsample_factors,
            padding=padding)

        self.lsd_head = torch.nn.Conv2d(in_channels=num_fmaps,out_channels=6, kernel_size=1)
        self.aff_head = torch.nn.Conv2d(in_channels=num_fmaps,out_channels=2, kernel_size=1)

    def forward(self, input):

        z = self.unet(input)
        lsds = self.lsd_head(z)
        affs = self.aff_head(z)

        return lsds, affs

# We want to combine the lsds and affs losses and minimize the sum
# we can do this by subclassing our loss function (torch.nn.MSELoss) and overriding the forward method

class CombinedLoss(torch.nn.MSELoss):

    def __init__(self):
        super(CombinedLoss, self).__init__()

    def forward(self, lsds_prediction, lsds_target, affs_prediction, affs_target):

        loss1 = super(CombinedLoss, self).forward(lsds_prediction,lsds_target)
        loss2 = super(CombinedLoss, self).forward(affs_prediction, affs_target)
        
        return loss1 + loss2

In [None]:
torch.manual_seed(42)

d_factors = [[2,2],[2,2],[2,2]]

in_channels=2
num_fmaps=32
fmap_inc_factors=4

net = MtlsdModel(in_channels,num_fmaps,fmap_inc_factors,d_factors)

loss_fn = CombinedLoss().to(device)

net = net.to(device)

In [None]:
training_steps = 3000
logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
writer = SummaryWriter(logdir)

net = net.to(device)
dtype = torch.FloatTensor

# set optimizer
learning_rate = 1e-4
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# set activation
activation = torch.nn.Sigmoid()

### create datasets

train_dataset = TissueNetDataset(root_dir='woodshole', split='train', crop_size=64)
test_dataset = TissueNetDataset(root_dir='woodshole', split='test', crop_size=128)
val_dataset = TissueNetDataset(root_dir='woodshole', split='test', val_split=True, crop_size=64)

batch_size = 4

# make dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

In [None]:
# update our training step to have two logits and two predictions

def model_step(model, loss_fn, optimizer, feature, gt_lsds, gt_affs, activation, train_step=True):
    
    # zero gradients if training
    if train_step:
        optimizer.zero_grad()
        
    # forward
    lsd_logits, affs_logits = model(feature)

    loss_value = loss_fn(lsd_logits, gt_lsds, affs_logits, gt_affs)
    
    # backward if training mode
    if train_step:
        loss_value.backward()
        optimizer.step()
        
    lsd_output = activation(lsd_logits)
    affs_output = activation(affs_logits)
   
    outputs = {
        'pred_lsds': lsd_output,
        'pred_affs': affs_output,
        'lsds_logits': lsd_logits,
        'affs_logits': affs_logits,
    }
    
    return loss_value, outputs

In [None]:
# update our training loop to do both lsds and affs

# set flags
net.train() 
loss_fn.train()
step = 0

with tqdm(total=training_steps) as pbar:
    while step < training_steps:
        # reset data loader to get random augmentations
        np.random.seed()
        tmp_loader = iter(train_loader)
        for feature, gt_lsds, gt_affs in tmp_loader:
            gt_lsds = gt_lsds.to(device)
            gt_affs = gt_affs.to(device)
            feature = feature.to(device)
                        
                    
            loss_value, pred = model_step(net, loss_fn, optimizer, feature, gt_lsds, gt_affs, activation)
            writer.add_scalar('loss',loss_value.cpu().detach().numpy(),step)
            step += 1
            pbar.update(1)
            
            if step % 100 == 0:
                net.eval()
                tmp_val_loader = iter(test_loader)
                acc_loss = []
                for feature, gt_lsds, gt_affs in tmp_val_loader:                    
                    gt_lsds = gt_lsds.to(device)
                    gt_affs = gt_affs.to(device)
                    feature = feature.to(device)
                    loss_value, _ = model_step(net, loss_fn, optimizer, feature, gt_lsds, gt_affs, activation, train_step=False)
                    acc_loss.append(loss_value.cpu().detach().numpy())
                writer.add_scalar('val_loss',np.mean(acc_loss),step) 
                net.train()
                print(np.mean(acc_loss))
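The `gt_lsds` regressed in this loop come from `TissueNetDataset`. To make the "Mean offsets" panel in the next cell easier to interpret, here is a simplified sketch of that LSD component in 2D. For each foreground pixel it computes the vector from the pixel to the Gaussian-weighted centre of mass of its own object. The sketch assumes `scipy` is installed and that background is label 0. The function name `mean_offset_lsds` and the default `sigma` are hypothetical. Full LSDs also include covariance and size channels, and because the network's LSD output passes through a sigmoid, the notebook's targets are presumably rescaled to [0, 1]. This sketch omits both.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def mean_offset_lsds(labels: np.ndarray, sigma: float = 5.0) -> np.ndarray:
    """Hypothetical helper: approximate 2D mean-offset LSD channels, shape (2, H, W).

    For every foreground pixel, the offset (dy, dx) from the pixel to the
    Gaussian-weighted centre of mass of the object it belongs to.
    """
    h, w = labels.shape
    yy, xx = np.meshgrid(np.arange(h, dtype=np.float64), np.arange(w, dtype=np.float64), indexing="ij")
    offsets = np.zeros((2, h, w))
    for label_id in np.unique(labels):
        if label_id == 0:  # background keeps zero offsets
            continue
        mask = labels == label_id
        m = mask.astype(np.float64)
        # Gaussian-weighted mass and first moments of this object around every pixel
        mass = gaussian_filter(m, sigma, mode="constant")
        mean_y = gaussian_filter(m * yy, sigma, mode="constant") / np.maximum(mass, 1e-12)
        mean_x = gaussian_filter(m * xx, sigma, mode="constant") / np.maximum(mass, 1e-12)
        # only pixels of this object receive its offsets
        offsets[0][mask] = (mean_y - yy)[mask]
        offsets[1][mask] = (mean_x - xx)[mask]
    return offsets
```

Near the centre of an object the offsets are close to zero, and toward its edges they point back inward. This is why the offset maps change sign across an object and look different on either side of a boundary between two touching objects.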

In [None]:
# visualize a few predictions - have the lsds helped to improve the affinities?
# For a future challenge you could try using a weighted combined loss and watershed + agglomeration to get strong segmentations

net.eval()

activation = torch.nn.Sigmoid()

# inference only: no gradients needed
with torch.no_grad():
    for idx, (image, gt_lsds, gt_affs) in enumerate(test_loader):
        image = image.to(device)
        lsds_logits, affs_logits = net(image)
        pred_lsds = activation(lsds_logits)
        pred_affs = activation(affs_logits)

        # drop the batch dimension and convert everything to numpy
        image = np.squeeze(image.cpu().numpy())
        gt_lsds = np.squeeze(gt_lsds.cpu().numpy())
        gt_affs = np.squeeze(gt_affs.cpu().numpy())

        pred_lsds = np.squeeze(pred_lsds.cpu().numpy())
        pred_affs = np.squeeze(pred_affs.cpu().numpy())

        fig, axes = plt.subplots(1, 3, figsize=(20, 20), sharex=True, sharey=True, squeeze=False)

        # the raw image has two channels; append an empty third channel so it can be shown as RGB
        image = np.vstack((image, np.zeros_like(image)[:1]))
        image = image.transpose(1, 2, 0)

        axes[0][0].imshow(image)
        axes[0][0].title.set_text('Raw')

        # the first two LSD channels hold the mean offsets (one per spatial axis)
        axes[0][1].imshow(pred_lsds[0], cmap='jet')
        axes[0][1].imshow(pred_lsds[1], cmap='jet', alpha=0.5)
        axes[0][1].title.set_text('Mean offsets')

        # combine the two affinity channels (one per neighbour direction) into a single map
        axes[0][2].imshow(pred_affs[0] + pred_affs[1], cmap='jet')
        axes[0][2].title.set_text('Affs')

        # only show the first three test images
        if idx == 2:
            break
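The comment at the top of the visualization cell suggests watershed + agglomeration as a next step. Below is a minimal sketch of the watershed part only. It assumes `scikit-image` and `scipy` are available and that `pred_affs` is a single (2, H, W) NumPy array, like the one plotted above. The procedure has three steps: average the affinities into a foreground mask, place one seed at each local maximum of the distance transform, and flood from those seeds. The function name `affs_to_instances`, the 0.5 threshold and the seed radius are illustrative choices, not values from this notebook. Agglomeration of the resulting fragments is not shown.

```python
import numpy as np
from scipy import ndimage
from skimage.segmentation import watershed


def affs_to_instances(pred_affs: np.ndarray, threshold: float = 0.5, seed_radius: int = 3) -> np.ndarray:
    """Hypothetical helper: turn (2, H, W) affinities in [0, 1] into an (H, W) label image."""
    # high mean affinity = inside an object, low = boundary or background
    affs_mean = pred_affs.mean(axis=0)
    foreground = affs_mean > threshold
    # distance of every foreground pixel to the nearest boundary/background pixel
    distance = ndimage.distance_transform_edt(foreground)
    # one seed per local maximum of the distance map, restricted to the foreground
    local_max = ndimage.maximum_filter(distance, size=2 * seed_radius + 1)
    maxima = (distance == local_max) & foreground
    seeds, _ = ndimage.label(maxima)
    # flood from the seeds over the inverted distance map, only inside the foreground
    return watershed(-distance, markers=seeds, mask=foreground)
```

On real predictions, noisy affinities usually produce several seeds per object, which leads to over-segmentation. Merging those fragments is the job of the agglomeration step.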

### Further learning

* Instance segmentation can be challenging, and this exercise only scratches the surface of what is possible.


* This notebook assumes that images fit into memory, but often this is not the case (especially in biology). 
    1. To see an example of predicting over an image in chunks and stitching the results together, see this [notebook](https://github.com/dlmbl/instance_segmentation/blob/2021-solutions/3_tile_and_stitch.ipynb)
    2. For a more advanced library that makes it easier to do machine learning on massive datasets, see gunpowder (navigate to the tutorials, or browse the API): https://funkelab.github.io/gunpowder
    
    
* We did not cover more complex loss functions. Here are some nice explanations / implementations of other loss functions that are useful for instance segmentation: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook


* A more complex (but powerful) approach is called metric learning. An example can be seen in last year's [exercise](https://github.com/dlmbl/instance_segmentation/blob/2021-solutions/2_instance_segmentation.ipynb)


* We did not cover StarDist in this tutorial, and barely scratched the surface of Cellpose and LSDs. For more tutorials, see:
    1. StarDist: https://github.com/maweigert/tutorials/tree/main/stardist
    2. Cellpose: https://github.com/MouseLand/cellpose#run-cellpose-10-without-local-python-installation
    3. LSDs: https://github.com/funkelab/lsd#notebooks
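The tile-and-stitch notebook linked above shows one approach to predicting large images. To convey the basic principle, here is a minimal sketch that predicts overlapping tiles and averages the predictions where tiles overlap. It assumes a fully convolutional model with "same" padding, so that its output has the same spatial size as its input. It also assumes `overlap < tile`. Many real pipelines instead use valid convolutions with extra context around each tile, so that no blending is needed. The names `predict_tiled` and `_tile_starts` are hypothetical.

```python
import torch


def _tile_starts(length: int, tile: int, stride: int) -> list:
    """Start offsets so that tiles of size `tile` cover the range [0, length)."""
    starts = list(range(0, max(length - tile, 0) + 1, stride))
    # add a final tile flush with the border if the stride skipped the end
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts


@torch.no_grad()
def predict_tiled(model, image: torch.Tensor, tile: int = 128, overlap: int = 32) -> torch.Tensor:
    """Hypothetical helper: predict a (C, H, W) image tile by tile and average overlaps.

    Assumes `model` maps a (1, C, h, w) tensor to a (1, K, h, w) tensor.
    """
    _, height, width = image.shape
    stride = tile - overlap
    output, counts = None, None
    for y in _tile_starts(height, tile, stride):
        for x in _tile_starts(width, tile, stride):
            patch = image[:, y:y + tile, x:x + tile].unsqueeze(0)
            pred = model(patch)[0]  # drop the batch dimension
            if output is None:
                output = torch.zeros(pred.shape[0], height, width, device=pred.device)
                counts = torch.zeros(1, height, width, device=pred.device)
            # accumulate predictions and how often each pixel was covered
            output[:, y:y + tile, x:x + tile] += pred
            counts[:, y:y + tile, x:x + tile] += 1
    # every pixel is covered at least once, so the division is safe
    return output / counts
```

With the two-headed network from this notebook, the heads could be concatenated. For example, `predict_tiled(lambda x: torch.cat(net(x), dim=1), image)` works for an `image` of shape (C, H, W) on the same device as `net`, and the sigmoid would then be applied to the stitched logits. A U-Net may additionally require `tile` to be divisible by its total downsampling factor.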
    
### Good luck on your instance segmentation endeavors!!