*  Before starting, click "Runtime" in the top menu, select "Change runtime type", and then choose "GPU" as the hardware accelerator.

*  This tutorial follows the mtlsd tutorial and is therefore condensed. If anything is unclear, check out the mtlsd tutorial (**train_mtlsd.ipynb**).

*  Try running each cell in order to see what happens before changing anything.

*  Some cells are collapsed by default; these are generally utility functions, or cells that were already shown expanded in a previous tutorial. Double-click a cell to expand or collapse it.

*  Colab can sometimes be slow; if this happens, you may need to restart the runtime. Also, you can generally only run one session at a time.

In [None]:
#@title install packages + repos

# packages
# NOTE(review): h5py and requests are imported in the next cell but are not
# installed here; presumably they ship with the Colab runtime -- install them
# explicitly when running outside Colab.
!pip install gunpowder
!pip install matplotlib
!pip install torch
!pip install zarr

# repos
# funlib.learn.torch provides the UNet and ConvPass modules used by MtlsdModel
!pip install git+https://github.com/funkelab/funlib.learn.torch.git

In [None]:
#@title import packages

import gunpowder as gp
import h5py
import io
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import requests
import torch
import zarr

from funlib.learn.torch.models import UNet, ConvPass
from gunpowder.torch import Predict

# render matplotlib figures inline in the notebook output
%matplotlib inline
# send log records of level INFO and above to stderr (e.g. from library loggers)
logging.basicConfig(level=logging.INFO)

In [None]:
#@title utility function to view labels

def create_lut(labels):
    """Color an integer label image with one random RGBA color per label id.

    matplotlib's colormaps give numerically close ids similar colors, so
    neighbouring objects become hard to tell apart; a random lookup table
    makes each object visually distinct.

    Args:
        labels: array-like of non-negative integer label ids (0 = background).

    Returns:
        uint8 array of shape ``labels.shape + (4,)``. Label 0 maps to
        (0, 0, 0, 0), i.e. fully transparent; every other id maps to a random
        opaque color.

    Raises:
        ValueError: if ``labels`` is empty (``np.max`` has no identity).
    """
    labels = np.asarray(labels)

    # Convert to a Python int *before* adding 1: with NumPy 2 promotion rules
    # np.uint8(255) + 1 wraps to 0, which would produce an empty table.
    num_labels = int(np.max(labels)) + 1

    lut = np.empty((num_labels, 4), dtype=np.uint8)
    # randint's upper bound is exclusive, so 256 allows the full 0..255 range.
    lut[:, :3] = np.random.randint(
        low=0, high=256, size=(num_labels, 3), dtype=np.uint8)
    lut[:, 3] = 255  # opaque

    lut[0] = 0  # background: transparent black
    return lut[labels]

In [None]:
#@title utility  function to download / save data as zarr

def create_data(
        url,
        name,
        offset,
        resolution,
        sections=None,
        squeeze=True):
    """Download an HDF5 volume and store selected sections as a zarr container.

    Reads ``volumes/raw`` and ``volumes/labels/neuron_ids`` from the HDF5
    file at ``url``. Each chosen section is written to ``raw/<i>`` and
    ``labels/<i>`` inside the zarr container ``name``, where ``i`` is the
    position of the section in ``sections`` (not the section index itself).

    Args:
        url: HTTP(S) location of the HDF5 file.
        name: path of the zarr container (opened in append mode).
        offset: value stored in each dataset's ``offset`` attribute.
        resolution: value stored in each dataset's ``resolution`` attribute.
        sections: indices along the first axis to export; defaults to every
            section of the volume.
        squeeze: whether to drop singleton dimensions from each slice.

    Raises:
        requests.HTTPError: if the download returns an HTTP error status.
    """
    response = requests.get(url)
    # Fail loudly on HTTP errors instead of handing an error page to h5py.
    response.raise_for_status()

    # The context manager closes the HDF5 file once all sections are copied.
    with h5py.File(io.BytesIO(response.content), 'r') as in_f:
        raw = in_f['volumes/raw']
        labels = in_f['volumes/labels/neuron_ids']

        container = zarr.open(name, 'a')

        if sections is None:
            # Every section along the first axis (including the last one).
            sections = range(raw.shape[0])

        for index, section in enumerate(sections):

            print(f'Writing data for section {section}')

            raw_slice = raw[section]
            labels_slice = labels[section]

            if squeeze:
                raw_slice = np.squeeze(raw_slice)
                labels_slice = np.squeeze(labels_slice)

            for ds_name, data in [('raw', raw_slice), ('labels', labels_slice)]:
                key = f'{ds_name}/{index}'
                container[key] = data
                dataset = container[key]
                dataset.attrs['offset'] = offset
                dataset.attrs['resolution'] = resolution

