# Pix2pix-style U-Net for pet segmentation (Oxford-IIIT Pet)

In [None]:
import time
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
class RC:
    """Run-level settings: per-run flags and output locations."""

    # Flag not read in the visible notebook code; presumably toggles plotting
    # predictions before training -- verify before relying on it.
    plot_initial_preds = False
    # Per-run directory (stamped with the Unix time at class-definition time)
    # where the training cell writes its weights checkpoint.
    logs = '/tf/logs/oxford_iiit_pet/' + str(int(time.time()))

class MC:
    """Model settings for the U-Net encoder (down stack) and decoder (up stack)."""

    # MobileNetV2 layers whose activations feed the U-Net, shallowest first.
    # The last entry is the bottleneck; the others become skip connections.
    # Spatial sizes below are for the 256x256 input set in DC.image_size
    # (input size divided by the cumulative stride at that layer).
    downstack_layers = [
        'block_1_expand_relu',   # 128x128 (stride 2)
        'block_3_expand_relu',   # 64x64   (stride 4)
        'block_6_expand_relu',   # 32x32   (stride 8)
        'block_13_expand_relu',  # 16x16   (stride 16)
        'block_16_project',      # 8x8     (stride 32)
    ]

    # Keyword arguments for each UpSample decoder step, applied deepest first.
    # There must be exactly one entry per skip connection,
    # i.e. len(downstack_layers) - 1.
    upstack_layers = [
        dict(filters=512, size=3, dropout=.2),
        dict(filters=256, size=3, dropout=.2),
        dict(filters=128, size=3, dropout=.2),
        dict(filters=64, size=3, dropout=.2),
    ]

class DC:
    """Dataset and input-pipeline settings."""

    # (height, width) that every image and mask is resized to.
    image_size = 256, 256
    # Examples per batch for every split.
    batch_size = 64
    # Number of elements held in the tf.data shuffle buffer.
    buffer_size = 1_000

    # Number of per-pixel output channels of the segmentation head; should
    # equal the number of distinct mask labels (3 for the pet trimaps --
    # assumption, confirm against the dataset).
    output_channels = 3
    
class TC:
    """Optimization settings."""

    # Maximum number of training epochs (early stopping may end sooner).
    epochs = 30
    # Initial learning rate for the Adam optimizer.
    lr = 1e-3
    # Whether the training split should be augmented (see augment_fn).
    augmentation = True

class Config:
    """Single access point that groups all configuration namespaces."""

    run = RC        # per-run flags and log/checkpoint paths
    data = DC       # dataset sizes and tf.data parameters
    model = MC      # encoder/decoder architecture
    training = TC   # optimizer and schedule settings

## Setup

In [None]:
# Apply seaborn's default theme to all matplotlib figures drawn below.
sns.set()

In [None]:
def plot(y, titles=None, rows=1, i0=0):
    """Draw each element of ``y`` into its own cell of a ``rows``-row grid.

    The grid has ``ceil(len(y) / rows)`` columns, and every cell has its axes
    hidden.

    Args:
        y: sized sequence of images. A ``None`` entry leaves its cell empty.
        titles: optional titles, indexed in parallel with ``y``.
        rows: number of grid rows.
        i0: offset added to the 1-based subplot position of every cell.
    """
    from math import ceil

    for idx, image in enumerate(y):
        position = i0 + idx + 1
        if image is not None:
            title = titles[idx] if titles else None
            plt.subplot(rows, ceil(len(y) / rows), position, title=title)
            plt.imshow(image)
        else:
            # Placeholder cell: create the axes only so it can be hidden.
            plt.subplot(rows, ceil(len(y) / rows), position)
        plt.axis('off')

In [None]:
def normalize(input_image, input_mask):
    """Rescale the image and shift mask labels down by one.

    Returns:
        ``(image, mask)``: the image cast to float32 and divided by 255 (so
        0..255 pixels map to [0, 1]), and the mask with 1 subtracted from
        every value (presumably turning 1-based labels into 0-based class
        ids -- verify against the dataset).
    """
    scaled_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1
    return scaled_image, input_mask

