# pix2pix

## Load images

## download dataset

In [None]:
import os

# Directory holding the extracted CMP facade "base" set; the load cell below
# globs "*.png" (and the matching ".jpg") from here.
# NOTE(review): assumes the zip unpacks into a top-level "base/" folder — confirm.
path = "datasets/facades/base/"

# Download and extract only when the directory is missing, so re-running the
# notebook does not re-fetch the archive. The `!` lines are IPython shell
# escapes: they require mkdir, wget and unzip on the PATH.
if not os.path.isdir(path):
    ! mkdir -p datasets/facades
    ! wget -O datasets/facades.zip http://cmp.felk.cvut.cz/~tylecr1/facade/CMP_facade_DB_base.zip
    ! unzip -q -o datasets/facades.zip -d datasets/facades

### common dependencies

In [None]:
import numpy as np
import glob
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm_notebook
from keras_tqdm import TQDMNotebookCallback

## load dataset

In [None]:
def _load_rgb(image_path, size=(256, 256)):
    """Open an image file, convert it to RGB, resize it and scale it to [0, 1].

    Args:
        image_path: path of the image file to read.
        size: target ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        A float64 array of shape ``(height, width, 3)`` with values in [0, 1].
    """
    # `with` closes the underlying file; convert() returns a new in-memory image.
    with Image.open(image_path) as src:
        rgb = src.convert("RGB").resize(size)
    return np.asarray(rgb, dtype=np.float64) / 255


xs = []
ys = []

# glob returns files in arbitrary filesystem order; sorting makes the sample
# order (and therefore any downstream split) reproducible across machines.
for image_path in sorted(glob.glob(os.path.join(path, "*.png"))):
    xs.append(_load_rgb(image_path))
    # The paired image has the same stem with a ".jpg" extension. splitext only
    # touches the extension, unlike str.replace, which would also rewrite any
    # ".png" occurring in directory names.
    ys.append(_load_rgb(os.path.splitext(image_path)[0] + ".jpg"))

xs = np.array(xs)
ys = np.array(ys)

#### Visualization of random images and their labels

In [None]:
%matplotlib inline
fig, ax = plt.subplots(6,6,figsize=(16,16))
fig.tight_layout()
ax = ax.flatten()

for i in range(18):
    rand = np.random.randint(len(xs)-1)
    x = xs[rand]
    y = ys[rand]
    
    ax[2 * i].imshow(x)
    ax[2 * i].set_title(f"{i}_x")
    ax[2 * i].axis("off")
    ax[2 * i + 1].imshow(y)
    ax[2 * i + 1].set_title(f"{i}_y")
    ax[2 * i + 1].axis("off")

### Import pix2pix and create the model

In [None]:
%run pix2pix.ipynb
# `%run` executes pix2pix.ipynb in this notebook's namespace; it must define
# `Pix2pix`.
# NOTE(review): `tf` is used by later cells but never imported in this
# notebook, so it presumably also comes from pix2pix.ipynb — confirm.
model = Pix2pix()

In [None]:
# Hold out 5% of the (x, y) pairs as test/validation data.
(train_x, train_y), (test_x, test_y) = model.split_dataset(
    xs, ys, validation_split=0.05
)

## Checkpoints (resume from the latest saved weights)

In [None]:
checkpoint_dir = 'checkpoints/facades/'
# ModelCheckpoint fills `{epoch}` in with the epoch number, e.g. checkpoint-0005.ckpt.
checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint-{epoch:04d}.ckpt')
os.makedirs(checkpoint_dir, exist_ok=True)

# If a checkpoint exists, build the model and restore its weights.
# tf.train.latest_checkpoint returns None when the directory has no checkpoint
# index; look it up once and reuse the result.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    # NOTE(review): build() presumably creates the variables so load_weights
    # has something to restore into — confirm against Pix2pix.
    model.build(train_x.shape)
    status = model.load_weights(latest_checkpoint)

### fit model

In [None]:
import re

# When resuming, continue the epoch count from the newest checkpoint. A resumed
# run then trains up to `epochs` in total, and new checkpoint files keep
# increasing numbers instead of overwriting the old ones.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
initial_epoch = 0
if latest_checkpoint is not None:
    epoch_match = re.search(r"checkpoint-(\d+)\.ckpt$", latest_checkpoint)
    if epoch_match:
        initial_epoch = int(epoch_match.group(1))

tqdm_callback = TQDMNotebookCallback(inner_description_update="Epoch: {epoch}")
# Save weights only. This writes TF-format weight checkpoints plus the
# `checkpoint` index file that tf.train.latest_checkpoint() reads when
# resuming. A full-model save (save_weights_only=False) produces SavedModel
# directories instead, so the resume logic would never find them.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True)

model.fit(
    train_x,
    train_y,
    batch_size=10,
    epochs=200,
    initial_epoch=initial_epoch,
    validation_data=(test_x, test_y),
    callbacks=[tqdm_callback, checkpoint_callback],
)

### Visualize predictions on the test data

In [None]:
out = model.predict(test_x, batch_size=10)

# One figure per test sample: input x, target y and model output g(x).
for x, y, o in zip(test_x, test_y, out):
    fig, ax = plt.subplots(1, 3, figsize=(10, 10))
    for axis, image, title in zip(ax, (x, y, o), ("x", "y", "g(x)")):
        axis.imshow(image)
        axis.set_title(title)
        axis.axis("off")
    fig.tight_layout()
    # Display and close each figure right away. Otherwise every figure stays
    # open until the cell ends, and matplotlib warns once more than 20 are open.
    plt.show()