In [None]:
import os
import numpy as np
import tensorflow as tf
import pandas as pd
import datetime
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import itertools as iter

physical_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(physical_devices))


# Root folder for every input (monet_tfrec/, processed/) and output (trainImages/,
# saved models) used by this notebook.
DATA_PATH = os.getenv('DATA_PATH')
print(DATA_PATH)
if not DATA_PATH:
    # Fail here with a clear message instead of a TypeError from os.path.join later on.
    raise EnvironmentError("DATA_PATH environment variable is not set; point it at the data root directory.")


## Reading the Monet TFRecords

In [None]:
# Every file below DATA_PATH/monet_tfrec, in os.walk traversal order.
filpaths = [
    os.path.join(dirpath, name)
    for dirpath, _subdirs, names in os.walk(os.path.join(DATA_PATH, "monet_tfrec"))
    for name in names
]

In [None]:
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode JPEG bytes into a float32 RGB tensor scaled to [-1, 1].

    The result is reshaped to ``IMAGE_SIZE + [3]``; the reshape fails if the decoded
    image does not contain exactly that many values.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    scaled = pixels / 127.5 - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized example and return its decoded image.

    "image_name", "image" and "target" are declared as required scalar string
    features; only "image" is used (decoded by ``decode_image``).
    """
    features = tf.io.parse_single_example(
        example,
        {
            "image_name": tf.io.FixedLenFeature([], tf.string),
            "image": tf.io.FixedLenFeature([], tf.string),
            "target": tf.io.FixedLenFeature([], tf.string),
        },
    )
    return decode_image(features['image'])

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from TFRecord files.

    Args:
        filenames: path or list of paths to TFRecord files.
        labeled: accepted for call-site compatibility only; ``read_tfrecord`` always
            returns just the image, so this flag has no effect.
        ordered: when False (default) the parallel ``map`` may emit elements as soon as
            they are ready rather than in file order, which can improve throughput.
            Pass True for a reproducible element order.

    Returns:
        A ``tf.data.Dataset`` of image tensors produced by ``read_tfrecord``.
    """
    dataset = tf.data.TFRecordDataset(filenames)
    return dataset.map(
        read_tfrecord,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=ordered,
    )

In [None]:
# datasetMonet = load_dataset(filpaths, labeled=True)

## Larger dataset: image files under `DATA_PATH/processed`

In [None]:
# Every file below DATA_PATH/processed; these feed the training data generator.
allFiles = [
    os.path.join(root, f)
    for root, _dirs, files in os.walk(os.path.join(DATA_PATH, "processed"))
    for f in files
]
# os.walk silently yields nothing for a missing directory; fail here instead of
# "training" on an empty dataset.
assert allFiles, "No image files found under {}".format(os.path.join(DATA_PATH, "processed"))

In [None]:
def generator():
    """Yield augmented training images from ``allFiles`` as float32 tensors in [-1, 1].

    Files are visited in a fresh random order on every call. ``from_generator`` calls
    this function again each time the dataset is iterated, so every epoch sees a new
    permutation; a small shuffle buffer applied downstream only mixes nearby elements.

    NOTE(review): later in this notebook the name ``generator`` is rebound to the Keras
    generator model. Datasets already built from this function keep their own
    reference, but re-running cells that call ``generator()`` after that point will
    call the model instead.
    """
    for idx in np.random.permutation(len(allFiles)):
        raw = tf.io.read_file(allFiles[idx])
        image = tf.io.decode_jpeg(raw, channels=3)
        # Light photometric jitter as data augmentation.
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.95, 1.05)
        # [0, 255] -> [-1, 1], the same range as the generator's tanh output.
        yield (tf.cast(image, tf.float32) / 127.5) - 1

In [None]:
# Pull one augmented sample to sanity-check the input pipeline.
g = generator()
test = next(g)
# from_generator below declares (*IMAGE_SIZE, 3) per element; check it now rather
# than failing mid-training.
assert tuple(test.shape) == (*IMAGE_SIZE, 3), "unexpected image shape {}".format(test.shape)
test.shape

