In [None]:
# some extra package installations for your conda environment
# %pip installs into the environment of the running kernel (unlike !pip).
# GitHub no longer serves the unauthenticated git:// protocol, so use https.

%pip install cython
%pip install git+https://github.com/funkelab/lsd.git
%pip install git+https://github.com/funkey/gunpowder.git
%pip install git+https://github.com/funkey/waterz.git
%pip install git+https://github.com/funkelab/funlib.segment.git
%pip install git+https://github.com/funkelab/daisy.git
%pip install git+https://github.com/funkelab/funlib.learn.torch.git
%pip uninstall -y tornado
%pip install tornado

In [None]:
# IPython.display is the public home of these helpers
# (importing them from IPython.core.display is deprecated)
from IPython.display import display, HTML
from IPython.display import Image

import os
# set before importing torch so the device selection applies when CUDA starts
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch

import glob
import random
import zarr
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# let cudnn autotune convolution algorithms for our fixed patch sizes
# (the chosen algorithms may differ from run to run)
torch.backends.cudnn.benchmark = True

# seed every RNG the notebook draws from (sample choice, label colors,
# augmentations via numpy, weight initialisation via torch)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# decompress the dataset archive into the current working directory
# (later cells read from data_epithelia/, presumably created by this archive)
from shutil import unpack_archive

archive_path = os.path.join('datasets', 'data_epithelia.tar.gz')
unpack_archive(archive_path, extract_dir='./')

In [None]:
# overview figure of the epithelia data (rendered as this cell's output)
Image(filename='utils' + '/' + 'epithelia.png')

In [None]:
'''
inspect the structure of data

 └── volumes
     ├── gt_affs (2, 256, 256) uint8--> affinities map, 2 channels
     ├── gt_fgbg (1, 256, 256) uint8--> foreground and background semantic segmentation
     ├── gt_labels (1, 256, 256) uint16--> instance segmentation, each instance has a different integer label
     ├── gt_tanh (1, 256, 256) float32-->squared distance transformation 
     └── raw (1, 256, 256) float32--> raw input for the model
'''

# sort: glob order depends on the filesystem, so "the first sample" would
# otherwise differ between machines
sample_path = sorted(glob.glob(os.path.join("data_epithelia", "train", "*.zarr")))
if not sample_path:
    raise FileNotFoundError(
        "no *.zarr found in data_epithelia/train - run the decompression cell first")

# read-only: we only inspect the container
first_data = zarr.open(sample_path[0], 'r')
print(first_data.tree())

In [None]:
# we will use local shape descriptors as an auxiliary learning task for training affinities
# see here for info: https://localshapedescriptors.github.io/

from lsd import local_shape_descriptor

# sort: glob order depends on the filesystem, which would make the random
# pick below irreproducible even with a seeded `random`
files = sorted(glob.glob(os.path.join("data_epithelia", "train", "*.zarr")))

# open read-only: zarr.open defaults to mode 'a' (read/write, create if missing)
f = zarr.open(random.choice(files), mode='r')

raw = f['volumes/raw'][:,0:150,0:150]
labels = f['volumes/gt_labels'][:,0:150,0:150].astype(np.uint64)

# drop the leading channel axis: (1, h, w) -> (h, w)
labels = np.squeeze(labels)

# sigma is more or less the radius of the gaussian we want to grow around each voxel - 10 voxels is a good bet
# use an arbitrary voxel size of 1
lsds = local_shape_descriptor.get_local_shape_descriptors(
        segmentation=labels,
        sigma=(10,)*2,
        voxel_size=[1,1])

In [None]:
# utility function to view unique labels with matplotlib 

def create_lut(labels, rng=None):
    """Map an integer label image to random RGBA colors for display.

    Every label id gets a random RGB color with full opacity; label 0
    (background) maps to fully transparent black.

    Args:
        labels: array of non-negative integer label ids, any shape.
        rng: optional ``numpy.random.Generator`` for reproducible colors.
            When omitted, numpy's global random state is used.

    Returns:
        uint8 array of shape ``labels.shape + (4,)`` (RGBA in 0..255), which
        ``plt.imshow`` displays directly.
    """
    labels = np.asarray(labels)

    # nothing to color (np.max would raise on an empty array)
    if labels.size == 0:
        return np.zeros(labels.shape + (4,), dtype=np.uint8)

    num_colors = int(labels.max()) + 1

    # high is exclusive: 256 makes the full 0..255 range reachable
    if rng is None:
        colors = np.random.randint(0, 256, size=(num_colors, 3), dtype=np.uint8)
    else:
        colors = rng.integers(0, 256, size=(num_colors, 3), dtype=np.uint8)

    alpha = np.full((num_colors, 1), 255, dtype=np.uint8)
    lut = np.concatenate([colors, alpha], axis=1)

    # background is transparent so overlays stay readable
    lut[0] = 0

    return lut[labels]

In [None]:
%matplotlib inline

fig, axes = plt.subplots(
        2,
        4,
        figsize=(15, 8),
        sharex=True,
        sharey=True,
        squeeze=False)

axes[0][0].imshow(np.squeeze(raw), cmap='gray')
axes[0][0].set_title('raw')

axes[0][1].imshow(create_lut(labels))
axes[0][1].set_title('labels')

axes[0][2].imshow(np.squeeze(lsds[0]), cmap='jet')
axes[0][2].set_title('mean offset y')

axes[0][3].imshow(np.squeeze(lsds[1]), cmap='jet')
axes[0][3].set_title('mean offset x')

axes[1][0].imshow(np.squeeze(lsds[2]), cmap='jet')
axes[1][0].set_title('direction y')

axes[1][1].imshow(np.squeeze(lsds[3]), cmap='jet')
axes[1][1].set_title('direction x')

axes[1][2].imshow(np.squeeze(lsds[4]), cmap='jet')
axes[1][2].set_title('direction change xy')

axes[1][3].imshow(np.squeeze(lsds[5]), cmap='jet')
axes[1][3].set_title('size')

plt.show()

In [None]:
# we will use gunpowder for our training and prediction pipelines
# gunpowder is a library to facilitate machine learning on large, multi-dimensional arrays
# it can make it easier to create pipelines and handle augmentations
# here is a tutorial if you are interested in learning more: http://funkey.science/gunpowder/index.html

