In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Switch for running on Google Colab; it is passed to Settings / Environment further down.
collab_mode = False

if collab_mode:
    # Colab-only magic selecting TensorFlow 2.x; kept ahead of the tensorflow import below.
    %tensorflow_version 2.x
# imports
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import numpy as np

import warnings # This ignore all the warning messages
warnings.filterwarnings('ignore')

from os import path
import os
import time

import dataset_helpers as ds_helpers

# Report the TensorFlow version and the GPUs TensorFlow can see.
# list_physical_devices() only enumerates hardware. tf.test.gpu_device_name() would instantiate the
# GPU devices right here, i.e. before allow_memory_growth() in the next cell gets a chance to run.
print("Tensorflow version is", tf.__version__, ", GPU devices", tf.config.experimental.list_physical_devices('GPU'))

In [None]:
def allow_memory_growth():
    """Enable on-demand GPU memory growth on every visible GPU.

    Does nothing when no GPU is found. A RuntimeError raised by TensorFlow is printed
    instead of being propagated.
    """
    physical = tf.config.experimental.list_physical_devices('GPU')
    if not physical:
        return
    try:
        # The setting has to be the same on all GPUs, so apply it to each of them.
        for device in physical:
            tf.config.experimental.set_memory_growth(device, True)
        logical = tf.config.experimental.list_logical_devices('GPU')
        print(len(physical), "Physical GPUs,", len(logical), "Logical GPUs")
    except RuntimeError as err:
        # Raised when the GPUs were already initialised before this call.
        print(err)

allow_memory_growth()

In [None]:
def get_time():
    """Return the current local time as a ``DD-MM-YYYY_HH-MM-SS`` string (safe for file names)."""
    timestamp_format = "%d-%m-%Y_%H-%M-%S"
    return time.strftime(timestamp_format)

In [None]:
def process_image(img, image_shape):
    """Scale pixel values to [-1, 1] and resize the image to the requested size.

    Args:
        img: image tensor with pixel values on the 0-255 scale.
        image_shape: target size. Only the first two entries (height, width) are used, so a
            (height, width, channels) shape is accepted too. ``None`` falls back to 64x64,
            the size that used to be hard-coded here.

    Returns:
        float32 tensor resized to (height, width) with values in [-1, 1].
    """
    target_size = (64, 64) if image_shape is None else tuple(image_shape[:2])
    img = tf.cast(img, tf.float32) / 127.5 - 1  # IMPORTANT: pixels end up in the range <-1, 1>
    return tf.image.resize(img, target_size)

def load_image(filename):
    """Read a JPEG file and decode it into a 3-channel (RGB) image tensor.

    Args:
        filename: path of the JPEG file (string or string tensor).

    Returns:
        The decoded image as returned by ``tf.image.decode_jpeg``.
    """
    raw_bytes = tf.io.read_file(filename)
    # Force three channels: a greyscale JPEG would otherwise decode to a single channel and
    # could not be batched with RGB images or fed to the 3-channel discriminator.
    return tf.image.decode_jpeg(raw_bytes, channels=3)

def display_image_from_dataset(data):
    """Show the first image of the first batch and print its shape, minimum and maximum.

    Args:
        data: batched dataset whose images hold values in [-1, 1].
    """
    nothing = object()
    for batch in data.take(1):
        first = next(iter(batch), nothing)
        if first is nothing:
            continue  # empty batch: nothing to show
        rescaled = (first + 1) / 2  # map [-1, 1] to [0, 1] for imshow
        plt.imshow(rescaled)
        print(rescaled.shape, np.min(rescaled), np.max(rescaled))

def save_generated_image(settings, epoch):
    """Save the current matplotlib figure as ``img_<epoch>_<timestamp>.png``.

    Args:
        settings: object exposing ``generated_images_path``, the target directory.
        epoch: value embedded in the file name.
    """
    save_dir = settings.generated_images_path
    # exist_ok avoids the check-then-create race and handles nested directories in one call.
    os.makedirs(save_dir, exist_ok=True)
    name = path.join(save_dir, 'img_{}_{}.png'.format(epoch, get_time()))
    plt.savefig(name)


