# Training on TPU

## Import nobrainer

In [None]:
!pip install --no-cache-dir nobrainer[gpu]

In [None]:
# TMP
import sys; sys.path.append('..'); del sys

import nobrainer

## Create dataset

We create a `tf.data.Dataset` with sample data. The sample data are 10 T1-weighted brain scans and their corresponding FreeSurfer segmentations. If you want more information about this part, please refer to notebook 01.

In [None]:
# Fetch the sample data. The return value is kept as `csv_path`
# (presumably the path to a CSV listing the downloaded files -- confirm in
# nobrainer.utils.get_data).
csv_path = nobrainer.utils.get_data()
# Bare expression so the notebook displays the value.
csv_path

In [None]:
# Run the `nobrainer convert` command-line tool on the CSV of file paths, writing
# output named by the `tfrecords/data_shard-{shard:03d}.tfrecords` template with
# --volumes-per-shard=4 and a volume shape of 256 x 256 x 256.
# NOTE(review): the CSV path is hard-coded; presumably it is the same path that
# `nobrainer.utils.get_data()` returned as `csv_path` -- confirm.
!nobrainer convert \
    --csv='/tmp/nobrainer-data/filepaths.csv' \
    --tfrecords-template='tfrecords/data_shard-{shard:03d}.tfrecords' \
    --volumes-per-shard=4 \
    --volume-shape 256 256 256 \
    --num-parallel-calls=8 \
    --verbose

In [None]:
# Where the data live and what they look like. `file_pattern` matches the
# TFRecords files written by the conversion step, and `volume_shape` matches the
# --volume-shape passed to it.
file_pattern = 'tfrecords/data_shard-*.tfrecords'
volume_shape = (256, 256, 256)
block_shape = (128, 128, 128)
n_classes = 1

# Input pipeline settings. These names are reused by later cells, so they are
# kept as notebook-level variables.
batch_size = 2
n_epochs = 1
augment = False
shuffle_buffer_size = 4
num_parallel_calls = 4

# Build the dataset from the settings above.
dataset = nobrainer.volume.get_dataset(
    file_pattern=file_pattern,
    n_classes=n_classes,
    batch_size=batch_size,
    volume_shape=volume_shape,
    block_shape=block_shape,
    augment=augment,
    n_epochs=n_epochs,
    shuffle_buffer_size=shuffle_buffer_size,
    num_parallel_calls=num_parallel_calls,
)

# Bare expression so the notebook displays the dataset.
dataset

## Instantiate Model

Important note: the shapes of all tensors must be static to train on a TPU. This means that the batch size must be set when the model is instantiated.

In [None]:
import os
import tensorflow as tf

In [None]:
# The TPU address is read from the COLAB_TPU_ADDR environment variable; a
# KeyError here means that variable is not set in the current runtime.
tpu_address = os.environ['COLAB_TPU_ADDR']
TPU_WORKER = 'grpc://' + tpu_address

# Build the cluster resolver first, then hand it to the TPU distribution strategy.
tpu_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)
strategy = tf.contrib.distribute.TPUStrategy(tpu_cluster_resolver=tpu_resolver)

In [None]:
# Create and compile the model inside the strategy scope.
with strategy.scope():
    # Block shape plus a trailing dimension of size 1
    # (presumably the channel axis -- confirm against nobrainer.models.unet).
    input_shape = tuple(block_shape) + (1,)
    model = nobrainer.models.unet(
        n_classes=n_classes,
        input_shape=input_shape,
        batch_size=batch_size,
    )

    # Adam with a learning rate of 1e-4, Jaccard loss, Dice reported as a metric.
    optimizer = tf.train.AdamOptimizer(1e-04)
    model.compile(
        optimizer=optimizer,
        loss=nobrainer.losses.jaccard,
        metrics=[nobrainer.metrics.dice],
    )

In [None]:
# Number of volumes in the sample data set (10 scans, per the dataset description).
n_volumes = 10

# Steps per epoch for the volume/block shapes and batch size chosen earlier.
steps_per_epoch = nobrainer.volume.get_steps_per_epoch(
    n_volumes=n_volumes,
    volume_shape=volume_shape,
    block_shape=block_shape,
    batch_size=batch_size,
)

# Bare expression so the notebook displays the value.
steps_per_epoch

In [None]:
# Train the model and keep the returned history.
# - `steps_per_epoch` (computed in the previous cell) tells Keras where each epoch
#   ends when the input is a `tf.data.Dataset`.
# - `epochs` uses `n_epochs`, the same value the dataset was built with. This
#   cell previously asked for 20 epochs while the dataset was created with
#   n_epochs=1, which would presumably exhaust the input after the first epoch.
#   To train longer, raise `n_epochs` where the dataset is created.
history = model.fit(dataset, epochs=n_epochs, steps_per_epoch=steps_per_epoch)