In [None]:
! pip install git+https://github.com/tensorflow/examples.git

In [None]:
"""
TODO: docstring
"""
import matplotlib.pyplot as pyplot
import tensorflow
import tensorflow_datasets
import tensorflow_examples.models.pix2pix.pix2pix as pix2pix

def create_mask(pred_mask):
    """
    Turn raw per-pixel model output into a class-id mask for one image.

    The argmax is taken over the last axis, a trailing axis of size 1 is
    re-added, and only the first element along the leading axis is kept.

    Args:
        pred_mask: output of `model.predict`; presumably shaped
            (batch, height, width, classes) -- TODO confirm.

    Returns:
        The mask of the first batch element, with a trailing axis of size 1.
    """
    class_ids = tensorflow.argmax(pred_mask, axis=-1)
    return tensorflow.expand_dims(class_ids, axis=-1)[0]

def display(display_list):
    """
    Show up to three images side by side in one matplotlib figure.

    Entries are converted with `tensorflow.keras.utils.array_to_img` and
    titled, in order, 'Input Image', 'True Mask', 'Predicted Mask'.

    Args:
        display_list: sequence of image-like tensors/arrays, one subplot per
            entry.

    Raises:
        ValueError: if more entries are given than there are titles.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    count = len(display_list)
    # Fail before creating a figure rather than part-way through drawing it.
    if count > len(titles):
        raise ValueError(
            'display() supports at most {} images, got {}'.format(
                len(titles), count))
    pyplot.figure(figsize=(15, 15))
    for position, item in enumerate(display_list):
        pyplot.subplot(1, count, position + 1)  # subplot indices are 1-based
        pyplot.title(titles[position])
        pyplot.imshow(tensorflow.keras.utils.array_to_img(item))
        pyplot.axis('off')
    pyplot.show()

def load_image(datapoint):
    """
    Resize one dataset example to 128x128 and normalize it.

    The image uses the default resize method. The segmentation mask uses
    nearest-neighbor resizing so every pixel keeps one of the mask's existing
    integer values. The resized pair is then passed through `normalize`.

    Args:
        datapoint: dataset element dict with 'image' and 'segmentation_mask'
            entries.

    Returns:
        (input_image, input_mask) tuple as returned by `normalize`.
    """
    input_image = tensorflow.image.resize(datapoint['image'], (128, 128))
    # The default (bilinear) method blends neighbouring label ids at region
    # borders (e.g. halfway between 1 and 3 gives 2), producing fractional or
    # wrong class labels. Nearest neighbour keeps the mask values and dtype.
    input_mask = tensorflow.image.resize(
        datapoint['segmentation_mask'], (128, 128),
        method=tensorflow.image.ResizeMethod.NEAREST_NEIGHBOR)
    return normalize(input_image, input_mask)

def normalize(input_image, input_mask):
    """
    Rescale image pixels and shift mask values down by one.

    Args:
        input_image: image tensor; cast to float32 and divided by 255.0
            (maps to [0, 1] assuming 0-255 pixel values).
        input_mask: mask tensor/array; 1 is subtracted from every value
            (presumably converting 1-based label ids to 0-based ones --
            confirm against the dataset's label convention).

    Returns:
        (input_image, input_mask) tuple with the adjusted values. The
        caller's arguments are not modified.
    """
    input_image = tensorflow.cast(input_image, tensorflow.float32) / 255.0
    # Plain subtraction instead of `-=`: on a NumPy array `-=` would modify
    # the caller's mask in place.
    input_mask = input_mask - 1
    return input_image, input_mask

def show_predictions(dataset=None, num=1):
    """
    Plot model predictions next to their input images and true masks.

    Reads the module-level globals `model`, `sample_image` and `sample_mask`,
    which must be defined before this is called.

    Args:
        dataset: optional batched dataset of (image, mask) pairs. When given,
            the first `num` batches are predicted and the first example of
            each batch is shown. When None, the module-level
            `sample_image`/`sample_mask` pair is shown instead.
        num: number of batches to take from `dataset`; ignored when
            `dataset` is None.
    """
    if dataset is None:
        # Add a leading batch axis because model.predict takes batched input.
        prediction = model.predict(sample_image[tensorflow.newaxis, ...])
        display([sample_image, sample_mask, create_mask(prediction)])
        return
    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        display([image[0], mask[0], create_mask(pred_mask)])

def unet_model(output_channels:int):
    """
    Build a U-Net style model from the module-level encoder and decoder.

    The global `down_stack` is called on a 128x128x3 input. Its last output
    is the bottleneck and the remaining outputs are skip activations. Each
    block of the global `up_stack` upsamples the running tensor, and the
    result is concatenated with the next skip, deepest first. A final
    stride-2 `Conv2DTranspose` yields `output_channels` channels.

    Args:
        output_channels: number of channels in the final layer. No
            activation is set on it, so the outputs are raw values (logits).

    Returns:
        An uncompiled `tensorflow.keras.Model`.
    """
    layers = tensorflow.keras.layers
    inputs = layers.Input(shape=[128, 128, 3])
    # encoder: everything but the last activation feeds the skip connections
    *skip_outputs, x = down_stack(inputs)
    # decoder: pair each upsampling block with the skip at its resolution
    for upsample, skip in zip(up_stack, reversed(skip_outputs)):
        x = layers.Concatenate()([upsample(x), skip])
    # output head: 64x64 -> 128x128
    head = layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2, padding='same')
    return tensorflow.keras.Model(inputs=inputs, outputs=head(x))

class Augment(tensorflow.keras.layers.Layer):
    """
    Random horizontal-flip augmentation for (image, mask) batches.

    Holds two `RandomFlip` layers built with the same seed, one for the
    inputs and one for the labels.
    """
    def __init__(self, seed=42):
        """
        Create the two flip layers.

        Args:
            seed: seed passed to both `RandomFlip` layers.
        """
        super().__init__()
        flip_options = dict(mode='horizontal', seed=seed)
        # Shared seed so that image and mask are meant to flip together.
        # NOTE(review): verify the installed Keras version really keeps the
        # two layers in sync with equal seeds.
        self.augment_inputs = tensorflow.keras.layers.RandomFlip(**flip_options)
        self.augment_labels = tensorflow.keras.layers.RandomFlip(**flip_options)

    def call(self, inputs, labels):
        """
        Flip a batch of inputs and its matching labels.

        Args:
            inputs: batch of images.
            labels: batch of masks corresponding to `inputs`.

        Returns:
            (inputs, labels) tuple after augmentation.
        """
        return self.augment_inputs(inputs), self.augment_labels(labels)

class DisplayCallback(tensorflow.keras.callbacks.Callback):
    """
    Keras callback that plots the sample prediction after every epoch.
    """
    def on_epoch_end(self, epoch, logs=None):
        """
        Plot the module-level sample prediction and report the epoch.

        Args:
            epoch: epoch index passed in by Keras; printed as `epoch + 1`.
            logs: metrics dict supplied by Keras; not used here.
        """
        show_predictions()
        print(f'\nsample Prediction after epoch {epoch + 1}\n')

# Load the Oxford-IIIT Pet dataset (any 3.x version) together with its
# metadata (`info` provides the split sizes used below).
dataset, info = tensorflow_datasets.load('oxford_iiit_pet:3.*.*', with_info=True)

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000  # shuffle buffer size, in examples
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE  # full batches per epoch

# Resize and normalize every example, in parallel.
train_images = dataset['train'].map(
    load_image, num_parallel_calls=tensorflow.data.AUTOTUNE)

test_images = dataset['test'].map(
    load_image, num_parallel_calls=tensorflow.data.AUTOTUNE)

# Training input pipeline.
train_batches = (
    train_images
    .cache()  # keep the preprocessed examples after the first pass
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()  # endless; fit() is bounded by steps_per_epoch instead
    .map(Augment())  # called as Augment()(images, masks) per batch
    .prefetch(buffer_size=tensorflow.data.AUTOTUNE))

# Test data is only batched: no shuffling, repetition or augmentation.
test_batches = test_images.batch(BATCH_SIZE)

# Preview two augmented training examples. After the loop, sample_image /
# sample_mask stay bound to the last one shown; show_predictions() and
# DisplayCallback read these globals.
for images, masks in train_batches.take(2):
    sample_image, sample_mask = images[0], masks[0]
    display([sample_image, sample_mask])

# MobileNetV2 encoder without its classification head (Keras loads ImageNet
# weights by default since `weights` is not given).
base_model = tensorflow.keras.applications.MobileNetV2(
    input_shape=[128, 128, 3], include_top=False)

# use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project']      # 4x4

base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# create the feature extraction model
# (one output per layer above; unet_model() uses the last one as the
# bottleneck and the others as skip connections)
down_stack = tensorflow.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

# Freeze the encoder so that only the decoder's weights are trained.
down_stack.trainable = False

# Decoder blocks from tensorflow_examples' pix2pix helpers; the arguments are
# presumably (filters, kernel size) -- see pix2pix.upsample.
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3)]   # 32x32 -> 64x64

# 3 output channels, presumably one per mask class -- TODO confirm.
# Must run after down_stack and up_stack exist: unet_model reads them as globals.
model = unet_model(output_channels=3)

# from_logits=True because the model's final Conv2DTranspose has no
# activation, so it outputs raw logits.
model.compile(
    optimizer='adam', metrics=['accuracy'],
    loss=tensorflow.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
# NOTE: the *sparse* cross-entropy is used because each mask pixel presumably
# holds a single integer class id rather than a one-hot vector. (The original
# author noted that the tutorial this code follows states the loss incorrectly.)

# Architecture diagram (needs pydot + graphviz). The returned image is not
# displayed here because this is not the last expression in the cell; Keras
# writes it to a file (to_file defaults to 'model.png').
tensorflow.keras.utils.plot_model(model, show_shapes=True)

# Baseline prediction on the sample image before any training.
show_predictions()

EPOCHS = 100
VAL_SUBSPLITS = 5
# Each epoch validates on roughly 1/VAL_SUBSPLITS of the full test batches.
VALIDATION_STEPS = (info.splits['test'].num_examples // BATCH_SIZE) // VAL_SUBSPLITS

# train_batches repeats forever, so steps_per_epoch defines an epoch;
# DisplayCallback plots the sample prediction after every epoch.
model_history = model.fit(
    train_batches,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=test_batches,
    validation_steps=VALIDATION_STEPS,
    callbacks=[DisplayCallback()],
)

# Per-epoch loss curves recorded by fit().
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

pyplot.figure()
pyplot.plot(model_history.epoch, loss, 'r', label='Training loss')
pyplot.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
pyplot.title('Training and Validation Loss')
pyplot.xlabel('Epoch')
pyplot.ylabel('Loss Value')
# Keep the [0, 1] view when every value fits, but raise the upper bound so
# losses above 1 (e.g. early epochs or a rising validation loss) are not
# clipped off the plot.
pyplot.ylim([0, max([1.0, *loss, *val_loss])])
pyplot.legend()
pyplot.show()

# Predictions for the first example of each of the first three test batches.
show_predictions(test_batches, 3)