In [None]:
import tensorflow as tf

# Restrict TensorFlow to a fixed slice of the first GPU instead of letting it
# claim the whole device.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    # One virtual device on gpus[0] with memory_limit=3072
    # (presumably megabytes, per the TensorFlow docs - confirm for the installed version).
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=3072)])
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Best-effort: the error is only reported and the notebook continues.
    # NOTE(review): presumably raised when the GPU was already initialised - confirm.
    print(e)

In [None]:
import numpy as np
import glob
import json

# Configuration

# Root folder holding inputs, checkpoints and generated outputs.
PATH = "./data"

# Sub-folders derived from the root.
INPUT_PATH = "%s/input/wood" % PATH
CHECKPOINTS_PATH = "%s/checkpoints" % PATH

# Class names; their order defines the layout of the one-hot label vector.
TYPES = [
    "cedar",
    "oak",
    "pine",
    "redwood",
]

# Functions

def load_dictionary(filename):
    """Read ``filename`` and return the parsed JSON document it contains.

    Args:
        filename: Path of a text file holding a JSON document.

    Returns:
        The decoded JSON value (a dict for the metadata files used here).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid JSON.
    """
    # The context manager closes the handle even when reading or parsing fails.
    with open(filename, "r") as fd:
        return json.load(fd)

def generate_label_vector(metafile):
    """Build the one-hot class vector described by a metadata file.

    The file is loaded with ``load_dictionary`` and its ``"type"`` entry is
    compared against every name in ``TYPES``; the matching position gets 1.0
    and all others 0.0. A type that is not listed yields an all-zero vector.

    Returns:
        A float NumPy array with one element per entry of ``TYPES``.
    """
    wood_type = load_dictionary(metafile)["type"]
    one_hot = [1. if candidate == wood_type else 0. for candidate in TYPES]
    return np.array(one_hot).astype(float)

def normalize(image):
    """Rescale pixel values linearly so that 0 maps to -1 and 255 maps to 1."""
    return (image / 127.5) - 1

def load_image(id):
    """Load one training sample identified by ``id`` from ``INPUT_PATH``.

    Three files are read: ``<id>-edges.png`` (network input),
    ``<id>-image.png`` (target) and ``<id>-meta.txt`` (class metadata).

    Returns:
        A tuple ``(input_img, target_img, class_vector)`` where both images
        are 3-channel float32 tensors passed through ``normalize`` and
        ``class_vector`` is the array produced by ``generate_label_vector``.
    """
    sample_prefix = INPUT_PATH + "/" + id

    def _decode_png(suffix):
        # Decode as 3-channel PNG and convert to float32.
        raw_bytes = tf.io.read_file(sample_prefix + suffix)
        return tf.cast(tf.image.decode_png(raw_bytes, channels=3), tf.float32)

    edges = _decode_png("-edges.png")
    photo = _decode_png("-image.png")
    labels = generate_label_vector(sample_prefix + "-meta.txt")

    return normalize(edges), normalize(photo), labels

# Main

# The number of target images on disk determines the dataset size.
image_paths = glob.glob(INPUT_PATH + "/*image*")
data_size = len(image_paths)

# Ids are generated, not parsed from file names: zero-padded 4-digit
# counters starting at 0000.
ids = ["%.4d" % (index,) for index in range(data_size)]

# 80 % of the samples go to training, the rest to testing.
train_size = round(data_size * 0.80)

# Shuffle a copy so `ids` keeps its original order.
# NOTE(review): no random seed is set, so the split differs between runs.
ids_rand = np.array(ids)
np.random.shuffle(ids_rand)

train_ids, test_ids = ids_rand[:train_size], ids_rand[train_size:]

# Load every sample eagerly into memory.
train_tensors = [load_image(sample_id) for sample_id in train_ids]
test_tensors = [load_image(sample_id) for sample_id in test_ids]

In [None]:
# Shape of the input (edges) image of the first shuffled sample; the
# generator and discriminator are sized from it further down.
image_size = load_image(ids_rand[0])[0].shape

print(image_size)

In [None]:
import matplotlib.pyplot as plt

# Preview the target image (tuple index 1) of the first shuffled sample.
img = load_image(ids_rand[0])[1]
# Undo normalize(): map [-1, 1] back to [0, 1] so imshow renders it correctly.
img = (img + 1) / 2
plt.imshow(img)
plt.show()

