In [0]:
import os
import tempfile
from typing import Optional, Sequence

import tensorflow as tf

In [0]:
# Authenticate to GCS. `google.colab` only exists inside Colab; elsewhere GCS
# access has to come from credentials configured outside this notebook, so a
# missing module is not treated as an error.
try:
    from google.colab import auth
except ImportError:
    auth = None

if auth is not None:
    auth.authenticate_user()

In [3]:
# Sanity check: list the TFRecord shards present in the bucket before building datasets.
tf.gfile.ListDirectory('gs://pontiml/inria/records')

['train-0.tfrecord',
 'train-1.tfrecord',
 'train-2.tfrecord',
 'train-3.tfrecord',
 'train-4.tfrecord',
 'train-5.tfrecord',
 'train-6.tfrecord',
 'train-7.tfrecord',
 'train-8.tfrecord',
 'train-9.tfrecord',
 'validation-0.tfrecord']

## Data

In [0]:
def _map_batch(record_batch, source_size=1000, target_size=512):
    """Decode a batch of parsed TFRecord features into model inputs and labels.

    Args:
        record_batch: dict of batched features (as built by make_dataset). Reads
            the raw 'image/bytes' string (decoded as uint8, 3 channels) and the
            raw 'image/label' string (decoded as float32, 1 channel).
        source_size: height/width of the stored square images and labels.
        target_size: height/width to resize images and labels to.

    Returns:
        A tuple ({'image': images}, labels):
        images is float32 [batch, target_size, target_size, 3] scaled to [0, 1];
        labels is float32 [batch, target_size, target_size, 1].
    """
    images = tf.decode_raw(record_batch['image/bytes'], tf.uint8)
    images = tf.reshape(images, (-1, source_size, source_size, 3))
    # Convert (and scale uint8 -> [0, 1]) BEFORE resizing: bilinear resize_images
    # already returns float32, and convert_image_dtype is a no-op on float32 input,
    # so converting afterwards would leave pixel values in [0, 255].
    images = tf.image.convert_image_dtype(images, dtype=tf.float32)
    images = tf.image.resize_images(images, (target_size, target_size))

    labels = tf.decode_raw(record_batch['image/label'], tf.float32)
    labels = tf.reshape(labels, (-1, source_size, source_size, 1))
    # NOTE(review): bilinear resizing produces fractional label values along mask
    # edges; consider ResizeMethod.NEAREST_NEIGHBOR if labels must stay binary.
    labels = tf.image.resize_images(labels, (target_size, target_size))

    # TODO: random augmentation (e.g. left/right flips) must apply the identical
    # transform to image and label, e.g. by stacking them before flipping; see
    # https://stackoverflow.com/a/38403715/1602729
    return {'image': images}, labels

def make_dataset(file_pattern, num_epochs=None, batch_size=32, shuffle=True):
    """Build a batched ({'image': ...}, label) dataset from TFRecord files.

    Args:
        file_pattern: file or glob of TFRecord files to read.
        num_epochs: passes over the data, forwarded to
            make_batched_features_dataset (None there means repeat indefinitely).
        batch_size: records per batch.
        shuffle: whether to shuffle records (buffer of 4 * batch_size).

    Returns:
        tf.data.Dataset of batches as produced by _map_batch.
    """
    # os.cpu_count() returns None when the CPU count cannot be determined.
    num_threads = os.cpu_count() or 1

    features = {
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/label': tf.FixedLenFeature([], tf.string),
        'image/bytes': tf.FixedLenFeature([], tf.string)
    }

    dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern, batch_size, features,
        num_epochs=num_epochs, shuffle=shuffle,
        shuffle_buffer_size=4 * batch_size, sloppy_ordering=True,
        reader_num_threads=num_threads, parser_num_threads=num_threads,
        prefetch_buffer_size=4)

    # Decoding and resizing in _map_batch is the expensive step. Run it in
    # parallel and prefetch its output so training does not wait on it; the
    # internal prefetch above only covers parsing.
    dataset = dataset.map(_map_batch, num_parallel_calls=num_threads)
    return dataset.prefetch(1)

## Model

<img alt="unet model architecture" src="http://deeplearning.net/tutorial/_images/unet.jpg" width=500 />


In [0]:
# Short alias for tf.keras.layers, used by the model-building helpers below.
l = tf.keras.layers

