*  Before starting, click "Runtime" in the top panel, select "Change runtime type" and then choose "GPU"

*  This tutorial follows the inference tutorial, and is therefore condensed. Check out the inference tutorial (**inference.ipynb**) if there is any confusion throughout

*  Try running each cell consecutively to see what is happening before changing things around

*  Some cells are collapsed by default; these are generally utility functions or were expanded by default in a previous tutorial. Double click to expand/collapse

*  Sometimes Colab can be slow; if this happens you may need to restart the runtime. Also, you can generally only run one session at a time.

In [None]:
#@title install packages + repos

# packages
!pip install matplotlib
!pip install scikit-image
!pip install scipy
!pip install torch
!pip install zarr

# repos
!pip install git+git://github.com/funkelab/daisy.git
!pip install git+git://github.com/funkey/gunpowder.git
!pip install git+git://github.com/funkelab/funlib.learn.torch.git
!pip install git+git://github.com/funkey/waterz.git

In [None]:
#@title fetch model checkpoint

!git clone https://github.com/funkelab/lsd.git
!cd lsd && git checkout tutorial

In [None]:
#@title import packages

import daisy
import gunpowder as gp
import h5py
import io
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import requests
import torch
import waterz
import zarr

from funlib.learn.torch.models import UNet, ConvPass
from gunpowder.torch import Predict
from scipy.ndimage import label
from scipy.ndimage import measurements
from scipy.ndimage.filters import maximum_filter
from scipy.ndimage.morphology import distance_transform_edt
from skimage.segmentation import watershed

%matplotlib inline
logging.basicConfig(level=logging.INFO)

In [None]:
#@title utility function to view labels

# matplotlib uses a default shader
# we need to recolor as unique objects

def create_lut(labels):
    """Colour an integer label image through a random RGBA lookup table.

    One table row is drawn per label id in ``0 .. max(labels)``. Row 0 is
    zeroed (all four channels), every other row gets random RGB values in
    ``[0, 255)`` and an alpha of 255.

    Args:
        labels: integer array of label ids; it is used directly as an index
            into the lookup table.

    Returns:
        Array of shape ``labels.shape + (4,)`` holding the RGBA colour of
        each element of ``labels``.
    """

    n_entries = int(np.max(labels) + 1)

    # random colour per label id
    rgb = np.random.randint(
            low=0,
            high=255,
            size=(n_entries, 3),
            dtype=np.uint64)

    # fully opaque alpha column
    alpha = np.full((n_entries, 1), 255, dtype=np.uint8)

    lut = np.concatenate((rgb, alpha), axis=1)

    # label 0 maps to an all-zero (transparent black) entry
    lut[0] = 0

    return lut[labels]

In [None]:
#@title utility  function to download / save data as zarr
def create_data(
    url,
    name,
    offset,
    resolution,
    sections=None,
    squeeze=True):
  """Download an hdf5 volume and store selected sections in a zarr container.

  The file at ``url`` is fetched into memory and its ``volumes/raw`` and
  ``volumes/labels/neuron_ids`` datasets are read. Each requested section is
  written to ``raw/{i}`` and ``labels/{i}`` inside the zarr container
  ``name`` (opened in append mode), where ``i`` is the position of the
  section within ``sections``. ``offset`` and ``resolution`` are stored as
  attributes on every written dataset.

  Args:
    url: location of the hdf5 file to download.
    name: path of the zarr container to write to.
    offset: value stored in the ``offset`` attribute of each dataset.
    resolution: value stored in the ``resolution`` attribute of each dataset.
    sections: iterable of indices along the first axis of the volume;
      ``None`` selects every section.
    squeeze: if true, drop the singleton section axis so that each dataset
      is 2D; otherwise keep the slice with a leading axis of length 1.

  Raises:
    requests.HTTPError: if the download returns an error status.
  """

  response = requests.get(url)
  # fail on HTTP errors here rather than handing an error page to h5py
  response.raise_for_status()

  out_f = zarr.open(name, 'a')

  # the context manager closes the in-memory hdf5 file once all sections are
  # copied
  with h5py.File(io.BytesIO(response.content), 'r') as in_f:

    raw = in_f['volumes/raw']
    labels = in_f['volumes/labels/neuron_ids']

    if sections is None:
      # every section, including the last one
      sections = range(raw.shape[0])

    for i, r in enumerate(sections):

      print(f'Writing data for section {r}')

      # slicing with r:r+1 keeps the section axis (length 1)
      raw_slice = raw[r:r+1,:,:]
      labels_slice = labels[r:r+1,:,:]

      if squeeze:
        raw_slice = np.squeeze(raw_slice)
        labels_slice = np.squeeze(labels_slice)

      for ds_name, data in (
          (f'raw/{i}', raw_slice),
          (f'labels/{i}', labels_slice)):

        out_f[ds_name] = data
        out_f[ds_name].attrs['offset'] = offset
        out_f[ds_name].attrs['resolution'] = resolution

