## Please download a pre-trained model from
`https://www.dropbox.com/s/tir8ob3q67p79fj/precomputed_model_checkpoint?dl=0`

In [1]:
import gunpowder as gp
from funlib.learn.torch.models import UNet

import zarr
import numpy as np
import torch

import logging

In [2]:
# Emit INFO-level log records so progress messages are shown.
logging.basicConfig(level=logging.INFO)

# Constants
checkpoint = "precomputed_model_checkpoint"
voxel_size = gp.Coordinate((4, 4, 4))

# Both the network input and output are enlarged by the same number of
# voxels, so the margin between them (the context) is unchanged.
grow = (36, 36, 36)
input_shape = gp.Coordinate((76, 76, 76)) + grow
output_shape = gp.Coordinate((36, 36, 36)) + grow

# Sizes are the voxel shapes scaled by the voxel size.
input_size = input_shape * voxel_size
output_size = output_shape * voxel_size

# Margin by which the input extends past the output on each side.
context = (input_size - output_size) / 2

In [3]:
# initialize model
unet = UNet(
    in_channels=1,
    num_fmaps=16,
    fmap_inc_factor=5,
    downsample_factors=[(2, 2, 2), (2, 2, 2)],
    activation="ReLU",
    voxel_size=voxel_size,
    num_heads=1,
    constant_upsample=True,
)
# 1x1x1 convolution mapping the 16 UNet feature maps to 4 output channels.
logits = torch.nn.Conv3d(16, 4, (1,) * 3)
# Normalise over the channel axis (dim 1 of an (N, C, D, H, W) batch).
# Passing `dim` explicitly avoids PyTorch's "implicit dimension choice for
# softmax has been deprecated" warning; dim=1 is what the implicit rule
# selects for 5-D input, so the results are unchanged.
probs = torch.nn.Softmax(dim=1)

# Softmax has no parameters, so checkpoint keys of the Sequential are unaffected.
semantic_model = torch.nn.Sequential(unet, logits, probs)
semantic_model.eval()