def show_images(images, epoch, settings, save_images=False, display_images=False):
    """Plot up to ten images in one row and optionally save and/or display the figure.

    Args:
        images: batch indexed as [image, height, width, channel] with values expected in
            [-1, 1]; either a tensor whose slices expose ``.numpy()`` or a NumPy array.
        epoch: forwarded to ``save_generated_image`` for the file name.
        settings: forwarded to ``save_generated_image``.
        save_images: write the figure to disk when true.
        display_images: call ``plt.show()`` when true.
    """
    if images.shape[0] == 0:
        return  # an empty batch cannot be plotted (zero-width figure)

    print("image pixels range", np.min(images), np.max(images), "std", np.std(images))
    num_of_images = min(10, images.shape[0])
    # one row, one column per image
    fig = plt.figure(figsize=(num_of_images, 1))
    try:
        for i in range(num_of_images):
            plt.subplot(1, num_of_images, i + 1)
            img = images[i, :, :, :]
            img = img.numpy() if hasattr(img, 'numpy') else np.asarray(img)
            # Clip before the cast: values outside [-1, 1] would otherwise wrap around in uint8.
            img = np.clip(img * 127.5 + 127.5, 0, 255).astype(np.uint8)
            plt.imshow(img)
            plt.axis('off')

        if save_images:
            save_generated_image(settings, epoch)
        if display_images:
            plt.show()
    finally:
        # Release the figure; this function runs once per epoch and open figures accumulate.
        plt.close(fig)

def load_dataset(dataset_path, image_shape, preprocess_images=True, shuffle_size=500, seed=101):
    """Build an unbatched tf.data pipeline over the ``*.jpg`` files directly inside a directory.

    Args:
        dataset_path: directory containing the JPEG files.
        image_shape: forwarded to ``process_image`` for every image.
        preprocess_images: when false, the decoded images are returned without ``process_image``.
        shuffle_size: size of the shuffle buffer.
        seed: seed used for the file listing and for the shuffle buffer.

    Returns:
        A dataset of single images; the caller is responsible for batching.
    """
    img_path = path.join(dataset_path, '*.jpg')
    data = (
        tf.data.Dataset.list_files(img_path, seed=seed)
        # Seed the buffer as well, otherwise the order is random despite the `seed` parameter.
        .shuffle(shuffle_size, seed=seed)
        .map(load_image)
    )
    if preprocess_images:
        data = data.map(lambda x: process_image(x, image_shape))
    return data

In [None]:
class Settings:
    """Hyper-parameters and filesystem layout shared by the whole notebook.

    Every path property is derived from `get_base_path`: the Google Drive project folder when
    `collab_mode` is true, otherwise the working directory captured at construction time.
    """

    def __init__(self, collab_mode):
        """Store the defaults and, in collab mode, mount Google Drive.

        Args:
            collab_mode: True when the notebook runs on Google Colab.
        """
        self.root_local_path = os.getcwd()
        self.root_gdrive_path = '/content/drive'
        self.gdrive_project_path = 'My Drive/pp/GSN/FaceGenerator'
        self.dataset_name = "celeb_a"
        self.subdataset_dir = "1000"
        self.dataset_image_size = (28, 28)
        self.image_size = (64, 64)
        self.image_channels = 3
        # Derived from image_size (64), not dataset_image_size (28): the generator's
        # Dense(1024) -> Reshape([16, 16, 256]) needs 64 * 1024 = 16 * 16 * 256 values per sample,
        # and the notebook feeds it noise of shape (batch, 64, 3).
        self.generator_input_shape = (self.image_size[0], self.image_channels)
        self.gdrive_mounted = False
        self.collab_mode = collab_mode
        self.batch_size = 100
        self.epochs = 100
        self.save_models = False  # save models at the end?
        self.mount_gdrive()

    @property
    def run_name(self):
        """Name of the run; used as a sub-directory of ``saved_state``."""
        return 'run_{}'.format(self.subdataset_dir)
        # return "{}_epochs_{}_batch_{}".format(self.epochs, self.batch_size, self.subdataset_dir)

    @property
    def image_shape(self):
        """Image shape as (height, width, channels)."""
        return (*self.image_size, self.image_channels)

    @property
    def download_path(self):
        """Directory the dataset is downloaded to: ``<base>/datasets/<dataset_name>``."""
        return path.join(self.get_base_path, 'datasets', self.dataset_name)

    @property
    def dataset_path(self):
        """Directory with the images actually used: ``<download_path>/<subdataset_dir>``."""
        return path.join(self.download_path, self.subdataset_dir)

    @property
    def tensorboard_log_dir(self):
        """TensorBoard log directory of the current run."""
        return self._run_path('tensorboard_logs')

    @property
    def checkpoint_dir(self):
        """Checkpoint directory of the current run."""
        return self._run_path("ckpt")

    @property
    def model_save_path(self):
        """Directory where the final models of the current run are saved."""
        return self._run_path('models')

    @property
    def generated_images_path(self):
        """Directory where generated sample images of the current run are saved."""
        return self._run_path('generated_images')

    @property
    def get_base_path(self):
        """Root of all project paths (a property despite the verb-like name)."""
        if self.collab_mode:
            # os.path.join drops root_gdrive_path when gdrive_project_path is already absolute,
            # so this also works after mount_gdrive() stored the full path.
            return path.join(self.root_gdrive_path, self.gdrive_project_path)
        return self.root_local_path

    def _run_path(self, leaf):
        """Return ``<base>/saved_state/<run_name>/<leaf>``."""
        return path.join(self.get_base_path, 'saved_state', self.run_name, leaf)

    def mount_gdrive(self):
        """Mount Google Drive and make the project folder importable; no-op outside collab mode."""
        if not self.collab_mode:
            return
        import sys
        from google.colab import drive

        project_path = path.join(self.root_gdrive_path, self.gdrive_project_path)
        # Stored as the full path, as before, for anything that reads this attribute afterwards.
        self.gdrive_project_path = project_path
        drive.mount(self.root_gdrive_path)
        self.gdrive_mounted = True

        print("Files in path", project_path)
        # List the configured folder instead of shelling out to a hard-coded path.
        if path.isdir(project_path):
            for entry in sorted(os.listdir(project_path)):
                print(entry)
        if project_path not in sys.path:
            sys.path.append(project_path)
                
