# In [None]:
import gunpowder as gp
from funlib.learn.torch.models import UNet

import torch

import logging
import math

logging.basicConfig(level=logging.INFO)

# Training schedule.
num_iterations = 1000
batch_size = 10

# Geometry. "*_shape" values are in voxels; "*_size" values are in world
# units (shape multiplied by the voxel size).
voxel_size = gp.Coordinate((4, 4, 4))
dims = len(voxel_size)
input_shape = gp.Coordinate((76, 76, 76))
output_shape = gp.Coordinate((36, 36, 36))
input_size = input_shape * voxel_size
output_size = output_shape * voxel_size

# Affinity offsets: a unit step along each spatial axis, i.e.
# [(1, 0, 0), (0, 1, 0), (0, 0, 1)] for three dimensions.
neighborhood = [
    tuple(int(axis == step) for axis in range(dims)) for step in range(dims)
]

# Number of feature maps in the first UNet level. The UNet's output channel
# count is assumed to equal this (funlib's default when num_fmaps_out is not
# given) — it sizes the input of the 1x1x1 head below.
num_fmaps = 16

unet = UNet(
    in_channels=1,
    num_fmaps=num_fmaps,
    fmap_inc_factor=5,
    downsample_factors=[(2, 2, 2), (2, 2, 2)],
    activation="ReLU",
    voxel_size=voxel_size,
    num_heads=1,
    constant_upsample=True,
)


class AffinitiesModel(torch.nn.Module):
    """UNet followed by a 1x1x1 convolution and a sigmoid.

    The forward argument is named ``raw`` so that ``gp.torch.Train`` can
    pass the raw batch by keyword via ``inputs={"raw": raw}``.

    Args:
        unet: Backbone network producing ``in_channels`` feature maps.
        in_channels: Number of channels the backbone outputs.
        out_channels: Number of affinity channels to predict.
    """

    def __init__(self, unet, in_channels, out_channels):
        super().__init__()
        self.unet = unet
        self.head = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels, out_channels, kernel_size=1),
            torch.nn.Sigmoid(),  # affinities in [0, 1]
        )

    def forward(self, raw):
        return self.head(self.unet(raw))


# One output channel per neighborhood offset (== dims for the unit neighborhood).
# A plain list cannot be trained, so the layers are wrapped in a Module.
affinities_model = AffinitiesModel(unet, num_fmaps, len(neighborhood))
affinities_model.train()

affinities_loss = torch.nn.MSELoss()
affinities_optimizer = torch.optim.Adam(affinities_model.parameters(), lr=1e-5)

# get pipeline

# Keys identifying each array that flows through the pipeline.
raw = gp.ArrayKey("RAW")
predictions = gp.ArrayKey("PREDICTIONS")
prediction_gradients = gp.ArrayKey("PREDICTION_GRADIENTS")
labels = gp.ArrayKey("LABELS")
affs = gp.ArrayKey("AFFS")


def _volume_source(volume):
    """Return a source reading raw/labels of one volume at a random location."""
    zarr_source = gp.ZarrSource(
        filename="../MB-Z1213-56.zarr",
        datasets={
            raw: f"{volume}/raw",
            labels: f"{volume}/labels",
        },
    )
    return zarr_source + gp.RandomLocation()


# One source per volume; RandomProvider later picks among them.
sources = tuple(_volume_source(volume) for volume in ("A", "B", "C"))
# Training pipeline: pick one of the volume sources at random, normalize and
# augment, compute affinity targets, batch, run a training step, and write
# periodic snapshots.
pipeline = (
    sources
    + gp.RandomProvider()
    + gp.Normalize(raw)
    # Elastic deformation. The rotation angle is sampled from
    # rotation_interval; the previous (pi/2, pi/2) was a degenerate interval
    # that rotated every sample by exactly 90 degrees instead of randomly.
    + gp.ElasticAugment((10, 10, 10), (0.1, 0.1, 0.1), (0, math.pi / 2))
    + gp.SimpleAugment(transpose_only=(1, 2))
    + gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1)
    + gp.AddAffinities(
        affinity_neighborhood=neighborhood, labels=labels, affinities=affs
    )
    + gp.PreCache()
    + gp.Stack(batch_size)
    # Insert a channel axis at position 1 (right after the batch axis added
    # by Stack) so raw matches the model's in_channels=1 input.
    + gp.Unsqueeze([raw], axis=1)
    + gp.torch.Train(
        model=affinities_model,
        loss=affinities_loss,
        optimizer=affinities_optimizer,
        inputs={"raw": raw},  # passed to the model's forward as raw=...
        outputs={0: predictions},
        loss_inputs={0: predictions, 1: affs},  # loss(predictions, affs)
        gradients={0: prediction_gradients},
        save_every=500,  # checkpoint interval, in iterations
    )
    # Drop the channel axis again before the snapshot is written.
    + gp.Squeeze([raw], axis=1)
    + gp.Snapshot(
        output_filename="{iteration}.hdf",
        dataset_names={
            raw: "volumes/raw",
            labels: "volumes/labels",
            predictions: "volumes/predictions",
            prediction_gradients: "volumes/prediction_gradients",
            affs: "volumes/affinities",
        },
        every=500,
    )
)

# Batch request: raw at the network input size, every other array at the
# network output size.
request = gp.BatchRequest()
for key, size in (
    (raw, input_size),
    (labels, output_size),
    (affs, output_size),
    (predictions, output_size),
    (prediction_gradients, output_size),
):
    request.add(key, size)

# Input shapes are fixed, so let cuDNN benchmark and cache conv algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# main loop: build the pipeline once, then request one training batch per
# iteration (each request runs one step of the Train node).
if __name__ == "__main__":
    with gp.build(pipeline):
        for _iteration in range(num_iterations):
            batch = pipeline.request_batch(request)
