# In [4]:
from funlib.learn.torch.models import UNet, ConvPass
import gunpowder as gp
import logging
import math
import numpy as np
import torch

logging.basicConfig(level=logging.INFO)

# ---- training parameters ----------------------------------------------

# how many training iterations to run
num_iterations = 10 ** 5
# shape (in voxels) of the U-Net input: 7 consecutive time frames
input_shape = (7, 64, 124, 124)
# shape (in voxels) of the U-Net output: a single time frame
output_shape = (1, 32, 32, 32)

# ---- keys for the arrays and graphs flowing through the pipeline -------

raw = gp.ArrayKey("RAW")                # raw intensity data
centers = gp.GraphKey("CENTERS")        # point annotations of cell centers
blobs = gp.ArrayKey("BLOBS")            # Gaussian blobs drawn around each center
prediction = gp.ArrayKey("PREDICTION")  # output of the network

# ---- model, loss, and optimizer ----------------------------------------

# Convolutions are 3D (the downsample factors have three entries), so the
# 7 time frames of the input are consumed as 7 input channels.
unet = UNet(
    in_channels=7,
    num_fmaps=12,
    fmap_inc_factor=5,
    downsample_factors=[(1, 2, 2), (1, 2, 2), (2, 2, 2)],
    constant_upsample=True,
    padding='valid',
)

# U-Net features -> 1x1x1 conv to a single channel -> sigmoid into (0, 1)
model = torch.nn.Sequential(
    unet,
    ConvPass(12, 1, [(1, 1, 1)], activation=None),
    torch.nn.Sigmoid(),
)

loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# ---- data sources ---------------------------------------------------------

# raw volume; marked interpolatable so it may be resampled downstream
raw_source = gp.ZarrSource(
    '../data/mouse.n5',
    {raw: 'raw'},
    {raw: gp.ArraySpec(interpolatable=True)},
)
# cell center annotations read from a CSV file
centers_source = gp.CsvPointsSource('../data/mouse_cells.csv', centers)

# Build the raw source on its own, just long enough to read its voxel size.
with gp.build(raw_source):
    voxel_size = raw_source.spec[raw].voxel_size

# ---- requests (world units from here on) ---------------------------------

input_size = voxel_size * input_shape
output_size = voxel_size * output_shape

# one training batch: the raw input and the matching target blobs
request = gp.BatchRequest()
request.add(raw, input_size)
request.add(blobs, output_size)

# snapshots additionally store the prediction, using the same spec as blobs
snapshot_request = gp.BatchRequest()
snapshot_request[prediction] = request[blobs]

# The cell-center rasterization node is called `RasterizeGraph` in newer
# gunpowder releases. The installed version does not have that name
# ("AttributeError: module 'gunpowder' has no attribute 'RasterizeGraph'").
# Older releases name the node `RasterizePoints`. That node is assumed to take
# the same (graph key, array key, array_spec=, settings=) arguments. TODO:
# confirm this against the installed version.
# NOTE(review): if `gp.Squeeze` is missing as well, the installed gunpowder
# is older than this script expects; upgrading gunpowder is the real fix.
try:
    RasterizeNode = gp.RasterizeGraph
except AttributeError:
    RasterizeNode = gp.RasterizePoints

pipeline = (

    # combine both raw and cell center source into a single provider
    (raw_source, centers_source) +
    gp.MergeProvider() +

    # NOTE(review): raw reaches the network without a gp.Normalize node.
    # Confirm that the stored raw dtype/range is suitable as model input.

    # pick a random location to train on, but ensure that there is a cell
    # in the center
    gp.RandomLocation(ensure_nonempty=centers, ensure_centered=True) +

    # augment data using rotations and elastic deformation
    gp.ElasticAugment(
        control_point_spacing=(5, 10, 10),
        jitter_sigma=(1.0, 1.0, 1.0),
        rotation_interval=[0, math.pi/2.0],
        subsample=8
    ) +

    # turn cell center annotations into Gaussian blobs
    RasterizeNode(
        centers,
        blobs,
        array_spec=gp.ArraySpec(voxel_size=voxel_size, dtype=np.float32),
        settings=gp.RasterizationSettings(
            radius=(1, 10, 10, 10),
            mode='peak'
        )
    ) +

    # introduce a "batch" dimension
    gp.Stack(1) +
    # (arrays have shape (1, t, d, h, w) now)

    # use parallel processes to pre-fetch batches from upstream
    gp.PreCache() +

    # train the model to match the prediction with the blobs
    gp.torch.Train(
        model,
        loss,
        optimizer,
        inputs={
            'input': raw
        },
        outputs={
            0: prediction
        },
        loss_inputs={
            0: prediction,
            1: blobs
        },
        save_every=10000
    ) +

    # remove the batch dimension again
    gp.Squeeze([raw, blobs, prediction]) +
    # (arrays have shape (t, d, h, w) now)

    # store a training snapshot every 1000 iterations
    gp.Snapshot(
        {
            raw: 'raw',
            blobs: 'blobs',
            prediction: 'prediction'
        },
        output_filename='it_{iteration}.hdf',
        additional_request=snapshot_request,
        every=1000
    )
)

# Build the pipeline once. Each batch request then flows through the Train
# node, which performs the training step for that batch.
with gp.build(pipeline):
    for iteration in range(num_iterations):
        batch = pipeline.request_batch(request)

# Output of the original run:
# AttributeError: module 'gunpowder' has no attribute 'RasterizeGraph'