# In [None]:
from funlib.learn.torch.models import UNet

import numpy as np
import gunpowder as gp
import torch
import zarr
import daisy

import waterz
from watershed_helpers import watershed_from_affinities

# constants
# Number of voxels added per axis to both the network input and output block
# shapes below.
grow = (36, 36, 36)
voxel_size = gp.Coordinate((4, 4, 4))
# Block shapes are in voxels; the *_size values are the same blocks in world
# units (shape * voxel_size).
input_shape = gp.Coordinate((76, 76, 76)) + grow
input_size = input_shape * voxel_size
output_shape = gp.Coordinate((36, 36, 36)) + grow
output_size = output_shape * voxel_size

# The context is half the input/output difference on each side. It only lands
# on the voxel grid if that difference is an even number of voxels per axis.
if any(diff % 2 for diff in input_shape - output_shape):
    raise ValueError(
        f"input_shape {input_shape} and output_shape {output_shape} must differ "
        "by an even number of voxels along every axis"
    )
# Floor division keeps the context an exact integer world-unit offset.
context = (input_size - output_size) // 2
checkpoint = "checkpoint"

# Feature maps at the UNet's first level. The final conv below assumes the UNet
# emits this many channels (num_fmaps_out not overridden) -- confirm if the
# UNet configuration changes.
num_fmaps = 16
# One affinity channel per spatial axis of the data.
dims = len(voxel_size)

unet = UNet(
    in_channels=1,
    num_fmaps=num_fmaps,
    fmap_inc_factor=5,
    downsample_factors=[(2, 2, 2), (2, 2, 2)],
    activation="ReLU",
    voxel_size=voxel_size,
    num_heads=1,
    constant_upsample=True,
)

# Full network: UNet features -> 1x1x1 conv to `dims` channels -> sigmoid.
# It must be a torch Module (not a plain list) so it can load a state dict and
# be handed to gp.torch.Predict.
# NOTE(review): assumes the checkpoint was saved from this same Sequential
# layout (state-dict keys "0.*", "1.*") -- confirm against the training script.
model = torch.nn.Sequential(
    unet,
    torch.nn.Conv3d(num_fmaps, dims, (1,) * dims),
    torch.nn.Sigmoid(),
)
# Keep the earlier name bound to the same network.
affinities_model = model

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
model.load_state_dict(
    torch.load(checkpoint, map_location="cpu")["model_state_dict"]
)
model.eval()

raw = gp.ArrayKey("RAW")
predictions = gp.ArrayKey("PREDICTIONS")

# Shape of one Scan chunk: the network input block and the output block it
# predicts.
reference_request = gp.BatchRequest()
reference_request.add(raw, input_size)
reference_request.add(predictions, output_size)

source = gp.ZarrSource(
    filename="../MB-Z1213-56.zarr",
    datasets={
        raw: "TEST/raw",
    },
)

# Build the source once to query the ROI of the available raw data.
with gp.build(source):
    total_input_roi = source.spec[raw].roi
# Shrink by the network context so every output voxel has full input context.
total_output_roi = total_input_roi.grow(-context, -context)

# Pre-create the 3-channel float32 output dataset covering the whole output
# ROI, with write_size equal to the per-block output size.
daisy.prepare_ds(
    "predictions.zarr",
    "volumes/predictions",
    daisy.Roi(total_output_roi.get_offset(), total_output_roi.get_shape()),
    voxel_size,
    np.float32,
    write_size=output_size,
    num_channels=3,
)

pipeline = (
    source
    + gp.Normalize(raw)
    # Add two leading axes (channel, then batch) for the model.
    # Assumes raw is 3-D.
    + gp.Unsqueeze([raw])
    + gp.Unsqueeze([raw])
    + gp.torch.Predict(
        model=model,
        # Keys are forward() argument names; torch.nn.Sequential.forward
        # takes a single argument named `input`.
        inputs={"input": raw},
        outputs={0: predictions},
        # Give the output a provided ROI so ZarrWrite knows the full extent.
        array_specs={predictions: gp.ArraySpec(roi=total_output_roi)},
        checkpoint=checkpoint,
    )
    # Drop the batch axis from both arrays, then raw's channel axis; the
    # predictions keep their channel axis.
    + gp.Squeeze([raw, predictions])
    + gp.Squeeze([raw])
    + gp.ZarrWrite(
        output_dir="./",
        output_filename="predictions.zarr",
        dataset_names={
            raw: "volumes/raw",
            predictions: "volumes/predictions",
        },
        # dataset_dtypes maps array keys to numpy dtypes (not ArraySpecs).
        dataset_dtypes={predictions: np.float32},
    )
    + gp.Scan(reference_request)
)

# Request the whole volume; Scan splits it into reference_request-sized chunks.
request = gp.BatchRequest()
request[raw] = gp.ArraySpec(roi=total_input_roi)
request[predictions] = gp.ArraySpec(roi=total_output_roi)

with gp.build(pipeline):
    pipeline.request_batch(request)

# Post processing: agglomerate the predicted affinities into a segmentation and
# store it in the same zarr container as the predictions.
# The prediction step writes to "predictions.zarr", so read from there.
results_zarr = zarr.open("predictions.zarr", "r+")
# Load the full prediction volume into memory.
results = np.array(results_zarr["volumes/predictions"])

# Threshold passed to waterz.agglomerate.
threshold = 0.5

# Query the predictions' spec (ROI, voxel size) so the segmentation can be
# written with matching metadata.
predictions_source = gp.ZarrSource(
    filename="predictions.zarr",
    datasets={
        predictions: "volumes/predictions",
    },
)
with gp.build(predictions_source):
    spec = predictions_source.spec[predictions]

# NOTE(review): assumes watershed_from_affinities returns a tuple whose first
# element is the fragment label volume waterz expects -- confirm.
fragments = watershed_from_affinities(results)[0]
thresholds = [threshold]
segmentations = waterz.agglomerate(
    affs=results.astype(np.float32),
    fragments=fragments,
    thresholds=thresholds,
)

# agglomerate yields one result per threshold; only one was requested.
segmentation = next(segmentations)

# Write into the zarr group (results_zarr), not the in-memory numpy copy.
results_zarr["volumes/segmentation"] = segmentation
results_zarr["volumes/segmentation"].attrs["offset"] = list(spec.roi.get_offset())
results_zarr["volumes/segmentation"].attrs["resolution"] = list(spec.voxel_size)