In [1]:
from funlib.learn.torch.models import UNet, ConvPass
import gunpowder as gp
import logging
import math
import numpy as np
import torch

# configure the root logger so that messages of level INFO and above are shown
logging.basicConfig(level=logging.INFO)

### Network Input/Output Shape

Similar to what we did before, but this time the input and output shapes of our network are 4D!

In [2]:
# training parameters:

# how many training iterations to run in total
num_iterations = 100_000
# shape (in voxels) of the data fed into the U-Net: 7 time frames
input_shape = (7, 64, 124, 124)
# shape (in voxels) of what the U-Net produces: 1 time frame
output_shape = (1, 32, 32, 32)

### Model, Loss, and Optimizer Setup

Here we create a U-Net as before, but this time we use several input channels to feed different time frames into the network. We give the network seven frames and let it output a single feature map for the center frame at the end.

In [3]:
# create model, loss, and optimizer

# Number of feature maps requested from the U-Net. The output head below
# takes this many input channels, so the value is defined once and shared
# instead of being hard-coded in two places.
num_fmaps = 12

unet = UNet(
    # the time frames are fed to the network as input channels, so the
    # channel count follows the leading (time) axis of `input_shape`
    in_channels=input_shape[0],
    num_fmaps=num_fmaps,
    fmap_inc_factor=5,
    # downsampling factors per level
    downsample_factors=[
        (1, 2, 2),
        (1, 2, 2),
        (2, 2, 2)],
    constant_upsample=True,
    padding='valid')
model = torch.nn.Sequential(
    unet,
    # 1x1x1 convolution reducing the U-Net feature maps to a single channel
    # (no activation here, the sigmoid below is applied instead)
    ConvPass(num_fmaps, 1, [(1, 1, 1)], activation=None),
    # map the output into the range (0, 1)
    torch.nn.Sigmoid()
)
# mean squared error between the network output and the training target
loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

### Arrays and Graphs

This time, we will have a new kind of data in a batch: a graph, called 'centers'. Those are the locations of cells (each cell is a node with an x, y, and z coordinate). There are no edges in this example.

Since we can't train on graphs directly, we will have to convert the graph into a volumetric array. We will do that by drawing a Gaussian blob around each node in 'centers' and storing the result in an array called 'blobs'.

In [4]:
# declare arrays and graphs to use

# array keys:
#   raw        - the raw data
#   blobs      - Gaussian blobs drawn around each cell center
#   prediction - the prediction of the network
raw, blobs, prediction = map(gp.ArrayKey, ("RAW", "BLOBS", "PREDICTION"))

# graph key: point annotations for cell centers
centers = gp.GraphKey("CENTERS")

### Multiple Data Sources

This time, we have data coming from two sources: raw data from a zarr container, and a graph from a text file. Here we create one source for each, and merge both sources into one using the `MergeProvider`.

In [5]:
# create data sources

# raw data: dataset 'raw' of the container, marked as interpolatable
raw_source = gp.ZarrSource(
    '../data/mouse.n5',
    datasets={raw: 'raw'},
    array_specs={raw: gp.ArraySpec(interpolatable=True)})

# cell center annotations, read from a CSV file into the 'centers' graph
centers_source = gp.CsvPointsSource('../data/mouse_cells.csv', centers)

# merge the raw and the cell center source so that downstream nodes see a
# single provider offering both
sources = (raw_source, centers_source) + gp.MergeProvider()

### Train Pipeline

The train pipeline is very similar to the previous exercises. The biggest differences are:
1. We deal with 4D data now.
2. We need to convert the graph in 'centers' into an array of Gaussian blobs in 'blobs'.

In [None]:
# look up the voxel size of the raw data (read from the spec of the built
# raw source)
with gp.build(raw_source):
    voxel_size = raw_source.spec[raw].voxel_size

# the shapes above are in voxels; from here on, all sizes are in world units
input_size = voxel_size*input_shape
output_size = voxel_size*output_shape

# the request for one training batch: raw data and the blobs
request = gp.BatchRequest()
request.add(raw, input_size)
request.add(blobs, output_size)

# snapshots should additionally contain the prediction, covering the same
# region as the blobs
snapshot_request = gp.BatchRequest()
snapshot_request[prediction] = request[blobs]

# --- individual pipeline stages, chained together at the end ---

# pick a random location to train on, but ensure that there is a cell in
# the center
pick_location = gp.RandomLocation(ensure_nonempty=centers, ensure_centered=True)

# augment data using rotations and elastic deformation
augment = gp.ElasticAugment(
    control_point_spacing=(5, 10, 10),
    jitter_sigma=(1.0, 1.0, 1.0),
    rotation_interval=[0, math.pi/2.0],
    subsample=8
)

# turn cell center annotations into Gaussian blobs
blob_settings = gp.RasterizationSettings(
    radius=(1, 10, 10, 10),
    mode='peak'
)
rasterize_centers = gp.RasterizeGraph(
    centers,
    blobs,
    array_spec=gp.ArraySpec(voxel_size=voxel_size, dtype=np.float32),
    settings=blob_settings
)

# introduce a "batch" dimension
# (arrays have shape (1, t, d, h, w) afterwards)
add_batch_dim = gp.Stack(1)

# use parallel processes to pre-fetch batches from upstream
prefetch = gp.PreCache()

# train the model to match the prediction with the blobs
train_model = gp.torch.Train(
    model,
    loss,
    optimizer,
    inputs={'input': raw},
    outputs={0: prediction},
    loss_inputs={0: prediction, 1: blobs},
    save_every=10000
)

# remove the batch dimension again
# (arrays have shape (t, d, h, w) afterwards)
remove_batch_dim = gp.Squeeze([raw, blobs, prediction])

# store a training snapshot every 1000 iterations
write_snapshot = gp.Snapshot(
    {
        raw: 'raw',
        blobs: 'blobs',
        prediction: 'prediction'
    },
    output_filename='it_{iteration}.hdf',
    additional_request=snapshot_request,
    every=1000
)

# assemble the training pipeline; data flows from top to bottom
pipeline = (
    sources +
    pick_location +
    augment +
    rasterize_centers +
    add_batch_dim +
    prefetch +
    train_model +
    remove_batch_dim +
    write_snapshot
)


In [None]:
# build the pipeline and train
with gp.build(pipeline):
    # each request pulls one batch through the whole pipeline, including the
    # training node; `batch` holds the most recently returned batch
    for iteration in range(num_iterations):
        batch = pipeline.request_batch(request)