In [None]:
# download the volume and keep only its first section; squeeze=False keeps
# the singleton section axis on the stored arrays

create_data(
    url='https://cremi.org/static/data/sample_A_20160501.hdf',
    name='testing_data.zarr',
    offset=[0,0],
    resolution=[4,4],
    sections=[0],
    squeeze=False)

In [None]:
#@title inference wrapper
# size of one voxel along each of the two image axes; also stored as the
# 'resolution' attribute of the prediction datasets created in predict()
voxel_size = gp.Coordinate((4, 4))

# size (in voxels) of the window requested for the raw input and for the
# predictions during scanning
# NOTE(review): the output window is smaller than the input window,
# presumably because the UNet uses valid convolutions - confirm
input_shape = gp.Coordinate((164, 164))
output_shape = gp.Coordinate((124, 124))

# the same windows scaled by the voxel size
input_size = input_shape * voxel_size
output_size = output_shape * voxel_size

# total roi of image to predict on
# (this is an extent of 1250 x 1250 voxels scaled by the voxel size, held as
# a gp.Coordinate rather than a gp.Roi; it carries no offset)
total_roi = gp.Coordinate((1250,1250))*voxel_size

# number of feature maps handed to the UNet and used as the input channel
# count of both output heads in MtlsdModel
num_fmaps=12

class MtlsdModel(torch.nn.Module):
    """Multi-task network: one shared 2D UNet followed by two output heads.

    One head emits 6 channels (the LSDs), the other 2 channels (the
    affinities). The module-level ``num_fmaps`` is read at construction time.
    """

    def __init__(self):
        """Build the UNet and the two 1x1 convolution heads."""
        super().__init__()

        # shared network: single-channel input, two downsampling steps of
        # factor 2 in both axes, pairs of 3x3 kernels on every level
        self.unet = UNet(
            in_channels=1,
            num_fmaps=num_fmaps,
            fmap_inc_factor=5,
            downsample_factors=[
                [2, 2],
                [2, 2]],
            kernel_size_down=[
                [[3, 3], [3, 3]],
                [[3, 3], [3, 3]],
                [[3, 3], [3, 3]]],
            kernel_size_up=[
                [[3, 3], [3, 3]],
                [[3, 3], [3, 3]]])

        # both heads take the num_fmaps UNet output channels through a single
        # 1x1 kernel with a sigmoid activation: 6 channels for the LSDs,
        # 2 channels for the affinities
        self.lsd_head = ConvPass(num_fmaps, 6, [[1, 1]], activation='Sigmoid')
        self.aff_head = ConvPass(num_fmaps, 2, [[1, 1]], activation='Sigmoid')

    def forward(self, input):
        """Run the UNet once and apply both heads to its output.

        Args:
            input: tensor passed to the UNet (the parameter name is the key
                used in the ``inputs`` mapping of the Predict node).

        Returns:
            Tuple ``(lsds, affs)`` with the outputs of the LSD head and the
            affinity head, in that order.
        """

        z = self.unet(input)
        lsds = self.lsd_head(z)
        affs = self.aff_head(z)

        return lsds, affs