import gunpowder as gp

In [None]:
# wrapper for viewing our data with matplotlib

def imshow(
        raw=None,
        ground_truth=None,
        target=None,
        prediction=None,
        h=None,
        shader='jet',
        subplot=True,
        channel=0,
        target_name='target',
        prediction_name='prediction'):
    """Plot any combination of raw / labels / target / prediction arrays.

    Each non-None array gets its own row, in the order raw, ground_truth,
    target, prediction. Arrays with more than two dimensions are split along
    their leading axis into columns; the grid is as wide as the widest row.

    Rendering by array rank:
      * 2D: one image (raw in gray, anything else through `create_lut`).
      * 3D: one image per leading index (raw in gray, else `create_lut`).
      * 4D or more: one image per leading index showing ``im[channel]`` with
        colormap `shader`.

    Args:
        raw, ground_truth, target, prediction: arrays to plot, or None.
        h: optional vertical spacing (``hspace``) between rows.
        shader: colormap used for arrays with more than three dimensions.
        subplot: if False, draw onto the current pyplot figure instead of a
            grid; only raw (gray) and ground_truth (alpha 0.5) are drawn.
        channel: channel shown for arrays with more than three dimensions.
        target_name, prediction_name: titles of the target / prediction rows.
    """
    entries = [
        (raw, 'raw'),
        (ground_truth, 'labels'),
        (target, target_name),
        (prediction, prediction_name),
    ]
    entries = [(data, name) for data, name in entries if data is not None]

    # nothing to plot (this used to fail with a NameError on `cols`)
    if not entries:
        return

    rows = len(entries)
    # widest row wins, so no row indexes past the last column
    cols = max(data.shape[0] if len(data.shape) > 2 else 1 for data, _ in entries)

    if subplot:
        fig, axes = plt.subplots(
            rows,
            cols,
            figsize=(10, 4),
            sharex=True,
            sharey=True,
            squeeze=False)
        if h is not None:
            fig.subplots_adjust(hspace=h)
    elif h is not None:
        # no figure of our own in this mode: adjust the current one
        plt.subplots_adjust(hspace=h)

    for row, (data, name) in enumerate(entries):

        if not subplot:
            if name == 'raw':
                plt.imshow(data, cmap='gray')
            if name == 'labels':
                plt.imshow(data, alpha=0.5)
            continue

        ndim = len(data.shape)
        images = [data] if ndim == 2 else data

        for i, im in enumerate(images):
            ax = axes[row][i]
            if ndim > 3:
                ax.imshow(im[channel], cmap=shader)
            elif name == 'raw':
                ax.imshow(im, cmap='gray')
            else:
                ax.imshow(create_lut(im))
            ax.set_title(name)

    plt.show()

In [None]:
# import gunpowder local shape descriptor node

import logging
import math

from lsd.gp import AddLocalShapeDescriptor

# send INFO-level log records (root logger) to the notebook output
logging.basicConfig(level=logging.INFO)

In [None]:
# gunpowder class to convert our labels data dtype (lsd node assumes np.uint64 but our labels data is np.uint16)

class ConvertLabels(gp.BatchFilter):
    """Provide a copy of `in_array` under `out_array`, cast to `dtype`.

    Used because the lsd node expects np.uint64 labels while the zarr
    labels are stored as np.uint16.

    Args:
        in_array: gp.ArrayKey to read from.
        out_array: gp.ArrayKey to provide.
        dtype: numpy dtype the data is cast to.
    """

    def __init__(self, in_array, out_array, dtype):
        self.in_array = in_array
        self.out_array = out_array
        self.dtype = dtype

    def setup(self):
        # advertise the converted dtype (not the upstream one) so the provided
        # spec matches what `process` actually produces
        spec = self.spec[self.in_array].copy()
        spec.dtype = self.dtype
        self.provides(self.out_array, spec)

    def prepare(self, request):
        # the input is needed over exactly the roi requested for the output
        deps = gp.BatchRequest()
        deps[self.in_array] = request[self.out_array].copy()

        return deps

    def process(self, batch, request):

        data = batch[self.in_array].data.astype(self.dtype)

        spec = batch[self.in_array].spec.copy()
        spec.roi = request[self.out_array].roi.copy()
        spec.dtype = self.dtype

        outputs = gp.Batch()
        outputs[self.out_array] = gp.Array(data, spec)

        return outputs

In [None]:
# create a training pipeline
# no training for now, let's just create our lsds and affinities and see how to implement some things in gunpowder

