# Training a model with variational weights

Nobrainer implements variational convolutions. These layers, which can be used like any other Keras layer, learn Gaussian distributions instead of scalar weights. In other words, each weight is represented by a learned mean and a learned standard deviation. Storing two parameters per weight increases the memory footprint of the model, but variational convolutions enable things like [Distributed Weight Consolidation](https://arxiv.org/abs/1805.10863) and [measuring model uncertainty](https://arxiv.org/abs/1812.01719).
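To make the memory cost concrete, here is a rough count for a single 3D convolution. The kernel size and channel counts are illustrative assumptions, not values read from nobrainer, and biases are ignored.

```python
# Illustrative arithmetic only: the layer sizes below are assumptions,
# not values taken from nobrainer's MeshNet implementation.
kernel_size = (3, 3, 3)  # assumed 3D kernel
in_channels = 21         # assumed number of input feature maps
out_channels = 21        # assumed number of output feature maps

weights_per_layer = (
    kernel_size[0] * kernel_size[1] * kernel_size[2] * in_channels * out_channels
)

# A standard convolution stores one scalar per weight.
deterministic_params = weights_per_layer

# A variational convolution stores a mean and a standard deviation per weight.
variational_params = 2 * weights_per_layer

print(deterministic_params, variational_params)
```

Under these assumptions the variational layer holds exactly twice as many weight parameters as its deterministic counterpart.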

In this notebook, we will train a variational MeshNet model.

In [None]:
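# Temporary: add the parent directory to the import path so that `import nobrainer` can find a copy of the package there.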
import sys; sys.path.append('..'); del sys

import nobrainer

# Create a dataset

This assumes you have already downloaded the sample data in the 'getting started' notebook.

In [None]:
n_classes = 1
batch_size = 10
volume_shape = (256, 256, 256)
block_shape = (64, 64, 64)

dataset = nobrainer.volume.get_dataset(
    file_pattern='tfrecords/data_shard-*.tfrecords',
    n_classes=n_classes,
    batch_size=batch_size,
    volume_shape=volume_shape,
    block_shape=block_shape,
    augment=False,
    n_epochs=None,
    shuffle_buffer_size=5)

dataset

In [None]:
steps_per_epoch = nobrainer.volume.get_steps_per_epoch(
    n_volumes=10, 
    volume_shape=volume_shape, 
    block_shape=block_shape, 
    batch_size=batch_size)

steps_per_epoch
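As a sanity check on the value above, the number of steps can be worked out by hand. The sketch below assumes that each volume is split into non-overlapping blocks and that a partial final batch counts as a step. `sketch_steps_per_epoch` is a hypothetical helper written for this example, not part of nobrainer.

```python
import math


def sketch_steps_per_epoch(n_volumes, volume_shape, block_shape, batch_size):
    """Hypothetical helper: count batches per epoch for non-overlapping blocks."""
    blocks_per_volume = 1
    for volume_dim, block_dim in zip(volume_shape, block_shape):
        if volume_dim % block_dim != 0:
            raise ValueError("block_shape must evenly divide volume_shape")
        # Number of blocks that fit along this axis.
        blocks_per_volume *= volume_dim // block_dim
    total_blocks = blocks_per_volume * n_volumes
    # Round up so that a partial final batch still counts as one step.
    return math.ceil(total_blocks / batch_size)


steps = sketch_steps_per_epoch(
    n_volumes=10,
    volume_shape=(256, 256, 256),
    block_shape=(64, 64, 64),
    batch_size=10,
)
print(steps)  # 64 under the assumptions above
```

Each 256-voxel axis holds four 64-voxel blocks, so one volume yields 4 x 4 x 4 = 64 blocks, and 10 volumes yield 640 blocks, or 64 batches of 10.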

# Instantiate model

Setting the flag `is_mc` to `True` will cause the model to sample its weights from the learned distributions. It will draw a different sample for every item in every minibatch.

Setting `is_mc` to `False` will cause the model to use the mean of every weight distribution (i.e., the most likely sample).
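The difference between the two modes can be illustrated for a single weight. This is a minimal plain-Python sketch, not nobrainer's implementation: `effective_weight` is a hypothetical helper, and the mean and standard deviation are made-up values.

```python
import random


def effective_weight(mean, std, is_mc, rng):
    """Hypothetical helper mirroring the `is_mc` switch for one Gaussian weight."""
    if is_mc:
        # Stochastic mode: draw mean + std * eps, with eps ~ N(0, 1).
        return mean + std * rng.gauss(0.0, 1.0)
    # Deterministic mode: use the mean, which is also the mode of a Gaussian.
    return mean


rng = random.Random(0)   # seeded so the sketch is repeatable
mean, std = 0.5, 0.1     # made-up parameters for one weight

samples = [effective_weight(mean, std, True, rng) for _ in range(3)]
point = effective_weight(mean, std, False, rng)

print(samples)  # three different values scattered around 0.5
print(point)    # always the mean
```

With `is_mc=True`, repeated forward passes on the same input generally give different outputs. With `is_mc=False`, the output is deterministic.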

In [None]:
model = nobrainer.models.meshnet_vwn(
    n_classes, 
    input_shape=(*block_shape, 1), 
    filters=21, 
    is_mc=True)

# Compile model

You _must_ use the loss `nobrainer.losses.Variational` to train variational models. It is the only loss function in nobrainer that is aware of Gaussian weights.

In [None]:
import tensorflow as tf

model.compile(
    tf.keras.optimizers.Adam(1e-04), 
    loss=nobrainer.losses.Variational(model=model, n_examples=256**3),
)
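For intuition about why a loss needs to be aware of Gaussian weights, the sketch below shows a generic variational objective: a data term plus a KL penalty that pulls each weight distribution towards a prior. It is not claimed to be what `nobrainer.losses.Variational` computes. The standard-normal prior, the helper names, and the idea that `n_examples` scales the KL term are all assumptions made for illustration.

```python
import math


def kl_gaussian(mu_q, sigma_q, mu_p=0.0, sigma_p=1.0):
    """Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for scalars."""
    return (
        math.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
        - 0.5
    )


def sketch_variational_loss(data_loss, weight_posteriors, n_examples):
    """Hypothetical objective: data term plus KL penalty scaled by dataset size."""
    # Sum the KL divergence over every (mean, std) pair in the model.
    kl_total = sum(kl_gaussian(mu, sigma) for mu, sigma in weight_posteriors)
    # Dividing by n_examples spreads the penalty across the dataset (assumed role).
    return data_loss + kl_total / n_examples


# Made-up posteriors for three weights, each given as (mean, std).
posteriors = [(0.0, 1.0), (0.3, 0.5), (-1.2, 0.8)]
loss = sketch_variational_loss(data_loss=0.7, weight_posteriors=posteriors, n_examples=1000)
print(loss)
```

A weight whose distribution already matches the prior contributes zero penalty, so in this sketch only deviations from the prior are penalised.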

# Train the model

Here, we train on one GPU, but this model can also be trained on multiple GPUs or a TPU. Please refer to the other notebooks in the Nobrainer guide to learn how to train models on multiple GPUs or a TPU.

In [None]:
model.fit(dataset, steps_per_epoch=steps_per_epoch, epochs=3)

# Predict

For the sake of simplicity, we predict on our training data. Never do this in practice!

In [None]:
outputs = model.predict(dataset, steps=steps_per_epoch)
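Because a model built with `is_mc=True` samples its weights, repeated predictions on the same input can differ, and the spread of those predictions is one way to gauge uncertainty. The sketch below shows the idea with a toy scalar predictor standing in for the model. `mc_summary` and `toy_predict` are hypothetical names, and nothing here calls nobrainer or TensorFlow.

```python
import random
import statistics


def mc_summary(stochastic_predict, x, n_samples=20):
    """Hypothetical helper: mean and variance over repeated stochastic predictions."""
    draws = [stochastic_predict(x) for _ in range(n_samples)]
    return statistics.mean(draws), statistics.pvariance(draws)


rng = random.Random(0)  # seeded so the sketch is repeatable


def toy_predict(x):
    # Stand-in for a stochastic forward pass: one Gaussian weight, resampled per call.
    weight = 0.5 + 0.1 * rng.gauss(0.0, 1.0)
    return weight * x


mean_pred, var_pred = mc_summary(toy_predict, 2.0)
print(mean_pred, var_pred)
```

The mean serves as the prediction and the variance as a rough uncertainty estimate. A deterministic predictor would give a variance of zero.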