In [None]:
# Each sample is an (input image, target image, class vector) triple of
# float32 tensors; both splits share the same element signature.
_SAMPLE_TYPES = (tf.float32, tf.float32, tf.float32)
_SAMPLE_SHAPES = (
    tf.TensorShape([None, None, 3]),
    tf.TensorShape([None, None, 3]),
    tf.TensorShape([None]),
)

# Samples are served one per batch from the preloaded lists.
train_dataset = tf.data.Dataset.from_generator(
    lambda: train_tensors,
    output_types=_SAMPLE_TYPES,
    output_shapes=_SAMPLE_SHAPES,
).batch(1)

test_dataset = tf.data.Dataset.from_generator(
    lambda: test_tensors,
    output_types=_SAMPLE_TYPES,
    output_shapes=_SAMPLE_SHAPES,
).batch(1)

In [None]:
# Sanity-check the pipeline: take one batch and show its input image,
# target image and class vector.
for input_img, target_img, class_vector in train_dataset.take(1):
    print(input_img.shape)
    # Index 0 selects the only element of the batch; (x + 1) / 2 maps
    # [-1, 1] back to [0, 1] for display.
    plt.imshow((input_img[0,...] + 1) / 2)
    plt.show()
    
    print(target_img.shape)
    plt.imshow((target_img[0,...] + 1) / 2)
    plt.show()
    
    print(class_vector)

In [None]:
from tensorflow.keras import *
from tensorflow.keras.layers import *

def downsampler(filters, apply_batch_normalization=True):
    """Create an encoder block: stride-2 convolution, optional batch norm, LeakyReLU.

    Args:
        filters: Number of output channels of the convolution.
        apply_batch_normalization: When True a ``BatchNormalization`` layer
            follows the convolution and the convolution bias is disabled.

    Returns:
        A ``Sequential`` model containing the block's layers.
    """
    stages = [
        Conv2D(
            filters,
            kernel_size=4,
            strides=2,
            padding="same",
            kernel_initializer=tf.random_normal_initializer(0, 0.02),
            # The bias is only kept when no batch norm follows.
            use_bias=not apply_batch_normalization,
        )
    ]
    if apply_batch_normalization:
        stages.append(BatchNormalization())
    stages.append(LeakyReLU(alpha=0.2))

    return Sequential(stages)

# Build one block so the cell displays it (quick check that the factory runs).
downsampler(64)

In [None]:
def upsampler(filters, apply_dropout=False):
    """Create a decoder block: stride-2 transposed convolution, batch norm, optional dropout, ReLU.

    Args:
        filters: Number of output channels of the transposed convolution.
        apply_dropout: When True a ``Dropout(0.5)`` layer is inserted between
            batch normalization and the activation.

    Returns:
        A ``Sequential`` model containing the block's layers.
    """
    stages = [
        Conv2DTranspose(
            filters,
            kernel_size=4,
            strides=2,
            padding="same",
            kernel_initializer=tf.random_normal_initializer(0, 0.02),
            use_bias=False,
        ),
        BatchNormalization(),
    ]
    if apply_dropout:
        stages.append(Dropout(0.5))
    stages.append(ReLU())

    return Sequential(stages)

# Build one block so the cell displays it (quick check that the factory runs).
upsampler(64)

In [None]:
def pre_conditioner(input_filters, latent_vector_size):
    """Create the block that compresses a feature map into a latent vector.

    The input is flattened, passed through ``Dense(input_filters)`` with
    batch norm and LeakyReLU, then projected by ``Dense(latent_vector_size)``.

    Returns:
        A ``Sequential`` model whose output has ``latent_vector_size`` features.
    """
    return Sequential([
        Flatten(),
        Dense(input_filters),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Dense(latent_vector_size),
    ])

def conditioner(output_dim1, output_dim2, output_dim3):
    """Create the block that expands a flat vector back into a feature map.

    Two dense stages (2048 units, then ``output_dim1 * output_dim2 *
    output_dim3`` units), each followed by LeakyReLU and ``Dropout(0.2)``,
    with batch norm on the second; the result is reshaped to
    ``(output_dim1, output_dim2, output_dim3)``.

    Returns:
        A ``Sequential`` model producing the reshaped feature map.
    """
    flat_size = output_dim1 * output_dim2 * output_dim3
    return Sequential([
        Dense(2048),
        LeakyReLU(alpha=0.2),
        Dropout(0.2),
        Dense(flat_size),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Dropout(0.2),
        Reshape((output_dim1, output_dim2, output_dim3)),
    ])

