# pix2pix

## Load images

## Download dataset

In [None]:
import os

# Root of the cropped KITTI depth-selection validation split. The loading cell
# reads RGB frames from path + "image/" and derives the matching label path by
# swapping '/image/' for '/groundtruth_depth/'.
path = "datasets/kitti/depth_selection/val_selection_cropped/"

# Download and extract the archive only when the extracted directory is missing,
# so re-running the notebook does not re-download it.
if not os.path.isdir(path):
    ! mkdir -p datasets/kitti
    ! wget -O datasets/kitti.zip https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_selection.zip
    ! unzip -q -o datasets/kitti.zip -d datasets/kitti

### Common dependencies

In [None]:
import numpy as np
import glob
import matplotlib.pyplot as plt
import csv
import random
import math
from PIL import Image
from tqdm import tqdm

## Load dataset

In [None]:
# Build the training arrays:
#   xs: (N, 256, 256, 3) RGB inputs scaled to [0, 1]
#   ys: (N, 256, 256, 1) single-channel depth labels scaled to [0, 1]
xs = []
ys = []

num_samples = 600

# glob returns files in filesystem-dependent order; sort so the same images are
# selected on every run. Each iteration appends exactly one sample, so slicing
# to num_samples matches the original "stop once num_samples are loaded" logic
# while also giving tqdm an accurate total.
image_paths = sorted(glob.glob(path + "image/" + "*.png"))[:num_samples]

for image_path in tqdm(image_paths, desc="Loading Images"):
    # convert() returns a fully loaded copy, so the file can be closed right away.
    with Image.open(image_path) as img:
        x = img.convert('RGB')

    # Resize to a height of 256 px (preserving aspect ratio), then take a random
    # 256x256 crop. The crop offset is unseeded, so crops differ between runs.
    # randint requires width >= 256, i.e. landscape input images.
    width = math.floor(x.size[0] * 256 / x.size[1])
    x_offset = random.randint(0, width - 256)

    x = x.resize((width, 256))
    x = x.crop((x_offset, 0, x_offset + 256, 256))
    xs.append(np.asarray(x, dtype=np.float64).reshape((256, 256, 3)) / 255)

    # The ground-truth depth map lives in a parallel directory. Apply the same
    # resize and crop so the label stays pixel-aligned with the input.
    depth_path = (
        image_path
        .replace('/image/', '/groundtruth_depth/')
        .replace('sync_image', 'sync_groundtruth_depth')
    )
    with Image.open(depth_path) as img:
        # NOTE(review): if these PNGs are 16-bit, convert('L') saturates every
        # value above 255. Confirm whether the depth range should be rescaled
        # instead.
        y = img.convert('L')
    y = y.resize((width, 256))
    y = y.crop((x_offset, 0, x_offset + 256, 256))
    ys.append(np.asarray(y, dtype=np.float64).reshape((256, 256, 1)) / 255)

xs = np.array(xs)
ys = np.array(ys)

#### Visualize random images and their depth labels

In [None]:
%matplotlib inline
fig, ax = plt.subplots(6,6,figsize=(16,16))
fig.tight_layout()
ax = ax.flatten()

for i in range(18):
    rand = np.random.randint(len(xs)-1)
    x = xs[rand]
    y = ys[rand].reshape((256, 256))
    
    ax[2 * i].imshow(x)
    ax[2 * i].set_title(f"{i}_x")
    ax[2 * i].axis("off")
    ax[2 * i + 1].imshow(y)
    ax[2 * i + 1].set_title(f"{i}_y")
    ax[2 * i + 1].axis("off")

### Import pix2pix and build the model

In [None]:
# Execute the pix2pix notebook. It presumably defines the Pix2pix class used
# below (TODO confirm).
%run pix2pix.ipynb

# output_dim=1 matches the single-channel (256, 256, 1) depth labels in ys.
model = Pix2pix(output_dim=1)

### Fit the model

In [None]:
# Hold out 5% of the samples as a validation/test split, then train on the rest.
(train_x, train_y), (test_x, test_y) = model.split_dataset(
    xs,
    ys,
    validation_split=0.05,
)

model.fit(
    train_x,
    train_y,
    batch_size=10,
    epochs=150,
    validation_data=(test_x, test_y),
)

### Visualize results on the test data

In [None]:
out = model.predict(test_x, batch_size=10)

# One figure per test sample, left to right:
#   x:    input image
#   y:    ground-truth depth
#   g(x): the model's prediction
for x, y, o in zip(test_x, test_y, out):
    fig, ax = plt.subplots(1,3,figsize=(10,10))
    ax = ax.flatten()

    y = y.reshape((256, 256))
    o = o.reshape((256, 256))

    for axis, image, title in ((ax[0], x, "x"), (ax[1], y, "y"), (ax[2], o, "g(x)")):
        axis.imshow(image)
        axis.set_title(title)
        axis.axis("off")

    # Lay out after the titles are set so they are accounted for.
    fig.tight_layout()
    # Render each figure now. With the inline backend this also closes it, so we
    # don't accumulate one open figure per test sample (which triggers
    # matplotlib's "more than 20 figures" warning and holds memory).
    plt.show()