# Prepare CREMI Data

In this notebook, we project the CREMI ground-truth neuron segmentation onto (watershed) fragments.
These are stored in the bigcat CREMI format to enable proofreading.

The specific fragments used in this notebook were generated with a different CNN
than the one used for the base segmentation of the original CREMI ground truth.
CREMI was created 2-3 years ago, and since then networks have improved thanks to more training data and advances in design.
(Also, I don't know where to find the original CREMI network predictions and fragments.)

Prerequisites to run this notebook:
Besides standard Python packages, we need vigra, nifty and z5py to run the code; the plotting cells additionally use matplotlib and scikit-image.
The easiest way to install them is in a clean conda environment:

`conda create -n fix-cremi -c conda-forge -c ilastik-forge -c cpape python=3.6 nifty z5py jupyter vigra h5py matplotlib scikit-image`

`source activate fix-cremi`

In [None]:
# import all necessary packages
# this might throw a runtime error that can be ignored
import sys
import numpy as np
import h5py
import z5py
import vigra
import nifty.graph.rag as nrag

## Create output file

We assume the following input data:
- HDF5 file with the (padded and realigned) CREMI data, containing the raw data ('/volumes/raw'), the ground-truth segmentation ('/volumes/labels/neuron_ids') and a mask ('/volumes/labels/mask') indicating the region that is actually labeled.
- N5 file containing the affinity predictions. For the CREMI test samples, these predictions can be found in `/groups/saalfeld/home/papec/Work/neurodata_hdd/cremi_warped/sample*.n5`, dataset `predictions/affs_glia`.

The notebook will produce an HDF5 file with the raw data, the fragments and a look-up table from fragments to
projected ground-truth segments, to be ingested by bigcat.
The fragments will be cropped to the relevant bounding box.
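To illustrate how the bounding box and the bigcat offset are derived, here is a minimal numpy-only sketch on a synthetic in-memory mask. The helper `bounding_box_from_mask` is hypothetical and not part of the notebook; it assumes the same (z, y, x) axis order and (40, 4, 4) nm resolution as the code below, and treats every non-zero voxel as labeled (the notebook itself tests `mask == 1`).

```python
import numpy as np


def bounding_box_from_mask(mask, resolution=(40., 4., 4.)):
    """Hypothetical helper: tight bounding box (as slices) around the non-zero
    voxels of an in-memory mask, plus its start offset in nanometers."""
    coordinates = np.nonzero(mask)
    if len(coordinates[0]) == 0:
        raise ValueError("mask does not contain any labeled voxel")
    # the stop index is exclusive, hence max + 1
    bb = tuple(slice(int(c.min()), int(c.max()) + 1) for c in coordinates)
    # offset = start voxel times voxel size, per axis
    offset = tuple(b.start * res for b, res in zip(bb, resolution))
    return bb, offset


# toy mask: 10 x 100 x 100 volume with a labeled block in the middle
mask = np.zeros((10, 100, 100), dtype='uint8')
mask[2:8, 20:60, 30:90] = 1
bb, offset = bounding_box_from_mask(mask)
print(offset)           # (80.0, 80.0, 120.0)
print(mask[bb].shape)   # (6, 40, 60)
```

Cropping every volume with the same `bb` and storing `offset` as an attribute is what lets bigcat place the cropped fragments correctly relative to the uncropped raw data.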

In [None]:
# calculate the bounding box around the parts of the ground-truth
# that is actually labeled
def find_bounding_box_and_offset(path, mask_key):
    # load the mask that indicates labeled parts of the data
    with h5py.File(path, 'r') as f:
        mask = f[mask_key][:]

    # find the coordinates that are in the mask
    coordinates = np.where(mask == 1)
    # extract min and max masked coordinates
    min_coords = [np.min(coords) for coords in coordinates]
    max_coords = [np.max(coords) for coords in coordinates]
    
    # construct the bounding box
    bb = tuple(slice(minc, maxc + 1) for minc, maxc in zip(min_coords, max_coords))

    # compute the offset in nanometer for the bigcat format
    resolution = (40., 4., 4.)
    offset = tuple(b.start * res for b, res in zip(bb, resolution))
    return bb, offset

In [None]:
# specify all relevant paths
# for this example we use the MALA training data,
# change the paths accordingly for the actual test data

# the sample we will process
sample = 'A'

# path and key to the cremi h5 file with groundtruth, raw data and groundtruth mask
path_gt = '/home/papec/mnt/papec/20170312_mala_v2/sample_%s.augmented.0.hdf' % sample
key_gt = 'volumes/labels/neuron_ids'
key_raw = 'volumes/raw'
key_mask = 'volumes/labels/mask'

# path and key to the n5 file with affinity predictions
path_affs = '/home/papec/mnt/papec/20170312_mala_v2/affs/sample_%s_affs.n5' % sample
key_affs = 'predictions_mala'

# the path to the output h5 file
path_out = '/home/papec/mnt/papec/sample%s_fix_test.h5' % sample
key_fragments = 'volumes/labels/fragments'

# get the relevant bounding box
bb, offset = find_bounding_box_and_offset(path_gt, key_mask)
print("Groundtruth bounding box for sample", sample)
print(bb)
print("Corresponding to offset in nanometers:")
print(offset)

In [None]:
# create the output file and copy the raw data
# (explicit file modes: newer h5py versions open files read-only by default)
def create_output_file(in_path, out_path, key):
    # open the output file in append mode (it is created if it does not exist);
    # if the raw data is already present, there is nothing to do
    with h5py.File(out_path, 'a') as f:
        if key in f:
            return

    # otherwise, load the raw data and write it to the output file
    with h5py.File(in_path, 'r') as f:
        raw = f[key][:]
    with h5py.File(out_path, 'a') as f:
        ds = f.create_dataset(key, data=raw, compression='gzip')
        ds.attrs['resolution'] = [40., 4., 4.]
        ds.attrs['offset'] = [0., 0., 0.]
        # bigcat format
        f.attrs['file_format'] = 0.2

In [None]:
create_output_file(path_gt, path_out, key_raw)

## Make watershed fragments

We produce fragments using a 2d watershed on the averaged in-plane affinity predictions, inverted to obtain a height map. As seeds, we
use the local maxima of the distance transform of the (thresholded) height map.
This results in 2d fragments, which are appropriate for anisotropic data.
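The notebook itself uses `DTWatershed` from `cremi_tools`. As a rough, hypothetical stand-in that only illustrates the idea, the sketch below uses scipy and scikit-image (`dt_watershed_2d` is not part of either library or of `cremi_tools`). Per slice, it treats pixels whose height exceeds the threshold as boundaries, smooths the distance transform to them, uses its local maxima as seeds and runs a seeded watershed on the height map inside the mask. It assumes the parameters mean roughly the same as in the call below and omits the size filter and any other details of the original implementation.

```python
import numpy as np
from scipy import ndimage
from skimage.segmentation import watershed


def dt_watershed_2d(hmap, threshold=0.4, sigma=1.6, mask=None):
    """Hypothetical, simplified distance-transform watershed, applied slice by slice.

    hmap: float array (z, y, x) with high values on boundaries.
    Returns uint64 fragments, 0 outside the mask, ids unique over the whole volume.
    """
    fragments = np.zeros(hmap.shape, dtype='uint64')
    next_label = 1
    for z in range(hmap.shape[0]):
        hm = hmap[z]
        if mask is None:
            slice_mask = np.ones(hm.shape, dtype=bool)
        else:
            slice_mask = mask[z].astype(bool)
        boundary = hm > threshold
        if not boundary.any():
            # no boundary evidence: keep the masked slice as a single fragment
            fragments[z][slice_mask] = next_label
            next_label += 1
            continue
        # distance of every pixel to the nearest boundary pixel, then smoothed
        dt = ndimage.distance_transform_edt(~boundary)
        dt = ndimage.gaussian_filter(dt, sigma)
        # seeds: connected plateaus of local maxima of the smoothed distance map
        maxima = (dt == ndimage.maximum_filter(dt, size=3)) & slice_mask
        seeds, n_seeds = ndimage.label(maxima)
        if n_seeds == 0:
            continue
        # flood the height map from the seeds, restricted to the mask
        ws = watershed(hm, markers=seeds, mask=slice_mask).astype('uint64')
        # shift this slice's labels so that ids stay unique across slices
        fragments[z][ws > 0] = ws[ws > 0] + np.uint64(next_label - 1)
        next_label += n_seeds
    return fragments


# toy height map: one slice with a vertical boundary in column 20
hmap = np.zeros((1, 40, 40), dtype='float32')
hmap[:, :, 20] = 1.0
frags = dt_watershed_2d(hmap)
print(np.unique(frags))  # [1 2]
```

Running the watershed per slice and offsetting the ids keeps the fragments strictly 2d, so no fragment can leak along z through a poorly resolved boundary in anisotropic data.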

To run the function below, please clone the `cremi_tools` repository
(https://github.com/constantinpape/cremi_tools) and add it to the Python path.

In [None]:
# function to redo the watersheds
# it needs the additional `cremi_tools` repository,
# which can simply be added to the pythonpath
def run_watersheds(affinity_path, affinity_key,
                   mask_path, mask_key, bounding_box):
    # change to path where cremi_tools is located
    sys.path.append('/home/papec/Work/my_projects/cremi_tools')
    import cremi_tools.segmentation as cseg
    
    # load the xy affinities, average them over the xy-channels
    # (we assume that xy correspond to channels 1 and 2 !)
    f = z5py.File(affinity_path)
    affs = f[affinity_key][(slice(1,3),) + bounding_box]

    # convert affinities from 8bit to float if necessary
    if affs.dtype == np.dtype('uint8'):
        affs = affs.astype('float32') / 255.
    # invert and average affinities to obtain height map for the 2d watersheds
    hmap = np.mean(1. - affs, axis=0)

    # load the mask to make sure watersheds adhere to the ground-truth boundaries
    with h5py.File(mask_path, 'r') as f_mask:
        mask = f_mask[mask_key][bounding_box].astype('bool')

    # run the distance transform watersheds in 2d
    # with threshold 0.4, sigma for smoothing 1.6 and minimum fragment size 20
    ws = cseg.DTWatershed(0.4, 1.6, size_filter=20, is_anisotropic=True)
    fragments, _ = ws(hmap, mask)
    return fragments.astype('uint64')

In [None]:
# create and save the fragments if necessary, otherwise load them from file
# (append mode: newer h5py versions open files read-only by default)
with h5py.File(path_out, 'a') as f:
    have_fragments = key_fragments in f
    if have_fragments:
        # load fragments from file
        print("Fragments are already present, will load from file")
        frags = f[key_fragments][:]
    else:
        print("Fragments are not present, will compute and save")
        # create fragments
        frags = run_watersheds(path_affs, key_affs, 
                               path_gt, key_mask, bb)
        # write fragments to file and save necessary attributes
        print("Computation done, saving fragments to", path_out, key_fragments)
        ds = f.create_dataset(key_fragments, data=frags, compression='gzip')
        ds.attrs['offset'] = offset
        ds.attrs['resolution'] = [40., 4., 4.]

In [None]:
# plot a slice of raw data overlaid with the fragments (or any other segmentation)
def plot_segmentation(raw_path, raw_key, fragments, bounding_box,
                      slice_id=0, alpha=0.6):
    import matplotlib.pyplot as plt
    from skimage import img_as_float, color
    
    # load the slice from the raw data
    with h5py.File(raw_path, 'r') as f:
        bb_im = (slice_id + bounding_box[0].start,) + bounding_box[1:]
        raw_im = f[raw_key][bb_im]

    # relabel this slice consecutively to make our lives easier
    frag_im = fragments[slice_id].copy()
    vigra.analysis.relabelConsecutive(frag_im, out=frag_im,
                                      start_label=1, keep_zeros=True)

    # convert the grayscale raw data to rgb
    im = img_as_float(raw_im)
    img_color = np.dstack((im, im, im))

    # create one random color per fragment
    n_fragments = int(frag_im.max()) + 1
    random_colors = np.random.rand(n_fragments, 3)

    # create the color mask by looking up the color of each pixel's fragment;
    # we skip 0 (ignore label), which stays black
    color_mask = random_colors[frag_im]
    color_mask[frag_im == 0] = 0.
    
    # convert raw and fragments to hsv images
    im_hsv = color.rgb2hsv(img_color)
    mask_hsv = color.rgb2hsv(color_mask)
    
    # replace hue and saturation of the raw data
    # with that of color mask
    im_hsv[..., 0] = mask_hsv[..., 0]
    im_hsv[..., 1] = mask_hsv[..., 1] * alpha
    
    im_colored = color.hsv2rgb(im_hsv)
    f, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im_colored)
    plt.show()

In [None]:
# plot the fragments for one slice
plot_segmentation(path_gt, key_raw, frags, bb, slice_id=0)

## Project ground-truth to fragments

Finally, we project the ground-truth segmentation onto the fragments
we have just produced, via maximum overlap projection: each fragment is assigned the ground-truth label that it overlaps most.
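The notebook computes this projection with nifty's `gridRagAccumulateLabels`. The numpy sketch below (hypothetical helper `max_overlap_assignment`, not part of nifty) only shows the idea on small arrays: it counts, for every fragment, how many voxels carry each ground-truth label and picks the label with the largest count. It builds a dense fragment-by-label overlap matrix, so it is not meant for full CREMI volumes.

```python
import numpy as np


def max_overlap_assignment(fragments, gt):
    """Hypothetical helper: for every fragment id 0..fragments.max(), return the
    ground-truth label with the largest voxel overlap (ties -> smallest label).
    Fragment ids that do not occur get the smallest ground-truth label."""
    frag_flat = fragments.ravel().astype('int64')
    # map gt labels to 0..n_gt-1 so the overlap matrix stays small
    gt_ids, gt_flat = np.unique(gt.ravel(), return_inverse=True)
    n_frags = int(frag_flat.max()) + 1
    # overlap[i, j] = number of voxels of fragment i that carry gt label gt_ids[j]
    overlap = np.zeros((n_frags, len(gt_ids)), dtype='int64')
    np.add.at(overlap, (frag_flat, gt_flat), 1)
    # argmax returns the first maximum, i.e. the smallest gt label on ties
    return gt_ids[np.argmax(overlap, axis=1)]


fragments = np.array([[1, 1, 2, 2],
                      [1, 3, 2, 2]], dtype='uint64')
gt = np.array([[5, 5, 7, 7],
               [5, 7, 7, 9]], dtype='uint64')
assignment = max_overlap_assignment(fragments, gt)
# mapping the assignment back to pixels is a simple lookup by fragment id
projected = assignment[fragments]
print(assignment)  # [5 5 7 7]
```

Note how fragment 2 is assigned label 7 even though one of its voxels carries label 9: the projection can only be as fine as the fragments, which is why the fragments should over-segment.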

In [None]:
# project the cremi groundtruth to the watershed fragments
# TODO correct cremi-ignore label ?!
def project_gt_to_fragments(path, gt_key, fragments,
                            bounding_box,
                            cremi_ignore=0xfffffffffffffffd):
    # load the CREMI groundtruth from h5
    with h5py.File(path, 'r') as f:
        gt = f[gt_key][bounding_box]
    assert gt.shape == fragments.shape

    # build the region adjacency graph and use it to
    # extract the assignment of groundtruth labels
    # to fragments
    rag = nrag.gridRag(fragments.astype('uint32'))
    assignment = nrag.gridRagAccumulateLabels(rag, gt).astype('uint64')
    
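    # if groundtruth label 0 occurs, shift all labels by one so that it does not
    # collide with the ignore label 0, which is reserved for the masked fragment 0;
    # then relabel the assignment ids consecutively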
    if 0 in assignment:
        assignment += 1
    assignment[0] = 0
    vigra.analysis.relabelConsecutive(assignment, out=assignment,
                                      start_label=1, keep_zeros=True)

    # add the number of fragments as offset to the assignments, because
    # in bigcat fragment ids are also segment ids
    n_fragments = rag.numberOfNodes
    assignment[1:] += n_fragments
    # find the next valid segment id
    next_id = int(assignment.max() + 1)

    # build the correct lut format for bigcat (fragment ids are consecutive, starting at 0):
    # row 0 holds the fragment ids, row 1 the corresponding segment ids
    lut = np.stack([np.arange(len(assignment), dtype='uint64'),
                    assignment.astype('uint64')])

    # set the masked area to the cremi ignore value
    lut[1, 0] = cremi_ignore

    # map the assignments to a volumetric segmentation for inspection purposes
    projected = nrag.projectScalarNodeDataToPixels(rag, assignment)
    return lut, next_id, projected
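
The id bookkeeping at the end of `project_gt_to_fragments` can be hard to follow, so here is a numpy-only sketch of the same steps on a toy assignment. The helper `build_bigcat_lut` is hypothetical; `np.unique(..., return_inverse=True)` stands in for vigra's `relabelConsecutive`. It also yields consecutive ids starting at 1 with 0 kept, but it numbers them in sorted label order, which may differ from the order vigra produces.

```python
import numpy as np


def build_bigcat_lut(assignment, cremi_ignore=0xfffffffffffffffd):
    """Hypothetical numpy version of the lut bookkeeping in project_gt_to_fragments.

    assignment[i] is the ground-truth label of fragment i; fragment 0 is the masked area.
    """
    assignment = np.asarray(assignment, dtype='uint64').copy()
    # shift labels if gt label 0 occurs, so it cannot collide with the ignore id 0
    if 0 in assignment:
        assignment += np.uint64(1)
    assignment[0] = 0
    # consecutive relabeling that keeps 0: since assignment[0] == 0, id 0 is the
    # smallest unique value and maps to 0; the other ids become 1, 2, ... in sorted order
    _, inverse = np.unique(assignment, return_inverse=True)
    assignment = inverse.astype('uint64')
    # fragment ids and segment ids share one id space in bigcat,
    # so segment ids start after the last fragment id
    n_fragments = len(assignment)
    assignment[1:] += np.uint64(n_fragments)
    next_id = int(assignment.max()) + 1
    # row 0: fragment ids, row 1: segment ids
    lut = np.stack([np.arange(n_fragments, dtype='uint64'), assignment])
    # the masked fragment 0 maps to the CREMI ignore label
    lut[1, 0] = cremi_ignore
    return lut, next_id


lut, next_id = build_bigcat_lut([0, 5, 5, 0, 9])
print(lut[1].tolist())  # [18446744073709551613, 7, 7, 6, 8]
print(next_id)          # 9
```

Fragments 1 and 2 end up in the same segment (7), while every segment id lies above the largest fragment id (4), so bigcat can distinguish the two kinds of ids.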

In [None]:
# get the segment to fragment lut
lut, next_id, projected = project_gt_to_fragments(path_gt, key_gt, frags, bb)

In [None]:
# plot the projection for a slice
plot_segmentation(path_gt, key_raw, projected, bb, slice_id=0)

In [None]:
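# serialize the lut and next-id
with h5py.File(path_out, 'a') as f:
    # remove a lut from a previous run, otherwise create_dataset fails
    if 'fragment_segment_lut' in f:
        del f['fragment_segment_lut']
    f.create_dataset('fragment_segment_lut', data=lut, chunks=True, maxshape=(2, None))
    f.attrs['next_id'] = int(next_id)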