class WeightedMSELoss(torch.nn.MSELoss):
    """Sum of two weighted mean-squared-error terms (LSDs and affinities).

    Each term multiplies both the prediction and the target element-wise by
    its weights before the parent ``MSELoss`` is applied.
    """

    def __init__(self):
        super().__init__()

    def forward(
            self,
            lsds_prediction,
            lsds_target,
            lsds_weights,
            affs_prediction,
            affs_target,
            affs_weights):
        """Return the LSD loss plus the affinity loss.

        Args:
            lsds_prediction, lsds_target, lsds_weights: predicted LSDs, their
                target and the element-wise weights applied to both.
            affs_prediction, affs_target, affs_weights: the same three
                tensors for the affinities.

        Returns:
            The sum of the two parent-class MSE values.
        """

        # the unweighted loss of the parent class
        plain_mse = super().forward

        lsds_loss = plain_mse(
            lsds_prediction*lsds_weights,
            lsds_target*lsds_weights)

        affs_loss = plain_mse(
            affs_prediction*affs_weights,
            affs_target*affs_weights)

        return lsds_loss + affs_loss

def predict(
    checkpoint,
    raw_file,
    raw_dataset,
    out_file,
    out_lsds,
    out_affs):
  """Predict LSDs and affinities for a raw dataset and write them to zarr.

  Builds a gunpowder pipeline that reads ``raw_dataset`` from ``raw_file``,
  runs ``MtlsdModel`` (weights loaded from ``checkpoint``) chunk by chunk and
  writes the two outputs to ``out_lsds`` and ``out_affs`` in ``out_file``.

  Reads the module-level ``voxel_size``, ``input_size``, ``output_size`` and
  ``total_roi`` at call time.

  Args:
    checkpoint: checkpoint passed to the gunpowder Predict node.
    raw_file: path of the zarr container holding the raw data.
    raw_dataset: name of the raw dataset inside ``raw_file``.
    out_file: path of the zarr container to write to; it is opened in 'r+'
      mode if the path exists and in 'w' mode otherwise.
    out_lsds: dataset name for the predicted LSDs.
    out_affs: dataset name for the predicted affinities.

  Returns:
    None. The predictions are written to ``out_file``.
  """
  
  raw = gp.ArrayKey('RAW')
  pred_lsds = gp.ArrayKey('PRED_LSDS')
  pred_affs = gp.ArrayKey('PRED_AFFS')

  # keep an existing output container, otherwise start a new one
  if os.path.exists(out_file):
    mode='r+'
  else:
    mode='w'

  of = zarr.open(out_file, mode=mode)
  rd = zarr.open(raw_file)[raw_dataset]

  # create the output datasets if they are missing: 6 channels when the
  # dataset name contains 'lsd', 2 channels otherwise, followed by the
  # shape of the raw data with singleton axes removed
  for ds in [out_lsds, out_affs]:
      if ds not in of:
          if 'lsd' in ds:
              dims = 6
          else:
              dims = 2
          out_ds = of.create_dataset(
                  ds,
                  shape= (dims,) + np.squeeze(rd).shape,
                  dtype=np.float32)
          out_ds.attrs['resolution'] = voxel_size

  # the per-chunk request: raw at the network input size, both predictions
  # at the network output size
  scan_request = gp.BatchRequest()

  scan_request.add(raw, input_size)
  scan_request.add(pred_lsds, output_size)
  scan_request.add(pred_affs, output_size)

  source = gp.ZarrSource(
              raw_file,
          {
              raw: raw_dataset
          },
          {
              raw: gp.ArraySpec(interpolatable=True)
          })

  model = MtlsdModel()
  model.eval()

  # model outputs are matched by position: 0 -> lsds, 1 -> affs (the order
  # returned by MtlsdModel.forward); 'input' is forward's parameter name.
  # Note that this local name shadows the enclosing function inside its body.
  predict = gp.torch.Predict(
      model=model,
      checkpoint=checkpoint,
      inputs = {
                'input': raw
      },
      outputs = {
          0: pred_lsds,
          1: pred_affs})
  
  scan = gp.Scan(scan_request)

  write = gp.ZarrWrite(
      dataset_names={
          pred_lsds: out_lsds,
          pred_affs: out_affs},
      output_filename=out_file)
  
  # NOTE(review): Stack(1) presumably adds the batch axis the model expects
  # and Squeeze removes it again from the predictions - confirm against the
  # gunpowder documentation
  pipeline = source
  pipeline += gp.Normalize(raw)
  pipeline += gp.Stack(1)
  pipeline += predict
  pipeline += scan
  pipeline += gp.Squeeze([pred_lsds, pred_affs])
  pipeline += write

  # the overall request: all three arrays are requested with the same extent
  # NOTE(review): raw and predictions share one extent although each chunk
  # needs a larger raw window than it predicts - confirm how Scan treats the
  # border
  predict_request = gp.BatchRequest()

  predict_request.add(raw, total_roi)
  predict_request.add(pred_lsds, total_roi)
  predict_request.add(pred_affs, total_roi)

  with gp.build(pipeline):
      pipeline.request_batch(predict_request)

