# Specify Network Architecture
## mknet.py

In [2]:
from funlib.learn.tensorflow import models
import malis
import tensorflow as tf
import json

def create_network(input_shape, name):
    """Build the affinity-predicting U-Net graph and export it to disk.

    Writes two files next to the script:
      * ``<name>.meta`` -- the TensorFlow meta-graph (see
        ``tf.train.export_meta_graph``);
      * ``<name>.json`` -- the tensor/op names and the input/output shapes
        that the gunpowder training script reads back (e.g. ``train_net.json``).

    Args:
        input_shape: spatial shape of the raw input volume, as a tuple or
            list of ints (lists are accepted so a shape loaded from JSON can
            be passed straight back in).
        name: basename used for the ``.meta`` and ``.json`` output files.
    """

    # Normalise to a tuple: it is concatenated with the (1, 1) batch/channel
    # prefix below, which raises TypeError for a list.
    input_shape = tuple(input_shape)

    tf.reset_default_graph()

    with tf.variable_scope('setup0'):

        raw = tf.placeholder(tf.float32, shape=input_shape)
        # Prepend singleton batch and channel dimensions for the U-Net.
        raw_batched = tf.reshape(raw, (1, 1) + input_shape)

        # 12 feature maps at the top level, multiplied by 5 per level; the
        # last argument lists per-level downsampling factors (the first axis
        # is only downsampled at the deepest level).
        unet, _, _ = models.unet(
                raw_batched,
                12,
                5,
                [[1,3,3],[1,3,3],[3,3,3]])

        # 1x1 convolution to 3 sigmoid channels: one affinity per neighbour
        # offset.
        affs_batched, _ = models.conv_pass(
            unet,
            kernel_sizes=[1],
            num_fmaps=3,
            activation='sigmoid',
            name='affs')

        output_shape_batched = affs_batched.get_shape().as_list()
        output_shape = output_shape_batched[1:] # strip the batch dimension
        # output_shape is (channels, *spatial); keep both parts explicitly.
        num_affs = output_shape[0]
        spatial_output_shape = output_shape[1:]

        affs = tf.reshape(affs_batched, output_shape)

        gt_affs = tf.placeholder(tf.float32, shape=output_shape)
        affs_loss_weights = tf.placeholder(tf.float32, shape=output_shape)

        # Per-voxel weighted mean squared error between prediction and target.
        loss = tf.losses.mean_squared_error(
            gt_affs,
            affs,
            affs_loss_weights)

        summary = tf.summary.scalar('setup0', loss)

        opt = tf.train.AdamOptimizer(
            learning_rate=0.5e-4,
            beta1=0.95,
            beta2=0.999,
            epsilon=1e-8)
        optimizer = opt.minimize(loss)

        print("input shape : %s"%(input_shape,))
        print("output shape: %s"%(spatial_output_shape,))

        tf.train.export_meta_graph(filename=name + '.meta')

        # Names of graph tensors/ops and shapes, consumed by the training
        # script; 'output_shape' is spatial only (no channel dimension).
        config = {
            'raw': raw.name,
            'affs': affs.name,
            'gt_affs': gt_affs.name,
            'affs_loss_weights': affs_loss_weights.name,
            'loss': loss.name,
            'optimizer': optimizer.name,
            'input_shape': input_shape,
            'output_shape': spatial_output_shape,
            'summary': summary.name
        }

        # Derive the number of output channels from the graph instead of
        # repeating the literal 3 used in conv_pass above.
        config['outputs'] = {
            'affs': {"out_dims": num_affs, "out_dtype": "uint8"}}

        with open(name + '.json', 'w') as f:
            json.dump(config, f)

if __name__ == "__main__":

    # Extra size added on top of the (96, 484, 484) base shape of the second,
    # larger network, which is exported under the name 'config'.
    z, xy = 18, 162
    large_shape = (96 + z, 484 + xy, 484 + xy)

    # 'train_net' is the graph the training cell loads via train_net.json.
    create_network((84, 268, 268), 'train_net')
    create_network(large_shape, 'config')


Creating U-Net layer 0
f_in: (1, 1, 84, 268, 268)
number of variables added: 4236, new total: 4236
    Creating U-Net layer 1
    f_in: (1, 12, 80, 88, 88)
    number of variables added: 116760, new total: 120996
        Creating U-Net layer 2
        f_in: (1, 60, 76, 28, 28)
        number of variables added: 2916600, new total: 3037596
            Creating U-Net layer 3
            f_in: (1, 300, 24, 8, 8)
            bottom layer
            f_out: (1, 1500, 20, 4, 4)
            number of variables added: 72903000, new total: 75940596
        g_out: (1, 1500, 20, 4, 4)
        g_out_upsampled: (1, 300, 60, 12, 12)
        f_left_cropped: (1, 300, 60, 12, 12)
        f_right: (1, 600, 60, 12, 12)
        f_out: (1, 300, 56, 8, 8)
        number of variables added: 19440900, new total: 95381496
    g_out: (1, 300, 56, 8, 8)
    g_out_upsampled: (1, 60, 56, 24, 24)
    f_left_cropped: (1, 60, 56, 24, 24)
    f_right: (1, 120, 56, 24, 24)
    f_out: (1, 60, 52, 20, 20)
    number of v

# Train Network

In [None]:
from __future__ import print_function
#import multiprocessing
#multiprocessing.set_start_method('forkserver')
import sys
from gunpowder import *
from gunpowder.tensorflow import *
import os
import math
import json
import tensorflow as tf
import numpy as np
import logging

logging.basicConfig(level=logging.INFO)

# Directory holding one '<sample>.n5' container per training sample.
data_dir = '../data'

# Names of the training samples (each is joined with data_dir + '.n5').
samples = ['sample_' + letter for letter in ('A', 'B', 'C')]