def train(
    iterations,
    show_every,
    batch_size,
    show_gt=True,
    show_pred=False,
    lsd_channels=None,
    aff_channels=None):
    """Build the ground-truth pipeline (no model yet) and display batches.

    Reads random training patches, augments them, and computes ground-truth
    LSDs and affinities. It requests `iterations` batches and plots every
    `show_every`-th one.

    Args:
        iterations: number of batches to request.
        show_every: plot whenever the iteration index is a multiple of this.
        batch_size: number of samples stacked into each batch.
        show_gt: plot ground-truth lsds/affs for the requested channels.
        show_pred: kept for signature compatibility with the training version
            of `train`. This pipeline has no model, so there is nothing to
            show and a notice is printed instead.
        lsd_channels: optional mapping {display name: lsd channel index}.
        aff_channels: optional mapping {display name: affinity channel index}.
    """
    # set arbitrary voxel size
    voxel_size = gp.Coordinate((1,)*2)

    # input and output size of our network (in voxels)
    input_size = gp.Coordinate((136,)*2)
    output_size = gp.Coordinate((96,)*2)

    # we will use valid padding so the context of our network will come in handy when writing data
    # (explicit integer division: the context is a whole number of voxels)
    context = (input_size - output_size) // 2

    # create some keys to keep track of our data
    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    converted_labels = gp.ArrayKey('CONVERTED_LABELS')
    gt_lsds = gp.ArrayKey('GT_LSDS')
    gt_affs = gp.ArrayKey('GT_AFFS')

    # create a request to map our data to our patch sizes. The request behaves like a dictionary
    # mapping each array key to a region of interest (ROI), i.e., an offset and a size.
    request = gp.BatchRequest()

    # add our data to the request
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(converted_labels, output_size)
    request.add(gt_lsds, output_size)
    request.add(gt_affs, output_size)

    # get our training files
    files = glob.glob(os.path.join("data_epithelia", "train", "*.zarr"))

    # create a tuple of sources. for each source we will normalize the raw data
    # pad, and request a random location (equal to our input size - handled by the request)
    sources = tuple(
        gp.ZarrSource(
            sample,
            {
                raw: 'volumes/raw',
                labels: 'volumes/gt_labels'
            },
            {
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                labels: gp.ArraySpec(interpolatable=False, voxel_size=voxel_size)
            }) +
            gp.Normalize(raw) +
            gp.Pad(raw, None) +
            gp.Pad(labels, context) +
            gp.RandomLocation()
            for sample in files
        )

    # raw: (1, h, w)
    # labels: (1, h, w) (dtype = np.uint16)

    pipeline = sources

    # randomly choose a sample
    pipeline += gp.RandomProvider()

    # random mirror / transpose the batch
    pipeline += gp.SimpleAugment()

    # elastically deform
    pipeline += gp.ElasticAugment(
        control_point_spacing=(32, 32),
        jitter_sigma=(5.0, 5.0),
        rotation_interval=(0, math.pi/2))

    # randomly scale the raw intensity
    pipeline += gp.IntensityAugment(
        raw,
        scale_min=0.9,
        scale_max=1.1,
        shift_min=-0.1,
        shift_max=0.1)

    # remove channel dim from labels so we can calculate lsds on them
    pipeline += gp.Squeeze([labels])

    # raw: (1, h, w)
    # labels: (h, w) (dtype = np.uint16)

    # convert labels to np.uint64, write to a new array (converted_labels)
    pipeline += ConvertLabels(
        labels,
        converted_labels,
        np.uint64)

    # raw: (1, h, w)
    # converted_labels: (h, w) (dtype = np.uint64)

    # calculate lsds
    pipeline += AddLocalShapeDescriptor(
        converted_labels,
        gt_lsds,
        sigma=10,
        downsample=1)

    # gt_lsds: (6, h, w)

    # add affinities (nearest neighbor affinities so single voxel neighborhood in 2d)
    pipeline += gp.AddAffinities(
        affinity_neighborhood=[
            [0, -1],
            [-1, 0]],
        labels=converted_labels,
        affinities=gt_affs,
        dtype=np.float32)

    # gt_affs: (2, h, w)

    # add a channel dim back to labels
    pipeline += gp.Unsqueeze([converted_labels])

    # converted_labels: (1, h, w)

    # stack batch size
    pipeline += gp.Stack(batch_size)

    # raw: (b, 1, h, w)
    # converted_labels: (b, 1, h, w)
    # gt_lsds: (b, 6, h, w)
    # gt_affs: (b, 2, h, w)

    # crop window of the labels roi inside the raw array, in voxels. Computed
    # relative to the raw roi so it does not assume raw starts at the origin.
    raw_begin = request[raw].roi.get_begin()
    start = (request[converted_labels].roi.get_begin() - raw_begin) // voxel_size
    end = (request[converted_labels].roi.get_end() - raw_begin) // voxel_size

    if show_pred:
        # previously a NameError: this pipeline has no prediction arrays
        print("show_pred ignored: this pipeline does not produce predictions")

    # build pipeline
    with gp.build(pipeline):
        for i in range(iterations):
            batch = pipeline.request_batch(request)

            # every nth iteration, show a batch
            if i % show_every != 0:
                continue

            # crop raw data to labels roi for matplotlib
            imshow(raw=np.squeeze(batch[raw].data[:,:,start[0]:end[0],start[1]:end[1]]))
            imshow(ground_truth=batch[converted_labels].data)

            if show_gt:
                for n, c in (lsd_channels or {}).items():
                    imshow(target=batch[gt_lsds].data, target_name='gt '+n, channel=c)
                for n, c in (aff_channels or {}).items():
                    imshow(target=batch[gt_affs].data, target_name='gt '+n, channel=c)

In [None]:
# view a batch of ground truth lsds/affs, no need to show predicted lsds/affs yet

lsd_channels = {
    'offset (y)': 0,
    'offset (x)': 1,
    'orient (y)': 2,
    'orient (x)': 3,
    'yx change': 4,
    'voxel count': 5
}

# channel order follows the affinity neighborhood [[0, -1], [-1, 0]] used in
# `train`: offset [0, -1] compares x-neighbours (channel 0) and offset
# [-1, 0] compares y-neighbours (channel 1)
aff_channels = {'affs x': 0, 'affs y': 1}

train(
    iterations=1,
    show_every=1,
    batch_size=5,
    lsd_channels=lsd_channels,
    aff_channels=aff_channels)

In [None]:
# we will use this library for our unet model
from funlib.learn.torch.models import UNet, ConvPass

In [None]:
# create a multi task lsd + affs model
# this is the idea behind the MTLSD model in the lsd paper
# we have two output heads of the unet (one for lsd and one for affs)
# we combine their losses and then minimize during training

class MtlsdModel(torch.nn.Module):
    """2D U-Net with two 1x1 sigmoid heads: 6 LSD channels and 2 affinities.

    Both heads read the same U-Net feature map. `forward` returns the tuple
    ``(lsds, affs)``.
    """

    def __init__(self):
        super().__init__()

        num_fmaps = 12

        downsample_factors = [(2, 2)] * 3
        levels = len(downsample_factors) + 1

        # two 3x3 convolutions per level, on the way down and on the way up
        kernel_size_down = [[(3, 3), (3, 3)] for _ in range(levels)]
        kernel_size_up = [[(3, 3), (3, 3)] for _ in range(levels - 1)]

        self.unet = UNet(
            in_channels=1,
            num_fmaps=num_fmaps,
            fmap_inc_factor=5,
            downsample_factors=downsample_factors,
            kernel_size_down=kernel_size_down,
            kernel_size_up=kernel_size_up)

        # 2d lsds have 6 channels:
        #   0 mean offset y, 1 mean offset x, 2 orientation y,
        #   3 orientation x, 4 change in orientation y-x, 5 size (voxel count)
        self.lsd_head = ConvPass(num_fmaps, 6, [[1, 1]], activation='Sigmoid')

        # 2 affinity channels (x and y affinities in this case)
        self.aff_head = ConvPass(num_fmaps, 2, [[1, 1]], activation='Sigmoid')

    def forward(self, input):
        # `input` must keep this name: the training/predict nodes pass it by keyword
        features = self.unet(input)
        return self.lsd_head(features), self.aff_head(features)