def _conv_block(inputs: tf.Tensor, filters: int, name: str, repeat: int=2,
                kernel_size: int=3) -> tf.Tensor:
    """Apply `repeat` same-padded 2D convolutions with ReLU, one after another.

    Args:
        inputs: feature map to convolve.
        filters: output channels of every convolution in the block.
        name: prefix for layer names; layers are named f'{name}_{i}'.
        repeat: number of stacked convolutions.
        kernel_size: convolution kernel size (3x3 by default, as in U-Net).

    Returns:
        Output of the last convolution (or `inputs` unchanged if repeat == 0).
    """
    layer = inputs
    for idx in range(repeat):
        layer = l.Conv2D(filters, kernel_size=kernel_size, activation=tf.nn.relu,
                         padding='same', name=f'{name}_{idx}')(layer)
    return layer


def _up_conv(inputs: tf.Tensor, filters: Optional[int]=None) -> tf.Tensor:
    """2x2 transposed convolution with stride 2 ("up-convolution") plus ReLU.

    With 'same' padding and stride 2, the spatial size of the output is twice
    that of the input.

    Args:
        inputs: feature map to upsample; its last axis is the channel axis.
        filters: output channels. Defaults to the input channel count. The U-Net
            paper halves the channels here, which callers can request explicitly.

    Returns:
        The upsampled feature map.
    """
    if filters is None:
        filters = inputs.get_shape().as_list()[-1]
    return l.Conv2DTranspose(filters=filters, kernel_size=2, strides=2,
                             padding='same', activation=tf.nn.relu)(inputs)


def unet_model(batch_size: Optional[int]=None,
               filter_counts: Sequence[int]=(64, 128, 256, 512)) -> tf.keras.Model:
    """Build the U-Net segmentation model.

    Args:
        batch_size: fixed batch dimension of the input layer; None leaves it dynamic.
        filter_counts: filters of each contracting stage, mirrored on the
            expansive path; the bottleneck uses twice the last entry. The 512x512
            input must be divisible by 2 ** len(filter_counts) so that skip
            connections line up when concatenated.

    Returns:
        tf.keras.Model mapping an 'image' input of shape (512, 512, 3) to a
        (512, 512, 1) map of per-pixel sigmoid probabilities (not logits).

    Raises:
        ValueError: if filter_counts is empty.
    """
    filter_counts = list(filter_counts)
    if not filter_counts:
        raise ValueError('filter_counts must contain at least one stage')

    image_input = l.Input(shape=(512, 512, 3), batch_size=batch_size,
                          name='image', dtype=tf.float32)

    # Sample of the original input image for TensorBoard.
    tf.summary.image('inputs', image_input)

    net = image_input
    # Output of each contracting stage; consumed in reverse by the expansive path.
    skip_connections = []

    # Contracting path: conv block, keep its output for the skip, then halve resolution.
    for idx, num_filters in enumerate(filter_counts):
        net = _conv_block(net, num_filters, name=f'contracting{idx}')
        skip_connections.append(net)
        net = l.MaxPool2D(pool_size=2, strides=2, padding='same')(net)

    # Bottleneck: no skip copy and no max-pooling.
    net = _conv_block(net, 2 * filter_counts[-1], name='conexp')

    # Expansive path: upsample, concatenate the matching skip on channels, conv block.
    for idx, num_filters in enumerate(reversed(filter_counts)):
        net = _up_conv(net)
        net = l.concatenate([net, skip_connections.pop()], axis=3)
        net = _conv_block(net, num_filters, name=f'expansive{idx}')

    # 1x1 conv with sigmoid gives the segmentation map as probabilities. The
    # activation is already applied, so the loss must not apply another sigmoid.
    probabilities = l.Conv2D(filters=1, kernel_size=1, activation=tf.nn.sigmoid,
                             padding='same')(net)

    tf.summary.image('predictions', probabilities)

    return tf.keras.Model(inputs=image_input, outputs=probabilities)

## Train

In [0]:
# Run configuration. Every value can be overridden through an environment variable.
data_dir = os.getenv('DATA_DIR', 'gs://pontiml/inria/records')
model_dir = os.getenv('MODEL_DIR', 'gs://pontiml/inria/model/run6')
max_steps = int(os.getenv('MAX_STEPS', '10000'))
batch_size = int(os.getenv('BATCH_SIZE', '2'))

In [0]:
# Make sure model_dir exists before the TensorBoard/checkpoint callbacks write into it.
tf.gfile.MakeDirs(model_dir)

In [0]:
# Pick a distribution strategy, and the factor by which the batch size is scaled.
batch_size_multiplier = 1

