# In [None]:
import gunpowder as gp
import daisy
import zarr
import numpy as np
import torch

# Constants
# `checkpoint` is used both as the path of the saved model weights and as the
# stem of the output zarr container.
checkpoint = "checkpoint"
# Extra voxels added to the network's input and output shapes alike.
grow = (36, 36, 36)
voxel_size = gp.Coordinate((4, 4, 4))

# Shapes are in voxels ...
input_shape = gp.Coordinate((76, 76, 76)) + grow
output_shape = gp.Coordinate((36, 36, 36)) + grow

# ... sizes are the same extents expressed in world units.
input_size = input_shape * voxel_size
output_size = output_shape * voxel_size

# Border (per side, world units) by which the input exceeds the output.
context = (input_size - output_size) / 2


# initialize model
# NOTE(review): `UNet`, `ConvPass` and `config` are not defined in the visible
# part of this file; they are assumed to be provided elsewhere -- confirm.
unet = UNet(
    in_channels=config.in_channels,
    num_fmaps=config.num_fmaps,
    fmap_inc_factor=config.fmap_inc_factor,
    downsample_factors=[tuple(x) for x in config.downsample_factors],
    activation=config.activation,
    voxel_size=config.voxel_size,
    num_heads=1,
    constant_upsample=True,
)
# 1x1(x1) convolution mapping the feature maps to one logit per class, plus one
# extra channel.
logits = ConvPass(
    in_channels=config.num_fmaps,
    out_channels=config.num_classes + 1,
    kernel_sizes=[[1] * len(config.downsample_factors[0])],
    activation=None,
)
# Normalise over the channel axis. The dim is given explicitly because the
# implicit choice is deprecated; for the batched (batch, channel, ...) tensors
# this model is fed, dim=1 is also what the implicit rule would have picked.
probs = torch.nn.Softmax(dim=1)

# torch.nn.Sequential takes its modules as positional arguments; passing a
# single list raises a TypeError.
semantic_model = torch.nn.Sequential(unet, logits, probs)
# The rest of the script refers to the network as `model`.
model = semantic_model

# The checkpoint is a dict whose "model_state_dict" entry holds the weights.
model.load_state_dict(torch.load(checkpoint)["model_state_dict"])
# Inference only: switch off training-time behaviour (dropout, batch norm).
model.eval()
# Keys under which the input volume and the network output travel through the
# gunpowder pipeline.
raw = gp.ArrayKey("RAW")
predictions = gp.ArrayKey("PREDICTIONS")

# Input volume on disk.
source = gp.ZarrSource(
    filename="../MB-Z1213-56.zarr",
    datasets={raw: "TEST/raw"},
)

# Request describing a single network pass: one input-sized block of raw data
# and the output-sized block of predictions that results from it.
reference_request = gp.BatchRequest()
reference_request.add(raw, input_size)
reference_request.add(predictions, output_size)

# Build the source on its own just long enough to read the extent of the raw
# data.
with gp.build(source):
    total_input_roi = source.spec[raw].roi

# Predictions only exist where the full context is available, so the output
# region is the input region shrunk by `context` on both sides.
total_output_roi = total_input_roi.grow(-context, -context)

# Pre-create the predictions dataset covering the whole output region, chunked
# so that every block written by the pipeline (one `output_size` block per
# network pass) aligns with the dataset's write size.
daisy.prepare_ds(
    f"{checkpoint}.zarr",
    "volumes/predictions",
    daisy.Roi(total_output_roi.get_offset(), total_output_roi.get_shape()),
    # Derived from the shared constant rather than repeating the literal, so
    # the dataset resolution cannot drift from the one the pipeline uses.
    tuple(voxel_size),
    np.float32,
    write_size=output_size,
    # NOTE(review): presumably has to equal the model's output channel count
    # (`config.num_classes + 1`) -- confirm before changing either.
    num_channels=4,
)

# Chunk-wise prediction pipeline. `gp.Scan` (last node) splits the full request
# below into `reference_request`-sized pieces; each piece flows through the
# nodes above it and is written to disk by `gp.ZarrWrite`.
pipeline = (
    source
    + gp.Normalize(raw)
    # Add two leading singleton axes to raw before it is handed to the model
    # (presumably channel and batch -- confirm against the UNet's expected
    # input layout).
    + gp.Unsqueeze([raw])
    + gp.Unsqueeze([raw])
    + gp.torch.Predict(
        model=semantic_model,
        # Keys are passed as keyword arguments to `model.forward`; for a
        # torch.nn.Sequential the sole parameter is called `input`.
        inputs={"input": raw},
        outputs={0: predictions},
    )
    # Undo the added axes: one leading axis from both arrays, a second one
    # from raw only.
    + gp.Squeeze([raw, predictions])
    + gp.Squeeze([raw])
    + gp.ZarrWrite(
        output_dir="./",
        output_filename=f"{checkpoint}.zarr",
        dataset_names={
            raw: "volumes/raw",
            predictions: "volumes/predictions",
        },
        # `dataset_dtypes` maps array keys to dtypes (not ArraySpecs); float32
        # matches the dataset created with daisy.prepare_ds.
        dataset_dtypes={predictions: np.float32},
    )
    + gp.Scan(reference_request)
)

# Request the complete volume in one go; Scan takes care of the chunking.
request = gp.BatchRequest()
request[raw] = gp.ArraySpec(roi=total_input_roi)
request[predictions] = gp.ArraySpec(roi=total_output_roi)

with gp.build(pipeline):
    pipeline.request_batch(request)


# Post processing
# Re-open the container the pipeline wrote to (`f"{checkpoint}.zarr"`) in
# read/write mode so the segmentation can be stored next to the predictions.
results_zarr = zarr.open(f"{checkpoint}.zarr", "r+")
results = results_zarr["volumes/predictions"]

# The predictions dataset stores its channels on the first axis; the index of
# the strongest channel per voxel is the semantic label. `[:]` loads the whole
# dataset into memory.
semantic_segmentation = np.argmax(results[:], axis=0)

# Store the label volume in the container group (not inside the predictions
# array) and tag it with the same offset and resolution the predictions
# dataset was created with. These are already known, so the pipeline does not
# have to be rebuilt just to read them.
results_zarr["semantic_segmentation"] = semantic_segmentation
results_zarr["semantic_segmentation"].attrs["offset"] = list(
    total_output_roi.get_offset()
)
results_zarr["semantic_segmentation"].attrs["resolution"] = list(voxel_size)
