# pix2pix

## Load images

## download dataset

In [None]:
import os

# Directory whose presence is used as the "dataset already downloaded" marker;
# the loading cell below also reads images from this location.
# (presumably this folder is created by extracting the archive — verify against the zip layout)
path = "datasets/kitti/depth_selection/val_selection_cropped/"

# Only download and extract when the directory is missing, so re-running the
# cell does not fetch the archive again.
if not os.path.isdir(path):
    # Create the target folder (including parents) if it does not exist yet.
    ! mkdir -p datasets/kitti
    # Download the archive to datasets/kitti.zip.
    ! wget -O datasets/kitti.zip https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_selection.zip
    # Extract quietly (-q), overwriting existing files (-o), into datasets/kitti.
    ! unzip -q -o datasets/kitti.zip -d datasets/kitti

### common dependencies

In [None]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import csv
import random
import math
from PIL import Image
from tqdm import tqdm_notebook
from keras_tqdm import TQDMNotebookCallback

## load dataset

In [None]:
# Build the training arrays: `xs` holds RGB crops, `ys` the matching
# single-channel depth crops, both scaled to [0, 1].
xs = []
ys = []

# Upper bound on the number of (image, depth) crops to keep.
num_samples = 2000
# Side length in pixels of the square crops fed to the model.
crop_size = 256

for image_path in tqdm_notebook(glob.glob(path + "image/" + "*.png"), desc="Loading Images"):
    # The ground-truth depth file mirrors the image path with two substitutions.
    depth_path = image_path.replace('/image/', '/groundtruth_depth/').replace('sync_image', 'sync_groundtruth_depth')

    # Use context managers so the underlying file handles are released as soon
    # as the pixel data has been converted (convert() returns a loaded copy).
    with Image.open(image_path) as image_file, Image.open(depth_path) as depth_file:
        # Scale to a height of `crop_size` while keeping the aspect ratio.
        width = math.floor(image_file.size[0] * crop_size / image_file.size[1])

        x = image_file.convert('RGB').resize((width, crop_size))
        # NOTE(review): convert('L') yields 8-bit values; if the depth PNGs
        # store a wider range this loses precision — confirm this is intended.
        y = depth_file.convert('L').resize((width, crop_size))

    # Take one or two random square crops per image; image and depth use the
    # same offset and the same flip so they stay aligned.
    for _ in range(random.randint(1, 2)):
        x_offset = random.randint(0, width - crop_size)
        crop_box = (x_offset, 0, x_offset + crop_size, crop_size)
        x_crop = x.crop(crop_box)
        y_crop = y.crop(crop_box)

        # Random horizontal flip as augmentation.
        if random.randint(0, 1) > 0:
            x_crop = x_crop.transpose(Image.FLIP_LEFT_RIGHT)
            y_crop = y_crop.transpose(Image.FLIP_LEFT_RIGHT)

        # np.asarray converts the image buffer directly, avoiding the slow
        # per-pixel Python sequence produced by getdata().
        xs.append(np.asarray(x_crop).reshape((crop_size, crop_size, 3)) / 255)
        ys.append(np.asarray(y_crop).reshape((crop_size, crop_size, 1)) / 255)

    if len(xs) >= num_samples:
        break

# The last image may have contributed two crops, so trim to the requested cap.
xs = np.array(xs[:num_samples])
ys = np.array(ys[:num_samples])

#### Visualization of random images and their labels

In [None]:
%matplotlib inline
fig, ax = plt.subplots(6,6,figsize=(16,16))
fig.tight_layout()
ax = ax.flatten()

for i in range(18):
    rand = np.random.randint(len(xs)-1)
    x = xs[rand]
    y = ys[rand].reshape((256, 256))
    
    ax[2 * i].imshow(x)
    ax[2 * i].set_title(f"{i}_x")
    ax[2 * i].axis("off")
    ax[2 * i + 1].imshow(y)
    ax[2 * i + 1].set_title(f"{i}_y")
    ax[2 * i + 1].axis("off")

## Tensorboard