# Smoke test of the two conditioning blocks on dummy data.
# i: a 1x1 feature map with 4 channels; ia: a 2-element auxiliary vector.
i = tf.keras.layers.Input(shape=[1, 1, 4])
ia = tf.keras.layers.Input(shape=[2])

# pre_conditioner flattens i and projects it to 100 features; that vector is
# concatenated with ia and expanded by conditioner into a (1, 1, 512) map.
o = pre_conditioner(512, 100)(i)
o = conditioner(1, 1, 512)(concatenate([o, ia]))
m = Model(inputs=[i, ia], outputs=o)

# Run a single dummy sample in inference mode; the cell displays the result.
d = tf.cast(np.array([[[[0, 0.5, -0.5, 0]]]]), tf.float32)
da = tf.cast(np.array([[0, 1]]), tf.float32)
m([d, da], training=False)

In [None]:
def Generator(input_dim1, input_dim2, class_num):
    """Build the class-conditioned encoder/decoder generator.

    The image goes through eight ``downsampler`` stages. The bottleneck is
    compressed by ``pre_conditioner``, concatenated with the class vector and
    expanded again by ``conditioner``. Seven ``upsampler`` stages with skip
    connections and a final tanh transposed convolution produce the output.

    Args:
        input_dim1: First spatial dimension of the input image.
        input_dim2: Second spatial dimension of the input image.
        class_num: Length of the class vector.

    Returns:
        A ``Model`` taking ``[image, class_vector]`` and returning a
        3-channel image with values from a tanh activation.

    Note:
        Assumes both dimensions are multiples of 256 so the skip-connection
        shapes line up - TODO confirm for other sizes.
    """
    inputs = tf.keras.layers.Input(shape=[input_dim1, input_dim2, 3])
    class_vector = tf.keras.layers.Input(shape=[class_num])

    last_layer = Conv2DTranspose(
        filters = 3,
        kernel_size = 4,
        strides = 2,
        padding = "same",
        kernel_initializer = tf.random_normal_initializer(0, 0.02),
        activation = "tanh"
    )

    # Eight stride-2 stages shrink each spatial side by a factor of 2**8 = 256.
    bottleneck_dim1 = int(input_dim1 / 256)
    bottleneck_dim2 = int(input_dim2 / 256)

    # Encoder
    l_e1 = downsampler(64, apply_batch_normalization = False)(inputs)
    l_e2 = downsampler(128)(l_e1)
    l_e3 = downsampler(256)(l_e2)
    l_e4 = downsampler(512)(l_e3)
    l_e5 = downsampler(512)(l_e4)
    l_e6 = downsampler(512)(l_e5)
    l_e7 = downsampler(512)(l_e6)
    l_e8 = downsampler(512)(l_e7)

    # Conditioner
    l_c1 = pre_conditioner(int((input_dim1 / 256) * (input_dim2 / 256) * 512), 100)(l_e8)
    # Each bottleneck side comes from its own input dimension, so the
    # conditioning map matches l_e8 for non-square inputs as well.
    l_c2 = conditioner(bottleneck_dim1, bottleneck_dim2, 512)(concatenate([l_c1, class_vector]))

    # Decoder (each stage also receives the encoder output of the same resolution)
    l_d1 = upsampler(512, apply_dropout = True)(concatenate([l_c2, l_e8]))
    l_d2 = upsampler(512, apply_dropout = True)(concatenate([l_d1, l_e7]))
    l_d3 = upsampler(512, apply_dropout = True)(concatenate([l_d2, l_e6]))
    l_d4 = upsampler(512)(concatenate([l_d3, l_e5]))
    l_d5 = upsampler(256)(concatenate([l_d4, l_e4]))
    l_d6 = upsampler(128)(concatenate([l_d5, l_e3]))
    l_d7 = upsampler(64)(concatenate([l_d6, l_e2]))

    last = last_layer(l_d7)

    return Model(inputs=[inputs, class_vector], outputs=last)

# Size the generator from the loaded image shape, with one class slot per entry in TYPES.
generator = Generator(image_size[0], image_size[1], len(TYPES))