# combine the lsds and affs losses
# use weighted loss for each (just multiply predictions and targets by weights)
class WeightedMSELoss(torch.nn.MSELoss):
    """Sum of a weighted LSD MSE term and a weighted affinity MSE term.

    Both prediction and target are multiplied by the weights before the
    parent MSELoss (mean reduction by default). Each term therefore equals
    ``mean((weights * (prediction - target)) ** 2)``.
    """

    def __init__(self):
        super().__init__()

    def _weighted_term(self, prediction, target, weights):
        # parent MSELoss applied to the weighted prediction and target
        return super().forward(prediction * weights, target * weights)

    def forward(self, lsds_prediction, lsds_target, lsds_weights, affs_prediction, affs_target, affs_weights,):

        lsds_loss = self._weighted_term(lsds_prediction, lsds_target, lsds_weights)
        affs_loss = self._weighted_term(affs_prediction, affs_target, affs_weights)

        return lsds_loss + affs_loss

In [None]:
# get a train node (this is a wrapper around the zero_grad, backward, step functions you have become familiar with)
from gunpowder.torch import Train

In [None]:
# add train node and associated keys to pipeline
def train(
    iterations,
    show_every,
    batch_size,
    show_gt=True,
    show_pred=False,
    lsd_channels=None,
    aff_channels=None):
    """Train the MTLSD model with a gunpowder pipeline and plot progress.

    The pipeline builds augmented ground truth: LSDs with a mask, and
    affinities with balanced weights. A `Train` node drives `MtlsdModel`
    with `WeightedMSELoss` and Adam. Every `show_every`-th batch is plotted
    and its loss printed.

    Args:
        iterations: number of training iterations (batches requested).
        show_every: plot/log whenever the iteration index is a multiple of this.
        batch_size: number of samples stacked into each batch.
        show_gt: plot ground-truth lsds/affs for the requested channels.
        show_pred: plot predicted lsds/affs for the requested channels.
        lsd_channels: optional mapping {display name: lsd channel index}.
        aff_channels: optional mapping {display name: affinity channel index}.
    """
    voxel_size = gp.Coordinate((1,)*2)

    input_size = gp.Coordinate((172,)*2)
    output_size = gp.Coordinate((80,)*2)

    # explicit integer division: the context is a whole number of voxels
    context = (input_size - output_size) // 2

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    converted_labels = gp.ArrayKey('CONVERTED_LABELS')
    gt_lsds = gp.ArrayKey('GT_LSDS')
    lsds_weights = gp.ArrayKey('LSDS_WEIGHTS')
    pred_lsds = gp.ArrayKey('PRED_LSDS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    affs_weights = gp.ArrayKey('AFFS_WEIGHTS')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    request = gp.BatchRequest()

    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(converted_labels, output_size)
    request.add(gt_lsds, output_size)
    request.add(lsds_weights, output_size)
    request.add(pred_lsds, output_size)
    request.add(gt_affs, output_size)
    request.add(affs_weights, output_size)
    request.add(pred_affs, output_size)

    # get model, loss, optimizer
    model = MtlsdModel()
    loss = WeightedMSELoss()
    optimizer = torch.optim.Adam(lr=0.5e-4, params=model.parameters())

    files = glob.glob(os.path.join("data_epithelia", "train", "*.zarr"))

    sources = tuple(
        gp.ZarrSource(
            sample,
            {
                raw: 'volumes/raw',
                labels: 'volumes/gt_labels'
            },
            {
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                labels: gp.ArraySpec(interpolatable=False, voxel_size=voxel_size)
            }) +
            gp.Normalize(raw) +
            gp.Pad(raw, None) +
            gp.Pad(labels, context) +
            gp.RandomLocation()
            for sample in files
        )

    # raw: (1, h, w)
    # labels: (1, h, w) (dtype = np.uint16)

    pipeline = sources

    pipeline += gp.RandomProvider()

    pipeline += gp.SimpleAugment()

    pipeline += gp.ElasticAugment(
        control_point_spacing=(32, 32),
        jitter_sigma=(5.0, 5.0),
        rotation_interval=(0, math.pi/2))

    pipeline += gp.IntensityAugment(
        raw,
        scale_min=0.9,
        scale_max=1.1,
        shift_min=-0.1,
        shift_max=0.1)

    pipeline += gp.Squeeze([labels])

    # raw: (1, h, w)
    # labels: (h, w) (dtype = np.uint16)

    pipeline += ConvertLabels(
        labels,
        converted_labels,
        np.uint64)

    # raw: (1, h, w)
    # converted_labels: (h, w) (dtype = np.uint64)

    # erode the boundaries between labels, we want the network to learn what to do in between the labels
    pipeline += gp.GrowBoundary(converted_labels)

    pipeline += AddLocalShapeDescriptor(
        converted_labels,
        gt_lsds,
        mask=lsds_weights,
        sigma=10,
        downsample=1)

    # gt_lsds: (6, h, w)
    # lsds_weights: (6, h, w)

    pipeline += gp.AddAffinities(
        affinity_neighborhood=[
            [0, -1],
            [-1, 0]],
        labels=converted_labels,
        affinities=gt_affs,
        dtype=np.float32)

    # gt_affs: (2, h, w)

    # create a weights array for our affs
    pipeline += gp.BalanceLabels(
        gt_affs,
        affs_weights)

    # affs_weights: (2, h, w)

    pipeline += gp.Unsqueeze([converted_labels])

    # converted_labels: (1, h, w)

    pipeline += gp.Stack(batch_size)

    # raw: (b, 1, h, w)
    # converted_labels: (b, 1, h, w)
    # gt_lsds / lsds_weights: (b, 6, h, w)
    # gt_affs / affs_weights: (b, 2, h, w)

    # pass everything to our train node
    pipeline += Train(
        model,
        loss,
        optimizer,
        inputs={
            'input': raw
        },
        outputs={
            0: pred_lsds,
            1: pred_affs
        },
        loss_inputs={
            0: pred_lsds,
            1: gt_lsds,
            2: lsds_weights,
            3: pred_affs,
            4: gt_affs,
            5: affs_weights
        })

    # pred_lsds: (b, 6, h, w)
    # pred_affs: (b, 2, h, w)

    # crop window of the labels roi inside the raw array, in voxels. Computed
    # relative to the raw roi so it does not assume raw starts at the origin.
    raw_begin = request[raw].roi.get_begin()
    start = (request[converted_labels].roi.get_begin() - raw_begin) // voxel_size
    end = (request[converted_labels].roi.get_end() - raw_begin) // voxel_size

    with gp.build(pipeline):
        for i in range(iterations):
            batch = pipeline.request_batch(request)

            if i % show_every != 0:
                continue

            # the Train node attaches the current loss to the batch
            loss_value = getattr(batch, 'loss', None)
            if loss_value is not None:
                print(f"iteration {i}: loss {loss_value}")

            imshow(raw=np.squeeze(batch[raw].data[:,:,start[0]:end[0],start[1]:end[1]]))
            imshow(ground_truth=batch[converted_labels].data)

            for n, c in (lsd_channels or {}).items():
                if show_gt:
                    imshow(target=batch[gt_lsds].data, target_name='gt '+n, channel=c)
                if show_pred:
                    imshow(prediction=batch[pred_lsds].data, prediction_name='pred '+n, channel=c)

            for n, c in (aff_channels or {}).items():
                if show_gt:
                    imshow(target=batch[gt_affs].data, target_name='gt '+n, channel=c)
                if show_pred:
                    imshow(prediction=batch[pred_affs].data, prediction_name='pred '+n, channel=c)

In [None]:
# show predictions now. 
# feel free to train for longer - it will take a couple thousand iterations for the predictions to improve
# we will just use a pre-trained network for inference

# 2000 iterations, plotting (and predictions) every 100th batch of 5
train(
    iterations=2000,
    batch_size=5,
    show_every=100,
    show_pred=True,
    lsd_channels=lsd_channels,
    aff_channels=aff_channels,
)

In [None]:
# this library has some nice wrappers for data loading and cropping rois
import daisy

In [None]:
# sizes must match the network used for training / the pretrained checkpoint
voxel_size = gp.Coordinate((1,)*2)

input_size = gp.Coordinate((172,)*2)
output_size = gp.Coordinate((80,)*2)

# voxels lost on each side by the valid convolutions;
# explicit integer division because this is used to pad/grow ROIs
context = (input_size - output_size) // 2

In [None]:
# create a predict pipeline (this is also known as inference)
def predict(
    checkpoint,
    raw_file,
    raw_dataset,
    labels_dataset,
    out_file,
    out_datasets):
    """Predict LSDs and affinities over a whole image and write them to zarr.

    Scans the image in network-sized tiles. The tile sizes come from the
    module-level `input_size` / `output_size` / `context` / `voxel_size`.
    Runs `MtlsdModel` loaded from `checkpoint` and writes raw, labels and the
    predictions to `out_file`.

    Args:
        checkpoint: path of the model checkpoint to load.
        raw_file: zarr container holding the raw and labels datasets.
        raw_dataset: name of the raw dataset inside `raw_file`.
        labels_dataset: name of the labels dataset inside `raw_file`.
        out_file: zarr container to write into.
        out_datasets: dataset names to create in `out_file`. Names containing
            'lsd' get 6 channels, 'affs' 2, anything else 1. Names containing
            'labels' are uint64; everything else is float32.
    """
    # import explicitly instead of relying on `gp.torch` having been loaded
    # as a side effect of an earlier `from gunpowder.torch import ...`
    from gunpowder.torch import Predict

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    pred_lsds = gp.ArrayKey('PRED_LSDS')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    # create a scan request. we will tile over our full raw image size in patches of our network size
    scan_request = gp.BatchRequest()

    scan_request.add(raw, input_size)
    scan_request.add(labels, input_size)
    scan_request.add(pred_lsds, output_size)
    scan_request.add(pred_affs, output_size)

    source = gp.ZarrSource(
        raw_file,
            {
                raw: raw_dataset,
                labels: labels_dataset
            },
            {
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                # label ids must never be interpolated (same as in training)
                labels: gp.ArraySpec(interpolatable=False, voxel_size=voxel_size)
            }
        )

    # since our network uses valid convolutions, our network output size is smaller than our input size
    # but we want to evaluate on the full image.
    # so we need to grow the input by the context of the network and pad our data

    with gp.build(source):
        total_input_roi = source.spec[raw].roi.grow(context, context)
        total_output_roi = source.spec[raw].roi

    source += gp.Pad(raw, context)
    source += gp.Pad(labels, context)

    # create empty zarr datasets to write our data to
    for ds in out_datasets:
        if 'lsd' in ds:
            dims = 6
        elif 'affs' in ds:
            dims = 2
        else:
            dims = 1

        dtype = np.uint64 if 'labels' in ds else np.float32

        daisy.prepare_ds(
                out_file,
                ds,
                daisy.Roi(
                    total_output_roi.get_offset(),
                    total_output_roi.get_shape()
                ),
                voxel_size,
                write_size=output_size,
                num_channels=dims,
                dtype=dtype)

    model = MtlsdModel()

    # set model to eval mode
    model.eval()

    # add a predict node (named so it does not shadow this function)
    predict_node = Predict(
        model=model,
        checkpoint=checkpoint,
        inputs = {
            'input': raw
        },
        outputs = {
            0: pred_lsds,
            1: pred_affs}
        )

    # this will scan in chunks equal to the input/output sizes of the respective arrays
    scan = gp.Scan(scan_request)

    # write out data
    write = gp.ZarrWrite(
        dataset_names={
            raw: 'raw',
            labels: 'labels',
            pred_lsds: 'pred_lsds',
            pred_affs: 'pred_affs'},
        output_filename=out_file)

    pipeline = source
    pipeline += gp.Normalize(raw)

    # only need a batch size of 1 for prediction
    pipeline += gp.Stack(1)
    pipeline += predict_node
    pipeline += scan

    # remove batch dim from raw and labels
    pipeline += gp.Squeeze([raw, labels])

    # remove channel dim from raw, labels and batch dim from pred lsds / pred_affs
    pipeline += gp.Squeeze(
        [raw,
         labels,
         pred_lsds,
         pred_affs])
    pipeline += write

    # we have another request for the full roi we are scanning over
    predict_request = gp.BatchRequest()

    # this lets us know to process the full image. we will scan over it until it is done
    predict_request[raw] = total_input_roi
    predict_request[labels] = total_input_roi
    predict_request[pred_lsds] = total_output_roi
    predict_request[pred_affs] = total_output_roi

    with gp.build(pipeline):
        pipeline.request_batch(predict_request)

In [None]:
# get all our testing data
# sorted: glob order depends on the filesystem, so sort to make the later
# random pick reproducible
test_files = sorted(glob.glob(os.path.join("data_epithelia", "test", "*.zarr")))

In [None]:
# fetch a pretrained model checkpoint (skipped if it is already present)
# ?dl=1 asks Dropbox for the file itself instead of its preview page;
# -O keeps the local filename free of the query string
![ -f model_checkpoint_10000 ] || wget -O model_checkpoint_10000 "https://www.dropbox.com/s/2avnabaftjwqqgs/model_checkpoint_10000?dl=1"

In [None]:
# load checkpoint and predict on a random image from the test set
# predict on one randomly chosen test image using the pretrained checkpoint
checkpoint = 'model_checkpoint_10000'
raw_file = random.choice(test_files)
raw_dataset, labels_dataset = 'volumes/raw', 'volumes/gt_labels'
out_file = 'prediction.zarr'
out_datasets = ['raw', 'labels', 'pred_lsds', 'pred_affs']

predict(
    checkpoint,
    raw_file,
    raw_dataset,
    labels_dataset,
    out_file,
    out_datasets,
)

In [None]:
# load the data to visualize
# open read-only: zarr.open defaults to mode 'a' (read/write, create if missing)
pred_f = zarr.open('prediction.zarr', mode='r')
test_raw = pred_f['raw'][:]
test_labels = pred_f['labels'][:].astype(np.uint64)
test_lsds = pred_f['pred_lsds'][:]
test_affs = pred_f['pred_affs'][:]

In [None]:
# raw | labels | lsd offsets (y with x overlaid at 50% alpha) | summed affinities
fig, axes = plt.subplots(1, 4, figsize=(15, 8), sharex=True, sharey=True, squeeze=False)
raw_ax, labels_ax, offsets_ax, affs_ax = axes[0]

raw_ax.imshow(test_raw, cmap='gray')
raw_ax.set_title('raw')

labels_ax.imshow(create_lut(test_labels))
labels_ax.set_title('labels')

offsets_ax.imshow(np.squeeze(test_lsds[0]), cmap='jet')
offsets_ax.imshow(np.squeeze(test_lsds[1]), cmap='jet', alpha=0.5)
offsets_ax.set_title('predicted offset vectors')

affs_ax.imshow(np.squeeze(test_affs[0] + test_affs[1]), cmap='jet')
affs_ax.set_title('predicted affinities')

plt.show()

In [None]:
# function to relabel connected components
from scipy.ndimage import label

In [None]:
# use daisy to load our data because it is easier to crop data to a specific roi
# in this case we want to crop a few pixels off the edge of our image
# because of border artifacts from prediction (bc of valid convolutions)

raw = daisy.open_ds('prediction.zarr', 'raw')
labels = daisy.open_ds('prediction.zarr', 'labels')
lsds = daisy.open_ds('prediction.zarr', 'pred_lsds')
affs = daisy.open_ds('prediction.zarr', 'pred_affs')

# crop 8 voxels off every edge to drop prediction border artifacts
crop = gp.Coordinate((8,)*2)
roi = raw.roi.grow(-crop, -crop)

# read each dataset's cropped roi as a numpy array (same order as opened)
raw, labels, lsds, affs = (ds[roi].to_ndarray() for ds in (raw, labels, lsds, affs))

# remember to relabel gt since we cropped it
# NOTE(review): scipy's label() treats `labels` as a binary mask, so touching
# instances with no background between them would be merged - confirm the gt
# has background boundaries between cells
labels = label(labels)[0].astype(np.uint64)

fig, axes = plt.subplots(1, 4, figsize=(15, 8), sharex=True, sharey=True, squeeze=False)
raw_ax, labels_ax, offsets_ax, affs_ax = axes[0]

raw_ax.imshow(raw, cmap='gray')
raw_ax.set_title('raw')

labels_ax.imshow(create_lut(labels))
labels_ax.set_title('labels')

offsets_ax.imshow(np.squeeze(lsds[0]), cmap='jet')
offsets_ax.imshow(np.squeeze(lsds[1]), cmap='jet', alpha=0.5)
offsets_ax.set_title('predicted offset vectors')

affs_ax.imshow(np.squeeze(affs[0] + affs[1]), cmap='jet')
affs_ax.set_title('predicted affinities')

plt.show()

In [None]:
# these libraries will be useful for performing watershed and getting a segmentation

import waterz
# import from scipy.ndimage directly: the scipy.ndimage.filters and
# scipy.ndimage.morphology namespaces are deprecated
from scipy.ndimage import distance_transform_edt, maximum_filter
from skimage.segmentation import watershed

In [None]:
# utility functions to get supervoxels (fragments) from affinities

def watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask,
        return_seeds=False,
        id_offset=0,
        min_seed_distance=10):
    """Seeded watershed on a boundary-distance map.

    Seeds are the points where `boundary_distances` equals its maximum over
    a `min_seed_distance`-sized window. Each connected group of such points
    becomes one seed. The watershed then floods the inverted distance map
    from those seeds, restricted to `boundary_mask`.

    Args:
        boundary_distances: distance map (larger = further from a boundary).
        boundary_mask: mask of pixels the watershed may assign.
        return_seeds: also return the seed image.
        id_offset: added to every seed id so ids stay unique across calls.
        min_seed_distance: window size of the maximum filter.

    Returns:
        ``(fragments, max_id)``, or ``(fragments, max_id, seeds)`` when
        `return_seeds` is set. The arrays are uint64, and ``max_id`` is
        ``id_offset`` plus the number of seeds found.
    """
    max_filtered = maximum_filter(boundary_distances, min_seed_distance)
    maxima = max_filtered == boundary_distances
    seeds, n = label(maxima)

    print(f"Found {n} fragments")

    if n == 0:
        # no seeds: return empty fragments, but keep the tuple length consistent
        # with `return_seeds` so callers that read ret[2] do not fail
        empty = np.zeros(boundary_distances.shape, dtype=np.uint64)
        if return_seeds:
            return empty, id_offset, np.zeros_like(empty)
        return empty, id_offset

    seeds[seeds != 0] += id_offset

    fragments = watershed(
        boundary_distances.max() - boundary_distances,
        seeds,
        mask=boundary_mask)

    ret = (fragments.astype(np.uint64), n + id_offset)
    if return_seeds:
        ret = ret + (seeds.astype(np.uint64),)

    return ret