# Affinity neighbour offsets: one unit step back along each spatial axis,
# in the same axis order as voxel_size in train_until().
neighborhood = [
    [-1, 0, 0],
    [0, -1, 0],
    [0, 0, -1],
]

def train_until(max_iteration):
    """Train the 'train_net' graph with a gunpowder pipeline.

    Resumes from the latest checkpoint in the current directory, if any. The
    trailing ``_<number>`` of the checkpoint name is taken as the number of
    iterations already run. Returns immediately if that is already at least
    ``max_iteration``; otherwise requests ``max_iteration - trained_until``
    batches.

    Tensor names and input/output shapes come from ``train_net.json``, as
    written by ``create_network``.

    Args:
        max_iteration: total number of training iterations to reach.
    """

    # Query the checkpoint directory once and reuse the result.
    latest_checkpoint = tf.train.latest_checkpoint('.')
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open('train_net.json', 'r') as f:
        config = json.load(f)

    raw = ArrayKey('RAW')
    labels = ArrayKey('GT_LABELS')
    labels_mask = ArrayKey('GT_LABELS_MASK')
    affs = ArrayKey('PREDICTED_AFFS')
    gt = ArrayKey('GT_AFFINITIES')
    gt_mask = ArrayKey('GT_AFFINITIES_MASK')
    gt_scale = ArrayKey('GT_AFFINITIES_SCALE')
    affs_gradient = ArrayKey('AFFS_GRADIENT')

    # Anisotropic voxels: the first axis is 10x coarser than the other two
    # (units presumably nm -- confirm against the data).
    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(config['input_shape'])*voxel_size
    output_size = Coordinate(config['output_shape'])*voxel_size
    # Labels are padded by half an output window so random locations may
    # extend past the labelled volume; Reject on the (padded) mask filters
    # out batches that are mostly outside.
    context = output_size/2
    print('CONTEXT: ', context)

    request = BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_mask, output_size)
    request.add(gt, output_size)
    request.add(gt_mask, output_size)
    request.add(gt_scale, output_size)

    # Extra arrays only requested for snapshot iterations.
    snapshot_request = BatchRequest({
        affs: request[gt],
        affs_gradient: request[gt]
    })

    data_sources = tuple(
        N5Source(
            os.path.join(data_dir, sample + '.n5'),
            datasets = {
                raw: 'volumes/raw',
                labels: 'volumes/labels/neuron_ids',
                labels_mask: 'volumes/labels/mask',
            },
            array_specs = {
                raw: ArraySpec(interpolatable=True),
                labels: ArraySpec(interpolatable=False),
                labels_mask: ArraySpec(interpolatable=False)
            }
        ) +
        Normalize(raw) +
        Pad(labels, context) +
        Pad(labels_mask, context) +
        RandomLocation() +
        Reject(mask=labels_mask)
        for sample in samples
    )


    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment(
            control_point_spacing=[4,40,40],
            jitter_sigma=[0,2,2],
            rotation_interval=[0,math.pi/2.0],
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=10,
            subsample=8) +
        SimpleAugment(transpose_only=[1, 2]) +
        IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        GrowBoundary(labels, labels_mask, steps=1, only_xy=True) +
        AddAffinities(
            neighborhood,
            labels=labels,
            affinities=gt,
            labels_mask=labels_mask,
            affinities_mask=gt_mask) +
        BalanceLabels(
            gt,
            gt_scale,
            gt_mask) +
        DefectAugment(
            raw,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            contrast_scale=0.5,
            axis=0) +
        # raw -> 2*raw - 1 for the network (presumably [0, 1] -> [-1, 1]);
        # undone by the (0.5, 0.5) shift after Train so snapshots match.
        IntensityScaleShift(raw, 2,-1) +
        PreCache(cache_size=40,
                 num_workers=10) +
        Train(
            'train_net',
            optimizer=config['optimizer'],
            loss=config['loss'],
            inputs={
                config['raw']: raw,
                config['gt_affs']: gt,
                config['affs_loss_weights']: gt_scale,
            },
            outputs={
                config['affs']: affs
            },
            gradients={
                config['affs']: affs_gradient
            },
            summary=config['summary'],
            log_dir='log',
            save_every=10000) +
        IntensityScaleShift(raw, 0.5, 0.5) +
        Snapshot({
                raw: 'volumes/raw',
                labels: 'volumes/labels/neuron_ids',
                gt: 'volumes/gt_affinities',
                affs: 'volumes/pred_affinities',
                gt_mask: 'volumes/labels/gt_mask',
                labels_mask: 'volumes/labels/mask',
                affs_gradient: 'volumes/affs_gradient'
            },
            dataset_dtypes={
                labels: np.uint64
            },
            every=1000,
            output_filename='batch_{iteration}.hdf',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=10)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        # One request per remaining training iteration.
        for _ in range(max_iteration - trained_until):
            b.request_batch(request)
    print("Training finished")

if __name__ == "__main__":

    # Total iteration target; train_until() resumes from the newest
    # checkpoint and returns early if this target has already been reached.
    iteration = int(5e5)
    train_until(iteration)

INFO:gunpowder.tensorflow.local_server:Creating local tensorflow server
INFO:gunpowder.tensorflow.local_server:Server running at b'grpc://localhost:36769'
INFO:gunpowder.tensorflow.nodes.train:Initializing tf session, connecting to b'grpc://localhost:36769'...


CONTEXT:  (960, 112, 112)
Starting training...


INFO:gunpowder.tensorflow.nodes.train:Reading meta-graph...


Instructions for updating:
Colocations handled automatically by placer.


Instructions for updating:
Colocations handled automatically by placer.
INFO:gunpowder.tensorflow.nodes.train:No checkpoint found
