In [None]:
# TMP
import sys; sys.path.append('..'); del sys

import nobrainer

In [None]:
# A glob pattern to match the files we want to train on.
file_pattern = 'tfrecords/data_shard-*.tfrecords'

# The number of classes the model predicts. A value of 1 means the model performs
# binary classification (i.e., target vs background).
n_classes = 1

# Batch size is the number of features and labels we train on with each step.
batch_size = 4

# The shape of the original volumes.
volume_shape = (256, 256, 256)

# The shape of the non-overlapping sub-volumes. Most models cannot be trained on
# full volumes because of hardware and memory constraints, so we train and evaluate
# on sub-volumes.
block_shape = (64, 64, 64)

# Whether or not to apply random rigid transformations to the data on the fly.
# This can improve model generalizability but increases processing time.
augment = False

# The tfrecords filepaths will be shuffled before reading, but we can also shuffle
# the data itself through a buffer of `shuffle_buffer_size` volumes (e.g. a value
# of 10 would shuffle 10 volumes at a time). Larger buffer sizes require more
# memory, so choose a value based on how much memory you have available.
# NOTE(review): 0 presumably disables this extra shuffling — confirm against
# nobrainer.volume.get_dataset.
shuffle_buffer_size = 0

# Parallelism passed to nobrainer.volume.get_dataset as `num_parallel_calls`.
num_parallel_calls = 6

In [None]:
# Build the input pipeline from the settings in the config cell.
# n_epochs=None is passed explicitly; presumably this makes the dataset repeat
# indefinitely — confirm against nobrainer.volume.get_dataset.
dataset = nobrainer.volume.get_dataset(
    file_pattern=file_pattern,
    volume_shape=volume_shape,
    block_shape=block_shape,
    n_classes=n_classes,
    batch_size=batch_size,
    n_epochs=None,
    augment=augment,
    shuffle_buffer_size=shuffle_buffer_size,
    num_parallel_calls=num_parallel_calls,
)
dataset

In [None]:
# Number of batches per epoch, computed by nobrainer from the volume/block
# geometry and batch size.
# NOTE(review): n_volumes is hard-coded to 10 — confirm it matches the number of
# volumes in the files matched by `file_pattern`.
steps_per_epoch = nobrainer.volume.get_steps_per_epoch(
    batch_size=batch_size,
    block_shape=block_shape,
    volume_shape=volume_shape,
    n_volumes=10,
)
steps_per_epoch

## Create model

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

from nobrainer.layers import VariationalConv3D


def meshnet(n_classes, input_shape, filters=21, activation='relu', batch_size=None, name='meshnet',
            dropout_rate=None):
    """Build a MeshNet-style segmentation model from VariationalConv3D layers.

    Seven hidden blocks (3x3x3 VariationalConv3D with 'same' padding ->
    BatchNormalization -> activation) use dilation rates 1, 1, 2, 4, 8, 16, 1.
    They are followed by a 1x1x1 VariationalConv3D with ``n_classes`` filters
    and a final activation.

    Args:
        n_classes: number of output channels. 1 gives a single sigmoid channel
            (binary segmentation). More than 1 gives a softmax across channels.
        input_shape: shape of one example without the batch axis,
            e.g. ``(*block_shape, 1)``.
        filters: number of filters in every hidden convolution.
        activation: activation applied after each hidden batch normalization.
        batch_size: static batch size for the input layer (None = unspecified).
        name: name of the returned model.
        dropout_rate: if truthy, a Dropout layer with this rate follows each
            hidden activation. The default None adds no dropout.

    Returns:
        tf.keras.Model mapping ``inputs`` to per-voxel class probabilities.
    """

    def one_layer(x, layer_num, dilation_rate=(1, 1, 1)):
        # conv -> batch norm -> activation (-> optional dropout), with layer names
        # of the form 'layer{N}/<op>'.
        x = VariationalConv3D(
            filters, kernel_size=(3, 3, 3), padding='same', dilation_rate=dilation_rate,
            name='layer{}/conv3d'.format(layer_num))(x)
        x = layers.BatchNormalization(name='layer{}/batchnorm'.format(layer_num))(x)
        x = layers.Activation(activation, name='layer{}/activation'.format(layer_num))(x)
        if dropout_rate:
            x = layers.Dropout(dropout_rate, name='layer{}/dropout'.format(layer_num))(x)
        return x

    inputs = layers.Input(shape=input_shape, batch_size=batch_size, name='inputs')

    # Dilation grows to enlarge the receptive field, then returns to 1 for the
    # last hidden layer.
    x = inputs
    for layer_num, rate in enumerate((1, 1, 2, 4, 8, 16, 1), start=1):
        x = one_layer(x, layer_num, dilation_rate=(rate, rate, rate))

    x = VariationalConv3D(filters=n_classes, kernel_size=(1, 1, 1), padding='same', name='classification/conv3d')(x)

    # A single channel is an independent per-voxel probability (sigmoid).
    # Several channels are mutually exclusive classes, so use softmax; a sigmoid
    # on two channels would not make them sum to 1.
    final_activation = 'sigmoid' if n_classes == 1 else 'softmax'
    x = layers.Activation(final_activation, name='classification/activation')(x)

    return tf.keras.Model(inputs=inputs, outputs=x, name=name)