Sequential(
  (0): UNet(
    (l_conv): ModuleList(
      (0): ConvPass(
        (conv_pass): Sequential(
          (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (1): ReLU()
          (2): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (3): ReLU()
        )
      )
      (1): ConvPass(
        (conv_pass): Sequential(
          (0): Conv3d(16, 80, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (1): ReLU()
          (2): Conv3d(80, 80, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (3): ReLU()
        )
      )
      (2): ConvPass(
        (conv_pass): Sequential(
          (0): Conv3d(80, 400, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (1): ReLU()
          (2): Conv3d(400, 400, kernel_size=(3, 3, 3), stride=(1, 1, 1))
          (3): ReLU()
        )
      )
    )
    (l_down): ModuleList(
      (0): Downsample(
        (down): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
      )
   

In [4]:
# Keys identifying the arrays that flow through the pipeline.
raw = gp.ArrayKey("RAW")
predictions = gp.ArrayKey("PREDICTIONS")

# Reads the raw test volume from the zarr container.
source = gp.ZarrSource(
    filename="../data/MB-Z1213-56.zarr",
    datasets={raw: "TEST/raw"},
    array_specs={raw: gp.ArraySpec(interpolatable=True)},
)

# Request describing a single network-sized chunk: `input_size` of raw
# data in, `output_size` of predictions out.
reference_request = gp.BatchRequest()
reference_request.add(raw, input_size)
reference_request.add(predictions, output_size)

In [5]:
# Build the source on its own just long enough to read the extent of the
# raw data it provides.
with gp.build(source):
    total_input_roi = source.spec[raw].roi
# Shrink by the context on both sides: the network output is smaller than
# its input by that margin.
total_output_roi = total_input_roi.grow(-context, -context)

# Number of prediction channels; must match the out_channels of the final
# 1x1x1 convolution of the model.
num_channels = 4

# Mode "w" creates a fresh container, replacing any existing one.
f = zarr.open("predictions.zarr", "w")
ds = f.create_dataset(
    "volumes/predictions",
    # Leading channel axis followed by the spatial shape in voxels, so that
    # the stored array has the same rank as the arrays being predicted.
    shape=(num_channels, *(total_output_roi.get_shape() / voxel_size)),
    # One chunk per network output block, holding all channels.
    chunks=(num_channels, *output_shape),
    dtype=np.float32,
)
# Spatial metadata describing where the dataset sits and its voxel size.
ds.attrs["offset"] = total_output_roi.get_begin()
ds.attrs["resolution"] = voxel_size

INFO:daisy.datasets:Reusing existing dataset


<daisy.array.Array at 0x7fa8249f5880>

In [6]:
# Prediction pipeline, applied chunk by chunk.
pipeline = (
    source
    + gp.Normalize(raw)
    # Two leading singleton axes are added to raw before it is handed to
    # the model (presumably channel and batch).
    + gp.Unsqueeze([raw])
    + gp.Unsqueeze([raw])
    + gp.torch.Predict(
        model=semantic_model,
        # "input" is the name of the forward() argument that receives raw.
        inputs={"input": raw},
        outputs={0: predictions},
        checkpoint=checkpoint,
    )
    # Drop the added axes again: one from both arrays, then a second one
    # from raw only, so predictions keep one extra leading axis.
    + gp.Squeeze([raw, predictions])
    + gp.Squeeze([raw])
    + gp.ZarrWrite(
        output_dir="./",
        output_filename="predictions.zarr",
        dataset_names={
            raw: "volumes/raw",
            predictions: "volumes/predictions",
        },
        # dataset_dtypes maps array keys to storage data types (not to
        # ArraySpecs); store the predictions as 32-bit floats.
        dataset_dtypes={predictions: np.float32},
    )
    # Tile the full request into chunks shaped like reference_request.
    + gp.Scan(reference_request, num_workers=3)
)

# Request covering the whole volume: all of raw, and the part of it for
# which predictions can be produced.
request = gp.BatchRequest()
request[raw] = gp.ArraySpec(roi=total_input_roi)
request[predictions] = gp.ArraySpec(roi=total_output_roi)

In [7]:
# Build the pipeline and pull the full-volume request through it. The
# returned batch is discarded; the pipeline's ZarrWrite node is what
# persists the results.
with gp.build(pipeline):
    pipeline.request_batch(request)

INFO:gunpowder.nodes.scan:scanning over 125 chunks
  0%|          | 0/125 [00:00<?, ?it/s]INFO:gunpowder.torch.nodes.predict:Predicting on gpu
INFO:gunpowder.torch.nodes.predict:Predicting on gpu
INFO:gunpowder.torch.nodes.predict:Predicting on gpu
  input = module(input)
  input = module(input)
  input = module(input)
INFO:gunpowder.nodes.scan:allocating array of shape (400, 400, 400) for RAW
INFO:gunpowder.nodes.scan:allocating array of shape (4, 360, 360, 360) for PREDICTIONS
100%|██████████| 125/125 [01:31<00:00,  1.37it/s]INFO:gunpowder.producer_pool:terminating workers...
INFO:gunpowder.producer_pool:joining workers...

INFO:gunpowder.producer_pool:done


In [8]:
# Post processing
# Re-open the container in read/write mode so the segmentation can be
# stored next to the predictions.
results_zarr = zarr.open("predictions.zarr", "r+")
results = results_zarr["volumes/predictions"]

# Load the predictions into memory and take, per voxel, the index of the
# largest value along axis 0 (presumably the class/channel axis).
semantic_segmentation = np.argmax(results[:], axis=0)

segmentation_path = "volumes/semantic_segmentation"
results_zarr[segmentation_path] = semantic_segmentation

# Attach the same spatial metadata that the predictions dataset carries.
segmentation_ds = results_zarr[segmentation_path]
segmentation_ds.attrs["offset"] = total_output_roi.get_begin()
segmentation_ds.attrs["resolution"] = voxel_size
