In [1]:
import gunpowder as gp
from funlib.learn.torch.models import UNet

import numpy as np
import torch

import logging
import math

In [2]:
logging.basicConfig(level=logging.INFO)

# --- Training schedule -------------------------------------------------------
num_iterations = 1000  # total number of batches requested by the main loop
checkpoint_every = 500  # passed to the Train node as `save_every`
snapshot_every = 500  # passed to the Snapshot node as `every`

# --- Geometry ----------------------------------------------------------------
# "*_shape" values are in voxels; "*_size" values are the same extents in world
# units, i.e. shape multiplied element-wise by the voxel size.
voxel_size = gp.Coordinate((4, 4, 4))
dims = len(voxel_size)

input_shape = gp.Coordinate((76, 76, 76))
output_shape = gp.Coordinate((36, 36, 36))
input_size = input_shape * voxel_size
output_size = output_shape * voxel_size

# --- Batching / targets ------------------------------------------------------
batch_size = 10
# Unit offsets (one per spatial axis) used to define the affinity targets.
neighborhood = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]

In [3]:
# Base number of UNet feature maps. The 1x1x1 head below consumes the UNet's
# output channels, so both places must use the same value (funlib's UNet
# output width defaults to num_fmaps when num_fmaps_out is not given —
# NOTE(review): confirm against the installed funlib version).
num_fmaps = 16

unet = UNet(
    in_channels=1,
    num_fmaps=num_fmaps,
    fmap_inc_factor=5,
    downsample_factors=[(2, 2, 2), (2, 2, 2)],
    activation="ReLU",
    voxel_size=voxel_size,
    num_heads=1,
    constant_upsample=True,
)

# UNet features -> 1x1x1 conv producing one channel per neighborhood offset
# (matching the affinity targets) -> sigmoid so predictions lie in (0, 1).
affinities_model = torch.nn.Sequential(
    unet,
    torch.nn.Conv3d(num_fmaps, len(neighborhood), (1,) * dims),
    torch.nn.Sigmoid(),
)
affinities_model.train()

# Regress predicted affinities directly against the targets.
affinities_loss = torch.nn.MSELoss()
affinities_optimizer = torch.optim.Adam(affinities_model.parameters(), lr=1e-5)

In [4]:
raw = gp.ArrayKey("RAW")
predictions = gp.ArrayKey("PREDICTIONS")
prediction_gradients = gp.ArrayKey("PREDICTION_GRADIENTS")
labels = gp.ArrayKey("LABELS")
affs = gp.ArrayKey("AFFS")

# One source per training volume, each sampling a random location.
# Labels are integer object IDs: they must NOT be marked interpolatable,
# otherwise ElasticAugment blends neighbouring IDs into meaningless values.
sources = tuple(
    gp.ZarrSource(
        filename="../data/MB-Z1213-56.zarr",
        datasets={
            raw: f"{volume}/raw",
            labels: f"{volume}/labels",
        },
        array_specs={
            raw: gp.ArraySpec(interpolatable=True),
            labels: gp.ArraySpec(interpolatable=False),
        },
    )
    + gp.RandomLocation()
    for volume in ["A", "B"]
)

pipeline = (
    sources
    + gp.RandomProvider()
    + gp.Normalize(raw)
    + gp.ElasticAugment((10, 10, 10), (0.1, 0.1, 0.1), (math.pi / 2, math.pi / 2))
    + gp.SimpleAugment(transpose_only=(1, 2))
    + gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1)
    # Derive target affinities from the augmented labels (one channel per
    # neighborhood offset).
    + gp.AddAffinities(
        affinity_neighborhood=neighborhood,
        labels=labels,
        affinities=affs,
    )
    # AddAffinities emits uint8; cast to float32 (values unchanged) so the
    # MSE target matches the float predictions.
    + gp.Normalize(affs, factor=1.0)
    # Assemble batches in background workers so training isn't I/O bound.
    + gp.PreCache(cache_size=50, num_workers=20)
    # Stack single samples into a batch: arrays become (b, z, y, x).
    + gp.Stack(batch_size)
    # Add the channel axis expected by the network: (b, 1, z, y, x).
    + gp.Unsqueeze([raw], axis=1)
    + gp.torch.Train(
        model=affinities_model,
        loss=affinities_loss,
        optimizer=affinities_optimizer,
        inputs={"input": raw},
        outputs={0: predictions},
        loss_inputs={0: predictions, 1: affs},
        gradients={0: prediction_gradients},
        save_every=checkpoint_every,
    )
    + gp.Squeeze([raw], axis=1)
    + gp.Snapshot(
        output_filename="{iteration}.hdf",
        dataset_names={
            raw: "volumes/raw",
            labels: "volumes/labels",
            predictions: "volumes/predictions",
            prediction_gradients: "volumes/prediction_gradients",
            affs: "volumes/affinities",
        },
        every=snapshot_every,
    )
)

# Raw is requested at the (larger) network input size; everything compared
# with or produced by the network uses the output size.
request = gp.BatchRequest()
request.add(raw, input_size)
request.add(labels, output_size)
request.add(affs, output_size)
request.add(predictions, output_size)
request.add(prediction_gradients, output_size)

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

In [5]:
# Main training loop: each request pulls one batch through the pipeline, which
# contains a gp.torch.Train node, so every iteration performs a training step.
if __name__ == "__main__":
    with gp.build(pipeline):
        for iteration in range(num_iterations):
            batch = pipeline.request_batch(request)

INFO:gunpowder.torch.nodes.train:Starting training from scratch
INFO:gunpowder.torch.nodes.train:Using device cuda
INFO:gunpowder.nodes.precache:starting new set of workers (20, cache size 50)...
INFO:gunpowder.nodes.generic_train:Train process: iteration=1 loss=0.248180 time=11.864300
INFO:gunpowder.nodes.snapshot:saving to snapshots/1.hdf
INFO:gunpowder.nodes.generic_train:Train process: iteration=2 loss=0.246153 time=1.788366
INFO:gunpowder.nodes.generic_train:Train process: iteration=3 loss=0.251732 time=1.780503
INFO:gunpowder.nodes.generic_train:Train process: iteration=4 loss=0.246628 time=1.786965
INFO:gunpowder.nodes.generic_train:Train process: iteration=5 loss=0.251449 time=1.792226
INFO:gunpowder.nodes.generic_train:Train process: iteration=6 loss=0.253759 time=1.782859
INFO:gunpowder.nodes.generic_train:Train process: iteration=7 loss=0.247197 time=1.782277
INFO:gunpowder.nodes.generic_train:Train process: iteration=8 loss=0.247759 time=1.785345
INFO:gunpowder.nodes.generi

KeyboardInterrupt: 