In [None]:
# model weights from the cloned lsd repository
checkpoint = 'lsd/lsd/tutorial/notebooks/model_checkpoint_50000'

# input: the raw section written by create_data
raw_file = 'testing_data.zarr'
raw_dataset = 'raw/0'

# output: container and dataset names for the two predictions
out_file = 'prediction.zarr'
out_lsds = 'pred_lsds/0'
out_affs = 'pred_affs/0'

predict(
    checkpoint=checkpoint,
    raw_file=raw_file,
    raw_dataset=raw_dataset,
    out_file=out_file,
    out_lsds=out_lsds,
    out_affs=out_affs)

In [None]:
# load the input data and the predictions back into memory
test_f = zarr.open('testing_data.zarr')
predict_f = zarr.open('prediction.zarr')

raw = test_f['raw/0'][:]
labels = test_f['labels/0'][:]

pred_affs = predict_f['pred_affs/0'][:]
pred_lsds = predict_f['pred_lsds/0'][:]

# one row of four panels: raw, labels, first affinity channel, and the first
# two lsd channels blended on top of each other
fig, axes = plt.subplots(
    nrows=1,
    ncols=4,
    figsize=(20, 6),
    sharex=True,
    sharey=True,
    squeeze=False)

axes[0, 0].imshow(np.squeeze(raw), cmap='gray')
axes[0, 1].imshow(create_lut(np.squeeze(labels)))
axes[0, 2].imshow(np.squeeze(pred_affs[:1]), cmap='jet')
axes[0, 3].imshow(np.squeeze(pred_lsds[:1]), cmap='jet')
axes[0, 3].imshow(np.squeeze(pred_lsds[1:2]), cmap='jet', alpha=0.5)