class DatasetCache:
    """Keeps the loaded (unbatched) dataset so repeated loads of the same path are skipped."""

    def __init__(self):
        self.path = ""        # dataset path the cached data was loaded from
        self.batch_size = 0   # batch size applied by the `data` property
        self._data = None     # unbatched dataset, None until load_data() ran

    def load_data(self, settings):
        """Return the batched dataset for `settings`, downloading and loading only when needed.

        The cached dataset is reused when it was loaded from ``settings.dataset_path``; the batch
        size is always refreshed from ``settings.batch_size``.
        """
        if self._data is None or self.path != settings.dataset_path:
            print("downloading and loading data")
            self.download_dataset(settings)
            self._data = load_dataset(settings.dataset_path, settings.image_size)
            # Must be the attribute the check above compares, otherwise the cache never hits
            # and every call downloads again.
            self.path = settings.dataset_path
            self.data_path = settings.dataset_path  # kept for code that reads this attribute
        self.batch_size = settings.batch_size
        return self.data

    @property
    def data(self):
        """The cached dataset batched with the current batch size, or None if nothing is loaded."""
        if self._data is None:
            return None
        return self._data.batch(self.batch_size)

    def download_dataset(self, settings):
        '''Downloads the data into settings.download_path.'''
        print('dataset download path is {}'.format(settings.download_path))
        ds_helpers.download_extract('celeba', settings.download_path)
        
class TensorboardManager():
    def __init__(self):
        self.log_path = ''
        self.train_summary_writer = None
        self.test_summary_writer = None
        
    def initialize(self, settings):
        should_be_updated = False
        if settings.collab_mode and self.log_path != "tensorboard_logs":
            should_be_updated = True
            self.log_path = "tensorboard_logs"
        elif not settings.collab_mode and self.log_path != settings.tensorboard_log_dir:
            should_be_updated = True
            self.log_path = settings.tensorboard_log_dir
                
        if should_be_updated:
            self.train_summary_writer = tf.summary.create_file_writer(path.join(self.log_path, 'train'))
            self.test_summary_writer = tf.summary.create_file_writer(path.join(self.log_path, 'test'))
            print('Initialized tensorboard log dir with path', self.log_path)
            self.launch(settings.collab_mode)
        
    def launch(self, collab_mode):
        if collab_mode:
            %reload_ext tensorboard
            %tensorboard --logdir tensorboard_logs
            from tensorboard import notebook
            notebook.list() # View open TensorBoard instances
        else:
            print('open tensorboard with command')
            print('tensorboard --logdir {}'.format(self.log_path))
            
