In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Set to True when running on Google Colab; later cells use this flag to switch
# to Google Drive paths and to run Colab-only magics.
collab_mode = False

if collab_mode:
    # Colab-only line magic that selects the TensorFlow 2.x runtime.
    %tensorflow_version 2.x
# imports
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import numpy as np

import warnings # This ignore all the warning messages
warnings.filterwarnings('ignore')

from os import path
import os
import time

print(tf.__version__)

In [None]:
# --- Notebook configuration -------------------------------------------------
# Directory the notebook was started from; used as the base path outside Colab.
root_local_path = os.getcwd()
# Mount point passed to drive.mount() when collab_mode is True.
root_gdrive_path = '/content/drive'
# Project folder, relative to the Google Drive mount point.
gdrive_project_path = 'My Drive/pp/GSN/FaceGenerator'
# NOTE(review): not referenced elsewhere in the visible notebook (checkpoint_dir
# is built from 'training_checkpoints' instead) - confirm whether still needed.
checkpoints_path = 'checkpoints'
# Sub-folder (under the base path) that holds downloaded datasets.
dataset_path = 'datasets'
# Dataset sub-folder name used by download_dataset().
dataset_name = "celeb_a"
# Sub-folder (under the base path) used to build log_dir_path for TensorBoard.
tensorboard_logs_dir='tensorboard'
download_path = '' # output path for the dataset (re-assigned in the "Load dataset" cell)
# Sub-folder (under the base path) where save_generated_image() writes PNG files.
generated_images_path = 'generated_images'
# NOTE(review): not referenced elsewhere in the visible notebook; the loading
# code resizes to IMAGE_SIZES instead - confirm whether still needed.
dataset_image_size = (28, 28)
# Label of the current experiment; becomes part of output folder names.
run_name = 'default'
# Set to True by the mount cell once Google Drive has been mounted.
gdrive_mounted = False

In [None]:
def allow_memory_growth():
    """Enable memory growth on every physical GPU TensorFlow can see.

    Returns immediately when no GPU is listed. A ``RuntimeError`` raised while
    configuring the devices is printed instead of being propagated.
    """
    physical = tf.config.experimental.list_physical_devices('GPU')
    if not physical:
        return

    try:
        # The memory-growth setting has to be the same on all GPUs.
        for device in physical:
            tf.config.experimental.set_memory_growth(device, True)
        logical = tf.config.experimental.list_logical_devices('GPU')
        print(len(physical), "Physical GPUs,", len(logical), "Logical GPUs")
    except RuntimeError as err:
        # Raised when memory growth is set after the GPUs have been initialized.
        print(err)

# Memory growth is enabled unconditionally by the call below. The original note
# recommended it for local runtimes with a GTX 1660 or newer GPU and described it
# as a workaround for a known TensorFlow memory-allocation bug.

allow_memory_growth()
print("Getting device name")
# Last expression of the cell, so the notebook displays the returned device name.
tf.test.gpu_device_name()

Misc helper functions

In [None]:
def get_time():
    """Return the current local time as a ``DD-MM-YYYY-_HH-MM-SS`` string."""
    stamp_format = "%d-%m-%Y-_%H-%M-%S"
    return time.strftime(stamp_format)

# Mount gdrive disk if necessary

In [None]:
if collab_mode:
    # Colab-only import, kept inside the branch so local runs do not need it.
    from google.colab import drive
    # Full path of the project folder on the mounted drive.
    project_path = path.join(root_gdrive_path, gdrive_project_path )
    # NOTE(review): this rebinds gdrive_project_path from the relative folder to
    # the joined path, while other cells join root_gdrive_path with it again.
    # That presumably still resolves to the same folder because the joined value
    # starts with '/' - confirm before relying on it.
    gdrive_project_path = path.join(root_gdrive_path, gdrive_project_path)
    drive.mount(root_gdrive_path)
    gdrive_mounted = True

def get_base_path():
    """Return the project base directory.

    In Colab mode this is the project folder under the Google Drive mount point,
    otherwise the local working directory captured in ``root_local_path``.
    """
    return path.join(root_gdrive_path, gdrive_project_path) if collab_mode else root_local_path

## Import dataset_helpers

