*  Before starting, click "Runtime" in the top panel, select "Change runtime type" and then choose "GPU"

*  This tutorial follows the lsds tutorial, and is therefore condensed. Check out the lsds tutorial (**train_lsds.ipynb**) if there is any confusion throughout

*  Try running each cell consecutively to see what is happening before changing things around

*  Some cells are collapsed by default; these are generally utility functions or cells that were expanded by default in a previous tutorial. Double click to expand/collapse

*  Sometimes Colab can be slow when training; if this happens you may need to restart the runtime. Also, you can generally only run one session at a time.

In [None]:
#@title install packages + repos

# packages
!pip install mahotas
!pip install matplotlib
!pip install scikit-image
!pip install torch
!pip install zarr

# repos
!pip install git+git://github.com/funkelab/daisy.git
!pip install git+git://github.com/funkey/gunpowder.git
!pip install git+git://github.com/funkelab/funlib.learn.torch.git
!pip install git+git://github.com/funkelab/funlib.segment.git
!pip install git+git://github.com/funkelab/lsd.git
!pip install git+git://github.com/funkey/waterz.git

In [2]:
#@title import packages

import gunpowder as gp
import h5py
import io
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import random
import requests
import torch
import zarr

from funlib.learn.torch.models import UNet, ConvPass
from lsd.gp import AddLocalShapeDescriptor
from gunpowder.torch import Train

%matplotlib inline
logging.basicConfig(level=logging.INFO)

In [3]:
#@title utility function to view labels

# matplotlib uses a default shader
# we need to recolor as unique objects

def create_lut(labels):
    """Recolor a label array so that every id is drawn in its own random color.

    Args:
        labels: integer array of object ids; it is used directly as an index
            into the color table, so ids must be non-negative.

    Returns:
        Array of shape ``labels.shape + (4,)`` holding one RGBA value per
        element. Id 0 maps to all zeros (alpha included); every other id maps
        to a random color with alpha 255.
    """

    # one table row per possible id, 0..max inclusive
    num_ids = int(np.max(labels) + 1)

    # random rgb triplet per id (upper bound of randint is exclusive)
    rgb = np.random.randint(
        low=0,
        high=255,
        size=(num_ids, 3),
        dtype=np.uint64)

    # fully opaque alpha column
    alpha = np.full((num_ids, 1), 255, dtype=np.uint8)

    table = np.concatenate((rgb, alpha), axis=1)

    # id 0 becomes all zeros, alpha included
    table[0] = 0

    return table[labels]

In [4]:
#@title utility  function to download / save data as zarr
def create_data(
    url,
    name,
    offset,
    resolution,
    sections=None,
    squeeze=True):
  """Download an HDF5 volume and store its sections in a zarr container.

  The file at `url` is read fully into memory, then the datasets
  'volumes/raw' and 'volumes/labels/neuron_ids' are copied section by
  section into `name` as 'raw/{i}' and 'labels/{i}', where i is the
  position of the section in `sections` (not the section number itself).

  Args:
    url: location of the HDF5 file to download.
    name: path of the zarr container, opened in append mode.
    offset: value stored in the 'offset' attribute of every dataset.
    resolution: value stored in the 'resolution' attribute of every dataset.
    sections: indices along the first axis to copy; by default every
      section except the last one.
    squeeze: if True, drop the singleton section axis before writing.

  Raises:
    requests.HTTPError: if the download does not succeed.
  """

  response = requests.get(url)

  # fail early with a clear HTTP error instead of handing an error page to h5py
  response.raise_for_status()

  # the context manager closes the in-memory HDF5 file once everything is copied
  with h5py.File(io.BytesIO(response.content), 'r') as in_f:

    raw = in_f['volumes/raw']
    labels = in_f['volumes/labels/neuron_ids']

    f = zarr.open(name, 'a')

    if sections is None:
      # NOTE(review): this default stops one short of the final section;
      # kept as-is because the rest of the notebook is sized around it
      sections = range(raw.shape[0]-1)

    for i, r in enumerate(sections):

      print(f'Writing data for section {r}')

      # slicing with r:r+1 keeps the section axis so squeeze stays optional
      raw_slice = raw[r:r+1,:,:]
      labels_slice = labels[r:r+1,:,:]

      if squeeze:
        raw_slice = np.squeeze(raw_slice)
        labels_slice = np.squeeze(labels_slice)

      f[f'raw/{i}'] = raw_slice
      f[f'labels/{i}'] = labels_slice

      f[f'raw/{i}'].attrs['offset'] = offset
      f[f'raw/{i}'].attrs['resolution'] = resolution

      f[f'labels/{i}'].attrs['offset'] = offset
      f[f'labels/{i}'].attrs['resolution'] = resolution