def watershed_from_affinities(
        affs,
        max_affinity_value=1.0,
        fragments_in_xy=True,
        return_seeds=False,
        min_seed_distance=10,
        labels_mask=None):
    """Compute watershed fragments slice by slice from affinities.

    Channels 1 and 2 of `affs` are averaged (channel 0 is not read). Each
    section along the first remaining axis is thresholded at half of
    `max_affinity_value`, turned into a distance map and passed to
    `watershed_from_boundary_distance`. Fragment ids keep increasing across
    sections.

    Args:
        affs: affinity array indexed as affs[channel]; channels 1 and 2 are used.
        max_affinity_value: affinity value treated as fully connected.
        fragments_in_xy: accepted but not used.
        return_seeds: also return the stacked seed images.
        min_seed_distance: forwarded to `watershed_from_boundary_distance`.
        labels_mask: optional mask, ANDed into each section's boundary mask
            after the distance transform.

    Returns:
        ``(fragments, max_id)`` or ``(fragments, max_id, seeds)``.
    """
    mean_affs = 0.5 * (affs[1] + affs[2])

    fragments = np.zeros(mean_affs.shape, dtype=np.uint64)
    seeds = np.zeros(mean_affs.shape, dtype=np.uint64) if return_seeds else None

    next_id = 0

    for z, section in enumerate(mean_affs):

        boundary_mask = section > 0.5 * max_affinity_value
        boundary_distances = distance_transform_edt(boundary_mask)

        if labels_mask is not None:
            boundary_mask *= labels_mask.astype(bool)

        result = watershed_from_boundary_distance(
            boundary_distances,
            boundary_mask,
            return_seeds=return_seeds,
            id_offset=next_id,
            min_seed_distance=min_seed_distance)

        fragments[z] = result[0]
        if return_seeds:
            seeds[z] = result[2]

        next_id = result[1]

    if return_seeds:
        return fragments, next_id, seeds
    return fragments, next_id