In [None]:
test = next(g)
# Undo the [-1, 1] scaling so imshow receives floats in [0, 1].
preview = (test.numpy() + 1) / 2
plt.imshow(preview)
plt.axis('off')

In [None]:
# One float32 image of shape (*IMAGE_SIZE, 3) per element, as yielded by generator().
imageSpec = tf.TensorSpec(shape=(*IMAGE_SIZE, 3), dtype=tf.float32)
dataset = tf.data.Dataset.from_generator(generator, output_signature=imageSpec)

## Model definition and training (TensorFlow)

In [None]:
%load_ext tensorboard


In [None]:
from tensorflow.keras import layers

def upscaleBlock(xIn, channelsBefore, channelsAfter, filtersize):
    """Generator stage that doubles the spatial resolution of ``xIn``.

    Conv(channelsBefore) -> BN -> LeakyReLU at the input resolution, then 2x
    UpSampling2D followed by Conv(channelsAfter, L2 0.01) -> BN -> LeakyReLU.
    Both convolutions use 'same' padding.
    """
    kl = tf.keras.layers

    h = kl.Conv2D(channelsBefore, filtersize, padding="same")(xIn)
    h = kl.BatchNormalization()(h)
    h = kl.LeakyReLU()(h)

    h = kl.UpSampling2D(size=(2, 2))(h)
    h = kl.Conv2D(
        channelsAfter,
        filtersize,
        padding="same",
        kernel_regularizer=tf.keras.regularizers.L2(l2=0.01),
    )(h)
    h = kl.BatchNormalization()(h)
    return kl.LeakyReLU()(h)


def createGenerator():
    """Build the generator: a 1024-d noise vector -> a 256x256x3 image in (-1, 1).

    A dense layer seeds a 4x4x32 feature map; six ``upscaleBlock`` stages each double
    height and width (4 -> 256) while narrowing channels; a final 5x5 conv maps to RGB
    and tanh bounds the output.
    """
    noise = tf.keras.Input(shape=(1024,))

    seed = tf.keras.layers.Dense(4 * 4 * 32, kernel_regularizer="l2")(noise)
    seed = tf.keras.layers.BatchNormalization()(seed)
    seed = tf.keras.layers.LeakyReLU()(seed)
    x = tf.keras.layers.Reshape((4, 4, 32))(seed)

    # (channelsBefore, channelsAfter) for each stage, all with 3x3 kernels.
    for before, after in [(32, 32), (32, 16), (16, 8), (8, 8), (8, 4), (4, 2)]:
        x = upscaleBlock(x, before, after, 3)

    x = tf.keras.layers.Conv2D(
        3, 5, padding="same", kernel_regularizer=tf.keras.regularizers.L2(l2=0.01)
    )(x)
    image = tf.keras.activations.tanh(x)

    return tf.keras.Model(inputs=noise, outputs=image, name="generator")

generator = createGenerator()
generator.summary()

# Write an architecture diagram next to the notebook (plot_model relies on pydot/graphviz).
dot_img_file = f'./{generator.name}.png'
tf.keras.utils.plot_model(generator, to_file=dot_img_file, show_shapes=True)


In [None]:
def downscaleBlock(xIn, channelsBefore, channelsAfter, filtersize):
    """Discriminator stage that roughly halves the spatial resolution of ``xIn``.

    Conv(channelsBefore) -> BN -> LeakyReLU, then a stride-2 Conv(channelsAfter) ->
    BN -> LeakyReLU. Both convolutions use Conv2D's default 'valid' padding and an
    L2(0.01) kernel penalty.
    """
    h = xIn
    for channels, stride in ((channelsBefore, 1), (channelsAfter, 2)):
        h = tf.keras.layers.Conv2D(
            channels,
            filtersize,
            strides=stride,
            kernel_regularizer=tf.keras.regularizers.L2(l2=0.01),
        )(h)
        h = tf.keras.layers.BatchNormalization()(h)
        h = tf.keras.layers.LeakyReLU()(h)
    return h