In [None]:
# Download CREMI sample A and store one randomly chosen section (unsqueezed)
# in testing_data.zarr under raw/0 and labels/0.
sampled_sections = random.sample(range(0, 124), 1)

create_data(
    'https://cremi.org/static/data/sample_A_20160501.hdf',
    'testing_data.zarr',
    offset=[0, 0],
    resolution=[4, 4],
    sections=sampled_sections,
    squeeze=False,
)

In [None]:
# world-space size of one voxel along each of the two spatial axes; matches the
# resolution=[4,4] written by create_data (presumably nm -- confirm)
voxel_size = gp.Coordinate((4, 4))

# network input / output patch sizes in voxels; the output is smaller,
# presumably because the UNet uses unpadded (valid) 3x3 convolutions -- confirm
input_shape = gp.Coordinate((164, 164))
output_shape = gp.Coordinate((124, 124))

# the same patch sizes in world units, as used for the gunpowder requests in predict()
input_size = input_shape * voxel_size
output_size = output_shape * voxel_size

# total roi of image to predict on
# NOTE(review): not referenced by predict(), which reads the ROI from the zarr
# source instead; 1250x1250 voxels presumably matches the section size -- confirm
total_roi = gp.Coordinate((1250,1250))*voxel_size

# feature maps at the top UNet level; also the input channel count of both
# output heads in MtlsdModel
num_fmaps=12

In [None]:
#@title create mtlsd model

class MtlsdModel(torch.nn.Module):
    """Shared 2D UNet trunk feeding an LSD head and an affinity head.

    The trunk maps a 1-channel image to ``num_fmaps`` feature maps; two 1x1
    convolution passes with sigmoid activations turn those features into
    6 LSD channels and 2 affinity channels.
    """

    def __init__(self):
        super().__init__()

        # Attribute names determine the state_dict keys PyTorch uses, so they
        # presumably must match the stored checkpoint -- keep them stable.
        self.unet = UNet(
            in_channels=1,
            num_fmaps=num_fmaps,
            fmap_inc_factor=5,
            # two 2x downsamplings -> three resolution levels
            downsample_factors=[[2, 2] for _ in range(2)],
            # two 3x3 convolutions on each of the three levels going down...
            kernel_size_down=[[[3, 3], [3, 3]] for _ in range(3)],
            # ...and on each of the two levels coming back up
            kernel_size_up=[[[3, 3], [3, 3]] for _ in range(2)])

        self.lsd_head = ConvPass(num_fmaps, 6, [[1, 1]], activation='Sigmoid')
        self.aff_head = ConvPass(num_fmaps, 2, [[1, 1]], activation='Sigmoid')

    def forward(self, input):
        # The parameter keeps the name `input`: predict() wires the raw array
        # to the model through an inputs mapping keyed by 'input'.
        features = self.unet(input)
        return self.lsd_head(features), self.aff_head(features)

