In [None]:
'''
Segment images using a trained CNN 
with @jkimmel's fork of @vanvalen's DeepCell

DeepCell transfers weights from vanilla CNNs trained on small patches
(each the size of the network's receptive field) to models employing
atrous kernels, allowing segmentation of full-sized images without
patchwise classification.

This notebook outlines how to transfer weights from a trained vanilla CNN
to the corresponding atrous kernel network for segmentation.
NOTE: segmentImages.py wraps this process in a CLI.
'''

import h5py
import tifffile as tiff
from keras.backend.common import _UID_PREFIXES
import os
import numpy as np
import argparse
from cnn_functions import nikon_getfiles, get_image, run_models_on_directory, get_image_sizes, dice_jaccard_indices

# NOTE: The 'sparse' model with atrous kernels you employ for segmentation
# must mirror the structure of the vanilla CNN layer-for-layer
# The only difference should be the use of atrous kernels
# In model zoo, atrous kernel networks are merely prefaced with 'sparse_'
# for the corresponding vanilla CNN
from model_zoo import sparse_bn_feature_net_81x81 as fnet

In [None]:
# --- User configuration: edit the values in this cell before running ---

# Directory holding the images to segment.
direc_name = '/path/to/images/for/segmentation/'
# The cells below refer to the image directory as `data_location`; define it
# here so the notebook runs top to bottom on a fresh kernel instead of failing
# with a NameError. Point this at a subdirectory of `direc_name` if your
# images live in one.
data_location = direc_name
# Directory where segmentation outputs are written.
seg_location = '/path/to/save/segmentation/outputs/'

# Channel and feature names.
channel_names = ['DIC']
feature_names = ['feature0', 'feature1']

# Directory containing the trained network weight files.
trained_network_dir = '/path/to/trained/models/'
# Filename prefix of the weight files; the weight paths are built as
# <trained_network_dir>/<net_prefix><index>.h5
net_prefix = 'trained_network_'
# Number of trained networks to ensemble.
nb_networks = 1

In [None]:
# Window sizes handed to the segmentation call as win_x / win_y.
# NOTE(review): 40 presumably corresponds to the 81x81 model imported above
# (2 * 40 + 1 = 81) -- confirm before changing the model.
window_x, window_y = 40, 40

# Image dimensions, as reported by get_image_sizes for the first channel.
sz = list(get_image_sizes(data_location, channel_names[0]))
image_size_x = sz[0]
image_size_y = sz[1]

In [None]:
# Collect the weight-file path of every trained network in the ensemble:
# <trained_network_dir>/<net_prefix><index>.h5 for index 0 .. nb_networks - 1
list_of_cyto_weights = [
    os.path.join(trained_network_dir, net_prefix + str(j) + ".h5")
    for j in range(nb_networks)
]

In [None]:
# Images too large to fit on the GPU in one pass can be split into smaller
# pieces that are stitched back together afterwards. The `split` value
# selects how:
#   0 : no split
#   1 : split into quarters
#   2 : split to sixteenths
split = 0

# Run the model(s) over the image directory and keep the returned predictions.
predictions = run_models_on_directory(
    data_location=data_location,
    channel_names=channel_names,
    output_location=seg_location,
    model_fn=fnet,
    list_of_weights=list_of_cyto_weights,
    # NOTE(review): presumably should match the number of entries in
    # feature_names -- confirm
    n_features=2,
    image_size_x=image_size_x,
    image_size_y=image_size_y,
    win_x=window_x,
    win_y=window_y,
    split=split,
)