@tf.function
def load_image(datapoint):
    """Resize and normalize one dataset example.

    Args:
        datapoint: example dict with ``'image'`` and ``'segmentation_mask'``
            entries.

    Returns:
        ``(image, mask)`` resized to ``Config.data.image_size`` and passed
        through ``normalize``.
    """
    input_image = tf.image.resize(datapoint['image'], Config.data.image_size)
    # Masks hold discrete labels. Nearest-neighbour resizing keeps them
    # discrete, whereas the default bilinear method blends neighbouring labels
    # at object borders into fractional values that are not valid classes.
    # The cast keeps the float32 dtype the rest of the pipeline receives.
    input_mask = tf.cast(
        tf.image.resize(datapoint['segmentation_mask'], Config.data.image_size,
                        method='nearest'),
        tf.float32)

    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask

@tf.function
def augment_fn(image, segmask):
    """Randomly jitter image photometry and mirror image and mask together.

    Args:
        image: float image, expected in [0, 1] (as produced by ``normalize``).
        segmask: segmentation mask aligned with ``image``.

    Returns:
        ``(image, segmask)``. Brightness/contrast changes touch only the
        image; the horizontal flip (probability 0.5) is applied to both so
        they stay aligned.
    """
    image = tf.image.random_brightness(image, .2)
    image = tf.image.random_contrast(image, .75, 1.)
    # Neither op clips float output. Clip so augmented training images share
    # the [0, 1] range of the un-augmented validation/test images.
    image = tf.clip_by_value(image, 0., 1.)

    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_left_right(image)
        segmask = tf.image.flip_left_right(segmask)

    return image, segmask

In [None]:
def load_dataset(ds, augment=False):
    """Turn a raw split into a shuffled, batched, prefetched pipeline.

    Args:
        ds: dataset of example dicts accepted by ``load_image``.
        augment: if true, apply ``augment_fn`` to every example.

    Returns:
        Dataset of ``(images, masks)`` batches of ``Config.data.batch_size``.
    """
    # Cache the deterministic decode/resize work once.
    ds = ds.map(load_image, num_parallel_calls=tf.data.AUTOTUNE).cache()

    if augment:
        # Mapped after .cache() so every epoch draws fresh random augmentations.
        ds = ds.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)

    # Prefetch overlaps input preparation with model execution.
    return (ds.shuffle(Config.data.buffer_size)
              .batch(Config.data.batch_size)
              .prefetch(tf.data.AUTOTUNE))

## Dataset

In [None]:
class Data:
    """Oxford-IIIT Pet splits as batched tf.data pipelines.

    The TFDS ``train`` split is divided 70/30 into training and validation.
    The TFDS ``test`` split is kept for final evaluation.
    """
    (train, val, test), info = tfds.load('oxford_iiit_pet:3.*.*',
                               split=['train[:70%]', 'train[70%:]', 'test'],
                               with_info=True)
    # Only the training split is augmented, and only when enabled in the
    # config. Validation drives early stopping and checkpoint selection, so it
    # must see the same un-augmented inputs as the test split.
    train = load_dataset(train, augment=Config.training.augmentation)
    val = load_dataset(val)
    test = load_dataset(test)

In [None]:
# Preview the first 8 examples of one training batch: images on the top row,
# masks on the bottom row. take(1) yields exactly one batch, so the
# single-element unpacking `(x, y), = ...` extracts it.
(x, y), = ((x[:8], y[:8]) for x, y in Data.train.take(1))

plt.figure(figsize=(16, 4))
# Images are clipped to [0, 1] for display as a safeguard against
# out-of-range values from the training-time augmentation.
plot((*tf.clip_by_value(x, 0, 1), *y), rows=2)
plt.tight_layout()

In [None]:
# Preview the first 8 examples of one test batch: images on the top row,
# masks on the bottom row. The test split is built without augmentation,
# so its images are shown without clipping.
(x, y), = ((x[:8], y[:8]) for x, y in Data.test.take(1))

plt.figure(figsize=(16, 4))
plot((*x, *y), rows=2)
plt.tight_layout()

## Network

In [None]:
from tensorflow.keras import applications, Model, Input

# Encoder backbone: MobileNetV2 without its classification head
# (include_top=False), loaded with Keras' default pretrained weights.
base_model = applications.MobileNetV2(input_shape=[*Config.data.image_size, 3],
                                      include_top=False)

# Use the activations of these layers (listed shallowest first in
# Config.model.downstack_layers) as the outputs of a multi-output encoder.
# unet_model treats the last output as the bottleneck and the rest as skips.
layers = [base_model.get_layer(name).output for name in Config.model.downstack_layers]
down_stack = Model(inputs=base_model.input, outputs=layers, name='downstack')
# Freeze the encoder so only the decoder layers added in unet_model are trained.
down_stack.trainable = False

