*  Before starting, click "Runtime" in the top panel, select "Change runtime type" and then choose "GPU"

*  This is a bonus tutorial based on the LSD training tutorial (**train_lsd.ipynb**). It shows an example of training a network that learns to predict zero inside glia

*  Try running each cell consecutively to see what is happening before changing things around

*  Some cells are collapsed by default. These are generally utility functions, or cells that were expanded by default in a previous tutorial. Double click to expand/collapse

*  Sometimes Colab can be slow when training. If this happens, you may need to restart the runtime. Also, you can generally only run one session at a time.

In [None]:
#@title install packages + repos

# packages
!pip install gunpowder
!pip install matplotlib
!pip install scikit-image
!pip install torch
!pip install zarr

# repos
!pip install git+https://github.com/funkelab/funlib.learn.torch.git
!pip install git+https://github.com/funkelab/lsd.git

In [None]:
#@title import packages

import gunpowder as gp
import h5py
import io
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import random
import requests
import torch
import zarr

from funlib.learn.torch.models import UNet, ConvPass
from gunpowder.torch import Train
from lsd.train.gp import AddLocalShapeDescriptor
from tqdm import tqdm

%matplotlib inline
logging.basicConfig(level=logging.INFO)

In [None]:
#@title utility function to view labels

# matplotlib uses a default shader
# we need to recolor as unique objects

def create_lut(labels):
    """Colour an integer label array with random RGBA values.

    A lookup table with one row per id in ``0 .. max(labels)`` is built from
    random RGB values plus an alpha of 255. Row 0 is zeroed, so id 0 maps to
    ``(0, 0, 0, 0)``. The table is then indexed with ``labels``, giving an
    array of shape ``labels.shape + (4,)`` and dtype ``uint8``.

    Colours are drawn from ``np.random`` on every call, so the same id gets a
    different colour each time unless the global seed is fixed.
    """
    num_entries = int(np.max(labels) + 1)

    rgb = np.random.randint(
        low=0,
        high=255,
        size=(num_entries, 3),
        dtype=np.uint8)

    # opaque alpha channel for every id
    alpha = np.full((num_entries, 1), 255, dtype=np.uint8)

    table = np.concatenate((rgb, alpha), axis=1)

    # id 0 is background: all four channels zero
    table[0] = 0

    return table[labels]

In [None]:
#@title utility function to mask specific ids, download / save data as zarr

