# pix2pix facades
This example uses the pix2pix notebook, which implements the pix2pix conditional GAN, to generate images of facades from dense labels of the facade elements.

This example also appears in the pix2pix paper and is included to show the general functionality of our implementation of the pix2pix architecture.

## Load images
Before training the network we need to download and prepare our dataset.

### Download dataset
We download the dataset directly from the official [website of the dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/) if it has not been downloaded yet.

In [None]:
import os

path = "datasets/facades/base/"

if not os.path.isdir(path):
    ! mkdir -p datasets/facades
    ! wget -O datasets/facades.zip http://cmp.felk.cvut.cz/~tylecr1/facade/CMP_facade_DB_base.zip
    ! unzip -q -o datasets/facades.zip -d datasets/facades
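
The shell commands above depend on `wget` and `unzip` being installed. As a hedged alternative, the same download-and-unpack step can be sketched with the Python standard library only. The helper name `fetch_and_extract` is our own and not part of the pix2pix notebook. The demonstration at the bottom runs offline against a tiny local archive served through a `file://` URL.

```python
import os
import pathlib
import tempfile
import urllib.request
import zipfile


def fetch_and_extract(url, zip_path, target_dir):
    """Download `url` to `zip_path` and unpack it into `target_dir`."""
    os.makedirs(target_dir, exist_ok=True)
    urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target_dir)


# Offline demonstration: build a tiny archive and fetch it via a file:// URL.
work = tempfile.mkdtemp()
source_zip = os.path.join(work, "source.zip")
with zipfile.ZipFile(source_zip, "w") as archive:
    archive.writestr("base/example.png", b"not a real image")

target = os.path.join(work, "facades")
fetch_and_extract(pathlib.Path(source_zip).as_uri(),
                  os.path.join(work, "facades.zip"), target)
extracted = sorted(os.listdir(os.path.join(target, "base")))
print(extracted)
```

In the notebook, the call would take the dataset URL, `datasets/facades.zip` and `datasets/facades` as arguments, guarded by the same `os.path.isdir(path)` check as above.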

### Common dependencies

In [None]:
import numpy as np
import glob
import random
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm_notebook
from keras_tqdm import TQDMNotebookCallback

### Load dataset
After downloading the dataset, we prepare it. Each entry consists of one .png file with the labels and one .jpg file with the image, so we get the image for a label by replacing the .png extension with .jpg.

In [None]:
def random_jitter(x, y):
    x_out = x.resize((286, 286))
    y_out = y.resize((286, 286))
    
    x_start = random.randint(0, 30)
    y_start = random.randint(0, 30)
    
    x_out = x_out.crop((x_start, y_start, x_start + 256, y_start + 256))
    y_out = y_out.crop((x_start, y_start, x_start + 256, y_start + 256))
    
    if random.randint(0, 1) > 0:
        x_out = x_out.transpose(Image.FLIP_LEFT_RIGHT)
        y_out = y_out.transpose(Image.FLIP_LEFT_RIGHT)
        
    return (x_out, y_out)
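
`random_jitter` resizes both images to 286×286, cuts the same random 256×256 window out of each, and mirrors both with probability 0.5, so label and photo stay aligned. The sketch below isolates only the window arithmetic, using the standard library; `jitter_box` is a hypothetical helper for illustration and is not used by the notebook.

```python
import random


def jitter_box(resized=286, crop=256, rng=random):
    """Return a (left, upper, right, lower) crop box inside a resized image."""
    max_offset = resized - crop  # 30 for the sizes used in random_jitter
    left = rng.randint(0, max_offset)   # randint includes both endpoints
    upper = rng.randint(0, max_offset)
    return (left, upper, left + crop, upper + crop)


rng = random.Random(0)  # seeded so the check is repeatable
boxes = [jitter_box(rng=rng) for _ in range(1000)]

# Every window is exactly 256 pixels wide and tall ...
sizes_ok = all(r - l == 256 and b - u == 256 for l, u, r, b in boxes)
# ... and never reaches outside the 286x286 image.
inside_ok = all(l >= 0 and u >= 0 and r <= 286 and b <= 286
                for l, u, r, b in boxes)
print(sizes_ok, inside_ok)
```

Because the largest offset is 30, the window can touch the right and bottom edges (30 + 256 = 286) but cannot exceed them.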

In [None]:
xs = []
ys = []

for image_path in tqdm_notebook(glob.glob(path + "*.png")):
    # load the labels of the current entry
    x = Image.open(image_path)
    x = x.convert('RGB')

    # load the image of the current entry
    y = Image.open(image_path.replace(".png", ".jpg"))
    y = y.convert('RGB')
    
    # apply random jitter 0 to 4 times per entry (some entries may yield no samples)
    for i in range(random.randint(0, 4)):
        x_out, y_out = random_jitter(x, y)
        # convert to float arrays of shape (256, 256, 3) scaled to [0, 1]
        xs.append(np.asarray(x_out) / 255)
        ys.append(np.asarray(y_out) / 255)

xs = np.array(xs)
ys = np.array(ys)
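
The loading loop finds the photo for each label by replacing `.png` with `.jpg` in the path string. A minimal sketch of the same pairing with `pathlib` follows; it only touches the file extension. The helper `photo_path_for` and the example file name are illustrative assumptions, not names taken from the notebook.

```python
from pathlib import Path


def photo_path_for(label_path):
    """Map a label file (.png) to the photo (.jpg) with the same stem."""
    return Path(label_path).with_suffix(".jpg")


# Illustrative file name, assumed for this example only.
example = photo_path_for("datasets/facades/base/cmp_b0001.png")
print(example.as_posix())

# str.replace rewrites every occurrence, including ones in directory names.
risky = "my.png.files/label.png".replace(".png", ".jpg")
print(risky)
```

For the dataset layout used here both approaches give the same result; the difference only matters if `.png` also occurs elsewhere in the path.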

#### Visualization of random images and their labels
After loading the images and labels, we visualize some pairs from our dataset to check that everything loaded correctly.

In [None]:
%matplotlib inline
fig, ax = plt.subplots(6,6,figsize=(16,16))
fig.tight_layout()
ax = ax.flatten()

for i in range(18):
    # the upper bound of np.random.randint is exclusive, so len(xs) covers all indices
    rand = np.random.randint(len(xs))
    x = xs[rand]
    y = ys[rand]
    
    ax[2 * i].imshow(x)
    ax[2 * i].set_title(f"{i}_x")
    ax[2 * i].axis("off")
    ax[2 * i + 1].imshow(y)
    ax[2 * i + 1].set_title(f"{i}_y")
    ax[2 * i + 1].axis("off")

## Tensorboard

In [None]:
%load_ext tensorboard
%tensorboard --logdir {"logs/facades/"} --host 0.0.0.0 --port 8008

## Import pix2pix and generate model
After creating our dataset, we load the pix2pix notebook and create our model.

In [None]:
%run pix2pix.ipynb

model = Pix2pix(discriminator=Discriminator286())

To keep the split between training and test data consistent, we split our dataset before training.

In [None]:
((train_x, train_y), (test_x, test_y)) = model.split_dataset(xs, ys, validation_split=0.025)
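
`split_dataset` is defined in the pix2pix notebook, and its implementation is not shown here. As a rough sketch of what a `validation_split` of 0.025 means, the code below assumes the last fraction of the pairs is held out without shuffling; the real method may behave differently. `split_pairs` is a hypothetical stand-in.

```python
def split_pairs(xs, ys, validation_split):
    """Hold out the trailing fraction of aligned pairs as test data."""
    n_test = max(1, round(len(xs) * validation_split))
    cut = len(xs) - n_test
    return (xs[:cut], ys[:cut]), (xs[cut:], ys[cut:])


# Stand-in data: 400 aligned "samples" represented by their indices.
inputs = list(range(400))
targets = list(range(400))

(train_in, train_out), (test_in, test_out) = split_pairs(inputs, targets, 0.025)
print(len(train_in), len(test_in))
```

Splitting once, before training, means that a resumed run evaluates on the same held-out pairs as the run that wrote the checkpoint, provided the dataset itself is built the same way.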

#### Checkpoints
To pause the training process and resume it later, we use automatic checkpoints.
Before setting up these automatic checkpoints, we check whether an earlier checkpoint exists and, if so, load it.

In [None]:
checkpoint_dir = 'checkpoints/facades/'
checkpoint_path = checkpoint_dir + 'checkpoint-{epoch:04d}.ckpt'
os.makedirs(checkpoint_dir, exist_ok=True)

# if a checkpoint exists => build model and load weights
# (tf is assumed to be available after running pix2pix.ipynb above)
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    model.build(train_x.shape)
    status = model.load_weights(latest_checkpoint)
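
The `checkpoint_path` template contains an `{epoch:04d}` placeholder, which the Keras checkpoint callback fills with the epoch number when it saves. The sketch below expands the template with plain `str.format` to show the resulting file names; the epoch values are chosen only for illustration.

```python
checkpoint_dir = "checkpoints/facades/"
checkpoint_path = checkpoint_dir + "checkpoint-{epoch:04d}.ckpt"

# Expand the template for a few illustrative epoch numbers.
names = [checkpoint_path.format(epoch=e) for e in (1, 12, 50)]
for name in names:
    print(name)

# Zero padding keeps alphabetical order identical to numeric order.
ordered = sorted(names) == names
```

Padding to four digits is a design choice that keeps directory listings in training order for up to 9999 epochs.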

### Fit model
Now, after loading the model and potentially restoring it from an earlier checkpoint, we create the callbacks we need and train the model.

To show TQDM progress bars in Jupyter Lab, install the jupyterlab-manager widget before training:
``` bash
$ jupyter labextension install @jupyter-widgets/jupyterlab-manager
```

In [None]:
# callback for progressbars
tqdm_callback = TQDMNotebookCallback(inner_description_update="Epoch: {epoch}")

# callback for automatic checkpoints
# (weights only, so that tf.train.latest_checkpoint and load_weights above can restore them)
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True)

# callback for tensorboard
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='logs/facades/', histogram_freq=0, write_graph=True, write_images=False, update_freq='epoch')

# train the model
model.fit(train_x, train_y, batch_size=10, epochs=50, initial_epoch=0, validation_data=(test_x, test_y), callbacks=[tqdm_callback, checkpoint_callback, tensorboard_callback])

### Visualize results of test data
After training the model, we visualize some examples from our test dataset, showing the input, the expected output and the generated output.

In [None]:
out = model.predict(test_x, batch_size=10)
for i in range(10):
    fig, ax = plt.subplots(1,3,figsize=(10,10))
    fig.tight_layout()
    ax = ax.flatten()

    x = test_x[i]
    y = test_y[i]
    o = out[i]
    
    ax[0].imshow(x)
    ax[0].set_title("x")
    ax[0].axis("off")
    ax[1].imshow(y)
    ax[1].set_title("y")
    ax[1].axis("off")
    ax[2].imshow(o)
    ax[2].set_title("g(x)")
    ax[2].axis("off")

### Visualize results of training data
We also visualize the results on our training data.

In [None]:
out = model.predict(train_x[:10], batch_size=10)
for i in range(10):
    fig, ax = plt.subplots(1,3,figsize=(10,10))
    fig.tight_layout()
    ax = ax.flatten()

    x = train_x[i]
    y = train_y[i]
    o = out[i]
    
    ax[0].imshow(x)
    ax[0].set_title("x")
    ax[0].axis("off")
    ax[1].imshow(y)
    ax[1].set_title("y")
    ax[1].axis("off")
    ax[2].imshow(o)
    ax[2].set_title("g(x)")
    ax[2].axis("off")