In [None]:
# On Colab the helper module lives in the project folder on Google Drive, so that
# folder is added to the module search path before dataset_helpers is imported.
if collab_mode:
    path_with_imports = path.join(root_gdrive_path, gdrive_project_path)
    print("Files in path", path_with_imports)
    # NOTE(review): the folder listed here is hard-coded and must be kept in sync
    # with root_gdrive_path / gdrive_project_path by hand.
    !ls /content/drive/My\ Drive/pp/GSN/FaceGenerator
    # os.sys is the sys module reached through the os module's namespace.
    if path_with_imports not in os.sys.path:
        os.sys.path.append(path_with_imports)

import dataset_helpers as ds_helpers

### Download dataset

In [None]:
def download_dataset():
    '''Download and extract the dataset into <base>/dataset_path/dataset_name.

    <base> is the Google Drive project folder when collab_mode is set and the
    local working directory otherwise.
    '''
    if collab_mode:
        base_parts = (root_gdrive_path, gdrive_project_path)
    else:
        base_parts = (root_local_path,)
    target_dir = path.join(*base_parts, dataset_path, dataset_name)

    print('dataset download path is {}'.format(target_dir))
    ds_helpers.download_extract('celeba', target_dir)

download_dataset()

### Load dataset

In [None]:
# Dataset folder: the same <base>/dataset_path/dataset_name location that
# download_dataset() writes to, so local and Colab runs read what was downloaded.
download_path = path.join(get_base_path(), dataset_path, dataset_name)
# Glob pattern for the image files. The separator comes from path.join instead of
# a hard-coded backslash, so the pattern also matches on POSIX systems (Colab).
# NOTE(review): the '50000' sub-folder is presumably created by
# ds_helpers.download_extract - confirm.
img_path = path.join(download_path, '50000', '*.jpg')
# Target height and width of every image fed to the networks.
IMAGE_SIZES = (64, 64)
# Number of colour channels per image.
IMAGE_CHANNELS=3
# Full per-image shape: (height, width, channels).
IMAGES_SHAPE = (*IMAGE_SIZES, IMAGE_CHANNELS)
print(IMAGES_SHAPE)
def process_image(img):
    """Cast an image to float32, rescale its values and resize it to IMAGE_SIZES.

    The rescaling is ``x / 127.5 - 1``, which maps 0..255 pixel values onto the
    range <-1, 1>.
    """
    # IMPORTANT: after this line the pixels are in the range <-1, 1>.
    rescaled = tf.cast(img, tf.float32) / 127.5 - 1
    return tf.image.resize(rescaled, IMAGE_SIZES)

def load_image(filename):
    """Read a JPEG file and decode it into an image tensor.

    Args:
        filename: path of the JPEG file (string or string tensor).

    Returns:
        The decoded image with IMAGE_CHANNELS channels.
    """
    raw_bytes = tf.io.read_file(filename)
    # Force a fixed channel count: without it a grayscale JPEG decodes to a single
    # channel and can no longer be batched together with RGB images.
    return tf.image.decode_jpeg(raw_bytes, channels=IMAGE_CHANNELS)

def load_dataset(batch_size, preprocess_images=True, shuffle_size=500, seed=101):
    """Build a batched tf.data pipeline over the files matching ``img_path``.

    Args:
        batch_size: number of images per batch.
        preprocess_images: when True, every image goes through process_image().
        shuffle_size: size of the shuffle buffer.
        seed: random seed used for both the file listing and the shuffle buffer.

    Returns:
        A ``tf.data.Dataset`` yielding batches of images.
    """
    files = tf.data.Dataset.list_files(img_path, seed=seed)
    # The seed is forwarded to shuffle() as well; previously only the file listing
    # was seeded, so the resulting order was not reproducible.
    images = files.shuffle(shuffle_size, seed=seed).map(load_image)
    if preprocess_images:
        images = images.map(process_image)
    return images.batch(batch_size)

# Dataset used by display_image_from_dataset() for a quick visual check.
data = load_dataset(batch_size=100)

In [None]:
def display_image_from_dataset():
    """Show the first image of the first batch of ``data`` and print its stats.

    Prints the image shape together with its minimum and maximum value after the
    ``(x + 1) / 2`` rescaling.
    """
    for batch in data.take(1):
        first = next(iter(batch), None)
        if first is None:
            continue
        rescaled = (first + 1) / 2
        plt.imshow(rescaled)
        print(rescaled.shape, np.min(rescaled), np.max(rescaled))

display_image_from_dataset()

### Saving functions