In [None]:
def predict(
        checkpoint,
        raw_file,
        raw_dataset):
    """Run MtlsdModel over a whole 2D raw image, chunk by chunk.

    Args:
        checkpoint: path of the trained checkpoint handed to gunpowder's
            torch Predict node.
        raw_file: path of the zarr container holding the raw data.
        raw_dataset: dataset name inside ``raw_file`` (e.g. 'raw/0').

    Returns:
        Tuple ``(raw, pred_lsds, pred_affs)`` of numpy arrays. ``raw`` has
        shape (h, w) and covers the whole source. The predictions have shapes
        (6, h', w') and (2, h', w'), covering the source ROI shrunk by the
        network context on every side.
    """
    raw = gp.ArrayKey('RAW')
    pred_lsds = gp.ArrayKey('PRED_LSDS')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    # Reference request describing a single chunk; Scan tiles it over the
    # full request below.
    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(pred_lsds, output_size)
    scan_request.add(pred_affs, output_size)

    # Margin (world units) between network input and output on each side.
    context = (input_size - output_size) // 2

    source = gp.ZarrSource(
        raw_file,
        {raw: raw_dataset},
        {raw: gp.ArraySpec(interpolatable=True)})

    with gp.build(source):
        total_input_roi = source.spec[raw].roi
        # Only this inner region has full context for a valid prediction.
        total_output_roi = total_input_roi.grow(-context, -context)

    model = MtlsdModel()

    # set model to eval mode
    model.eval()

    predict_node = gp.torch.Predict(
        model=model,
        checkpoint=checkpoint,
        inputs={'input': raw},
        outputs={0: pred_lsds, 1: pred_affs})

    # this will scan in chunks equal to the input/output sizes of the respective arrays
    scan = gp.Scan(scan_request)

    pipeline = source
    pipeline += gp.Normalize(raw)
    pipeline += gp.Unsqueeze([raw])   # raw: (h, w) -> (c, h, w)
    pipeline += gp.Stack(1)           # raw: (c, h, w) -> (b, c, h, w)
    pipeline += predict_node          # pred_lsds / pred_affs: (b, c, h, w)
    pipeline += scan
    pipeline += gp.Squeeze([raw])     # raw: (b, c, h, w) -> (c, h, w)
    # raw: (c, h, w) -> (h, w); pred_lsds / pred_affs: (b, c, h, w) -> (c, h, w)
    pipeline += gp.Squeeze([raw, pred_lsds, pred_affs])

    # Request the exact ROIs. BatchRequest.add() expects a *shape*, so the
    # previous add(key, roi.get_end()) asked for predictions larger than the
    # valid output region (leaving an unpredicted border) and only worked for
    # sources with a zero offset.
    predict_request = gp.BatchRequest()
    predict_request[raw] = gp.ArraySpec(roi=total_input_roi)
    predict_request[pred_lsds] = gp.ArraySpec(roi=total_output_roi)
    predict_request[pred_affs] = gp.ArraySpec(roi=total_output_roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(predict_request)

    return batch[raw].data, batch[pred_lsds].data, batch[pred_affs].data

In [None]:
# fetch checkpoint
# wget saves this as ./model_checkpoint_50000, the path passed to predict() below
# NOTE(review): Dropbox share links without ?dl=1 may serve an HTML preview page
# instead of the file -- if loading the checkpoint fails, inspect the download
!wget https://www.dropbox.com/s/r1u8pvji5lbanyq/model_checkpoint_50000

In [None]:
# Inputs: the downloaded weights and the single section written to
# testing_data.zarr (stored under raw/0).
checkpoint = 'model_checkpoint_50000'
raw_file = 'testing_data.zarr'
raw_dataset = 'raw/0'

raw, pred_lsds, pred_affs = predict(
    checkpoint=checkpoint,
    raw_file=raw_file,
    raw_dataset=raw_dataset,
)

In [None]:
# The predictions only cover the valid output region, which sits centered
# inside the raw image (the network context is trimmed from every side).
# Crop raw to that same region so that, with shared axes, pixels line up
# across all three panels.
pred_h, pred_w = pred_affs.shape[-2:]
top = (raw.shape[0] - pred_h) // 2
left = (raw.shape[1] - pred_w) // 2
raw_valid = raw[top:top + pred_h, left:left + pred_w]

fig, axes = plt.subplots(
            1,
            3,
            figsize=(20, 6),
            sharex=True,
            sharey=True,
            squeeze=False)

# view predictions (for lsds we will just view the mean offset component)
axes[0][0].imshow(raw_valid, cmap='gray')
axes[0][0].set_title('raw (cropped to prediction region)')

axes[0][1].imshow(np.squeeze(pred_affs[0]), cmap='jet')
axes[0][1].set_title('affinities (channel 0)')

axes[0][2].imshow(np.squeeze(pred_lsds[0]), cmap='jet')
axes[0][2].imshow(np.squeeze(pred_lsds[1]), cmap='jet', alpha=0.5)
axes[0][2].set_title('LSDs (channels 0 and 1)')

*  See **segment.ipynb** for how to generate a segmentation from these predictions.