In [None]:
# Sanity check: run the untrained generator on one training sample.
for input_img, target_img, vector_class in train_dataset.take(1):
    # The dataset already yields images normalised to [-1, 1], the same range
    # the generator receives during training, so they are passed unchanged.
    gen_output = generator([input_img, vector_class], training=False)
    # The tanh output lies in [-1, 1]; map it to [0, 1] so imshow does not clip it.
    plt.imshow(gen_output[0, ...] * 0.5 + 0.5)

In [None]:
def expand_vector(dim1, dim2):
    """Return a function that spreads a per-sample vector over a ``dim1`` x ``dim2`` grid.

    The returned callable inserts two singleton axes after the first axis of
    its argument and tiles them ``dim1`` and ``dim2`` times, so a
    ``(batch, n)`` tensor becomes ``(batch, dim1, dim2, n)``.
    """
    def tile_over_grid(x):
        single_cell = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
        return tf.tile(single_cell, [1, dim1, dim2, 1])

    return tile_over_grid

def Discriminator(input_dim1, input_dim2, class_num):
    """Build the class-conditioned discriminator.

    Two images are concatenated along the channel axis and downsampled once.
    The class vector is tiled over the resulting half-resolution grid and
    concatenated in. Three more ``downsampler`` stages and a final stride-2
    convolution without activation produce a single-channel map of logits.

    Args:
        input_dim1: First spatial dimension of both image inputs.
        input_dim2: Second spatial dimension of both image inputs.
        class_num: Length of the class vector.

    Returns:
        A ``Model`` taking ``[input_img, generated_img, class_vector]``.
    """
    image_shape = [input_dim1, input_dim2, 3]
    input_img = Input(shape=image_shape)
    generated_img = Input(shape=image_shape)
    class_vector = Input(shape=[class_num])

    image_pair = concatenate([input_img, generated_img])
    features = downsampler(64, apply_batch_normalization=False)(image_pair)

    # Broadcast the class vector to the spatial size of the first feature map.
    tile_classes = expand_vector(int(input_dim1 / 2), int(input_dim2 / 2))
    class_map = Lambda(tile_classes)(class_vector)

    features = concatenate([features, class_map])
    for filters in (128, 256, 512):
        features = downsampler(filters)(features)

    patch_logits = Conv2D(
        filters=1,
        kernel_size=4,
        strides=2,
        padding="same",
        kernel_initializer=tf.random_normal_initializer(0, 0.02),
    )(features)

    return Model(inputs=[input_img, generated_img, class_vector], outputs=patch_logits)

# Size the discriminator from the loaded image shape, with one class slot per entry in TYPES.
discriminator = Discriminator(image_size[0], image_size[1], len(TYPES))

In [None]:
# Sanity check: run the untrained generator and discriminator on one sample.
for input_img, target_img, class_vector in train_dataset.take(1):
    # Inputs are already normalised to [-1, 1]; feed them exactly as training does.
    gen_output = generator([input_img, class_vector], training=False)
    # Same argument order as train_step: candidate image first, edge image second.
    disc_out = discriminator([gen_output, input_img, class_vector], training=False)

    # Show the single-channel logit map of the only sample in the batch.
    plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
    plt.colorbar()