if 'COLAB_TPU_ADDR' in os.environ:
    # Point the resolver at the Colab TPU explicitly (the same grpc:// address
    # later used for keras_to_tpu_model) instead of relying on auto-discovery.
    tpu_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        f'grpc://{os.environ["COLAB_TPU_ADDR"]}')
    distribute = tf.contrib.distribute.TPUStrategy(tpu_resolver)

    # Cloud TPU contains 8 TPU cores
    batch_size_multiplier = 8
else:
    distribute = tf.contrib.distribute.MirroredStrategy()

In [0]:
# Scale the per-replica batch size to the global one. It is recomputed from the
# BATCH_SIZE setting (default 2, as in the config cell) instead of using `*=`,
# so re-running this cell does not compound the multiplier.
batch_size = int(os.environ.get('BATCH_SIZE', 2)) * batch_size_multiplier

In [0]:
# Build the U-Net with its input batch dimension fixed to the (possibly TPU-scaled) batch size.
model = unet_model(batch_size=batch_size)

In [0]:
model.compile(
    # NOTE(review): 1e-2 is on the high side for Adam; try 1e-3 if training diverges.
    optimizer=tf.train.AdamOptimizer(learning_rate=1e-2),
    # The model's final layer already applies a sigmoid, so use binary
    # cross-entropy on probabilities. tf.losses.sigmoid_cross_entropy would
    # treat y_pred as logits and apply a second sigmoid.
    loss=tf.keras.losses.binary_crossentropy,
    metrics=[tf.keras.metrics.binary_accuracy],
    # NOTE: the `distribute` strategy chosen above is not passed to compile yet.
)

In [0]:
# Training reads every train shard with shuffling; validation reads its single shard unshuffled.
train_pattern = os.path.join(data_dir, 'train-*.tfrecord')
val_pattern = os.path.join(data_dir, 'validation-0.tfrecord')

train_ds = make_dataset(train_pattern, batch_size=batch_size, shuffle=True)
val_ds = make_dataset(val_pattern, batch_size=batch_size, shuffle=False)

In [0]:
# TensorBoard logs go directly to model_dir.
tensorboard = tf.keras.callbacks.TensorBoard(log_dir=model_dir)

# ModelCheckpoint writes HDF5 through h5py, which only opens local paths (not
# gs://). Checkpoint to a local file and mirror it into model_dir every epoch.
local_checkpoint_path = os.path.join(tempfile.gettempdir(), 'model.ckpt.hdf5')
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=local_checkpoint_path)
sync_checkpoint = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: tf.gfile.Copy(
        local_checkpoint_path, os.path.join(model_dir, 'model.ckpt.hdf5'),
        overwrite=True))

# Order matters: the copy must run after ModelCheckpoint has written the file.
callbacks = [tensorboard, model_checkpoint, sync_checkpoint]

In [20]:
epochs = 10
# train_ds repeats indefinitely (make_dataset's num_epochs=None), so Keras needs an
# explicit steps_per_epoch. Spread the MAX_STEPS budget evenly across the epochs.
steps_per_epoch = max(1, max_steps // epochs)

model.fit(
    train_ds,
    validation_data=val_ds,
    # No batch_size: the datasets are already batched, and tf.keras rejects
    # batch_size when fitting on a tf.data.Dataset.
    steps_per_epoch=steps_per_epoch,
    # Roughly 100 validation images per evaluation; at least one step.
    validation_steps=max(1, 100 // batch_size),
    epochs=epochs,
    callbacks=callbacks)

InvalidArgumentError: ignored

In [0]:
# Saving to a '.h5' path goes through h5py, which cannot open gs:// paths:
# save locally first, then copy into model_dir with tf.gfile.
local_weights_path = os.path.join(tempfile.gettempdir(), 'unet_model.h5')
model.save_weights(local_weights_path)
tf.gfile.Copy(local_weights_path, os.path.join(model_dir, 'unet_model.h5'), overwrite=True)

In [0]:


# When a Colab TPU is attached, convert the compiled Keras model for TPU
# execution. Either way, `training_model` is the model to train.
if 'COLAB_TPU_ADDR' in os.environ:
    TPU_WORKER = f'grpc://{os.environ["COLAB_TPU_ADDR"]}'

    tpu_model = tf.contrib.tpu.keras_to_tpu_model(
        model,
        strategy=tf.contrib.tpu.TPUDistributionStrategy(
            tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))
    training_model = tpu_model
else:
    # No conversion is needed off-TPU. Calling fit() without any data can only
    # raise, so training is left to an explicit fit call with real inputs.
    training_model = model