# One single-channel block in, `n_classes` channels out — driven by the config
# cell instead of a hard-coded class count.
# NOTE(review): the baseline compile uses binary cross-entropy, which only fits n_classes == 1.
model = meshnet(n_classes, (*block_shape, 1))

In [None]:
# Baseline compile: Adam (learning rate 1e-4) with binary cross-entropy for the
# single sigmoid output channel.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-04),
    loss='binary_crossentropy',
)

In [None]:
# Train for one epoch of `steps_per_epoch` batches. `dataset` is built with
# n_epochs=None (presumably repeating), so the step count is given explicitly.
model.fit(x=dataset, epochs=1, steps_per_epoch=steps_per_epoch)

In [None]:
# `dataset` is built with n_epochs=None (presumably repeating indefinitely), so
# predict must be told how many batches to consume.
outputs = model.predict(dataset, steps=steps_per_epoch)

In [None]:
# Inspect the prediction array's shape; presumably (n_batches * batch_size, *block_shape, n_classes) — TODO confirm.
outputs.shape

In [None]:
def _maybe_get_layer_attr(model, weight_attribute):
    """Get the weight attribute"""
    weights = []
    for layer in model.layers:
        try:
            this_weight = getattr(layer, weight_attribute)
            weights.append(this_weight)
        except AttributeError:
            pass
    return weights