class Environment():
    """Container for everything one experiment uses: settings, models, data cache and helpers."""

    def __init__(self, collab_mode):
        self.settings = Settings(collab_mode)
        self.models = {}
        self.datasetCache = DatasetCache()
        # Both are assigned later by the notebook once the models exist.
        self.checkpointManager = self.tensorboard = None

In [None]:
# Smoke test: download the dataset (if needed) and load it through the cache.
test_settings = Settings(collab_mode)

dataset = DatasetCache()
dataset.load_data(test_settings)

if not collab_mode:
    # Also build the pipeline directly. Path, image size and batch size come from the settings
    # so this check cannot drift away from the configuration used by the training cells.
    data = load_dataset(test_settings.dataset_path, test_settings.image_size)
    data = data.batch(test_settings.batch_size)
    display_image_from_dataset(data)

In [None]:
# Show the first image served by the cache; `dataset.data` batches the cached dataset with the
# batch size stored by the load_data() call in the previous cell.
display_image_from_dataset(dataset.data)

Loss functions


In [None]:
# Shared binary cross-entropy; from_logits=True means it expects raw (pre-sigmoid) scores.
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def min_max_discriminator_loss(real_out, gen_out):
    """Discriminator loss of the minimax GAN.

    Args:
        real_out: discriminator logits for real images (target label 1).
        gen_out: discriminator logits for generated images (target label 0).

    Returns:
        Sum of the two binary cross-entropy terms.
    """
    real_targets = tf.ones_like(real_out)
    generated_targets = tf.zeros_like(gen_out)
    return bce(real_targets, real_out) + bce(generated_targets, gen_out)


def min_max_generator_loss(gen_out):
    """Generator loss of the minimax (saturating) GAN.

    The generator maximises the discriminator's cross-entropy on generated samples, so the loss
    is that term negated. Only the generated-sample term is used: feeding a tensor of ones as
    "real logits" through the discriminator loss merely adds a constant that carries no gradient.

    Args:
        gen_out: discriminator logits for generated images.

    Returns:
        Negated binary cross-entropy between all-zero targets and `gen_out`.
    """
    return -bce(tf.zeros_like(gen_out), gen_out)


def w_discriminator_loss(real_out, gen_out):
    """Wasserstein critic loss: negated gap between the mean real and mean generated scores."""
    mean_real = tf.reduce_mean(real_out)
    mean_generated = tf.reduce_mean(gen_out)
    return -(mean_real - mean_generated)


def w_generator_loss(gen_out):
    """Wasserstein generator loss: the negated mean critic score of the generated samples."""
    mean_score = tf.reduce_mean(gen_out)
    return -mean_score

Generator


In [None]:
class GANNetwork(tf.keras.Model):
    """Base class of both networks: a keras Model wrapping a Sequential stored in `self.model`.

    Subclasses assign `self.model` in their constructor.
    """

    def __init__(self, model_name="Network", **kwargs):
        super().__init__(name=model_name, **kwargs)
        self.model = None

    def print_layers(self):
        """Print the wrapped model followed by ``name : input_shape -> output_shape`` per layer."""
        print(self.model)
        # A Sequential is not iterable; its layers are exposed through `.layers`.
        for layer in self.model.layers:
            print(layer.name, ":", layer.input_shape, "->", layer.output_shape)

    def summary(self, *args, **kwargs):
        """Print the summary of the wrapped Sequential; arguments are forwarded unchanged."""
        self.model.summary(*args, **kwargs)

    @tf.function
    def call(self, data, training=None):
        """Run the wrapped model on `data`.

        `training` is forwarded explicitly so the BatchNormalization / Dropout layers of the
        wrapped model run in the mode the caller asked for.
        """
        return self.model(data, training=training)

    def save_model(self, save_path):
        """Save the wrapped model as ``<save_path>/<model name>_<timestamp>``."""
        os.makedirs(save_path, exist_ok=True)
        filename = path.join(save_path, '{}_{}'.format(self.name, get_time()))
        print("Saving model", self.name, "as", filename)
        self.model.save(filename)