In [None]:
# utility function to agglomerate fragments using underlying affinities as edge weights
# returns a segmentation from a final threshold

def get_segmentation(affinities, threshold, labels_mask=None):
    """Agglomerate watershed fragments into a segmentation at one threshold.

    Args:
        affinities: affinity array; passed unchanged to
            ``watershed_from_affinities`` and as float32 to
            ``waterz.agglomerate``.
        threshold: the single waterz merge threshold to evaluate (higher
            thresholds merge more, lower ones split more).
        labels_mask: optional mask forwarded to the fragment watershed.

    Returns:
        The segmentation yielded by waterz for ``threshold``.
    """
    # only the fragment volume is needed here, not the largest fragment id
    fragments, _ = watershed_from_affinities(affinities, labels_mask=labels_mask)

    # waterz yields one segmentation per requested threshold; we request one
    return next(
        waterz.agglomerate(
            affs=affinities.astype(np.float32),
            fragments=fragments,
            thresholds=[threshold],
        )
    )

In [None]:
# watershed_from_affinities reads channels 1 and 2 and waterz needs 4d affs
# (c, d, h, w): prepend an all-zero placeholder channel to the two predicted
# channels, then add a singleton depth axis -> (3, 1, h, w).
# A new name is used so re-running this cell does not re-stack an already
# stacked array.
affs_waterz = np.expand_dims(
    np.stack([
        np.zeros_like(affs[0]),
        affs[0],
        affs[1]]),
    axis=1)

