# pix2pix

## Load images

## download dataset

In [None]:
import os

# Local directory that receives the extracted archives; later cells read
# "annotations/annotations.tsv" and "imageAlignedLD/..." relative to it.
path = "datasets/outdoor/"

# Download only when the directory is missing.
# NOTE(review): `mkdir -p` runs first, so if a download below fails the
# directory already exists and rerunning this cell will NOT retry — delete
# datasets/outdoor manually in that case.
if not os.path.isdir(path):
    # `!` lines are IPython shell escapes; this cell only runs inside Jupyter/IPython.
    ! mkdir -p datasets/outdoor
    # Image archive (presumably the aligned images of the Transient Attributes
    # dataset, judging by the URL), extracted into datasets/outdoor.
    ! wget -O datasets/outdoor.tar http://transattr.cs.brown.edu/files/aligned_images.tar
    ! tar -C datasets/outdoor -xf datasets/outdoor.tar
    # Annotation archive, extracted into the same directory.
    ! wget -O datasets/outdoor1.tar http://transattr.cs.brown.edu/files/annotations.tar
    ! tar -C datasets/outdoor -xf datasets/outdoor1.tar

### common dependencies

In [None]:
import numpy as np
import glob
import matplotlib.pyplot as plt
from PIL import Image
import csv
import random
from tqdm import tqdm

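## Load dataset
Available annotations (index: attribute name); `attribute_x` and `attribute_y` below select the source and target attribute by index: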
- 0: **dirty**
- 1: **daylight**
- 2: **night**
- 3: **sunrisesunset**
- 4: **dawndusk**
- 5: **sunny**
- 6: **clouds**
- 7: **fog**
- 8: **storm**
- 9: **snow**
- 10: **warm**
- 11: **cold**
- 12: **busy**
- 13: **beautiful**
- 14: **flowers**
- 15: **spring**
- 16: **summer**
- 17: **autumn**
- 18: **winter**
- 19: **glowing**
- 20: **colorful**
- 21: **dull**
- 22: **rugged**
- 23: **midday**
- 24: **dark**
- 25: **bright**
- 26: **dry**
- 27: **moist**
- 28: **windy**
- 29: **rain**
- 30: **ice**
- 31: **cluttered**
- 32: **soothing**
- 33: **stressful**
- 34: **exciting**
- 35: **sentimental**
- 36: **mysterious**
- 37: **boring**
- 38: **gloomy**
- 39: **lush**

In [None]:
# Attribute indices from the list above: images of a scene tagged with
# `attribute_x` are paired with images of the same scene tagged with `attribute_y`.
attribute_x = 18
attribute_y = 16

# An image counts as showing an attribute when its score exceeds this value.
attribute_threshold = 0.8

# (x_file, y_file) pairs, each path of the form "scene/image" relative to the
# image directory.
files = []


def _append_scene_pairs(scene, scene_attributes):
    """Append every (x, y) image pair of one scene to ``files``.

    ``scene_attributes`` maps an attribute index to the image names of
    ``scene`` that show that attribute. Every image showing ``attribute_x`` is
    paired with every image showing ``attribute_y``; a scene lacking either
    attribute contributes nothing.
    """
    for file_x in scene_attributes.get(attribute_x, []):
        for file_y in scene_attributes.get(attribute_y, []):
            files.append((scene + '/' + file_x, scene + '/' + file_y))


with open(path + 'annotations/annotations.tsv', newline='') as annotations_file:
    reader = csv.reader(annotations_file, delimiter='\t')

    # Assumes rows are grouped by scene: a scene's pairs are emitted when the
    # first row of the next scene appears.
    current_scene = ""
    current_attributes = {}
    for row in reader:
        if not row:  # csv.reader yields [] for blank lines
            continue
        name_parts = row[0].split('/')
        scene = name_parts[0]
        if scene != current_scene:
            _append_scene_pairs(current_scene, current_attributes)
            current_scene = scene
            current_attributes = {}

        # Column i (i >= 1) holds the annotation of attribute i - 1.
        # NOTE(review): only the first comma-separated value of each cell is
        # used as the score — confirm the cell format against the dataset docs.
        for attribute, cell in enumerate(row[1:]):
            if float(cell.split(',')[0]) > attribute_threshold:
                current_attributes.setdefault(attribute, []).append(name_parts[1])

    # A scene is only flushed when the next one starts, so the final scene must
    # be flushed explicitly or its pairs would be lost.
    _append_scene_pairs(current_scene, current_attributes)

In [None]:
# Loading and resizing every pair is slow, so keep only a random subset of
# at most `num_samples` pairs.
num_samples = 800

random.shuffle(files)
del files[num_samples:]

In [None]:
# Side length in pixels that every image is resized to.
image_size = 256


def _load_image(file_name, size=image_size):
    """Load one image from the aligned image directory as RGB floats in [0, 1].

    Returns a float64 array of shape (size, size, 3).
    """
    # `with` closes the file handle instead of leaving it to the garbage collector.
    with Image.open(path + "imageAlignedLD/" + file_name) as img:
        rgb = img.convert('RGB').resize((size, size))
    # np.asarray reads the pixel buffer directly (height, width, channels);
    # going through getdata() first built a per-pixel Python sequence, which is
    # far slower and yields the same values.
    # NOTE(review): float64 costs ~1.5 MB per 256x256 image; consider float32
    # if memory is tight (check what the model expects first).
    return np.asarray(rgb, dtype=np.float64) / 255


xs = []
ys = []

for (file_x, file_y) in tqdm(files):
    xs.append(_load_image(file_x))
    ys.append(_load_image(file_y))

xs = np.array(xs)
ys = np.array(ys)

#### Visualization of random input images (x) and their target images (y)

In [None]:
%matplotlib inline
fig, ax = plt.subplots(6,6,figsize=(16,16))
fig.tight_layout()
ax = ax.flatten()

for i in range(18):
    rand = np.random.randint(len(xs)-1)
    x = xs[rand]
    y = ys[rand]
    
    ax[2 * i].imshow(x)
    ax[2 * i].set_title(f"{i}_x")
    ax[2 * i].axis("off")
    ax[2 * i + 1].imshow(y)
    ax[2 * i + 1].set_title(f"{i}_y")
    ax[2 * i + 1].axis("off")

### Import pix2pix and create the model

In [None]:
# IPython magic: execute pix2pix.ipynb in this namespace. It is expected to
# define the `Pix2pix` class used on the next line.
%run pix2pix.ipynb

# NOTE(review): constructor arguments and defaults are defined in
# pix2pix.ipynb, which is not visible here.
model = Pix2pix()

### fit model

In [None]:
# Split the pairs into training and held-out sets.
# NOTE(review): validation_split=0.05 presumably holds out 5% of the pairs —
# see split_dataset in pix2pix.ipynb.
(train_x, train_y), (test_x, test_y) = model.split_dataset(
    xs,
    ys,
    validation_split=0.05,
)

model.fit(
    train_x,
    train_y,
    batch_size=10,
    epochs=150,
    validation_data=(test_x, test_y),
)

### Visualize results on the test data

In [None]:
# Run the model on the held-out inputs and show input, target and output
# side by side, one figure per held-out pair.
out = model.predict(test_x, batch_size=10)
for x, y, o in zip(test_x, test_y, out):
    fig, ax = plt.subplots(1, 3, figsize=(10, 10))

    for axis, image, title in zip(ax, (x, y, o), ("x", "y", "g(x)")):
        axis.imshow(image)
        axis.set_title(title)
        axis.axis("off")

    fig.tight_layout()
    # Display each figure right away. With the inline backend, show() also
    # closes it; otherwise one figure per test sample would stay open and
    # matplotlib warns once more than `figure.max_open_warning` (20) are open.
    plt.show()