# Inference
As with training, we use `gunpowder` here, this time to predict affinities over a larger area.
## Exercise 1
Your task is to run the script twice: once with `datasetsize` set to `'small'` and once with it set to `'big'`. This gives us two affinity predictions to work with in `03_agglomeration.ipynb`.

## Exercise 2
1. While you wait for the prediction to finish, calculate how many chunks would need to be predicted to process an entire fly brain.
2. Once the prediction has finished, estimate how long processing the whole brain would take if only a single GPU were available.

In [None]:
from __future__ import print_function
from gunpowder import *
from gunpowder.tensorflow import *
import json
import numpy as np
import os
import sys
import logging
import daisy
import os
import time


def predict(iteration, in_file, out_file, setup_dir,
            out_dataset, datasetsize='small'):

    with open(os.path.join(setup_dir, 'train_net.json'), 'r') as f:
        config = json.load(f)

    # voxels
    voxel_size = Coordinate((40, 4, 4))
    input_shape = Coordinate(config['input_shape'])
    output_shape = Coordinate(config['output_shape'])
    context = (input_shape - output_shape) // 2
    if datasetsize == 'big':
        offset = (50, 1400, 1400) * np.array(voxel_size)  # avoid big glia and cell body in sample 0
        chunkgrid = [2, 12, 12]

    elif datasetsize == 'small':
        offset = (90, 1600, 1400) * np.array(voxel_size)  # avoid big glia
        chunkgrid = [1, 5, 5]



    # Initial ROI
    # got nice offset visually from neuroglancer : 1507, 1678, 100
    roi = Roi(
        offset=offset,
        shape=output_shape * voxel_size * Coordinate(chunkgrid),
    )


    # nm
    context_nm = context * voxel_size
    read_roi = roi.copy()
    read_roi = read_roi.grow(context_nm, context_nm)


    # read_roi = read_roi.snap_to_grid(input_shape * voxel_size)
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    output_roi = read_roi.grow(-context_nm, -context_nm)
    print("Read ROI in nm is %s" % read_roi)
    print("Output ROI in nm is %s" % output_roi)

    print("Read ROI in voxel space is {}".format(read_roi / voxel_size))
    print("Output ROI in voxel space is {}".format(output_roi / voxel_size))

    raw = ArrayKey('RAW')
    affs = ArrayKey('AFFS')

    output_roi = daisy.Roi(
        output_roi.get_begin(),
        output_roi.get_shape()
    )

    # TODO: Introduces daisy dependency, does that work without ?
    # Also, daisy.ROI and gunpowder.Roi have different behaviour, important
    # source of confusion. Prepare_ds only works with daisy.Roi, while
    # gunpowder node eg. Crop only works with gunpowder.Roi
    ds = daisy.prepare_ds(
        out_file,
        out_dataset,
        output_roi,
        voxel_size,
        'float32',
        # write_size=output_size,
        write_roi=daisy.Roi((0, 0, 0), output_size),
        num_channels=3,
        # temporary fix until
        # https://github.com/zarr-developers/numcodecs/pull/87 gets approved
        # (we want gzip to be the default)
        compressor={'id': 'gzip', 'level': 5}
    )

    chunk_request = BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(affs, output_size)

    pipeline = (
            N5Source(
                in_file,
                datasets={
                    raw: 'volumes/raw'
                },
            ) +
            Pad(raw, size=None) +
            Crop(raw, read_roi) +
            Normalize(raw) +
            IntensityScaleShift(raw, 2, -1) +
            Predict(
                os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
                inputs={
                    config['raw']: raw
                },
                outputs={
                    config['affs']: affs
                },
                # TODO: change to predict graph
                graph=os.path.join(setup_dir, 'train_net.meta')
            ) +
            IntensityScaleShift(raw, 0.5, 0.5) +  # Just for visualization.
            ZarrWrite(
                dataset_names={
                    affs: out_dataset,
                    raw: 'volumes/raw',
                },
                output_filename=out_file
            ) +  # TODO: Would be nice to have a consistent file format (eg. only n5)
            PrintProfilingStats(every=10) +
            Scan(chunk_request)
    )
    start_time = time.time()
    print("Starting prediction...")
    with build(pipeline):
        pipeline.request_batch(BatchRequest())
    print("Prediction finished in {:0.2f}".format(time.time() - start_time))

In [None]:
logging.basicConfig(level=logging.INFO)
# Verbose output from gunpowder's write node.
write_logger = logging.getLogger('gunpowder.nodes.hdf5like_write_base')
write_logger.setLevel(logging.DEBUG)

# Which region to predict ('big' or 'small') and which checkpoint to use.
datasetsize = 'big'
iteration = 500000

# Raw input data.
in_file = '../jan/segmentation/data/sample_0.n5'
# Pretrained model directory; point this at your own training directory to
# use a network you trained yourself.
# NOTE(review): this path goes up one more level ('../../') than in_file —
# confirm both resolve from the notebook's working directory.
setup_dir = '../../jan/segmentation/snapshots/setup58_p/'

# Output container name encodes region size and zero-padded iteration.
out_dataset = 'volumes/affs'
out_file = 'affinities_{}_{:05}.zarr'.format(datasetsize, iteration)

predict(
    iteration,
    in_file,
    out_file,
    setup_dir,
    out_dataset,
    datasetsize=datasetsize,
)