### Import necessary dependencies

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import os
import time
# import matplotlib.pyplot as plt
from IPython.display import clear_output

print("Using TensorFlow version: " + tf.__version__)

Using TensorFlow version: 2.0.0-alpha0


In [3]:
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())
tf.test.gpu_device_name()
# watch -n 0.5 nvidia-smi

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 11967844869934486272
, name: "/device:XLA_GPU:0"
device_type: "XLA_GPU"
memory_limit: 17179869184
locality {
}
incarnation: 2657005771649407317
physical_device_desc: "device: XLA_GPU device"
, name: "/device:XLA_CPU:0"
device_type: "XLA_CPU"
memory_limit: 17179869184
locality {
}
incarnation: 3381951437603560624
physical_device_desc: "device: XLA_CPU device"
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 7385933415
locality {
  bus_id: 1
  links {
  }
}
incarnation: 11828848820783591893
physical_device_desc: "device: 0, name: GeForce GTX 1080, pci bus id: 0000:65:00.0, compute capability: 6.1"
]


'/device:GPU:0'

### Define parameters

In [4]:
BUFFER_SIZE = 400

BATCH_SIZE = 1

IMG_WIDTH = 256

IMG_HEIGHT = 256

INPUT_IMAGE_DIRECTORY = '../sketchy_database_photo_pairs/'

CHECKPOINT_DIRECTORY = '../checkpoints/'

OUTPUT_CHANNELS = 3

LAMBDA = 100

EPOCHS = 200

### Input image preprocessing functions

Loads a combined image file and splits it into the photo and sketch image counterparts.

In [5]:
def load(combined_image_file_path):
    combined_image = tf.io.read_file(combined_image_file_path)
    combined_image = tf.image.decode_jpeg(combined_image)

    w = tf.shape(combined_image)[1]

    w = w // 2
    photo = combined_image[:, :w, :]
    sketch = combined_image[:, w:, :]

    photo = tf.cast(photo, tf.float32)
    sketch = tf.cast(sketch, tf.float32)

    return photo, sketch

Resizes a photo and sketch image pair to the specified dimensions.

In [6]:
def resize(photo, sketch, height, width):
    photo = tf.image.resize(photo, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    sketch = tf.image.resize(sketch, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return photo, sketch

Randlomly crops a photo and sketch image pair.

In [7]:
def random_crop(photo, sketch):
    stacked_image = tf.stack([photo, sketch], axis=0)
    cropped_image = tf.image.random_crop(
    stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
    return cropped_image[0], cropped_image[1]

In [8]:
# normalizing the images to [-1, 1]
def normalize(photo, sketch):
    photo = (photo / 127.5) - 1
    photo = (sketch / 127.5) - 1
    return photo, sketch

In [9]:
@tf.function()
def random_jitter(photo, sketch):
    # resizing to 286 x 286 x 3
    photo, sketch = resize(photo, sketch, 286, 286)

    # randomly cropping to 256 x 256 x 3
    photo, sketch = random_crop(photo, sketch)

    if tf.random.uniform(()) > 0.5:
        # random mirroring
        photo = tf.image.flip_left_right(photo)
        sketch = tf.image.flip_left_right(sketch)

    return photo, sketch

### Create training and test datasets

In [10]:
def load_image_train(image_file_path):
    photo, sketch = load(image_file_path)
    photo, sketch = random_jitter(photo, sketch)
    photo, sketch = normalize(photo, sketch)
    return photo, sketch

In [11]:
def load_image_test(image_file_path):
    photo, sketch = load(image_file_path)
    photo, sketch = resize(photo, sketch, IMG_HEIGHT, IMG_WIDTH)
    photo, sketch = normalize(photo, sketch)
    return photo, sketch

In [12]:
train_dataset = tf.data.Dataset.list_files(INPUT_IMAGE_DIRECTORY + 'train/*.jpg')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(1)

In [13]:
test_dataset = tf.data.Dataset.list_files(INPUT_IMAGE_DIRECTORY + 'test/*.jpg')
# shuffling so that for every epoch a different image is generated
# to predict and display the progress of our model.
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(1)

### Build the generator

In [14]:
def downsample(filters, size, apply_batchnorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))

    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())

    result.add(tf.keras.layers.LeakyReLU())

    return result

In [15]:
def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))

    result.add(tf.keras.layers.BatchNormalization())

    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))

    result.add(tf.keras.layers.ReLU())

    return result

In [16]:
def Generator():
    down_stack = [
        downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4, strides=2, padding='same', kernel_initializer=initializer, activation='tanh') # (bs, 256, 256, 3)

    concat = tf.keras.layers.Concatenate()

    inputs = tf.keras.layers.Input(shape=[None,None,3])
    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = concat([x, skip])

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

In [17]:
generator = Generator()

### Build the discriminator

In [18]:
def Discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')

    x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)

    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    conv = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

    last = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)

In [19]:
discriminator = Discriminator()

### Define the loss functions and the optimizer

In [20]:
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [21]:
def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

    total_disc_loss = real_loss + generated_loss

    return total_disc_loss

In [22]:
def generator_loss(disc_generated_output, gen_output, target):
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    # mean absolute error
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

    total_gen_loss = gan_loss + (LAMBDA * l1_loss)

    return total_gen_loss

In [23]:
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

### Checkpoints (Object-based saving)

In [24]:
checkpoint_dir = CHECKPOINT_DIRECTORY
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, discriminator_optimizer=discriminator_optimizer, generator=generator, discriminator=discriminator)

### Training

In [25]:
def generate_images(model, test_input, tar):
    # the training=True is intentional here since
    # we want the batch statistics while running the model
    # on the test dataset. If we use training=False, we will get
    # the accumulated statistics learned from the training dataset
    # (which we don't want)
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15,15))

    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']

    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        # getting the pixel values between [0, 1] to plot it.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')
        
    plt.show()

In [26]:
@tf.function
def train_step(photo, sketch):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(photo, training=True)

        disc_real_output = discriminator([photo, sketch], training=True)
        disc_generated_output = discriminator([photo, gen_output], training=True)

        gen_loss = generator_loss(disc_generated_output, gen_output, sketch)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))


In [27]:
def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()
        
        count = 1
        for photo, sketch in dataset:
            print("train step " + str(count))
            count += 1
            train_step(photo, sketch)

        clear_output(wait=True)
        for inp, tar in test_dataset.take(1):
            generate_images(generator, inp, tar)

        # saving (checkpoint) the model every 20 epochs
        if (epoch + 1) % 20 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))


In [None]:
train(train_dataset, EPOCHS)

train step 1