In [None]:
class VWLoss(tf.keras.losses.Loss):
    """Keras loss: a data term plus a penalty on the model's variational weights.

    The penalty compares the means (``kernel_m``/``bias_m``) and standard
    deviations (``kernel_sigma``/``bias_sigma``) found on the layers of
    ``model`` against a Gaussian prior.

    Args:
        model: Keras model whose layers carry the variational attributes.
        other_loss_callable: data loss, called as
            ``other_loss_callable(y_true=..., y_pred=...)``.
        priors: None to use a prior whose means and sigmas are all 1.
            Otherwise a pair ``(prior_means, prior_sigmas)`` whose i-th entries
            align with the i-th mean / sigma collected from ``model``.
            NOTE(review): the (means, sigmas) ordering is assumed — confirm
            against whatever produces the priors.
        n_examples: the weight penalty is divided by this value.
        only_kld: if True, ``call`` returns only the summed squared differences
            between the weights' means/sigmas and the prior (no data term).

    Raises:
        ValueError: from ``call``, if the model exposes different numbers of
            mean and sigma tensors.
    """

    def __init__(self, model, other_loss_callable, priors=None, n_examples=1, only_kld=False):
        # Use the default Keras reduction. SUM_OVER_NONZERO_WEIGHTS is a v1
        # `tf.losses` key, not a valid `tf.keras.losses.Reduction` value. The
        # default (SUM_OVER_BATCH_SIZE) averages the per-element loss, which is
        # what SUM_OVER_NONZERO_WEIGHTS does when no sample weights are given.
        super(VWLoss, self).__init__(name='vwloss')
        self.model = model
        self.other_loss_callable = other_loss_callable
        self.priors = priors
        self.n_examples = n_examples
        self.only_kld = only_kld

    def call(self, y_true, y_pred):
        # Read the variational parameters from self.model, not a global `model`.
        ms = (_maybe_get_layer_attr(self.model, 'kernel_m')
              + _maybe_get_layer_attr(self.model, 'bias_m'))
        sigmas = (_maybe_get_layer_attr(self.model, 'kernel_sigma')
                  + _maybe_get_layer_attr(self.model, 'bias_sigma'))
        if len(ms) != len(sigmas):
            raise ValueError(
                'Expected as many mean tensors as sigma tensors, got {} and {}.'
                .format(len(ms), len(sigmas)))

        if self.priors is None:
            # NOTE(review): the default prior mean is 1 (not 0) — confirm intent.
            ms_prior = [tf.constant(1, dtype=v.dtype, shape=v.shape) for v in ms]
            sigmas_prior = [tf.constant(1, dtype=v.dtype, shape=v.shape) for v in sigmas]
        else:
            # Priors are fixed values, so plain tensors suffice; creating fresh
            # tf.Variable objects inside a loss is unnecessary.
            ms_prior = [tf.cast(tf.convert_to_tensor(self.priors[0][i]), v.dtype)
                        for i, v in enumerate(ms)]
            sigmas_prior = [tf.cast(tf.convert_to_tensor(self.priors[1][i]), v.dtype)
                            for i, v in enumerate(sigmas)]

        if self.only_kld:
            mse_m_loss = tf.add_n(
                [tf.reduce_sum(tf.square(m - m_p)) for m, m_p in zip(ms, ms_prior)],
                name='mse_m_loss')
            mse_sigmas_loss = tf.add_n(
                [tf.reduce_sum(tf.square(s - s_p)) for s, s_p in zip(sigmas, sigmas_prior)],
                name='mse_sigmas_loss')
            return mse_m_loss + mse_sigmas_loss

        eps = 1e-8  # guards against division by zero and log(0)
        # Terms of KL(N(m, s^2) || N(m_p, s_p^2)) that depend on m and s.
        l2_loss = tf.add_n(
            [tf.reduce_sum(tf.square(m - m_p) / ((tf.square(s_p) + eps) * 2.0))
             for m, m_p, s_p in zip(ms, ms_prior, sigmas_prior)],
            name='l2_loss')
        sigma_squared_loss = tf.add_n(
            [tf.reduce_sum(tf.square(s) / ((tf.square(s_p) + eps) * 2.0))
             for s, s_p in zip(sigmas, sigmas_prior)],
            name='sigma_squared_loss')
        # tf.math.log works in both TF 1.13+ and TF 2 (tf.log is TF1-only).
        log_sigma_loss = tf.add_n(
            [tf.reduce_sum(tf.math.log(s + eps)) for s in sigmas],
            name='log_sigmas_loss')

        nll_loss = self.other_loss_callable(y_true=y_true, y_pred=y_pred)
        return nll_loss + (l2_loss + sigma_squared_loss - log_sigma_loss) / float(self.n_examples)