In [5]:
#@title utility function to view a batch

# matplotlib.pyplot wrapper to view data
# default shape should be 2 - 2d data

def imshow(
        raw=None,
        ground_truth=None,
        target=None,
        prediction=None,
        h=None,
        shader='jet',
        subplot=True,
        channel=0,
        target_name='target',
        prediction_name='prediction'):
    """Plot any combination of raw, labels, target and prediction arrays.

    With `subplot=True` every array that is passed gets its own row of axes,
    in the order raw, ground_truth, target, prediction:

    * 2d arrays are drawn as a single image,
    * 3d arrays are drawn as one image per entry along the first axis,
    * arrays of any other rank are drawn as one image per entry along the
      first axis, showing only `channel` of each entry with the `shader`
      colormap.

    For 2d/3d arrays the row titled 'raw' is drawn in grayscale and every
    other row is recolored with `create_lut`.

    With `subplot=False` raw and ground_truth are drawn onto the current
    pyplot figure (ground_truth with alpha 0.5); target and prediction are
    not drawn.

    Args:
        raw: array shown in grayscale under the title 'raw'.
        ground_truth: array shown under the title 'labels'.
        target: array shown under the title `target_name`.
        prediction: array shown under the title `prediction_name`.
        h: optional vertical spacing passed to `subplots_adjust(hspace=...)`;
            only used when `subplot` is True.
        shader: colormap used for the channel view.
        subplot: create a dedicated grid of axes instead of drawing onto the
            current figure.
        channel: channel index shown in the channel view.
        target_name: title of the target row.
        prediction_name: title of the prediction row.

    Raises:
        ValueError: if `subplot` is True and no array was passed.
    """

    # (data, title) for every array that was actually passed, in row order
    panels = [
        (data, title)
        for data, title in (
            (raw, 'raw'),
            (ground_truth, 'labels'),
            (target, target_name),
            (prediction, prediction_name))
        if data is not None]

    axes = None

    if subplot:

        if not panels:
            raise ValueError(
                'imshow needs at least one of raw, ground_truth, '
                'target or prediction')

        # enough columns for the row with the most images; arrays with at
        # most two dimensions need a single column
        cols = max(
            data.shape[0] if len(data.shape) > 2 else 1
            for data, _ in panels)

        fig, axes = plt.subplots(
            len(panels),
            cols,
            figsize=(10, 4),
            sharex=True,
            sharey=True,
            squeeze=False)

        # the figure only exists in subplot mode, so h is only applied here
        if h is not None:
            fig.subplots_adjust(hspace=h)

    def draw(data, row, name):

        if not subplot:
            if name == 'raw':
                plt.imshow(data, cmap='gray')
            if name == 'labels':
                plt.imshow(data, alpha=0.5)
            return

        ndim = len(data.shape)

        if ndim in (2, 3):
            # a single 2d image is handled like a stack of one
            images = [data] if ndim == 2 else data

            for i, im in enumerate(images):
                if name == 'raw':
                    axes[row][i].imshow(im, cmap='gray')
                else:
                    axes[row][i].imshow(create_lut(im))
                axes[row][i].set_title(name)

        else:
            for i, im in enumerate(data):
                axes[row][i].imshow(im[channel], cmap=shader)
                axes[row][i].set_title(name)

    for row, (data, title) in enumerate(panels):
        draw(data, row, title)

    plt.show()

In [None]:
# download the CREMI sample A volume (see url) and write its sections to
# training_data.zarr as 2d datasets raw/{i} and labels/{i}; offset and
# resolution are stored as attributes on every dataset
create_data(
    'https://cremi.org/static/data/sample_A_20160501.hdf',
    'training_data.zarr',
    offset=[0,0],
    resolution=[4,4])

In [7]:
# 2d voxel size; matches the resolution passed to create_data
voxel_size = gp.Coordinate((4, 4))

# 2d shapes of the network input and output
input_shape = gp.Coordinate((164, 164))
output_shape = gp.Coordinate((124, 124))

# the same shapes scaled by voxel_size; these are the sizes used for the
# batch request in train()
input_size = input_shape * voxel_size
output_size = output_shape * voxel_size

