In [3]:
from funlib.learn.torch.models import UNet, ConvPass
import gunpowder as gp
import logging
import numpy as np
import torch
import zarr

# Surface INFO-level records so the pipeline nodes report their progress.
logging.basicConfig(level=logging.INFO)

# Training iteration; used to build the checkpoint and output file names.
checkpoint = 100000

# Keys under which the network input and output travel through the pipeline.
raw = gp.ArrayKey("RAW")
prediction = gp.ArrayKey("PREDICTION")

# Extent of a single network input / output chunk. These are multiplied by
# the voxel size further down, so they are expressed in voxels.
input_shape = (7, 64, 124, 124)
output_shape = (1, 32, 32, 32)

# Region to predict, as (offset, shape). Its shape is later divided by the
# voxel size, so it is presumably given in world units -- TODO confirm.
predict_roi = gp.Roi(
    offset=(250, 2000, 500, 500),
    shape=(1, 160, 1000, 1000),
)

# U-Net backbone taking 7 input channels; 'valid' padding means the output
# is smaller than the input (compare input_shape and output_shape).
unet = UNet(
    in_channels=7,
    num_fmaps=12,
    fmap_inc_factor=5,
    downsample_factors=[(1, 2, 2), (1, 2, 2), (2, 2, 2)],
    constant_upsample=True,
    padding='valid',
)

# Backbone followed by a ConvPass head (presumably 12 -> 1 feature maps with
# a 1x1x1 kernel -- TODO confirm against ConvPass) and a sigmoid.
# Module.eval() returns the module itself, so the model can be switched to
# evaluation mode in the same expression that builds it.
model = torch.nn.Sequential(
    unet,
    ConvPass(12, 1, [(1, 1, 1)], activation=None),
    torch.nn.Sigmoid(),
).eval()

# Source node reading the raw volume from the N5 container.
raw_source = gp.ZarrSource(
    '../data/mouse.n5',
    datasets={raw: 'volumes/raw'},
    array_specs={raw: gp.ArraySpec(interpolatable=True)},
)

# Build the source on its own once, only to print what it provides and to
# read the voxel size of the raw data.
with gp.build(raw_source):
    print(raw_source)
    voxel_size = raw_source.spec[raw].voxel_size

# Scale the per-chunk shapes by the voxel size.
input_size = voxel_size * input_shape
output_size = voxel_size * output_shape

# Request describing a single chunk; handed to gp.Scan when the pipeline
# is assembled.
scan_request = gp.BatchRequest()
scan_request.add(raw, input_size)
scan_request.add(prediction, output_size)

# Inference node: feeds `raw` to the model's 'input' argument and stores the
# model's output 0 under `prediction`, using the weights of the checkpoint.
predict_node = gp.torch.Predict(
    model,
    inputs={'input': raw},
    outputs={0: prediction},
    checkpoint=f'model_checkpoint_{checkpoint}',
    spawn_subprocess=True,
)

# Writer node: stores each predicted chunk in the 'prediction' dataset of
# the output container.
write_node = gp.ZarrWrite(
    output_filename=f'prediction_{checkpoint}.zarr',
    dataset_names={prediction: 'prediction'},
)

# Full chain, source first. Stack(1) and Squeeze bracket the prediction
# (presumably adding and removing a batch dimension of size one -- TODO
# confirm); Scan requests `scan_request`-sized chunks with 8 workers.
pipeline = (
    raw_source
    + gp.Pad(raw, None)
    + gp.Normalize(raw, dtype=np.float32)
    + gp.Stack(1)
    + predict_node
    + gp.Squeeze([raw, prediction])
    + write_node
    + gp.Scan(scan_request, num_workers=8)
)

# Prepare the output dataset before running the pipeline: one float32
# dataset covering the whole prediction ROI (in voxels), chunked like a
# single network output, with the ROI offset and voxel size as attributes.
f = zarr.open(f'prediction_{checkpoint}.zarr', mode='w')
ds = f.create_dataset(
    'prediction',
    shape=predict_roi.get_shape() / voxel_size,
    chunks=output_shape,
    dtype=np.float32,
)
ds.attrs['offset'] = predict_roi.get_begin()
ds.attrs['resolution'] = voxel_size

# Request the prediction over the complete region of interest.
total_request = gp.BatchRequest()
total_request[prediction] = predict_roi

# Build the pipeline and issue the single top-level request.
with gp.build(pipeline):
    batch = pipeline.request_batch(total_request)

ZarrSource[../data/mouse.n5], providing: 
	RAW: ROI: [225:275, 0:4940, 0:2048, 0:2169] (50, 4940, 2048, 2169), voxel size: (1, 5, 1, 1), interpolatable: True, non-spatial: False, dtype: uint16, placeholder: False



INFO:gunpowder.nodes.scan:scanning over 1024 chunks
INFO:gunpowder.torch.nodes.predict:Predicting on gpu
  0%|          | 0/1024 [00:00<?, ?it/s]INFO:gunpowder.nodes.scan:allocating array of shape (1, 32, 1000, 1000) for PREDICTION
100%|██████████| 1024/1024 [04:44<00:00,  3.60it/s]
INFO:gunpowder.producer_pool:terminating workers...
INFO:gunpowder.producer_pool:joining workers...
INFO:gunpowder.producer_pool:done
INFO:gunpowder.producer_pool:terminating workers...
INFO:gunpowder.producer_pool:joining workers...
INFO:gunpowder.producer_pool:done