In [None]:
# Shared adversarial loss. from_logits=True because the discriminator's last
# convolution has no activation, so its outputs are raw logits.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """Return the discriminator loss for one real/generated pair.

    The discriminator output for the real image is scored against a target
    of all ones, the output for the generated image against all zeros; the
    two cross-entropy values are summed.
    """
    loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_generated = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return loss_on_real + loss_on_generated

In [None]:
# Weight of the L1 reconstruction term relative to the adversarial term in generator_loss.
LAMBDA = 100

def generator_loss(disc_generated_output, generated_output, target):
    """Return the generator loss: adversarial term plus weighted L1 distance.

    The adversarial term scores the discriminator's output for the generated
    image against a target of all ones. The L1 term is the mean absolute
    difference between ``target`` and ``generated_output``, scaled by ``LAMBDA``.
    """
    adversarial_term = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    reconstruction_term = tf.reduce_mean(tf.abs(target - generated_output))
    return adversarial_term + (LAMBDA * reconstruction_term)

In [None]:
# One Adam optimizer per network, both with learning rate 2e-4 and beta_1 = 0.5.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
# Prefix passed as file_prefix when the training loop saves a checkpoint.
checkpoint_prefix = CHECKPOINTS_PATH + "/chkp"

# Track both networks and both optimizers so a saved state holds everything
# needed to continue training.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator
)

# Uncomment to resume from the most recent checkpoint in CHECKPOINTS_PATH.
#checkpoint.restore(tf.train.latest_checkpoint(CHECKPOINTS_PATH)).assert_consumed()

In [None]:
def generate_images(model, test_input, tar, class_vector, save_filename=False, display_imgs=True):
    """Run ``model`` on one batch and optionally save and/or display the result.

    Args:
        model: Generator called as ``model([test_input, class_vector], training=False)``.
        test_input: Batch of input images; element 0 is displayed.
        tar: Batch of ground-truth images; element 0 is displayed.
        class_vector: Batch of class vectors passed to the model.
        save_filename: If truthy, the first prediction is written to
            ``PATH/output/<save_filename>.jpg``.
        display_imgs: If True, input, ground truth and prediction are shown
            side by side.

    Returns:
        None.
    """
    import os  # local import: only needed to create the output folder

    prediction = model([test_input, class_vector], training = False)

    if save_filename:
        output_dir = PATH + '/output'
        # Saving fails if the folder is missing, so create it on demand.
        os.makedirs(output_dir, exist_ok=True)
        tf.keras.preprocessing.image.save_img(output_dir + '/' + save_filename + ".jpg", prediction[0, ...])

    # Only open a figure when something will be drawn in it.
    if not display_imgs:
        return

    plt.figure(figsize=(10,10))

    display_list = [test_input[0], tar[0], prediction[0]]
    title = ["Input image", "Ground truth", "Predicted Image"]

    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])

        # Map [-1, 1] to [0, 1] for display.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis("off")

    plt.show()

In [None]:
@tf.function()
def train_step(input_image, target_image, class_vector):
    """Run one optimisation step for both the generator and the discriminator.

    Args:
        input_image: Batch of input (edge) images.
        target_image: Batch of ground-truth images.
        class_vector: Batch of class vectors.

    Returns:
        A tuple ``(gen_loss, disc_loss)`` with the losses of this step.
    """
    # Only the forward pass and the losses are recorded on the tapes.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:

        generated_image = generator([input_image, class_vector], training=True)

        # NOTE(review): the candidate image is passed first and the edge image
        # second, the reverse of the discriminator's declared input names. The
        # order is the same for the generated and the real call - confirm intent.
        generated_image_disc = discriminator([generated_image, input_image, class_vector], training = True)

        target_image_disc = discriminator([target_image, input_image, class_vector], training = True)

        disc_loss = discriminator_loss(target_image_disc, generated_image_disc)

        gen_loss = generator_loss(generated_image_disc, generated_image, target_image)

    # Gradients are computed and applied after leaving the tape context so the
    # gradient and optimizer ops are not recorded themselves.
    discriminator_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)

    discriminator_optimizer.apply_gradients(zip(discriminator_grads, discriminator.trainable_variables))
    generator_optimizer.apply_gradients(zip(generator_grads, generator.trainable_variables))

    return gen_loss, disc_loss

In [None]:
from IPython.display import clear_output

def train(dataset, epochs):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    After every epoch the cell output is cleared and five test samples are
    rendered and saved through ``generate_images``; a checkpoint is written
    every 25 epochs.

    Args:
        dataset: Iterable of ``(input_image, target_image, class_vector)`` batches.
        epochs: Number of passes over ``dataset``.

    Note:
        The progress line uses ``len(train_ids)`` as the total, which matches
        ``dataset`` only when it is the training set with batch size 1.
    """
    for epoch in range(epochs):
        # Iterate the dataset that was passed in, not the global train_dataset.
        for img_counter, (input_image, target_image, class_vector) in enumerate(dataset):
            print("epoch %d - train: %d / %d" % (epoch, img_counter, len(train_ids)))
            train_step(input_image, target_image, class_vector)

        clear_output(wait=True)

        # Render and save a fixed number of test predictions for visual monitoring.
        for img_counter, (input_image, target_image, class_vector) in enumerate(test_dataset.take(5)):
            generate_images(generator, input_image, target_image, class_vector, "%d_%d" % (img_counter, epoch), display_imgs=True)

        if (epoch + 1) % 25 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

In [None]:
# Launch training: 1000 epochs over the training dataset.
train(train_dataset, 1000)