# num_fmaps:   passed to the UNet and used as input channels of both heads
# num_samples: number of zarr sections (raw/{i}, labels/{i}) used as sources
# batch_size:  number of samples stacked into one batch
num_fmaps=12
num_samples=124
batch_size=5

In [8]:
# mtlsd model - designed to use lsds as an auxiliary learning task for improving affinities
# raw --> lsds / affs

# wrap model in a class. need two out heads, one for lsds, one for affs

class MtlsdModel(torch.nn.Module):
    """2d UNet backbone feeding two 1x1 convolution heads (lsds and affs).

    Both heads read the same UNet output and end in a Sigmoid activation;
    the lsd head has 6 output channels, the affinity head has 2.
    """

    def __init__(self):
        super().__init__()

        # every level applies two 3x3 convolutions; a fresh list is built for
        # each level so no kernel spec is shared between levels
        def double_conv():
            return [[3, 3], [3, 3]]

        # two downsampling steps -> three levels on the way down, two on the
        # way up
        down_factors = [[2, 2], [2, 2]]

        self.unet = UNet(
            in_channels=1,
            num_fmaps=num_fmaps,
            fmap_inc_factor=5,
            downsample_factors=down_factors,
            kernel_size_down=[double_conv() for _ in range(3)],
            kernel_size_up=[double_conv() for _ in range(2)])

        self.lsd_head = ConvPass(
            num_fmaps,
            6,
            [[1, 1]],
            activation='Sigmoid')

        self.aff_head = ConvPass(
            num_fmaps,
            2,
            [[1, 1]],
            activation='Sigmoid')

    def forward(self, input):
        """Return ``(lsds, affs)`` predicted from `input`."""

        features = self.unet(input)

        return self.lsd_head(features), self.aff_head(features)

# combine the lsds and affs losses

class WeightedMSELoss(torch.nn.MSELoss):
    """Sum of two weighted mean squared errors, one for lsds and one for affs.

    For each task, prediction and target are both multiplied by the weights
    before the plain MSE is taken.

    NOTE(review): because the weights are applied before squaring, each
    element's squared error is effectively scaled by the squared weight.
    """

    def __init__(self):
        super().__init__()

    def forward(self, lsds_prediction, lsds_target, lsds_weights, affs_prediction, affs_target, affs_weights,):
        """Return the lsd loss plus the affinity loss."""

        # plain MSE of the parent class
        mse = super().forward

        def weighted_mse(prediction, target, weights):
            return mse(prediction*weights, target*weights)

        lsds_loss = weighted_mse(lsds_prediction, lsds_target, lsds_weights)
        affs_loss = weighted_mse(affs_prediction, affs_target, affs_weights)

        return lsds_loss + affs_loss

