In [39]:
import os
import pathlib
import datetime
import time

import tensorflow as tf

import matplotlib.pyplot as plt
from IPython import display

In [40]:
# Fetch the pix2pix "facades" archive and locate its extracted image folder.
# NOTE(review): assumes the tarball extracts to a `facades/` directory that
# sits next to the downloaded archive — confirm for the installed Keras version.
dataset_name = "facades"
dataset_url = f"http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz"
path_to_tar = pathlib.Path(
    tf.keras.utils.get_file(
        fname=f"{dataset_name}.tar.gz",
        origin=dataset_url,
        extract=True,
    )
)
dataset_path = path_to_tar.parent / dataset_name

In [41]:
# Final spatial size (pixels) of every image pair: training images are
# randomly cropped to it, test images are resized to it.
HEIGHT, WIDTH = 256, 256
# Number of image pairs per batch (used by both train and test datasets).
BATCHSIZE = 1
# Shuffle buffer size for the training dataset.
BUFFERSIZE = 400

In [42]:
def load(image_file):
    """Read one side-by-side pix2pix JPEG and split it into an (input, real) pair.

    Each file contains two images placed next to each other along the width
    axis. The right half is returned as the model input and the left half as
    the real (target) image. Both are cast to float32; values are still in
    the raw [0, 255] pixel range (see ``normalize`` for rescaling).

    Args:
        image_file: path of a JPEG file (Python str or scalar string tensor).

    Returns:
        (input_image, real_image): float32 tensors with 3 channels, each
        about half the width of the decoded file.
    """
    raw_bytes = tf.io.read_file(image_file)
    # Force 3 channels: random_crop later crops to a fixed depth of 3, so a
    # grayscale (1-channel) JPEG would otherwise fail there. This also gives
    # the decoded tensor a static channel dimension inside tf.data graphs.
    image = tf.io.decode_jpeg(raw_bytes, channels=3)
    half_width = tf.shape(image)[1] // 2
    input_image = tf.cast(image[:, half_width:, :], tf.float32)
    real_image = tf.cast(image[:, :half_width, :], tf.float32)
    return input_image, real_image

def resize(input_image, real_image, height, width):
    """Resize both images of a pair to ``[height, width]``.

    Both images use nearest-neighbour sampling, so every output pixel is
    copied from an input pixel rather than interpolated.

    Returns:
        (input_image, real_image) resized to the requested size.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    return tuple(
        tf.image.resize(img, target_size, method=nearest)
        for img in (input_image, real_image)
    )

def random_crop(input_image, real_image):
    """Cut the same random HEIGHT x WIDTH window out of both images.

    The pair is stacked along a new leading axis and cropped once, which keeps
    the two crops spatially aligned. Both images must have the same shape and
    3 channels.
    """
    pair = tf.stack([input_image, real_image])  # default axis=0 -> shape (2, H, W, C)
    crop = tf.image.random_crop(pair, size=[2, HEIGHT, WIDTH, 3])
    cropped_input = crop[0]
    cropped_real = crop[1]
    return cropped_input, cropped_real

def normalize(input_image, real_image):
    """Linearly map pixel values from [0, 255] to [-1, 1] for both images."""
    def _to_signed_unit(img):
        # 0 -> -1, 127.5 -> 0, 255 -> 1
        return img / 127.5 - 1

    return _to_signed_unit(input_image), _to_signed_unit(real_image)

@tf.function()
def random_jitter(input_image, real_image, jitter_margin=30):
    """Randomly augment an (input, real) pair, applying identical jitter to both.

    Steps: upscale both images to ``(HEIGHT + jitter_margin, WIDTH + jitter_margin)``
    with nearest-neighbour resizing, take one shared random ``HEIGHT x WIDTH``
    crop, then flip both images horizontally with probability ~0.5.

    Args:
        input_image, real_image: 3-channel image tensors (random_crop crops
            to a depth of 3).
        jitter_margin: extra pixels added to each side before cropping. Must be
            >= 0 so the crop fits. The default of 30 gives the standard 286x286
            upscale when HEIGHT = WIDTH = 256.

    Returns:
        (input_image, real_image) after augmentation, each HEIGHT x WIDTH.
    """
    # Derive the upscale size from the crop size, so changing HEIGHT/WIDTH in the
    # config cell can never make random_crop request a window larger than the image.
    input_image, real_image = resize(
        input_image, real_image, HEIGHT + jitter_margin, WIDTH + jitter_margin
    )
    input_image, real_image = random_crop(input_image, real_image)
    # A single random draw decides the flip for both images so the pair stays aligned.
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)
    return input_image, real_image

In [43]:
def load_image_train(image_file):
    """Training preprocessing for one file.

    Loads the (input, real) pair, applies random jitter, and rescales to [-1, 1].
    """
    pair = load(image_file)
    pair = random_jitter(*pair)
    return normalize(*pair)

def load_image_test(image_file):
    """Evaluation preprocessing for one file.

    Loads the (input, real) pair, resizes it deterministically to HEIGHT x WIDTH
    (no random augmentation), and rescales it to [-1, 1].
    """
    input_image, real_image = load(image_file)
    resized_pair = resize(input_image, real_image, HEIGHT, WIDTH)
    return normalize(*resized_pair)

# Training pipeline: parallel per-file load + random jitter, shuffle, batch.
train_dataset = (
    tf.data.Dataset.list_files(str(dataset_path / "train/*.jpg"))
    # Decode and augment several files concurrently instead of one at a time.
    .map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFERSIZE)
    .batch(BATCHSIZE)
    # Prepare the next batch while the current one is being consumed.
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: list_files shuffles by default (unseeded), which would change
# the evaluation/visualisation order on every run; shuffle=False keeps it fixed.
test_dataset = (
    tf.data.Dataset.list_files(str(dataset_path / "test/*.jpg"), shuffle=False)
    .map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCHSIZE)
)