In [None]:
# Load the TensorBoard notebook extension, then start TensorBoard on the log
# directory that the TensorBoard callback in the training cell writes to.
# The {...} is IPython's interpolation of a Python expression into the magic line.
# NOTE(review): host 0.0.0.0 makes the UI reachable from other machines on
# port 8007 — confirm that exposure is intended.
%load_ext tensorboard
%tensorboard --logdir {"logs/mono-depth-perception/"} --host 0.0.0.0 --port 8007

## Import pix2pix and generate model
After creating our dataset we load and create our model

In [None]:
# Execute the pix2pix notebook in this kernel.
# Presumably this defines `Pix2pix` and also brings `tf` into scope for the
# later cells (it is not imported in this notebook) — verify against pix2pix.ipynb.
%run pix2pix.ipynb

# Depth maps are single-channel, hence one output channel.
model = Pix2pix(output_dim=1)

To allow a consistent split between the training and test datasets, we split our dataset before training.

In [None]:
# Hold out 2.5% of the samples as the test/validation set.
(train_x, train_y), (test_x, test_y) = model.split_dataset(
    xs,
    ys,
    validation_split=0.025,
)

#### Checkpoint stuff
To pause the training process and resume it later, we use automatic checkpoints.
Before setting up these automatic checkpoints, we check whether an earlier checkpoint exists and, if so, load it.

In [None]:
# Checkpoints are stored per epoch; `{epoch:04d}` is left unformatted here so
# the checkpoint callback can fill in the epoch number when saving.
checkpoint_dir = 'checkpoints/mono-depth-perception/'
checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint-{epoch:04d}.ckpt')
os.makedirs(checkpoint_dir, exist_ok=True)

# if a checkpoint exists => build model and load weights
# Look the checkpoint up once and reuse the result instead of querying twice.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    # The model is built from the training input shape before the weights
    # are loaded into it.
    model.build(train_x.shape)
    status = model.load_weights(latest_checkpoint)
    print(f"Restored weights from {latest_checkpoint}")

### fit model
Now, after loading the model and potentially restoring it from an earlier checkpoint, we create the callbacks we need and train the model.

To show TQDM progress bars in JupyterLab, install the jupyterlab-manager extension before training:
``` bash
$ jupyter labextension install @jupyter-widgets/jupyterlab-manager
```

In [None]:
# callback for progressbars
tqdm_callback = TQDMNotebookCallback(inner_description_update="Epoch: {epoch}")

# callback for automatic checkpoints
# Save weights only: the restore cell locates checkpoints with
# tf.train.latest_checkpoint and loads them via model.load_weights, which
# work with weight checkpoints rather than full-model saves.
# (assumes TF 2.x tf.keras checkpoint behaviour — confirm for the installed version)
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
)

# callback for tensorboard
tensorboard_calback = tf.keras.callbacks.TensorBoard(
    log_dir='logs/mono-depth-perception/',
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    update_freq='epoch',
)

# train the model
# NOTE(review): initial_epoch is always 0, so a run resumed from a checkpoint
# restarts epoch numbering (and checkpoint file names) from the beginning.
model.fit(
    train_x,
    train_y,
    batch_size=3,
    epochs=100,
    initial_epoch=0,
    validation_data=(test_x, test_y),
    callbacks=[tqdm_callback, checkpoint_callback, tensorboard_calback],
)

### visualize results of test data
After training the model, we visualize some examples from our testing dataset, including the input, the expected output, and the model output.

In [None]:
# Run the trained model on the held-out samples and show, per sample,
# the input image, the expected depth map and the predicted depth map.
out = model.predict(test_x, batch_size=10)
for i in range(len(out)):
    fig, ax = plt.subplots(1,3,figsize=(10,10))
    ax = ax.flatten()

    x = test_x[i]
    # Drop the channel axis so imshow receives 2-D arrays.
    y = test_y[i].reshape((256, 256))
    o = out[i].reshape((256, 256))

    for axis, image, title in zip(ax, (x, y, o), ("x", "y", "g(x)")):
        axis.imshow(image)
        axis.set_title(title)
        axis.axis("off")

    # Lay out after drawing, then render and release the figure right away:
    # one figure per test sample would otherwise pile up open figures and
    # trigger matplotlib's "too many open figures" warning.
    fig.tight_layout()
    plt.show()
    plt.close(fig)