In [None]:
def save_generated_image(epoch):
    """Save the current matplotlib figure as a PNG for the given epoch.

    The file is written to <base>/generated_images_path/run_name and is named
    ``img_<epoch>_<timestamp>.png``.

    Args:
        epoch: epoch number embedded in the file name.
    """
    save_dir = path.join(get_base_path(), generated_images_path, run_name)
    # makedirs also creates the missing parent folder (os.mkdir fails when
    # generated_images_path does not exist yet) and tolerates an existing folder.
    os.makedirs(save_dir, exist_ok=True)
    name = path.join(save_dir,
                     'img_{}_{}.png'.format(epoch, get_time()))
    plt.savefig(name)


def show_images(images, epoch, save_images=False, display_images=False):
    """Plot up to ten images in one row and optionally save and/or display them.

    Args:
        images: batch of images indexed as [image, row, column, channel] with
            values scaled like the training data (x * 127.5 + 127.5 gives 0..255).
            Both TensorFlow tensors and NumPy arrays are accepted.
        epoch: epoch number forwarded to save_generated_image().
        save_images: when True, the figure is written to disk.
        display_images: when True, the figure is shown with plt.show().
    """
    print("image pixels range", np.min(images), np.max(images), "std", np.std(images))

    num_of_images = min(10, images.shape[0])
    fig = plt.figure(figsize=(num_of_images, 1))
    for i in range(num_of_images):
        plt.subplot(1, num_of_images, i + 1)
        # np.asarray works for eager tensors and plain arrays alike.
        img = np.asarray(images[i, :, :, :])
        # Clip before the uint8 cast so out-of-range values saturate instead of
        # wrapping around.
        img = np.clip(img * 127.5 + 127.5, 0, 255).astype(np.uint8)
        plt.imshow(img)
        plt.axis('off')

    if save_images:
        save_generated_image(epoch)
    if display_images:
        plt.show()
    # Release the figure; otherwise one figure per call stays open for the whole
    # training loop.
    plt.close(fig)

In [None]:
# def save_models(generator, discriminator):
#     def save(epoch_number):
#         path = get_path()
#     return save

In [None]:
# TensorBoard log folder derived from the base path and the current run_name.
log_dir_path = path.join(get_base_path(), tensorboard_logs_dir, run_name)
if collab_mode:
    %reload_ext tensorboard

    print('tensorboard log dir {}'.format(log_dir_path))
    # NOTE(review): TensorBoard is started on 'logs' (the prefix the summary
    # writers in the next cell use), not on the log_dir_path printed above -
    # confirm which folder is intended.
    %tensorboard --logdir logs
    from tensorboard import notebook
    notebook.list() # View open TensorBoard instances
else:
    # Outside Colab only the command is printed; TensorBoard is not started.
    print('open tensorboard with command')
    print('tensorboard --logdir {}'.format(log_dir_path))


In [None]:
# Take the timestamp once so the train and test writers always share the same run
# folder; two separate get_time() calls can straddle a second boundary.
_run_stamp = run_name + get_time()
train_log_dir = 'logs/gradient_tape/' + _run_stamp + '/train'
test_log_dir = 'logs/gradient_tape/' + _run_stamp + '/test'
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

# Checkpoints go to ./training_checkpoints/<run_name>/ckpt-*; train() reads
# checkpoint_prefix when it saves.
checkpoint_dir = os.path.join('.', 'training_checkpoints', run_name)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

def restore_from_checkpoint():
    """Placeholder for restoring a checkpoint; does nothing and returns None."""
    return None

Loss functions


In [None]:
# Binary cross-entropy on raw logits, shared by the min-max losses below.
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def min_max_discriminator_loss(real_out, gen_out):
    """Min-max discriminator loss computed from discriminator logits.

    Real outputs are scored against a target of one, generated outputs against a
    target of zero, and the two cross-entropy terms are added.
    """
    loss_on_real = bce(tf.ones_like(real_out), real_out)
    loss_on_generated = bce(tf.zeros_like(gen_out), gen_out)
    return loss_on_real + loss_on_generated


def min_max_generator_loss(gen_out):
    """Min-max generator loss: the negated discriminator loss.

    The "real" argument is filled with constant logits of one
    (``tf.ones_like(gen_out)``), so that term does not vary with the generator
    output values; only the generated-sample term does.
    """
    constant_real_logits = tf.ones_like(gen_out)
    return -min_max_discriminator_loss(constant_real_logits, gen_out)