# just test a 0.5 threshold. higher thresholds will merge more, lower thresholds will split more
threshold = 0.5

segmentation = get_segmentation(affs_waterz, threshold)

In [None]:
# this is useful for removing small holes in an array
from skimage.morphology import remove_small_holes

In [None]:
# Treat every non-zero segment as foreground, fill holes smaller than 256
# pixels, then give each connected component of the foreground its own id.
# NOTE(review): segments that touch without a background gap between them
# end up in the same connected component and are merged here.
foreground_filled = remove_small_holes(
    segmentation.astype(bool),
    area_threshold=256)
seg, _ = label(foreground_filled)
seg = seg.astype(np.uint64)

In [None]:
fig, axes = plt.subplots(
    1, 3,
    figsize=(15, 8),
    sharex=True,
    sharey=True,
    squeeze=False)
ax_raw, ax_labels, ax_seg = axes[0]

# raw image, ground-truth instances and predicted instances side by side
ax_raw.imshow(raw, cmap='gray')
ax_raw.set_title('raw')

ax_labels.imshow(create_lut(labels))
ax_labels.set_title('labels')

# seg keeps the singleton depth axis of the waterz input; drop it for display
ax_seg.imshow(create_lut(np.squeeze(seg)))
ax_seg.set_title('segmentation')

plt.show()

In [None]:
# this package has a useful function for eroding boundaries
from scipy import ndimage

In [None]:
# we want to erode the ground truth boundaries before we evaluate 
def erode_boundaries(
        labels,
        iterations=1,
        border_value=1):
    """Return a copy of ``labels`` in which every instance is eroded.

    Each non-zero label is eroded on its own with
    ``ndimage.binary_erosion``, so neighbouring instances end up separated by
    a background gap. Pixels not kept by any eroded instance are set to 0.

    Args:
        labels: integer instance label array; 0 is background.
        iterations: number of erosion iterations applied to each instance.
        border_value: value assumed outside the array during erosion
            (1 means instances touching the image border are not eroded
            from that side).

    Returns:
        A new array with the same shape and dtype as ``labels``; the input
        array is left unchanged.
    """
    # work on a copy so the caller's ground truth is not modified
    eroded = labels.copy()
    foreground = np.zeros(shape=labels.shape, dtype=bool)

    # `label_id` rather than `label` avoids shadowing scipy's label function
    for label_id in np.unique(labels):
        if label_id == 0:
            continue

        label_mask = labels == label_id

        eroded_label_mask = ndimage.binary_erosion(
            label_mask, iterations=iterations, border_value=border_value)
        foreground |= eroded_label_mask

    eroded[~foreground] = 0

    return eroded