In [None]:
#@title watershed wrappers
def watershed_from_boundary_distance(
        boundary_distances,
        boundary_mask,
        return_seeds=False,
        id_offset=0,
        min_seed_distance=10):
    """Seeded watershed on a boundary-distance map.

    Seeds are the connected components of the local maxima of
    ``boundary_distances``, i.e. of the points that equal the maximum filter
    of window size ``min_seed_distance``. The watershed is run on the
    inverted distances and restricted to ``boundary_mask``.

    Args:
        boundary_distances: array of distances to the nearest boundary.
        boundary_mask: mask handed to the watershed; only points inside it
            are labelled.
        return_seeds: if true, the seed array is appended to the result.
        id_offset: value added to every non-zero seed id, so that ids do not
            collide with those of earlier calls.
        min_seed_distance: window size of the maximum filter.

    Returns:
        ``(fragments, max_id)``, or ``(fragments, max_id, seeds)`` when
        ``return_seeds`` is true. ``fragments`` and ``seeds`` are ``uint64``
        arrays and ``max_id`` is ``n + id_offset`` for ``n`` seeds found.
        The tuple has the same length when no seed is found; the arrays are
        then all zero and ``max_id`` equals ``id_offset``.
    """

    max_filtered = maximum_filter(boundary_distances, min_seed_distance)
    maxima = max_filtered==boundary_distances
    seeds, n = label(maxima)

    print(f"Found {n} fragments")

    if n == 0:
        # keep the return arity consistent with the regular path, callers
        # index the seeds as ret[2] when return_seeds is set
        empty = np.zeros(boundary_distances.shape, dtype=np.uint64)
        ret = (empty, id_offset)
        if return_seeds:
            ret = ret + (np.zeros_like(empty),)
        return ret

    # shift the ids past those handed out so far; 0 stays "no seed"
    seeds[seeds!=0] += id_offset

    # invert the distances so that the seeds sit in the basins
    fragments = watershed(
        boundary_distances.max() - boundary_distances,
        seeds,
        mask=boundary_mask)

    ret = (fragments.astype(np.uint64), n + id_offset)
    if return_seeds:
        ret = ret + (seeds.astype(np.uint64),)

    return ret

def watershed_from_affinities(
        affs,
        max_affinity_value=1.0,
        fragments_in_xy=True,
        return_seeds=False,
        min_seed_distance=10,
        labels_mask=None):
    """Extract fragments section by section from an affinity volume.

    Channels 1 and 2 of ``affs`` are averaged. For every index ``z`` along
    the first axis of that average, the section is thresholded at half of
    ``max_affinity_value``, distance-transformed, and passed to
    ``watershed_from_boundary_distance``. Fragment ids keep increasing across
    sections.

    Args:
        affs: affinities indexed as ``affs[channel][z]``; only channels 1
            and 2 are read.
        max_affinity_value: largest possible affinity; half of it is the
            threshold above which a point counts as interior.
        fragments_in_xy: accepted for interface compatibility; this
            implementation does not read it and always works per section.
        return_seeds: if true, the seeds are appended to the result.
        min_seed_distance: forwarded to ``watershed_from_boundary_distance``.
        labels_mask: optional array, cast to bool, that restricts where
            fragments are created. It may have the shape of one section (it
            is then applied to every section) or carry a leading depth axis
            like the averaged affinities (section ``z`` then uses
            ``labels_mask[z]``).

    Returns:
        ``(fragments, max_id)``, or ``(fragments, max_id, seeds)`` when
        ``return_seeds`` is true. ``fragments`` and ``seeds`` are ``uint64``
        arrays shaped like the averaged affinities.
    """

    mean_affs = 0.5*(affs[1] + affs[2])
    depth = mean_affs.shape[0]

    fragments = np.zeros(mean_affs.shape, dtype=np.uint64)
    if return_seeds:
        seeds = np.zeros(mean_affs.shape, dtype=np.uint64)

    if labels_mask is not None:
        # convert once instead of once per section
        labels_mask = np.asarray(labels_mask).astype(bool)
        # a mask with a depth axis is indexed per section, a mask shaped
        # like a single section is shared by all sections
        mask_has_depth = labels_mask.ndim == mean_affs.ndim

    id_offset = 0

    for z in range(depth):

        boundary_mask = mean_affs[z]>0.5*max_affinity_value
        boundary_distances = distance_transform_edt(boundary_mask)

        if labels_mask is not None:

            # applied after the distance transform, so the mask limits the
            # watershed without changing the distances
            section_mask = labels_mask[z] if mask_has_depth else labels_mask
            boundary_mask &= section_mask

        ret = watershed_from_boundary_distance(
            boundary_distances,
            boundary_mask,
            return_seeds=return_seeds,
            id_offset=id_offset,
            min_seed_distance=min_seed_distance)

        fragments[z] = ret[0]
        if return_seeds:
            seeds[z] = ret[2]

        # continue the id range in the next section
        id_offset = ret[1]

    ret = (fragments, id_offset)
    if return_seeds:
        ret += (seeds,)

    return ret