In [None]:
from tensorflow.keras.layers import (Layer, Conv2DTranspose, Dropout, Activation,
                                     BatchNormalization, ZeroPadding2D)

class UpSample(Layer):
    """Pix2pix-style decoder block.

    Pipeline: stride-2 transposed conv -> [norm] -> [dropout] -> activation.
    With ``strides=2`` and ``padding='same'``, the output has twice the
    spatial size of the input.

    Args:
        filters: output channels of the transposed convolution.
        size: kernel size of the transposed convolution.
        norm: ``'batch'`` adds BatchNormalization; any other value disables
            normalization.
        dropout: dropout rate; a falsy value (e.g. 0) disables dropout.
        activation: passed to ``keras.layers.Activation``.
        *args, **kwargs: forwarded to ``keras.layers.Layer``
            (e.g. ``name``).
    """

    def __init__(self,
                 filters,
                 size=3,
                 norm='batch',
                 dropout=0.,
                 activation='relu',
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.filters = filters
        self.size = size
        self.norm = norm
        self.dropout = dropout
        self.activation = activation

        # Sub-layers are created in the same order as before so weight
        # ordering (and therefore saved checkpoints) is unchanged.
        self.conv2d_tr = Conv2DTranspose(
            filters, size, strides=2,
            padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False)
        self.norm_fn = (BatchNormalization() if norm == 'batch' else None)
        self.dropout_fn = (Dropout(dropout) if dropout else None)
        self.activation_fn = Activation(activation)

    def call(self, x, training=None):
        y = self.conv2d_tr(x)
        # Test the sub-layer itself: `norm` may be a truthy value other than
        # 'batch', in which case no normalization layer was built.
        if self.norm_fn is not None:
            y = self.norm_fn(y, training=training)
        if self.dropout_fn is not None:
            y = self.dropout_fn(y, training=training)

        return self.activation_fn(y)

    def get_config(self):
        """Return constructor arguments so the layer can be re-created.

        Only serializable if ``activation`` was given as a name (string).
        """
        config = super().get_config()
        config.update(filters=self.filters,
                      size=self.size,
                      norm=self.norm,
                      dropout=self.dropout,
                      activation=self.activation)
        return config

In [None]:
from tensorflow.keras.layers import Concatenate

def unet_model(
        downstack_layers,
        upstack_layers,
        image_size,
        output_channels,
        encoder=None):
    """Build a U-Net whose decoder mirrors a multi-output encoder.

    Args:
        downstack_layers: names of the encoder layers whose activations the
            encoder returns, shallowest first. The last one is the
            bottleneck. The others are skip connections and are used to name
            the matching decoder layers.
        upstack_layers: one kwargs dict per ``UpSample`` step, applied
            deepest first. Exactly one entry per skip connection is required.
        image_size: ``(height, width)`` of the 3-channel input.
        output_channels: number of channels of the final layer (it has no
            activation, so it produces raw scores).
        encoder: multi-output encoder model. Defaults to the module-level
            ``down_stack``.

    Returns:
        A ``Model`` named ``'unet'`` mapping images to per-pixel scores.

    Raises:
        ValueError: if the numbers of decoder steps, skip activations and
            skip names disagree. ``zip`` would otherwise silently drop steps,
            and the output would not match the input resolution.
    """
    if encoder is None:
        encoder = down_stack

    inputs = Input(shape=[*image_size, 3], name='images')
    outputs = encoder(inputs)
    x = outputs[-1]
    skips = list(reversed(outputs[:-1]))
    # Skips are consumed deepest first, so pair each decoder step with the
    # name of the skip layer it actually concatenates.
    skip_names = list(reversed(downstack_layers[:-1]))

    if not len(upstack_layers) == len(skips) == len(skip_names):
        raise ValueError(
            f'expected one upstack layer per skip connection: got '
            f'{len(upstack_layers)} upstack layers, {len(skips)} skip '
            f'activations and {len(skip_names)} skip layer names')

    print(f'last :- {x.shape}')

    # Upsampling and establishing the skip connections
    for name, args, skip in zip(skip_names, upstack_layers, skips):
        y = UpSample(**args, name=f'{name}/upsampling')(x)
        y = Concatenate(name=f'{name}/concat')([y, skip])
        print(f'{name} {args} {y.shape} :- {x.shape}, {skip.shape}')
        x = y

    # A final stride-2 transposed conv restores full resolution. It has no
    # activation, so the model emits logits.
    x = Conv2DTranspose(output_channels, 3,
                        strides=2,
                        padding='same',
                        name='segments')(x)
    return Model(inputs=inputs, outputs=x, name='unet')

In [None]:
# Build the full U-Net from the configuration namespaces.
u = unet_model(
    downstack_layers=Config.model.downstack_layers,
    upstack_layers=Config.model.upstack_layers,
    image_size=Config.data.image_size,
    output_channels=Config.data.output_channels,
)

In [None]:
# Render the architecture diagram, with tensor shapes, to pix2pix.png.
tf.keras.utils.plot_model(
    model=u,
    to_file='pix2pix.png',
    show_shapes=True,
    show_layer_names=True,
    show_dtype=False,
)

In [None]:
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# optimizers reject. The segmentation head emits raw scores (no activation),
# hence from_logits=True; masks are integer class ids, hence the sparse loss.
u.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=Config.training.lr),
          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
          metrics=['accuracy'])