In [None]:
class Generator(GANNetwork):
    """Turns noise of shape `input_shape` into an image via conv / transposed-conv stages.

    The feature map starts at 16x16 and is upsampled twice with stride-2 transposed
    convolutions (16x16 -> 32x32 -> 64x64); the last layer produces 3 channels.
    """

    def __init__(self, input_shape, model_name="Generator", **kwargs):
        super().__init__(model_name, **kwargs)
        layers = tf.keras.layers

        def norm_relu():
            # every hidden stage ends with batch normalisation followed by ReLU
            return [layers.BatchNormalization(), layers.ReLU()]

        stack = []
        # dense projection, then fold the result into a 16x16x256 feature map
        stack += [layers.Dense(1024, input_shape=input_shape)] + norm_relu()
        stack += [layers.Reshape([16, 16, 256])]
        # conv without stride (16x16)
        stack += [layers.Conv2D(filters=256, kernel_size=5, strides=1, padding='same')] + norm_relu()
        # transposed conv with stride 2 (32x32)
        stack += [layers.Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='same')] + norm_relu()
        # conv without stride (32x32)
        stack += [layers.Conv2D(filters=64, kernel_size=5, strides=1, padding='same')] + norm_relu()
        # transposed conv with stride 2 (64x64)
        stack += [layers.Conv2DTranspose(filters=32, kernel_size=5, strides=2, padding='same')] + norm_relu()
        # 1x1 conv down to the 3 output channels, no activation
        stack += [layers.Conv2D(filters=3, kernel_size=(1, 1), strides=1, padding='same')]

        self.model = tf.keras.Sequential(stack)

Discriminator


In [None]:
class Discriminator(GANNetwork):
    """Scores images of shape `input_shape`; the output is a single raw logit per image."""

    def __init__(self, input_shape, model_name="Discriminator", **kwargs):
        super().__init__(model_name, **kwargs)
        layers = tf.keras.layers

        def norm_relu_dropout():
            # The discriminator is a classifier and should be robust, hence the extra dropout
            # regularisation after every hidden stage.
            return [layers.BatchNormalization(), layers.ReLU(), layers.Dropout(0.3)]

        stack = []
        # conv with stride 2 (32x32 for a 64x64 input)
        first_conv = layers.Conv2D(filters=64, kernel_size=5, strides=2, padding='same',
                                   input_shape=input_shape)
        stack += [first_conv] + norm_relu_dropout()
        # conv with stride 2 (16x16x128 for a 64x64 input)
        stack += [layers.Conv2D(filters=128, kernel_size=3, strides=2, padding='same')] + norm_relu_dropout()
        # flatten + hidden layer
        stack += [layers.Flatten(), layers.Dense(64)] + norm_relu_dropout()
        # prediction (LOGITS!)
        stack += [layers.Dense(1)]

        self.model = tf.keras.Sequential(stack)

Noise generator


In [None]:
class NoiseGenerator(tf.keras.layers.Layer):
    """Layer producing the generator input: standard-normal noise, one sample per input row."""

    def __init__(self, distribution_size, channels=3):
        """
        Args:
            distribution_size: size of the second axis of the generated noise.
            channels: size of the last axis of the generated noise (defaults to 3, the value
                that used to be hard-coded in `call`).
        """
        super().__init__()
        self.distribution_size = distribution_size
        self.channels = channels
        # self.data_distributions = self.add_weight(shape=(num_classes, distribution_size), trainable=True)
        # self.data_distributions = tf.tile(tf.range(0, num_classes, dtype=tf.float32)[:, tf.newaxis], [1, distribution_size])
        # TODO:

    def call(self, inputs):
        """Return noise of shape (batch, distribution_size, channels).

        Only the size of the first axis of `inputs` is used; its values are ignored.
        """
        # dists = tf.nn.embedding_lookup(self.data_distributions, inputs)
        # dists += tf.random.uniform(tf.shape(dists), -0.35, 0.35)
        # return dists
        # TODO
        batch = tf.shape(inputs)[0]
        return tf.random.normal([batch, self.distribution_size, self.channels])

    def diverse_distributions_loss(self):
        """Placeholder for a not yet implemented loss; currently always returns None."""
        # TODO
        return None

Training step