def w_discriminator_loss(real_out, gen_out):
    """Wasserstein discriminator loss: minus (mean(real_out) - mean(gen_out))."""
    mean_real = tf.reduce_mean(real_out)
    mean_generated = tf.reduce_mean(gen_out)
    return -(mean_real - mean_generated)


def w_generator_loss(gen_out):
    """Wasserstein generator loss: the negated mean discriminator output."""
    mean_score = tf.reduce_mean(gen_out)
    return -mean_score

In [None]:
def print_layers(model):
    """Print one line per layer of ``model``: name, input shape and output shape."""
    for lyr in model.layers:
        print("{} : {} -> {}".format(lyr.name, lyr.input_shape, lyr.output_shape))

Generator


In [None]:
class Generator(tf.keras.Model):
    """Generator network: decodes a noise tensor into a 64x64 image with 3 channels.

    The whole network lives in ``self.noise_decoder``, a Sequential stack of a
    Dense layer, a reshape to 16x16x256 and alternating convolutions / transposed
    convolutions that upsample 16x16 -> 32x32 -> 64x64. The last layer has no
    activation, so the output values are unbounded.

    Args:
        input_shape: per-sample shape of the noise input (without the batch axis).
        model_name: Keras name of the model.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, input_shape, model_name="Generator", **kwargs):
        super(Generator, self).__init__(name=model_name, **kwargs)
        self.noise_decoder = tf.keras.Sequential([
            # Dense layer on the noise; the Reshape below requires
            # 16 * 16 * 256 = 65536 values per sample.
            tf.keras.layers.Dense(1024, input_shape=input_shape),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Reshape([16, 16, 256]),
            # conv without stride (16x16)
            tf.keras.layers.Conv2D(256, 5, 1, 'same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            # t_conv with stride (32x32)
            tf.keras.layers.Conv2DTranspose(128, 5, 2, 'same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            # conv without stride (32x32)
            tf.keras.layers.Conv2D(64, 5, 1, 'same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            # t_conv with stride (64x64)
            tf.keras.layers.Conv2DTranspose(32, 5, 2, 'same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            # 1x1 conv without stride: projects to 3 output channels
            tf.keras.layers.Conv2D(3, (1, 1), 1, 'same')
        ])
        # Side effect kept from the original: log every layer's shapes on creation.
        print_layers(self.noise_decoder)

    def call(self, noise, training=None):
        """Run the decoder on ``noise``.

        ``training`` is forwarded explicitly so the BatchNormalization layers run
        in the mode the caller asked for; ``None`` lets Keras decide.
        """
        return self.noise_decoder(noise, training=training)

    def summary(self, *args, **kwargs):
        """Print the summary of the inner Sequential model.

        Any arguments are passed through to ``Sequential.summary``.
        """
        self.noise_decoder.summary(*args, **kwargs)

# Per-sample shape of the noise fed to the generator (batch axis excluded).
generator_input_shape = (64, 3)
# Instantiate once here to print the layer shapes and the model summary.
generator = Generator(input_shape=generator_input_shape)
# generator.build((None, *generator_input_shape))
generator.summary()

# Earlier experiment kept for reference: generator built on the image shape.
# generator = Generator(input_shape=IMAGES_SHAPE)
# generator.build((None, *IMAGES_SHAPE))
# generator.summary()
# tf.keras.utils.plot_model(generator, "gen.png")

Discriminator


In [None]:
class Discriminator(tf.keras.Model):
    """Discriminator network: encodes an image batch into one logit per image.

    The whole network lives in ``self.image_encoder``: two strided convolutions,
    a flatten, a 64-unit hidden layer and a single-unit output layer without
    activation (raw logits). Every hidden stage uses batch normalization, ReLU
    and dropout.

    Args:
        input_shape: per-sample shape of the input images (without the batch axis).
        model_name: Keras name of the model.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, input_shape, model_name="Discriminator", **kwargs):
        super(Discriminator, self).__init__(name=model_name, **kwargs)

        # The discriminator is a classifier and should be robust, hence the extra
        # regularization (dropout) intended to protect against pixel attacks.
        self.image_encoder = tf.keras.Sequential([
            # conv with stride 2 (halves height and width; 32x32 for 64x64 input)
            tf.keras.layers.Conv2D(64, 5, 2, 'same', input_shape=input_shape),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dropout(0.3),
            # conv with stride 2 (16x16x128 for 64x64 input)
            tf.keras.layers.Conv2D(128, 3, 2, 'same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dropout(0.3),
            # flatten + hidden layer
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dropout(0.3),
            # prediction (LOGITS!)
            tf.keras.layers.Dense(1)
        ])
        # Side effect kept from the original: log every layer's shapes on creation.
        print_layers(self.image_encoder)

    def call(self, images, training=None):
        """Run the encoder on ``images``.

        ``training`` is forwarded explicitly so the Dropout and
        BatchNormalization layers run in the mode the caller asked for; ``None``
        lets Keras decide.
        """
        return self.image_encoder(images, training=training)

    def summary(self, *args, **kwargs):
        """Print the summary of the inner Sequential model.

        Any arguments are passed through to ``Sequential.summary``.
        """
        self.image_encoder.summary(*args, **kwargs)

# discriminator = Discriminator(input_shape=IMAGES_SHAPE)
# discriminator.build(input_shape=(None, *IMAGES_SHAPE))
# discriminator.summary()
# start_training()

Noise generator


In [None]:
class NoiseGenerator(tf.keras.layers.Layer):
    """Layer that emits random-normal noise, one sample per element of the input batch."""

    def __init__(self, distribution_size):
        super().__init__()
        # Size of the middle axis of the generated noise tensor.
        self.distribution_size = distribution_size
        # Ideas kept from the original for per-class distributions:
        # self.data_distributions = self.add_weight(shape=(num_classes, distribution_size), trainable=True)
        # self.data_distributions = tf.tile(tf.range(0, num_classes, dtype=tf.float32)[:, tf.newaxis], [1, distribution_size])
        # TODO:

    def call(self, inputs):
        """Return noise of shape [batch, distribution_size, 3]; only the batch size of ``inputs`` is used."""
        # dists = tf.nn.embedding_lookup(self.data_distributions, inputs)
        # dists += tf.random.uniform(tf.shape(dists), -0.35, 0.35)
        # return dists
        # TODO
        batch_dim = tf.shape(inputs)[0]
        noise_shape = [batch_dim, self.distribution_size, 3]
        return tf.random.normal(noise_shape)

    def diverse_distributions_loss(self):
        """Not implemented yet; always returns None."""
        # TODO
        return None

Training step



In [None]:
def train_step_template(generator, discriminator, noise, d_optim, g_optim, d_loss_f, g_loss_f):
    """Build a ``tf.function`` that performs one simultaneous GAN update.

    Args:
        generator: model turning noise into images.
        discriminator: model scoring images.
        noise: layer producing the generator input from an image batch.
        d_optim: optimizer for the discriminator variables.
        g_optim: optimizer for the generator (and noise) variables.
        d_loss_f: callable ``(real_out, gen_out) -> loss`` for the discriminator.
        g_loss_f: callable ``(gen_out) -> loss`` for the generator.

    Returns:
        A function taking a batch of real images; it updates both networks and
        returns nothing.
    """

    @tf.function
    def _train_step_template(images):
        # Two tapes: one per network, both recording the same forward pass.
        with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
            real_out = discriminator(images, True)
            latent = noise(images)
            fake_images = generator(latent, True)
            gen_out = discriminator(fake_images, True)

            d_loss = d_loss_f(real_out, gen_out)
            g_loss = g_loss_f(gen_out)

        d_vars = discriminator.trainable_variables
        g_vars = generator.trainable_variables + noise.trainable_variables

        # Both gradients are computed before either optimizer is applied.
        d_grads = d_tape.gradient(d_loss, d_vars)
        g_grads = g_tape.gradient(g_loss, g_vars)

        d_optim.apply_gradients(zip(d_grads, d_vars))
        g_optim.apply_gradients(zip(g_grads, g_vars))

    return _train_step_template


In [None]:
# Number of images per training batch.
batch_size = 100
# Training pipeline; start_training() passes this dataset to train().
train_data = load_dataset(batch_size=batch_size)


Inference step



In [None]:
def gen_step_template(generator, noise):
    """Build a ``tf.function`` that generates images in inference mode.

    The returned function takes an image batch (passed to ``noise`` to produce
    the generator input), runs the generator with ``training=False`` and clips
    the result to the range [-1, 1].
    """
    @tf.function
    def _gen_step_template(images):
        latent = noise(images)
        generated = generator(latent, False)
        return tf.clip_by_value(generated, -1, 1)

    return _gen_step_template

Training

In [None]:
def train(train_step, gen_step, epochs, data, save_images=True, display_images=True, checkpoint=None):
    """Run the training loop and visualise generated samples after every epoch.

    Args:
        train_step: callable invoked with every batch of ``data``.
        gen_step: callable producing generated images from a batch.
        epochs: number of passes over ``data``.
        data: dataset of image batches.
        save_images: forwarded to show_images().
        display_images: forwarded to show_images().
        checkpoint: optional checkpoint object; when given it is saved every
            fifth epoch under ``checkpoint_prefix``.
    """
    for epoch in range(epochs):
        started = time.time()
        for batch in data:
            train_step(batch)
        elapsed = time.time() - started

        print('$+'*30)
        print('Epoch {0}/{1}, duration {2}'.format(epoch, epochs, elapsed))

        checkpoint_due = (epoch + 1) % 5 == 0
        if checkpoint_due and checkpoint is not None:
            print("Saving checkpoint")
            checkpoint.save(file_prefix=checkpoint_prefix)

        # One batch of the training data drives the sample generation.
        sample_batch = list(data.take(1))[0].numpy()
        generated = gen_step(sample_batch)
        show_images(generated, epoch, save_images=save_images, display_images=display_images)
        print('$-'*30)


In [None]:
# check how images are displayed/saved
def test_image_generation():
    """Smoke-check image display/saving with a freshly created generator.

    Loads one batch of ten images, generates images from it with a new Generator
    and hands the result to show_images() with saving and displaying both on.
    """
    sample_data = load_dataset(batch_size=10)

    latent_shape = (64, 3)
    fresh_generator = Generator(input_shape=latent_shape)
    noise_layer = NoiseGenerator(64)

    gen_step = gen_step_template(generator=fresh_generator, noise=noise_layer)

    # One batch of the dataset drives the sample generation.
    first_batch = list(sample_data.take(1))[0].numpy()
    generated = gen_step(first_batch)
    show_images(generated, -1, save_images=True, display_images=True)
    
# test_image_generation()

Training with Wasserstein loss function

In [None]:
# Module-level models; start_training() below trains and checkpoints exactly
# these two instances.
generator_input_shape = (64, 3)
generator = Generator(input_shape=generator_input_shape)
# Build with an unspecified batch size so the summary can list the weights.
generator.build((None, *generator_input_shape))
generator.summary()

discriminator = Discriminator(input_shape=IMAGES_SHAPE)
discriminator.build(input_shape=(None, *IMAGES_SHAPE))
discriminator.summary()

In [None]:
def start_training():
    """Train the module-level generator and discriminator with the Wasserstein losses.

    Creates the noise layer, two Adam optimizers (learning rate 1e-4) and a
    checkpoint covering both optimizers and both models, then runs train() for
    one epoch on ``train_data``.
    """
    noise_layer = NoiseGenerator(64)
    critic_optimizer = tf.optimizers.Adam(1e-4)
    generator_optimizer = tf.optimizers.Adam(1e-4)

    ckpt = tf.train.Checkpoint(
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=critic_optimizer,
        generator=generator,
        discriminator=discriminator,
    )

    # The generator and the noise layer are shared by both step functions.
    shared_models = dict(generator=generator, noise=noise_layer)

    step_fn = train_step_template(
        discriminator=discriminator,
        d_optim=critic_optimizer,
        g_optim=generator_optimizer,
        d_loss_f=w_discriminator_loss,
        g_loss_f=w_generator_loss,
        **shared_models
    )
    sample_fn = gen_step_template(**shared_models)

    train(
        train_step=step_fn,
        gen_step=sample_fn,
        epochs=1,
        data=train_data,
        checkpoint=ckpt
    )

# NOTE(review): run_name is changed only here. When the cells run top to bottom,
# log_dir_path, the summary-writer folders and checkpoint_dir were already built
# from the previous value, while save_generated_image() reads run_name at call
# time - so only the generated images land in a folder named after this run.
run_name="first"

print("Start time", get_time())
print('%'*30)
start = time.time()

start_training()
# At the beginning the generated images are white, with a very small deviation in
# pixel values: about 17 on the 0-255 scale.
# The generator uses only part of the scale, e.g. 52-160.
# Later the generator learns to increase the deviation and the pixel values in
# the images spread out to the full 0-255 range.

end = time.time()
print('%'*30)
print("End time", get_time())
print("seconds elapsed", end - start)
