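## Please download a pre-trained model from
`https://www.dropbox.com/s/tir8ob3q67p79fj/precomputed_model_checkpoint?dl=0`
and save it next to this notebook as `precomputed_model_checkpoint` (the checkpoint path loaded below).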

In [1]:
import gunpowder as gp
from funlib.learn.torch.models import UNet

import zarr
import numpy as np
import torch

import logging

In [2]:
logging.basicConfig(level=logging.INFO)
# Constants
# Path to the pre-trained weights that gp.torch.Predict loads below.
checkpoint = "precomputed_model_checkpoint"
# Extra voxels added to BOTH the input and the output shape. Growing both by
# the same amount leaves (input - output), i.e. the context, unchanged while
# presumably letting each forward pass cover a larger block.
grow = (36, 36, 36)
# Physical size of one voxel along each axis (presumably nm -- confirm
# against the dataset).
voxel_size = gp.Coordinate((4, 4, 4))
# Network input/output shapes are in voxels; gunpowder requests and ROIs are
# in world units, so multiply by the voxel size to get the sizes.
input_shape = gp.Coordinate((76, 76, 76)) + grow
input_size = input_shape * voxel_size
output_shape = gp.Coordinate((36, 36, 36)) + grow
output_size = output_shape * voxel_size
# Margin (per side, world units) the input extends beyond the output.
# Floor division keeps this an integer Coordinate.
context = (input_size - output_size) // 2

In [3]:
# initialize model
unet = UNet(
    in_channels=1,
    num_fmaps=16,
    fmap_inc_factor=5,
    downsample_factors=[(2, 2, 2), (2, 2, 2)],
    activation="ReLU",
    voxel_size=voxel_size,
    num_heads=1,
    constant_upsample=True,
)
# 1x1x1 convolution mapping 16 feature channels to 4 class scores (logits).
logits = torch.nn.Conv3d(16, 4, (1,) * 3)
# Softmax over dim=1 -- the channel axis of Conv3d's (N, C, D, H, W) output --
# turns the 4 logits at every voxel into class probabilities.
probabilities = torch.nn.Softmax(dim=1)

# Semantic segmentation model: UNet -> logits -> probabilities.
# torch.nn.Sequential.forward takes one argument named ``input``, which is the
# key used in ``inputs={"input": raw}`` of gp.torch.Predict.
# Softmax has no parameters, so the state-dict keys are the same as for
# Sequential(unet, logits) ("0.*", "1.*"); presumably the checkpoint was saved
# from that layout -- TODO confirm.
semantic_model = torch.nn.Sequential(unet, logits, probabilities)
# Switch to inference mode before prediction.
semantic_model.eval()

In [4]:
# Keys naming the raw input volume and the network output in every batch.
raw, predictions = gp.ArrayKey("RAW"), gp.ArrayKey("PREDICTIONS")

# Request describing a single block: one network input of ``input_size`` and
# the output of ``output_size`` it produces (both in world units).
reference_request = gp.BatchRequest()
for array_key, array_size in ((raw, input_size), (predictions, output_size)):
    reference_request.add(array_key, array_size)

# Source reading the raw test volume; the spec declares that raw values may
# be interpolated.
source = gp.ZarrSource(
    filename="../data/MB-Z1213-56.zarr",
    datasets={raw: "TEST/raw"},
    array_specs={raw: gp.ArraySpec(interpolatable=True)},
)

In [5]:
# Ask the source for the full extent of the raw data.
with gp.build(source):
    total_input_roi = source.spec[raw].roi
# Predictions exist only where the network sees its whole input context, so
# shrink the input ROI by ``context`` on every side.
total_output_roi = total_input_roi.grow(-context, -context)

# Pre-create the output dataset: 4 class channels over the output ROI.
zarr_container = zarr.open("predictions.zarr")
predictions_dataset = zarr_container.create_dataset(
    "volumes/predictions",
    # Shape in voxels. Passing only the shape avoids allocating a full
    # in-memory zeros array; unwritten chunks read back as zarr's fill value
    # (0 by default).
    shape=(4, *(total_output_roi.get_shape() // voxel_size)),
    # Chunks are in voxels: one network output block (output_shape), not the
    # world-unit output_size, which would give chunks 4x too large per axis.
    chunks=(4, *output_shape),
    dtype=np.float32,
    # Allow re-running this cell without a "dataset already exists" error.
    overwrite=True,
)
# Record the dataset's world-coordinate placement, in the same attribute
# layout used for the semantic segmentation during post-processing.
predictions_dataset.attrs["offset"] = [int(o) for o in total_output_roi.get_offset()]
predictions_dataset.attrs["resolution"] = [int(v) for v in voxel_size]

In [6]:
pipeline = (
    source
    + gp.Normalize(raw)
    # Add two leading singleton axes (channel, batch) so raw matches the
    # (N, C, D, H, W) layout torch's Conv3d expects.
    + gp.Unsqueeze([raw])
    + gp.Unsqueeze([raw])
    + gp.torch.Predict(
        model=semantic_model,
        inputs={"input": raw},
        outputs={0: predictions},
        # Give the generated predictions an explicit ROI (the valid output
        # region). This is an ArraySpec, so it belongs here and not in
        # ZarrWrite's dataset_dtypes.
        array_specs={predictions: gp.ArraySpec(roi=total_output_roi)},
        checkpoint=checkpoint,
    )
    # Remove the batch axis from both arrays...
    + gp.Squeeze([raw, predictions])
    # ...and the channel axis from raw, restoring its original rank.
    + gp.Squeeze([raw])
    + gp.ZarrWrite(
        output_dir="./",
        output_filename="predictions.zarr",
        dataset_names={
            raw: "volumes/raw",
            predictions: "volumes/predictions",
        },
        # dataset_dtypes maps array keys to numpy dtypes.
        dataset_dtypes={predictions: np.float32},
    )
    # Scan splits the full request below into blocks shaped like
    # reference_request and runs them one at a time, so the torch model is
    # applied blockwise and ZarrWrite stores each block as it is produced.
    + gp.Scan(reference_request)
)

# Request covering the whole volume; Scan tiles it into blocks.
request = gp.BatchRequest()
request[raw] = gp.ArraySpec(roi=total_input_roi)
request[predictions] = gp.ArraySpec(roi=total_output_roi)

In [7]:
# The build context keeps the pipeline set up for the duration of the block.
pipeline_build = gp.build(pipeline)
with pipeline_build:
    # Results are persisted by the pipeline's ZarrWrite node; the returned
    # batch is not used.
    pipeline.request_batch(request)

In [8]:
# Post processing
results_zarr = zarr.open(f"predictions.zarr", "r+")
results = results_zarr["volumes/predictions"]

semantic_segmentation = np.argmax(results, axis=0)

results_zarr["volumes/semantic_segmentation"] = semantic_segmentation
results_zarr["volumes/semantic_segmentation"].attrs["offset"] = total_output_roi.get_offset()
results_zarr["volumes/semantic_segmentation"].attrs["resolution"] = voxel_size