In [None]:
#@title segmentation wrapper
def get_segmentation(affinities, threshold, labels_mask=None):
    """Segment an affinity volume: watershed fragments, then agglomeration.

    Args:
        affinities: affinity array handed to ``watershed_from_affinities``
            and, cast to float32, to ``waterz.agglomerate``.
        threshold: the single agglomeration threshold to evaluate.
        labels_mask: optional mask forwarded to the watershed.

    Returns:
        The first (and only) segmentation produced by the agglomeration.
    """

    watershed_result = watershed_from_affinities(
        affinities,
        labels_mask=labels_mask)

    # the first element holds the fragments, the rest is not needed here
    fragments = watershed_result[0]

    agglomeration = waterz.agglomerate(
        affs=affinities.astype(np.float32),
        fragments=fragments,
        thresholds=[threshold],
    )

    # one threshold was given, so only the first result is taken
    return next(agglomeration)

In [None]:
# open the predicted affinities and the ground-truth labels as daisy arrays
affs = daisy.open_ds("prediction.zarr", "pred_affs/0")
labels = daisy.open_ds("testing_data.zarr", "labels/0")

# affs shape: c, h, w (2, h, w)
# labels shape: d, h, w (fake d (1))

# get affs meta data
# note: this rebinds the module-level voxel_size that predict() reads when it
# is called; offset is not used again in the cells shown here
roi = affs.roi
voxel_size = affs.voxel_size
offset = affs.roi.get_begin()

# get labels back to 2d daisy array
labels.materialize()

# build a 2d roi from the first two entries of the labels roi
labels_roi = daisy.Roi(
        (labels.roi.get_begin()[0], labels.roi.get_begin()[1]),
        (labels.roi.get_shape()[0], labels.roi.get_shape()[1]))

# drop the leading (fake) axis of the label data and wrap it with the 2d roi
labels = daisy.Array(
        labels.data.reshape(
            labels.shape[1],
            labels.shape[2]),
        labels_roi,
        voxel_size)

# intersect labels with affs roi (affs were predicted on smaller roi)
labels = labels[roi]

# convert to numpy arrays
affs = affs.to_ndarray()
labels = labels.to_ndarray()

# watershed_from_affinities averages channels 1 and 2, so prepend an all-zero
# channel to move the two predicted channels to those positions
affs = np.stack([
    np.zeros_like(affs[0]),
    affs[0],
    affs[1]]
)

# affs shape: 3, h, w

# waterz agglomerate requires 4d affs (c, d, h, w) - add fake z dim
affs = np.expand_dims(affs, axis=1)

#affs shape: 3, 1, h, w

#just test a 0.5 threshold. higher thresholds will merge more, lower thresholds will split more
threshold = 0.5

# the labels double as a mask: they are cast to bool inside the watershed, so
# fragments are only created where the label id is non-zero
segmentation = get_segmentation(affs, threshold, labels_mask=labels)

In [None]:
# one row of five panels: raw, labels, summed affinity channels, the first
# two lsd channels blended on top of each other, and the segmentation
fig, axes = plt.subplots(
    nrows=1,
    ncols=5,
    figsize=(20, 6),
    sharex=True,
    sharey=True,
    squeeze=False)

axes[0, 0].imshow(np.squeeze(raw), cmap='gray')
axes[0, 1].imshow(create_lut(np.squeeze(labels)))
axes[0, 2].imshow(np.squeeze(pred_affs[:1] + pred_affs[1:2]), cmap='jet')
axes[0, 3].imshow(np.squeeze(pred_lsds[:1]), cmap='jet')
axes[0, 3].imshow(np.squeeze(pred_lsds[1:2]), cmap='jet', alpha=0.5)
axes[0, 4].imshow(create_lut(np.squeeze(segmentation)))