def create_data(
    url,
    name,
    offset,
    resolution,
    mask_ids,
    sections=None,
    squeeze=True):
    """Download an HDF5 volume and write it section by section into a zarr container.

    The file at ``url`` is fetched into memory and its ``volumes/raw`` and
    ``volumes/labels/neuron_ids`` datasets are read. For every requested
    section four arrays are written to the zarr container ``name`` (opened in
    append mode), each under ``<dataset>/<index>``:

    * ``raw``         - the raw section
    * ``labels``      - the label section with every id in ``mask_ids`` set to 0
    * ``labels_mask`` - all ones, same shape as the labels
    * ``ids_mask``    - 1 where the (masked) labels are 0, 0 everywhere else

    Args:
        url: location of the HDF5 file to download.
        name: path of the zarr container to write to.
        offset: value stored in each array's ``offset`` attribute.
        resolution: value stored in each array's ``resolution`` attribute.
        mask_ids: label id(s) to zero out in the labels.
        sections: indices along the first axis to export. ``None`` exports
            ``range(raw.shape[0] - 1)``.
        squeeze: if true, ``np.squeeze`` is applied to each array before
            writing.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    response = requests.get(url)

    # fail with a clear HTTP error instead of handing an error page to h5py
    response.raise_for_status()

    container = zarr.open(name, 'a')

    # the context manager closes the in-memory HDF5 file once everything has
    # been copied out of it
    with h5py.File(io.BytesIO(response.content), 'r') as in_f:

        raw = in_f['volumes/raw']
        labels = in_f['volumes/labels/neuron_ids']

        if sections is None:
            # NOTE(review): this leaves out the last section of the volume.
            # Kept as is because the rest of the notebook hard-codes the
            # resulting section count - confirm whether it is intentional.
            sections = range(raw.shape[0] - 1)

        for index, section in enumerate(sections):

            print(f'Writing data for section {section}')

            raw_slice = raw[section]
            labels_slice = labels[section]

            # set mask id(s) to zero
            labels_slice[np.isin(labels_slice, mask_ids)] = 0

            # create labels mask (ones like labels)
            labels_mask_slice = np.ones_like(labels_slice).astype(np.uint8)

            # create ids mask: 1 where the label is 0 (i.e. background or a
            # masked id), 0 on every remaining object
            ids_mask_slice = (1 - (labels_slice > 0)).astype(np.uint8)

            for ds_name, data in [
                    ('raw', raw_slice),
                    ('labels', labels_slice),
                    ('labels_mask', labels_mask_slice),
                    ('ids_mask', ids_mask_slice)]:

                if squeeze:
                    data = np.squeeze(data)

                container[f'{ds_name}/{index}'] = data
                container[f'{ds_name}/{index}'].attrs['offset'] = offset
                container[f'{ds_name}/{index}'].attrs['resolution'] = resolution

In [None]:
#@title utility function to view a batch

# matplotlib.pyplot wrapper to view data
# default shape should be 2 - 2d data

def imshow(
        raw=None,
        ground_truth=None,
        target=None,
        prediction=None,
        h=None,
        shader='jet',
        subplot=True,
        channel=0,
        target_name='target',
        prediction_name='prediction'):
    """Plot any combination of raw, labels, target and prediction arrays.

    With ``subplot=True`` a grid is created with one row per supplied array,
    in the order raw, ground_truth, target, prediction. Arrays with more than
    two dimensions are iterated along their first axis, one column per item;
    the widest array decides the number of columns.

    How an array is drawn depends on its number of dimensions:

    * 2 or 3 dims - grayscale if its title is ``'raw'``, otherwise coloured
      through ``create_lut``
    * anything else - ``im[channel]`` of each item, using colormap ``shader``

    With ``subplot=False`` nothing is laid out; data titled ``'raw'`` is drawn
    in grayscale and data titled ``'labels'`` is drawn with ``alpha=0.5`` on
    the current pyplot axes. Other arrays are ignored in that mode.

    Args:
        raw, ground_truth, target, prediction: arrays to display, or ``None``
            to leave that row out.
        h: optional ``hspace`` passed to ``fig.subplots_adjust``. Only used
            when ``subplot`` is true.
        shader: colormap name for multi-channel data.
        subplot: whether to create a subplot grid.
        channel: channel index shown for multi-channel data.
        target_name: title used for ``target``.
        prediction_name: title used for ``prediction``.
    """

    # (array, title) pairs in display order, leaving out what was not given
    panels = [
        (data, name)
        for data, name in (
            (raw, 'raw'),
            (ground_truth, 'labels'),
            (target, target_name),
            (prediction, prediction_name))
        if data is not None]

    rows = len(panels)

    # use the widest input so every array has enough columns; fall back to a
    # single column when nothing was passed
    cols = max(
        (data.shape[0] if len(data.shape) > 2 else 1 for data, _ in panels),
        default=1)

    axes = None

    # only build a grid when there is something to put in it
    if subplot and rows > 0:
        fig, axes = plt.subplots(
            rows,
            cols,
            figsize=(10, 4),
            sharex=True,
            sharey=True,
            squeeze=False)

        # fig only exists in subplot mode, so the spacing is applied here
        if h is not None:
            fig.subplots_adjust(hspace=h)

    def wrapper(data, row, name="raw"):

        if not subplot:
            if name == 'raw':
                plt.imshow(data, cmap='gray')
            if name == 'labels':
                plt.imshow(data, alpha=0.5)
            return

        ndim = len(data.shape)
        single_channel = ndim in (2, 3)

        # a lone 2d image is treated as a batch of one
        images = [data] if ndim == 2 else data

        # single-channel data titled 'raw' is always drawn on the first row
        ax_row = axes[0] if (single_channel and name == 'raw') else axes[row]

        for i, im in enumerate(images):
            if not single_channel:
                ax_row[i].imshow(im[channel], cmap=shader)
            elif name == 'raw':
                ax_row[i].imshow(im, cmap='gray')
            else:
                ax_row[i].imshow(create_lut(im))
            ax_row[i].set_title(name)

    for row, (data, name) in enumerate(panels):
        wrapper(data, row=row, name=name)

    plt.show()

In [None]:
# id of a dark, stringy glia object in sample A that the network should
# learn to ignore
glia_ids = [20474]

# download sample A and write it to training_data.zarr with the glia id
# zeroed out in the labels
create_data(
    url='https://cremi.org/static/data/sample_A_20160501.hdf',
    name='training_data.zarr',
    offset=[0, 0],
    resolution=[4, 4],
    mask_ids=glia_ids,
)

In [None]:
# view a random 5 sections...
# top row: coloured labels, bottom row: raw with the ids mask overlaid

fig, axes = plt.subplots(
            2,
            5,
            figsize=(20, 6),
            sharex=True,
            sharey=True,
            squeeze=False)

# open the container once rather than three times per section
container = zarr.open('training_data.zarr')

rand = random.sample(range(0, 124), 5)

for i,j in enumerate(rand):

  raw = container[f'raw/{j}'][:]
  labels = container[f'labels/{j}'][:]
  ids_mask = container[f'ids_mask/{j}'][:]

  axes[0][i].imshow(create_lut(labels))
  axes[1][i].imshow(raw, cmap='gray')
  axes[1][i].imshow(ids_mask, alpha=0.5)

In [None]:
# size of one voxel, and the network input / output shapes in voxels
voxel_size = gp.Coordinate((4, 4))
input_shape = gp.Coordinate((164, 164))
output_shape = gp.Coordinate((124, 124))

# the same shapes scaled by the voxel size
input_size, output_size = (
    shape * voxel_size for shape in (input_shape, output_shape))

# number of sections to sample from, and samples per batch
num_samples, batch_size = 124, 5

In [None]:
# mean squared error, scaled element-wise by a weight array

class WeightedMSELoss(torch.nn.MSELoss):
    """Squared error multiplied by ``weights`` and averaged.

    If any weighted error is non-zero, the mean is taken only over elements
    whose weight is greater than zero (``weights`` is broadcast against the
    error). Otherwise the mean of the (all-zero) weighted error is returned.
    """

    def __init__(self):
        super().__init__()

    def forward(self, prediction, target, weights):

        weighted_sq_err = weights * (prediction - target) ** 2

        # nothing to select from: every weighted error is zero
        if len(torch.nonzero(weighted_sq_err)) == 0:
            return torch.mean(weighted_sq_err)

        positive = torch.gt(weights, 0)
        selected = torch.masked_select(weighted_sq_err, positive)

        return torch.mean(selected)

In [None]:
def train(
    iterations,
    batch_size,
    show_every,
    show_gt=True,
    show_pred=False,
    channels=None):
    """Build the gunpowder pipeline, train a 2d LSD network, and plot batches.

    A UNet followed by a 1x1 ``ConvPass`` with six sigmoid outputs is trained
    with ``WeightedMSELoss`` and Adam. Data is read from
    ``training_data.zarr``, one source per section, and batch locations are
    restricted through ``ids_mask``. The loss is weighted by ``labels_mask``.

    Relies on the module-level ``input_size``, ``output_size`` and
    ``voxel_size`` defined in an earlier cell.

    Args:
        iterations: number of batches to request (one training step each).
        batch_size: number of samples stacked into one batch.
        show_every: plot the batch whenever ``i % show_every == 0``.
        show_gt: plot the ground-truth LSDs for each entry in ``channels``.
        show_pred: plot the predicted LSDs for each entry in ``channels``.
        channels: mapping of display name to LSD channel index. Defaults to
            ``{'offset in y': 0}``.
    """

    # build the default per call rather than sharing one dict between calls
    if channels is None:
        channels = {'offset in y': 0}

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    labels_mask = gp.ArrayKey('LABELS_MASK')
    ids_mask = gp.ArrayKey('IDS_MASK')
    gt_lsds = gp.ArrayKey('GT_LSDS')
    pred_lsds = gp.ArrayKey('PRED_LSDS')

    num_samples = 124
    num_fmaps = 12

    ds_fact = [(2,2),(2,2)]
    num_levels = len(ds_fact) + 1
    ksd = [[(3,3), (3,3)]]*num_levels
    ksu = [[(3,3), (3,3)]]*(num_levels - 1)

    # create unet
    unet = UNet(
        in_channels=1,
        num_fmaps=num_fmaps,
        fmap_inc_factor=5,
        downsample_factors=ds_fact,
        kernel_size_down=ksd,
        kernel_size_up=ksu,
        constant_upsample=True)

    # 1x1 convolution mapping the unet feature maps to the 6 lsd channels
    model = torch.nn.Sequential(
        unet,
        ConvPass(num_fmaps, 6, [[1, 1]], activation='Sigmoid'))

    loss = WeightedMSELoss()
    optimizer = torch.optim.Adam(lr=0.5e-4, params=model.parameters())

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_mask, output_size)
    request.add(ids_mask, output_size)
    request.add(gt_lsds, output_size)
    request.add(pred_lsds, output_size)

    # one source per section, each normalized and sampled at random locations
    sources = tuple(
        gp.ZarrSource(
            'training_data.zarr',
            {
                raw: f'raw/{i}',
                labels: f'labels/{i}',
                labels_mask: f'labels_mask/{i}',
                ids_mask: f'ids_mask/{i}'
            },
            {
                raw: gp.ArraySpec(interpolatable=True),
                labels: gp.ArraySpec(interpolatable=False),
                labels_mask: gp.ArraySpec(interpolatable=False),
                ids_mask: gp.ArraySpec(interpolatable=False)
            }) +
            gp.Normalize(raw) +
            # just use ids mask for restricting batches
            gp.RandomLocation(mask=ids_mask, min_masked=0.05)
            for i in range(num_samples)
        )

    # raw:      (h, w)
    # labels:   (h, w)

    pipeline = sources

    pipeline += gp.RandomProvider()

    pipeline += gp.SimpleAugment()

    pipeline += gp.IntensityAugment(
        raw,
        scale_min=0.9,
        scale_max=1.1,
        shift_min=-0.1,
        shift_max=0.1)

    pipeline += gp.GrowBoundary(labels)

    pipeline += AddLocalShapeDescriptor(
        labels,
        gt_lsds,
        sigma=80,
        downsample=2)

    pipeline += gp.Unsqueeze([raw, labels_mask])

    pipeline += gp.Stack(batch_size)

    pipeline += gp.PreCache(num_workers=10)

    # use labels mask as weights. We want predictions to go close to zero in these regions
    pipeline += Train(
        model,
        loss,
        optimizer,
        inputs={
            'input': raw
        },
        outputs={
            0: pred_lsds
        },
        loss_inputs={
            0: pred_lsds,
            1: gt_lsds,
            2: labels_mask
        })

    with gp.build(pipeline):
        progress = tqdm(range(iterations))
        for i in progress:
            batch = pipeline.request_batch(request)

            if i % show_every == 0:

                # crop raw to the labels roi of the request so it lines up
                # with the labels and lsds; only needed when plotting
                start = request[labels].roi.get_begin()/voxel_size
                end = request[labels].roi.get_end()/voxel_size

                imshow(raw=np.squeeze(batch[raw].data[:,:,start[0]:end[0],start[1]:end[1]]))
                imshow(ground_truth=batch[labels].data)

                for n,c in channels.items():

                    if show_gt:
                        imshow(target=batch[gt_lsds].data, target_name='gt '+n, channel=c)
                    if show_pred:
                        imshow(prediction=batch[pred_lsds].data, prediction_name='pred '+n, channel=c)

            progress.set_description(f'Training iteration {i}')

In [None]:
# plot one batch of ground truth lsds, one figure per channel
# predicted lsds are not shown yet

channels = {
    name: index
    for index, name in enumerate([
        'offset (y)',
        'offset (x)',
        'orient (y)',
        'orient (x)',
        'yx change',
        'voxel count',
    ])
}

train(
    iterations=1,
    batch_size=5,
    show_every=1,
    channels=channels,
)

In [None]:
# run roughly 1k training iterations and plot every 100th batch
# only the mean offset (y) channel is shown, for both the ground truth
# and the prediction
# expect this to converge more slowly than affs

channels = dict([('offset (y)', 0)])

train(
    iterations=1001,
    batch_size=5,
    show_every=100,
    show_pred=True,
    channels=channels,
)