In [9]:
def train(
    iterations,
    show_every,
    show_gt=True,
    show_pred=False,
    lsd_channels=None,
    aff_channels=None):
    """Build the mtlsd training pipeline, run it and plot selected batches.

    A fresh MtlsdModel, WeightedMSELoss and Adam optimizer are created on
    every call. Each requested batch passes through the whole pipeline,
    which ends in the gunpowder Train node.

    Args:
        iterations: number of batches to request from the pipeline.
        show_every: plot a batch whenever its iteration index is a multiple
            of this value.
        show_gt: plot the ground-truth lsds/affs of the selected channels.
        show_pred: plot the predicted lsds/affs of the selected channels.
        lsd_channels: optional dict mapping a display name to the index of
            the lsd channel to plot.
        aff_channels: optional dict mapping a display name to the index of
            the affinity channel to plot.
    """

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    gt_lsds = gp.ArrayKey('GT_LSDS')
    lsds_weights = gp.ArrayKey('LSDS_WEIGHTS')
    pred_lsds = gp.ArrayKey('PRED_LSDS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    affs_weights = gp.ArrayKey('AFFS_WEIGHTS')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    model = MtlsdModel()
    loss = WeightedMSELoss()
    optimizer = torch.optim.Adam(lr=0.5e-4, params=model.parameters())

    # raw is requested at the network input size, everything else at the
    # network output size
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(gt_lsds, output_size)
    request.add(lsds_weights, output_size)
    request.add(pred_lsds, output_size)
    request.add(gt_affs, output_size)
    request.add(affs_weights, output_size)
    request.add(pred_affs, output_size)

    # one source per section stored in the zarr container
    sources = tuple(
        gp.ZarrSource(
            'training_data.zarr',
            {
                raw: f'raw/{i}',
                labels: f'labels/{i}'
            },
            {
                raw: gp.ArraySpec(interpolatable=True),
                labels: gp.ArraySpec(interpolatable=False)
            }) +
            gp.Normalize(raw) +
            gp.RandomLocation()
            for i in range(num_samples)
        )

    # raw:      (h, w)
    # labels:   (h, w)

    pipeline = sources

    pipeline += gp.RandomProvider()

    pipeline += gp.SimpleAugment()

    pipeline += gp.ElasticAugment(
        control_point_spacing=(64, 64),
        jitter_sigma=(5.0, 5.0),
        rotation_interval=(0, math.pi/2))

    pipeline += gp.IntensityAugment(
        raw,
        scale_min=0.9,
        scale_max=1.1,
        shift_min=-0.1,
        shift_max=0.1)

    pipeline += gp.GrowBoundary(labels)

    pipeline += AddLocalShapeDescriptor(
        labels,
        gt_lsds,
        mask=lsds_weights,
        sigma=80,
        downsample=1)

    pipeline += gp.AddAffinities(
        affinity_neighborhood=[
            [0, -1],
            [-1, 0]],
        labels=labels,
        affinities=gt_affs,
        dtype=np.float32)

    pipeline += gp.BalanceLabels(
        gt_affs,
        affs_weights)

    pipeline += gp.Unsqueeze([raw])

    pipeline += gp.Stack(batch_size)

    pipeline += gp.PreCache(num_workers=10)

    # loss_inputs are passed positionally to WeightedMSELoss.forward
    pipeline += Train(
        model,
        loss,
        optimizer,
        inputs={
            'input': raw
        },
        outputs={
            0: pred_lsds,
            1: pred_affs
        },
        loss_inputs={
            0: pred_lsds,
            1: gt_lsds,
            2: lsds_weights,
            3: pred_affs,
            4: gt_affs,
            5: affs_weights
        })

    with gp.build(pipeline):
        for i in range(iterations):
            batch = pipeline.request_batch(request)

            if i % show_every != 0:
                continue

            # bounds of the requested labels roi divided by the voxel size;
            # used to crop raw to the labels region for display
            start = request[labels].roi.get_begin()/voxel_size
            end = request[labels].roi.get_end()/voxel_size

            imshow(raw=np.squeeze(batch[raw].data[:,:,start[0]:end[0],start[1]:end[1]]))
            imshow(ground_truth=batch[labels].data)

            if lsd_channels:
                for n,c in lsd_channels.items():

                    if show_gt:
                        imshow(target=batch[gt_lsds].data, target_name='gt '+n, channel=c)
                    if show_pred:
                        imshow(prediction=batch[pred_lsds].data, prediction_name='pred '+n, channel=c)

            if aff_channels:
                for n,c in aff_channels.items():

                    if show_gt:
                        imshow(target=batch[gt_affs].data, target_name='gt '+n, channel=c)
                    if show_pred:
                        # predictions go through the prediction arguments,
                        # same as for the lsds above
                        imshow(prediction=batch[pred_affs].data, prediction_name='pred '+n, channel=c)

In [None]:
# view a single batch of ground truth lsds/affs; show_pred is left at its
# default (False), so no predictions are plotted yet

# display name -> index of the lsd channel to plot
lsd_channels = {
    'offset (y)': 0,
    'offset (x)': 1,
    'orient (y)': 2,
    'orient (x)': 3,
    'yx change': 4,
    'voxel count': 5
}

# just view the first affinity channel (index 0)
# NOTE(review): presumably this is the first entry of affinity_neighborhood
# in train(), i.e. [0, -1] - confirm which axis that offset refers to
aff_channels = {'affs': 0}

train(
    iterations=1,
    show_every=1,
    lsd_channels=lsd_channels,
    aff_channels=aff_channels)

In [None]:
# let's just view the two offset channels of the lsds (indices 0 and 1)
# train for 1k iterations and plot every 100th batch
# show the predictions as well as the ground truth (show_gt defaults to True)

lsd_channels = {
    'offset (y)': 0,
    'offset (x)': 1
}

# first affinity channel only
aff_channels = {'affs': 0}

train(
    iterations=1000,
    show_every=100,
    show_pred=True,
    lsd_channels=lsd_channels,
    aff_channels=aff_channels)

*  Just a general idea of how to use gunpowder - the networks in the paper are all in 3d and should be trained on sufficient hardware

*  Results will probably vary since these are 2d slices of 3d data - sometimes more information is required in the z-dimension to inform predictions (especially for neuron segmentation). Feel free to try training for longer.

*  See how to run inference in **inference.ipynb**