In [None]:
def train_step_template(generator, discriminator, noise, d_optim, g_optim, d_loss_f, g_loss_f, summary_writer):
    """Build the function that runs one optimisation step of both networks on a batch.

    Args:
        generator, discriminator: the two networks, called as ``net(data, training)``.
        noise: callable mapping a batch of images to the generator input.
        d_optim, g_optim: optimizers of the discriminator / generator.
        d_loss_f: ``f(real_out, gen_out)`` returning the discriminator loss.
        g_loss_f: ``f(gen_out)`` returning the generator loss.
        summary_writer: TensorBoard writer, or None to disable logging.

    Returns:
        A function ``step(images, epoch)`` that updates both networks in place.
    """

    @tf.function
    def _apply_step(images):
        # One tape per network so each loss is differentiated w.r.t. its own variables only.
        with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
            real_out = discriminator(images, True)
            gen_out = discriminator(generator(noise(images), True), True)

            d_loss = d_loss_f(real_out, gen_out)
            g_loss = g_loss_f(gen_out)

        if summary_writer is not None:
            with summary_writer.as_default():
                # The losses are tensors and are logged directly (reduce_mean keeps this valid
                # even for a loss function that does not reduce to a scalar). The optimizer's
                # iteration counter gives every batch its own step on the TensorBoard axis.
                tf.summary.scalar('discriminator_loss', tf.reduce_mean(d_loss), step=d_optim.iterations)
                tf.summary.scalar('generator_loss', tf.reduce_mean(g_loss), step=d_optim.iterations)

        d_vars = discriminator.trainable_variables
        g_vars = generator.trainable_variables + noise.trainable_variables

        d_grads = d_tape.gradient(d_loss, d_vars)
        g_grads = g_tape.gradient(g_loss, g_vars)

        d_optim.apply_gradients(zip(d_grads, d_vars))
        g_optim.apply_gradients(zip(g_grads, g_vars))

    def _train_step_template(images, epoch):
        # `epoch` stays in the signature for existing callers but is not passed into the traced
        # function: a Python int argument makes tf.function retrace for every new value.
        _apply_step(images)

    return _train_step_template


Inference step



In [None]:
def gen_step_template(generator, noise):
    """Build the inference function: noise -> generator (inference mode) -> values clipped to [-1, 1]."""

    @tf.function
    def _gen_step_template(images):
        latent = noise(images)
        raw_images = generator(latent, False)  # False = not training
        return tf.clip_by_value(raw_images, -1, 1)

    return _gen_step_template

Training

In [None]:
def train(train_step, gen_step, epochs, data, settings, ckptManager, save_images=True, display_images=True):
    """Run the training loop and show/save generated samples after every epoch.

    Args:
        train_step: ``f(images, epoch)`` performing one optimisation step.
        gen_step: ``f(images)`` returning generated images for a batch.
        epochs: number of passes over `data`.
        data: batched dataset; iterated once per epoch and sampled with ``take(1)``.
        settings: forwarded to ``show_images``.
        ckptManager: object with ``save()``, called every 5th and on the last epoch;
            None disables checkpointing.
        save_images, display_images: forwarded to ``show_images``.
    """
    for epoch in range(epochs):
        epoch_start = time.time()
        for images in data:
            train_step(images, epoch)

        epoch_end = time.time()
        print('-'*30)
        print('Epoch {0}/{1}, duration {2}'.format(epoch+1, epochs, epoch_end-epoch_start))
        is_checkpoint_epoch = (epoch + 1) % 5 == 0 or epoch == epochs-1
        if ckptManager is not None and is_checkpoint_epoch:
            print("Saving checkpoint, epoch", epoch+1)
            ckptManager.save()

        # take one batch from the training data as input for the sample images
        sample_batch = next(iter(data.take(1)), None)
        if sample_batch is None:
            print("dataset is empty, skipping image generation")
        else:
            generated = gen_step(sample_batch.numpy())
            show_images(generated, epoch, settings, save_images=save_images, display_images=display_images)
        print('+'*30)


In [None]:
# check how images are displayed/saved
def test_image_generation(settings):
    save_images = True
    display_images = True
    batch_size = 10
    data = load_dataset(batch_size=settings.batch_size)
    
    generator_input_shape = (64, 3)
    generator = Generator(input_shape=generator_input_shape)
    noise = NoiseGenerator(64)
    
    gen_step = gen_step_template(
        generator=generator,
        noise=noise
    )
    images_to_generate = [img for img in data.take(1)][0].numpy() # take one batch from train_data
    generated = gen_step(images_to_generate)
    show_images(generated, -1, settings, save_images=save_images, display_images=display_images)
    