def createDiscriminator():
    """Build the discriminator: an (*IMAGE_SIZE, 3) image -> one unbounded score.

    Four ``downscaleBlock`` stages (3x3 kernels) shrink the image, a further conv and a
    two-layer dense head follow, and a linear Dense(1) produces the score (no sigmoid).
    """
    image = tf.keras.Input(shape=(*IMAGE_SIZE, 3))

    x = image
    # (channelsBefore, channelsAfter) for each stage.
    for before, after in [(2, 4), (4, 8), (8, 16), (16, 32)]:
        x = downscaleBlock(x, before, after, 3)

    x = tf.keras.layers.Conv2D(32, 3, kernel_regularizer=tf.keras.regularizers.L2(l2=0.01))(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Flatten()(x)

    # Dense head: the first layer carries a stronger L2(0.05) penalty plus dropout.
    x = tf.keras.layers.Dense(64, kernel_regularizer=tf.keras.regularizers.L2(l2=0.05))(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Dropout(0.3)(x)
    x = tf.keras.layers.Dense(64, kernel_regularizer=tf.keras.regularizers.L2(l2=0.01))(x)
    x = tf.keras.layers.LeakyReLU()(x)

    score = layers.Dense(1)(x)
    return tf.keras.Model(inputs=image, outputs=score, name="Discriminator")

discriminator = createDiscriminator()
discriminator.summary()

# Write an architecture diagram next to the notebook (plot_model relies on pydot/graphviz).
dot_img_file = f'./{discriminator.name}.png'
tf.keras.utils.plot_model(discriminator, to_file=dot_img_file, show_shapes=True)

In [None]:
# Set to True to replace the freshly built networks with the epoch-195 checkpoints.
# NOTE(review): the training cell still sets startEpoch=0, so a resumed run repeats the
# Td/Tg warm-up phases and restarts epoch numbering; adjust startEpoch there when resuming.
loadExistingModel = False

if loadExistingModel:
    checkpointRoot = DATA_PATH
    generator = tf.keras.models.load_model(os.path.join(checkpointRoot, "generator_epoch_195"))
    discriminator = tf.keras.models.load_model(os.path.join(checkpointRoot, "discriminator_epoch_195"))

In [None]:
def saveImages(model, epoch):
    """Render a 3x3 grid of samples from ``model`` to DATA_PATH/trainImages/epoch_XXXX.png.

    Args:
        model: generator Keras model; it is fed a batch of 9 uniform-noise vectors and its
            output is assumed to be images in [-1, 1] (as produced by a tanh output layer).
        epoch: integer epoch number, zero-padded to 4 digits in the file name.
    """
    # Sample noise of the width the model actually expects instead of hard-coding 1024.
    latentDim = model.input_shape[-1]
    testInput = tf.random.uniform((9, latentDim))
    predictions = model(testInput, training=False)

    fig = plt.figure(figsize=(5, 5))
    for i in range(predictions.shape[0]):
        plt.subplot(3, 3, i + 1)
        # [-1, 1] -> integer pixel values in [0, 255].
        plt.imshow(np.rint(predictions[i, :, :, :] * 127.5 + 127.5).astype(int))
        plt.axis('off')

    outDir = os.path.join(DATA_PATH, "trainImages")
    os.makedirs(outDir, exist_ok=True)  # savefig does not create missing directories
    plt.savefig(os.path.join(outDir, 'epoch_{:04d}.png'.format(epoch)))
    # Called once per epoch from the training loop; without closing, every figure stays
    # open (and is rendered when the cell finishes).
    plt.close(fig)

In [None]:

# --- Run schedule -----------------------------------------------------------------
BATCH_SIZE = 64
LOG_INTERVAL = 5      # print / write summaries every LOG_INTERVAL batches
epochs = 200
saveModel = True      # checkpoint both networks on every 10th epoch
startEpoch = 0

# Warm-up: epochs below Td update only the discriminator, epochs below Td+Tg only the
# generator; after that both networks are trained.
Td = 2
Tg = 3

# --- Logging ----------------------------------------------------------------------
runStamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "./logs/" + generator.name + "/" + runStamp
# NOTE: this Keras callback is not used by the custom training loop in this cell;
# summaries are written through summary_writer instead.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=True,
    update_freq=5,
)
summary_writer = tf.summary.create_file_writer(log_dir)

# --- Optimisation -----------------------------------------------------------------
optimizerGen = tf.keras.optimizers.Adam(learning_rate=1e-4)
optimizerDis = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Least-squares adversarial losses: MSE against 0/1 targets (SGD and
# BinaryCrossentropy were earlier alternatives).
lossFnGen = tf.keras.losses.MeanSquaredError()
lossFnDis = tf.keras.losses.MeanSquaredError()

# Thresholds the raw discriminator score at BinaryAccuracy's default of 0.5.
accuracyDis = tf.keras.metrics.BinaryAccuracy()

# --- Input pipeline ---------------------------------------------------------------
datasetShuffled = dataset.shuffle(200)
batchedDataset = datasetShuffled.batch(BATCH_SIZE, drop_remainder=False).prefetch(tf.data.AUTOTUNE)


# Keras collects kernel_regularizer penalties in model.losses, but only fit() adds them to
# the loss; this custom loop adds them explicitly. Set to False to optimise the adversarial
# losses alone (the regularizers then have no effect). Read when the steps are traced.
APPLY_WEIGHT_PENALTIES = True


def _latentBatch(trueImages):
    """Sample uniform noise for the generator, one vector per image in ``trueImages``.

    Sizing by the real batch rather than a fixed BATCH_SIZE keeps fake and real batches
    equal for a smaller final batch (the dataset is batched with drop_remainder=False).
    The noise width is read from the generator's input layer.
    """
    batchSize = tf.shape(trueImages)[0]
    return tf.random.uniform((batchSize, generator.input_shape[-1]))


def _weightPenalty(model):
    """Sum of the regularization losses Keras tracks for ``model`` (0.0 if none).

    Must be called inside the GradientTape: the penalty tensors are built on access.
    """
    penalties = model.losses
    return tf.add_n(penalties) if penalties else tf.constant(0.0)


def _trainStep(trueImages, updateGen, updateDis):
    """Shared body of the three train-step functions below.

    Runs G and D on one batch, applies the requested updates, records discriminator
    accuracy, and returns the adversarial losses ``(lossGen, totalLossDis)``. Weight
    penalties are excluded from the returned values so logged losses stay comparable.
    ``updateGen`` / ``updateDis`` are Python bools, resolved at trace time.
    """
    with tf.GradientTape() as tapeGen, tf.GradientTape() as tapeDis:
        fakeImages = generator(_latentBatch(trueImages), training=True)

        discOutputFake = discriminator(fakeImages, training=True)
        discOutputTrue = discriminator(trueImages, training=True)

        # Targets: 1 for real, 0 for generated; G is scored as if its fakes were real.
        lossGen = lossFnGen(tf.ones_like(discOutputFake), discOutputFake)
        lossDisTrue = lossFnDis(tf.ones_like(discOutputTrue), discOutputTrue)
        lossDisFake = lossFnDis(tf.zeros_like(discOutputFake), discOutputFake)
        totalLossDis = lossDisTrue + lossDisFake

        objectiveGen = lossGen
        objectiveDis = totalLossDis
        if APPLY_WEIGHT_PENALTIES:
            objectiveGen = objectiveGen + _weightPenalty(generator)
            objectiveDis = objectiveDis + _weightPenalty(discriminator)

    # Compute every needed gradient before applying any update.
    gradsGen = tapeGen.gradient(objectiveGen, generator.trainable_weights) if updateGen else None
    gradsDis = tapeDis.gradient(objectiveDis, discriminator.trainable_weights) if updateDis else None

    accuracyDis.update_state(tf.zeros_like(discOutputFake), discOutputFake)
    accuracyDis.update_state(tf.ones_like(discOutputTrue), discOutputTrue)

    if updateGen:
        optimizerGen.apply_gradients(zip(gradsGen, generator.trainable_weights))
    if updateDis:
        optimizerDis.apply_gradients(zip(gradsDis, discriminator.trainable_weights))

    return lossGen, totalLossDis


@tf.function()
def trainStepGen(trueImages):
    """Update only the generator on ``trueImages``; returns ``(lossGen, totalLossDis)``."""
    return _trainStep(trueImages, updateGen=True, updateDis=False)


@tf.function()
def trainStepDis(trueImages):
    """Update only the discriminator on ``trueImages``; returns ``(lossGen, totalLossDis)``."""
    return _trainStep(trueImages, updateGen=False, updateDis=True)


@tf.function()
def trainStepGenDis(trueImages):
    """Update both networks from one forward pass; returns ``(lossGen, totalLossDis)``."""
    return _trainStep(trueImages, updateGen=True, updateDis=True)



# Batches per epoch; the final partial batch counts because drop_remainder=False.
maxStep = int(np.ceil(len(allFiles) / BATCH_SIZE))
# TensorBoard step, incremented once per batch so consecutive epochs never share a step;
# offset by startEpoch so a resumed run's x-axis continues after the earlier epochs.
globalStep = int(startEpoch) * maxStep
# Once both networks train, the discriminator is updated only every discHandicap steps.
discHandicap = 1

for epoch in range(startEpoch, startEpoch + epochs):
    print("\nStart of epoch %d" % (epoch,))

    for step, x_batch_train in enumerate(batchedDataset):
        if epoch < Td:
            # Warm-up phase 1: discriminator only.
            lossGen, lossDis = trainStepDis(x_batch_train)
        elif epoch < Td + Tg:
            # Warm-up phase 2: generator only.
            lossGen, lossDis = trainStepGen(x_batch_train)
        elif step % discHandicap == 0:
            lossGen, lossDis = trainStepGenDis(x_batch_train)
        else:
            lossGen, lossDis = trainStepGen(x_batch_train)

        if step % LOG_INTERVAL == 0:
            accDis = accuracyDis.result().numpy()
            template = 'Epoch {}/Step {}, Loss Generator: {:.4f}, Loss Discriminator: {:.4f}, Accuracy Dis: {:.4f}'
            print(template.format(epoch, step, lossGen.numpy(), lossDis.numpy(),  accDis))

            with summary_writer.as_default():
                tf.summary.scalar('lossGen', lossGen, step=globalStep)
                tf.summary.scalar('lossDis', lossDis, step=globalStep)
                tf.summary.scalar('Disc Accuracy', accDis, step=globalStep)
                summary_writer.flush()

        globalStep += 1

    # Adapt the handicap to the whole epoch's accuracy, not the last logged value.
    accDis = accuracyDis.result().numpy()
    if accDis > 0.95 and discHandicap < 10 and epoch >= Td + Tg:
        discHandicap += 1
        print("Decrease disc training frequency to every {} steps".format(discHandicap))
    if accDis < 0.65 and discHandicap > 1 and epoch >= Td + Tg:
        discHandicap -= 1
        print("Increase disc training frequency to every {} steps".format(discHandicap))

    accuracyDis.reset_state()

    print("Saving images")
    saveImages(generator, epoch)

    if saveModel and epoch % 10 == 0:
        generator.save(os.path.join(DATA_PATH, "generator_" + "epoch_{}".format(epoch)))
        discriminator.save(os.path.join(DATA_PATH, "discriminator_" + "epoch_{}".format(epoch)))

In [None]:
# Show a fresh 3x3 grid of samples from the current generator.
testInput = tf.random.uniform((9, 1024))
predictions = generator(testInput, training=False)

fig = plt.figure(figsize=(5, 5))
for panel, sample in enumerate(predictions, start=1):
    plt.subplot(3, 3, panel)
    # [-1, 1] -> integer pixel values in [0, 255].
    plt.imshow(np.rint(sample * 127.5 + 127.5).astype(int))
    plt.axis('off')

plt.show()

In [None]:
# Persist the final networks next to the per-epoch checkpoints.
for net, folder in ((generator, "generator"), (discriminator, "discriminator")):
    tf.keras.models.save_model(net, os.path.join(DATA_PATH, folder))