In [None]:
def variational_loss(model, other_loss_callable, priors=None, n_examples=1, only_kld=False):
    """Return a Keras loss ``l(y_true, y_pred)`` for models with variational weights.

    The penalty compares the means (``kernel_m``/``bias_m``) and standard
    deviations (``kernel_sigma``/``bias_sigma``) found on the layers of
    ``model`` against a Gaussian prior.

    Args:
        model: Keras model whose layers carry the variational attributes.
        other_loss_callable: data loss, called as
            ``other_loss_callable(y_true=..., y_pred=...)``. Its result is
            averaged over axis 0.
        priors: None to use a prior whose means and sigmas are all 1.
            Otherwise a pair ``(prior_means, prior_sigmas)`` whose i-th entries
            align with the i-th mean / sigma collected from ``model``.
            NOTE(review): the (means, sigmas) ordering is assumed — confirm.
        n_examples: the weight penalty is divided by this value.
        only_kld: if True, the loss is only the summed squared differences
            between the weights' means/sigmas and the prior (no data term).

    Returns:
        Function ``l(y_true, y_pred)`` suitable for ``model.compile(loss=...)``.
        It raises ValueError if the model exposes different numbers of mean and
        sigma tensors.
    """
    def l(y_true, y_pred):
        ms = (_maybe_get_layer_attr(model, 'kernel_m')
              + _maybe_get_layer_attr(model, 'bias_m'))
        sigmas = (_maybe_get_layer_attr(model, 'kernel_sigma')
                  + _maybe_get_layer_attr(model, 'bias_sigma'))
        if len(ms) != len(sigmas):
            raise ValueError(
                'Expected as many mean tensors as sigma tensors, got {} and {}.'
                .format(len(ms), len(sigmas)))

        if priors is None:
            # NOTE(review): the default prior mean is 1 (not 0) — confirm intent.
            ms_prior = [tf.constant(1, dtype=v.dtype, shape=v.shape) for v in ms]
            sigmas_prior = [tf.constant(1, dtype=v.dtype, shape=v.shape) for v in sigmas]
        else:
            # Read the closed-over `priors` argument (there is no `self` here).
            # Priors are fixed values, so plain tensors suffice.
            ms_prior = [tf.cast(tf.convert_to_tensor(priors[0][i]), v.dtype)
                        for i, v in enumerate(ms)]
            sigmas_prior = [tf.cast(tf.convert_to_tensor(priors[1][i]), v.dtype)
                            for i, v in enumerate(sigmas)]

        if only_kld:
            mse_m_loss = tf.add_n(
                [tf.reduce_sum(tf.square(m - m_p)) for m, m_p in zip(ms, ms_prior)],
                name='mse_m_loss')
            mse_sigmas_loss = tf.add_n(
                [tf.reduce_sum(tf.square(s - s_p)) for s, s_p in zip(sigmas, sigmas_prior)],
                name='mse_sigmas_loss')
            return mse_m_loss + mse_sigmas_loss

        eps = 1e-8  # guards against division by zero and log(0)
        # Terms of KL(N(m, s^2) || N(m_p, s_p^2)) that depend on m and s.
        l2_loss = tf.add_n(
            [tf.reduce_sum(tf.square(m - m_p) / ((tf.square(s_p) + eps) * 2.0))
             for m, m_p, s_p in zip(ms, ms_prior, sigmas_prior)],
            name='l2_loss')
        sigma_squared_loss = tf.add_n(
            [tf.reduce_sum(tf.square(s) / ((tf.square(s_p) + eps) * 2.0))
             for s, s_p in zip(sigmas, sigmas_prior)],
            name='sigma_squared_loss')
        # tf.math.log works in both TF 1.13+ and TF 2 (tf.log is TF1-only).
        log_sigma_loss = tf.add_n(
            [tf.reduce_sum(tf.math.log(s + eps)) for s in sigmas],
            name='log_sigmas_loss')

        nll_loss = tf.reduce_mean(other_loss_callable(y_true=y_true, y_pred=y_pred), axis=0)
        return nll_loss + (l2_loss + sigma_squared_loss - log_sigma_loss) / float(n_examples)

    return l

In [None]:
# Recompile with the variational loss: a Jaccard data term plus the weight
# penalty from variational_loss, using its defaults (priors=None, n_examples=1).
# NOTE(review): with n_examples=1 the penalty is not scaled down by dataset
# size — confirm this weighting is intended.
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-04),
    loss=variational_loss(model, nobrainer.losses.jaccard),
)



In [None]:
# Train the variational model for 20 epochs of `steps_per_epoch` batches each.
model.fit(x=dataset, epochs=20, steps_per_epoch=steps_per_epoch)

In [None]:
# Predict over one epoch's worth of batches; the explicit step count bounds the
# (presumably repeating) dataset.
outputs = model.predict(x=dataset, steps=steps_per_epoch)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Binarize the predicted probabilities and drop size-1 axes (e.g. the single
# class channel when n_classes == 1). The cut-off is named so it is easy to tune.
prediction_threshold = 0.3
outputs = (outputs > prediction_threshold).squeeze()