# test_image_generation(Settings(collab_mode))

Training with Wasserstein loss function

In [None]:
def get_models(settings):
    """Create and build the generator and the discriminator described by `settings`.

    Returns:
        Tuple ``(generator, discriminator)``. Call ``.summary()`` on either to inspect it.
    """
    noise_shape = settings.generator_input_shape
    picture_shape = settings.image_shape

    generator = Generator(input_shape=noise_shape)
    generator.build((None, *noise_shape))  # None = variable batch size

    discriminator = Discriminator(input_shape=picture_shape)
    discriminator.build(input_shape=(None, *picture_shape))

    return generator, discriminator

In [None]:
# Experiment setup: everything the training cell below needs is created here and kept in `env`
# or in module-level names (noise, optimizers, train_step, gen_step).
env = Environment(collab_mode)
# Override the defaults from Settings for this run.
env.settings.epochs = 2
env.settings.batch_size = 100
env.settings.save_models = True
# load data
env.datasetCache.load_data(env.settings)
# set models
env.models['generator'], env.models['discriminator'] = get_models(env.settings)
# setup tensorboard
env.tensorboard = TensorboardManager()
env.tensorboard.initialize(env.settings)

# NoiseGenerator(64) yields noise of shape (batch, 64, 3).
# NOTE(review): this has to agree with settings.generator_input_shape - confirm.
noise = NoiseGenerator(64)
# Separate Adam optimizers (learning rate 1e-4) for the discriminator and the generator.
d_optim = tf.keras.optimizers.Adam(1e-4)
g_optim = tf.keras.optimizers.Adam(1e-4)

# The checkpoint tracks both models and both optimizers; `noise` is not part of it.
checkpoint = tf.train.Checkpoint(generator_optimizer=g_optim,
                                 discriminator_optimizer=d_optim,
                                 generator=env.models['generator'],
                                 discriminator=env.models['discriminator'])
# Keeps the 3 most recent checkpoints in the run's checkpoint directory.
env.checkpointManager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                                   directory=env.settings.checkpoint_dir,
                                                   max_to_keep=3
                                                  )
# Resume from the latest checkpoint of this run, if one exists.
if env.checkpointManager.latest_checkpoint:
    print("restoring state from", env.checkpointManager.latest_checkpoint)
    checkpoint\
        .restore(env.checkpointManager.latest_checkpoint)

# Training step using the Wasserstein losses defined earlier in the notebook.
train_step = train_step_template(
    generator=env.models['generator'],
    discriminator=env.models['discriminator'],
    noise=noise,
    d_optim=d_optim,
    g_optim=g_optim,
    d_loss_f=w_discriminator_loss,
    g_loss_f=w_generator_loss,
    summary_writer=env.tensorboard.train_summary_writer
)

# Inference step used to produce the sample images after every epoch.
gen_step = gen_step_template(
    generator=env.models['generator'],
    noise=noise
)

In [None]:
# Run the training configured in the previous cell and report the wall-clock time.
print("Start time", get_time())
print('%'*30)
start = time.time()

# `data` is the batched dataset from the cache (batch size taken from the last load_data call).
train(
    train_step=train_step,
    gen_step=gen_step,
    epochs=env.settings.epochs,
    data=env.datasetCache.data,
    settings=env.settings,
    ckptManager=env.checkpointManager
)

end = time.time()
print('%'*30)
print("End time", get_time())
print("seconds elapsed", end - start)

# Optionally save every model registered in env.models (here: generator and discriminator).
if env.settings.save_models:
    print('saving models')
    # iterating the dict yields the keys, i.e. the model names
    for model in env.models:
        env.models[model].save_model(env.settings.model_save_path)


uwagi:
* na początku generowane obrazki są białe, bardzo małe odchylenie w wartościach pikseli ok 17 dla skali 0-255
* generator używa tylko skali np 52-160
* później generator uczy się zwiększać odchylenie i wartości pikseli na obrazkach zwiększają się do przedziału 0-255
