## Please download a pre trained model from
`https://www.dropbox.com/s/tir8ob3q67p79fj/precomputed_model_checkpoint?dl=0`

In [3]:
import gunpowder as gp
from funlib.learn.torch.models import UNet

import zarr
import numpy as np
import torch

import logging

In [6]:
logging.basicConfig(level=logging.INFO)
# Constants
checkpoint = "precomputed_model_checkpoint"
grow = (36, 36, 36)
voxel_size = gp.Coordinate((4, 4, 4))
input_shape = gp.Coordinate((76, 76, 76)) + grow
input_size = input_shape * voxel_size
output_shape = gp.Coordinate((36, 36, 36)) + grow
output_size = output_shape * voxel_size
context = (input_size - output_size) / 2

In [3]:
# adapt the training model to output probabilities rather than logprobabilities
semantic_model.eval()

In [7]:
raw = gp.ArrayKey("RAW")
predictions = gp.ArrayKey("PREDICTIONS")

reference_request = gp.BatchRequest()
reference_request.add(raw, input_size)
reference_request.add(predictions, output_size)

# Create your ZarrSource

In [10]:
with gp.build(source):
    total_input_roi = source.spec[raw].roi
total_output_roi = total_input_roi.grow(-context, -context)

zarr_container = zarr.open("predictions.zarr")
zarr_container.create_dataset(
    "volumes/predictions",
    data=np.zeros((4, *total_output_roi.get_shape()/voxel_size)),
    chunks=(4, *output_size)
)

<zarr.core.Array '/volumes/predictions' (4, 360, 360, 360) float64>

In [11]:
pipeline = (
    source
    # make sure your data is normalized and has both a batch and a channel dimension
    + gp.torch.Predict(
        model=semantic_model,
        inputs={"input": raw},
        outputs={0: predictions},
        checkpoint=checkpoint,
    )
    # make sure you remove the unneeded batch/channel dimensions
    + gp.ZarrWrite(
        output_dir="./",
        output_filename="predictions.zarr",
        dataset_names={
            raw: "volumes/raw",
            predictions: "volumes/predictions",
        },
        dataset_dtypes={predictions: gp.ArraySpec(roi=total_output_roi)},
    )
    # Add a scan node to predict accross the whole volume
)

request = gp.BatchRequest()
request[raw] = gp.ArraySpec(roi=total_input_roi)
request[predictions] = gp.ArraySpec(roi=total_output_roi)

NameError: name 'semantic_model' is not defined

In [7]:
with gp.build(pipeline):
    pipeline.request_batch(request)

In [8]:
# Post processing
results_zarr = zarr.open(f"predictions.zarr", "r+")
results = results_zarr["volumes/predictions"]

semantic_segmentation = np.argmax(results, axis=0)

results_zarr["volumes/semantic_segmentation"] = semantic_segmentation
results_zarr["volumes/semantic_segmentation"].attrs["offset"] = total_output_roi.get_offset()
results_zarr["volumes/semantic_segmentation"].attrs["resolution"] = voxel_size