In [None]:
# create eroded labels; pass a copy so `labels` itself is never modified and
# re-running this cell does not erode the ground truth a second time
eroded_labels = erode_boundaries(labels.copy(), iterations=2)

In [None]:
fig, axes = plt.subplots(
    1, 3,
    figsize=(15, 8),
    sharex=True,
    sharey=True,
    squeeze=False)
ax_raw, ax_eroded, ax_seg = axes[0]

# compare the eroded ground truth (used for evaluation) with the segmentation
ax_raw.imshow(raw, cmap='gray')
ax_raw.set_title('raw')

ax_eroded.imshow(create_lut(eroded_labels))
ax_eroded.set_title('eroded labels')

# seg keeps the singleton depth axis of the waterz input; drop it for display
ax_seg.imshow(create_lut(np.squeeze(seg)))
ax_seg.set_title('segmentation')

plt.show()

In [None]:
# get util function for evaluation
from utils.evaluate import evaluate_linear_sum_assignment

In [None]:
# get metrics for this image: unpack the six values returned by
# evaluate_linear_sum_assignment for the predicted vs. eroded instances
# NOTE(review): seg has a singleton leading axis that eroded_labels may not;
# presumably evaluate_linear_sum_assignment tolerates this — confirm
ap, precision, recall, tp, fp, fn = evaluate_linear_sum_assignment(
    seg,
    eroded_labels)

In [None]:
# report each metric with its name, in the same format as the per-file
# report of the evaluation loop
print(f'ap: {ap}')
print(f'precision: {precision}')
print(f'recall: {recall}')
print(f'tp: {tp}')
print(f'fp: {fp}')
print(f'fn: {fn}')

In [None]:
# now lets do it for all images and get the averages for the metrics

# settings shared by every test image (loop-invariant, so defined once)
crop = gp.Coordinate((8,)*2)  # pixels cropped from each image border
threshold = 0.5               # waterz merge threshold
hole_area_threshold = 256     # foreground holes smaller than this many pixels are filled

avg_ap = 0.0
avg_precision = 0.0
avg_recall = 0.0
avg_tp = 0.0
avg_fp = 0.0
avg_fn = 0.0

# count evaluated images explicitly: the loop variable `index` is undefined
# (or left over from an earlier cell) when test_files is empty
n_evaluated = 0

for index, f in enumerate(test_files):

    print(f'running prediction on {f}')

    out_file = f'prediction_{index}.zarr'

    predict(
        checkpoint,
        f,
        raw_dataset,
        labels_dataset,
        out_file,
        out_datasets)

    raw = daisy.open_ds(out_file, 'raw')
    labels = daisy.open_ds(out_file, 'labels')
    lsds = daisy.open_ds(out_file, 'pred_lsds')
    affs = daisy.open_ds(out_file, 'pred_affs')

    # crop a little off the edges
    roi = raw.roi.grow(-crop, -crop)

    raw = raw[roi].to_ndarray()
    labels = labels[roi].to_ndarray()
    lsds = lsds[roi].to_ndarray()
    affs = affs[roi].to_ndarray()

    # remember to relabel gt since we cropped it
    labels, _ = label(labels)
    labels = labels.astype(np.uint64)

    # erode a copy so `labels` keeps the un-eroded ground truth
    eroded_labels = erode_boundaries(labels.copy(), iterations=2)

    # prepend an all-zero placeholder channel and add a singleton depth axis
    # -> (3, 1, h, w); afterwards affs[1] and affs[2] hold the predictions
    affs = np.stack([
        np.zeros_like(affs[0]),
        affs[0],
        affs[1]])

    affs = np.expand_dims(affs, axis=1)

    segmentation = get_segmentation(affs, threshold)

    # fill small holes in the foreground and relabel connected components
    segmentation, _ = label(
        remove_small_holes(
                segmentation.astype(bool),
                area_threshold=hole_area_threshold))

    segmentation = segmentation.astype(np.uint64)

    ap, precision, recall, tp, fp, fn = evaluate_linear_sum_assignment(segmentation, eroded_labels)

    avg_ap += ap
    avg_precision += precision
    avg_recall += recall
    avg_tp += tp
    avg_fp += fp
    avg_fn += fn
    n_evaluated += 1

    print('\n', f'metrics on {f}', '\n')

    print(f'ap: {ap}')
    print(f'precision: {precision}')
    print(f'recall: {recall}')
    print(f'tp: {tp}')
    print(f'fp: {fp}')
    print(f'fn: {fn}')

    fig, axes = plt.subplots(
        1,
        5,
        figsize=(15, 8),
        sharex=True,
        sharey=True,
        squeeze=False)

    axes[0][0].imshow(raw, cmap='gray')
    axes[0][0].set_title('raw')

    axes[0][1].imshow(np.squeeze(lsds[0]), cmap='jet')
    axes[0][1].imshow(np.squeeze(lsds[1]), cmap='jet', alpha=0.5)
    axes[0][1].set_title('predicted offset vectors')

    # affs[0] is the zero placeholder added above; sum the two predicted channels
    axes[0][2].imshow(np.squeeze(affs[1] + affs[2]), cmap='jet')
    axes[0][2].set_title('predicted affs')

    axes[0][3].imshow(create_lut(eroded_labels))
    axes[0][3].set_title('labels')

    axes[0][4].imshow(create_lut(np.squeeze(segmentation)))
    axes[0][4].set_title('segmentation')

    plt.show()
    # release the figure so memory does not grow with the number of test files
    plt.close(fig)

if n_evaluated == 0:
    print('No test files were evaluated; averages are undefined.')
else:
    avg_ap /= n_evaluated
    avg_precision /= n_evaluated
    avg_recall /= n_evaluated
    avg_tp /= n_evaluated
    avg_fp /= n_evaluated
    avg_fn /= n_evaluated

    print(f'Average ap: {avg_ap}')
    print(f'Average precision: {avg_precision}')
    print(f'Average recall: {avg_recall}')
    print(f'Average tp: {avg_tp}')
    print(f'Average fp: {avg_fp}')
    print(f'Average fn: {avg_fn}')