## Training

In [None]:
import os
from tensorflow.keras import callbacks

# ModelCheckpoint writes into the per-run log directory; make sure it exists.
os.makedirs(Config.run.logs, exist_ok=True)

# The History callback stores per-epoch metrics on `u.history` (plotted below).
u.fit(
    Data.train,
    epochs=Config.training.epochs,
    validation_data=Data.val,
    callbacks=[
        # Abort immediately if the loss becomes NaN.
        callbacks.TerminateOnNaN(),
        # Stop when the monitored metric (Keras default: val_loss) has not
        # improved for half the epoch budget.
        callbacks.EarlyStopping(patience=Config.training.epochs // 2, verbose=1),
        # Keep only the best weights (by val_loss); they are reloaded in the
        # Testing section.
        callbacks.ModelCheckpoint(Config.run.logs + '/weights.h5',
                                  save_weights_only=True,
                                  save_best_only=True,
                                  verbose=1),
        # Reduce the learning rate after 5 epochs without val_loss improvement.
        callbacks.ReduceLROnPlateau(patience=5, verbose=1),
    ],
    verbose=2);  # trailing `;` suppresses the History repr in the notebook

In [None]:
# Training curves: accuracy (top-left), loss (top-right) and the
# learning-rate schedule recorded by ReduceLROnPlateau (bottom-left).
plt.figure(figsize=(16, 8))
for panel, series in (
        (221, (('accuracy', 'train accuracy'), ('val_accuracy', 'val accuracy'))),
        (222, (('loss', 'train loss'), ('val_loss', 'val loss'))),
        (223, (('lr', 'learning rate'),)),
):
    plt.subplot(panel)
    for key, label in series:
        plt.plot(u.history.history[key], label=label)
    plt.legend()

## Testing

In [None]:
# Restore the best checkpoint (by val_loss) written by ModelCheckpoint during training.
u.load_weights(Config.run.logs + '/weights.h5')

In [None]:
def predictions_to_segments(p):
    """Convert per-pixel class scores into a label map.

    Takes the argmax over the last axis, then re-appends a size-1 channel
    axis so the result has the same rank as ``p`` and can be displayed like
    the ground-truth masks.
    """
    labels = tf.argmax(p, axis=-1)
    return labels[..., tf.newaxis]

def show_predictions(model, ds, num=1, max_examples=8):
    """Plot inputs, ground-truth masks and predicted masks.

    Draws one figure per batch with three rows: images (clipped to [0, 1] for
    display), ground-truth masks, and predicted label maps.

    Args:
        model: segmentation model returning per-pixel class scores.
        ds: batched dataset of ``(images, masks)`` pairs.
        num: number of batches to visualize, one figure each.
        max_examples: maximum number of examples shown from each batch.
    """
    for x, y in ds.take(num):
        x, y = x[:max_examples], y[:max_examples]
        p = predictions_to_segments(model.predict(x))

        plt.figure(figsize=(16, 6))
        plot((*tf.clip_by_value(x, 0, 1), *y, *p), rows=3)
        plt.tight_layout()

In [None]:
show_predictions(u, Data.train)

In [None]:
show_predictions(u, Data.test)

In [None]:
# Loss and metrics of the current weights on every split:
# one row per metric, one column per split.
pd.DataFrame(
    {split: u.evaluate(ds, verbose=0)
     for split, ds in (('train', Data.train), ('val', Data.val), ('test', Data.test))},
    index=u.metrics_names)