In [None]:
import tensorflow as tf
# tfe = tf.contrib.eager
#tf.enable_eager_execution()

from tensorflow.keras import layers

from tensorflow.keras import backend as K

from tensorflow.keras.backend import batch_flatten

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
#import PIL
import imageio
from IPython import display
import pathlib
AUTOTUNE=tf.data.experimental.AUTOTUNE

In [None]:
# Display the installed TensorFlow version as the cell output.
tf.__version__

In [None]:
def get_paths(directory):
    """Return the paths of every entry directly inside ``directory``.

    ``directory`` is joined onto the current working directory; an absolute
    ``directory`` is therefore used unchanged.

    Args:
        directory: Directory to list (string or path-like).

    Returns:
        A list of path strings sorted lexicographically. The list is empty
        when the directory contains no entries.
    """
    root = pathlib.Path.cwd() / directory
    # Path.glob yields entries in arbitrary, filesystem-dependent order.
    # Sorting keeps index-based slicing of the result (e.g. train/test
    # splits) reproducible between runs and machines.
    return sorted(str(entry) for entry in root.glob('*'))


def preprocess_image(image, size):
    """Decode JPEG bytes into a square RGB image scaled to the [0, 1] range.

    Args:
        image: Encoded JPEG contents, as accepted by ``tf.image.decode_jpeg``.
        size: Edge length in pixels; the image is resized to ``(size, size)``.

    Returns:
        The decoded, resized image divided by 255.
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    resized = tf.image.resize(decoded, (size, size))
    # Normalize pixel values to the [0, 1] range.
    return resized / 255.0

def load_and_preprocess_image(path, size):
    """Read the file at ``path`` and run it through ``preprocess_image``.

    Args:
        path: Location of the encoded image file.
        size: Edge length forwarded to ``preprocess_image``.

    Returns:
        Whatever ``preprocess_image`` returns for the file's contents.
    """
    raw_bytes = tf.io.read_file(path)
    return preprocess_image(raw_bytes, size)

def from_path_to_tensor(paths, batch_size, size):
    """Build a batched ``tf.data`` pipeline of preprocessed images.

    The dataset is neither shuffled nor repeated: it yields the images in the
    order of ``paths``, exactly once per iteration over the dataset.

    Args:
        paths: Sequence of image file paths.
        batch_size: Number of images per batch (the last batch may be smaller).
        size: Edge length every image is resized to.

    Returns:
        A ``tf.data.Dataset`` whose elements are batches of images produced by
        ``load_and_preprocess_image``.
    """
    path_ds = tf.data.Dataset.from_tensor_slices(paths)
    # Let tf.data choose the decode parallelism and prefetch depth instead of
    # pinning both to 1; element order is unchanged.
    image_ds = path_ds.map(
        lambda path: load_and_preprocess_image(path, size),
        num_parallel_calls=AUTOTUNE)
    return image_ds.batch(batch_size).prefetch(buffer_size=AUTOTUNE)


In [None]:
# got some model ideas from here https://medium.com/@jonathan_hui/gan-dcgan-deep-convolutional-generative-adversarial-networks-df855c438f
# http://karpathy.github.io/2019/04/25/recipe/
# https://towardsdatascience.com/deciding-optimal-filter-size-for-cnns-d6f7b56f9363

# image dim must be divisible by 8
class VAE(tf.keras.Model):
    """Convolutional variational autoencoder with an optional perceptual network.

    Attributes set by ``__init__``:
        inference_net: Encoder. Emits ``2 * latent_dim`` values per example,
            which ``encode`` splits into mean and log-variance.
        generative_net: Decoder. Maps a latent vector to a 3-channel image
            through a sigmoid output layer.
        percep_net: Only present when ``mode`` is 'dfc' or 'combo'. A clone of
            ``inference_net`` initialised with the same weights; its
            intermediate outputs are returned by ``get_features``.
        selected_layers: Names of the layers whose outputs ``get_features``
            collects.
    """

    def __init__(self, latent_dim, image_dim, mode, kernelsize=3, selected_layers = None, loader=None):
        """Build fresh networks, or load previously saved ones.

        Args:
            latent_dim: Size of the latent vector.
            image_dim: Edge length of the square RGB input. When building
                fresh networks it must be divisible by 8, because the decoder
                upsamples an ``image_dim // 8`` feature map by 2 three times.
            mode: 'dfc' or 'combo' additionally creates ``percep_net``; any
                other value skips it.
            kernelsize: Kernel size used by every (transposed) convolution.
            selected_layers: Layer names used by ``get_features``. Defaults to
                the first two layers of ``inference_net`` whose name starts
                with 'conv'.
            loader: Directory containing the 'inf' and 'gen' files written by
                ``saver``. When falsy, new networks are built instead.

        Raises:
            ValueError: If new networks are built and ``image_dim`` is not
                divisible by 8, or if a perceptual mode is requested and a
                selected layer does not exist in ``percep_net``.
        """
        super(VAE, self).__init__()
        if loader:
            self.inference_net = tf.keras.models.load_model(loader+'/inf')
            self.generative_net = tf.keras.models.load_model(loader+'/gen')
        else:
            if image_dim % 8 != 0:
                raise ValueError(
                    'image_dim must be divisible by 8, got {}'.format(image_dim))
            # Spatial size of the decoder's first feature map; three stride-2
            # transposed convolutions bring it back up to image_dim.
            base_dim = image_dim // 8

            self.inference_net = tf.keras.Sequential(
                [
                    tf.keras.layers.Conv2D(
                        filters=8, kernel_size=kernelsize, strides=(2, 2), activation='relu',
                        use_bias=False, input_shape=(image_dim, image_dim, 3)),
                    tf.keras.layers.BatchNormalization(),
                    tf.keras.layers.Conv2D(
                        filters=4, kernel_size=kernelsize, strides=(2, 2), activation='relu',
                        use_bias=False),
                    tf.keras.layers.BatchNormalization(),
                    tf.keras.layers.Conv2D(
                        filters=2, kernel_size=kernelsize, strides=(2, 2), activation='relu',
                        use_bias=False),
                    tf.keras.layers.BatchNormalization(),
                    tf.keras.layers.Flatten(),
                    # No activation: the output holds mean and log-variance.
                    tf.keras.layers.Dense(latent_dim + latent_dim),
                ]
            )

            self.generative_net = tf.keras.Sequential(
                [
                    tf.keras.layers.Dense(
                        units=base_dim*base_dim*32, activation=tf.nn.relu,
                        input_shape=(latent_dim,)),
                    tf.keras.layers.Reshape(target_shape=(base_dim, base_dim, 32)),
                    tf.keras.layers.Conv2DTranspose(
                        filters=2,
                        kernel_size=kernelsize,
                        strides=(2, 2),
                        padding="SAME",
                        activation='relu'),
                    tf.keras.layers.Conv2DTranspose(
                        filters=4,
                        kernel_size=kernelsize,
                        strides=(2, 2),
                        padding="SAME",
                        activation='relu'),
                    tf.keras.layers.Conv2DTranspose(
                        filters=8,
                        kernel_size=kernelsize,
                        strides=(2, 2),
                        padding="SAME",
                        activation='relu'),
                    # Sigmoid output, stride 1: produces the 3-channel image.
                    tf.keras.layers.Conv2DTranspose(
                        filters=3, kernel_size=kernelsize, strides=(1, 1), padding="SAME",
                        activation='sigmoid'),
                ]
            )

        # If no layers are specified, use the first two convolution layers.
        if selected_layers:
            self.selected_layers = selected_layers
        else:
            self.selected_layers = [layer.name for layer in self.inference_net.layers if layer.name.startswith('conv')][:2]

        if mode == 'dfc' or mode == 'combo':
            self.percep_net = tf.keras.models.clone_model(self.inference_net)
            self.percep_net.set_weights(self.inference_net.get_weights())
            # Fail early: get_features would otherwise fall off the end of its
            # loop and return None when a requested layer is absent.
            percep_names = [layer.name for layer in self.percep_net.layers]
            missing = [name for name in self.selected_layers if name not in percep_names]
            if missing:
                raise ValueError(
                    'selected_layers not found in percep_net: {}'.format(missing))

    @tf.function
    def encode(self, x):
        """Run the encoder and split its output into ``(mean, logvar)`` halves."""
        mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    @tf.function
    def reparameterize(self, mean, logvar):
        """Sample ``mean + eps * exp(logvar / 2)`` with ``eps`` standard normal."""
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * .5) + mean

    @tf.function
    def decode(self, z):
        """Run the decoder on latent vectors ``z``."""
        return self.generative_net(z)

    @tf.function
    def get_features(self, x):
        """Return the outputs of the selected ``percep_net`` layers for ``x``.

        Layers are applied in order and the pass stops as soon as one output
        per entry of ``selected_layers`` has been collected.

        Returns:
            A list of layer outputs in network order. Returns None if the loop
            finishes without collecting ``len(selected_layers)`` outputs.
        """
        # NOTE(review): layers are called without an explicit ``training``
        # argument; confirm BatchNormalization runs in the intended mode.
        rv = []
        for layer in self.percep_net.layers:
            x = layer(x)
            if layer.name in self.selected_layers:
                rv.append(x)
            if len(rv) == len(self.selected_layers):
                return rv

    def saver(self, DIR, tag):
        """Save both networks in HDF5 format under ``./<DIR>/<tag>/``.

        The encoder is written to 'inf' and the decoder to 'gen'.
        """
        directory = './'+DIR+'/'+tag
        # makedirs also creates missing parent directories and is a no-op
        # when the directory already exists.
        os.makedirs(directory, exist_ok=True)
        self.inference_net.save(directory+'/inf', save_format='h5')
        self.generative_net.save(directory+'/gen', save_format='h5')

    def loader(self, directory):
        """Unimplemented placeholder; does nothing.

        To restore saved networks, pass ``loader=<directory>`` to the
        constructor instead.
        """
        pass

In [None]:
#loss definitions

# Couldn't figure out how to do a simple per-pixel MSE... smh TensorFlow!
@tf.function
def mse(label, prediction):
    """Mean squared error between two batches, computed per example.

    Both tensors are flattened while keeping the batch dimension, then passed
    to ``tf.losses.MSE``.
    """
    flat_label = batch_flatten(label)
    flat_prediction = batch_flatten(prediction)
    return tf.losses.MSE(flat_label, flat_prediction)

@tf.function
def compute_loss(model, x, mode, scales, test=False):
    """Encode, sample, decode ``x`` and compute the losses for ``mode``.

    Args:
        model: Object providing ``encode``, ``reparameterize``, ``decode`` and,
            for the perceptual modes, ``get_features`` and ``selected_layers``.
        x: Batch of input images.
        mode: 'vae' (reconstruction + KL), 'dfc' (perceptual + KL) or
            'combo' (perceptual + reconstruction + KL).
        scales: Dict of optional multipliers. Recognised keys are 'kl_loss',
            'rc_loss', 'percep_loss' and any name in ``model.selected_layers``.
        test: When true, the input and its reconstruction are also returned
            under 'x' and 'x_r'.

    Returns:
        Dict with 'kl_loss', the mode-specific terms ('rc_loss', one entry per
        selected layer, 'percep_loss'), all already scaled, and 'total_loss',
        the batch mean of the summed terms.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    if mode not in ('vae', 'dfc', 'combo'):
        raise ValueError(
            "Unknown mode {!r}; expected 'vae', 'dfc' or 'combo'.".format(mode))

    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_r = model.decode(z)
    rv = {}

    # Regularization term (KL divergence)
    kl_loss = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar), axis=-1)
    if 'kl_loss' in scales: kl_loss *= scales['kl_loss']
    rv['kl_loss'] = kl_loss
    per_example = kl_loss

    if mode == 'dfc' or mode == 'combo':
        # Perceptual loss: compare deep features of input and reconstruction.
        outputs = model.get_features(x)
        outputs_r = model.get_features(x_r)
        perceptual_losses = [mse(original, reconstructed) for original, reconstructed in zip(outputs, outputs_r)]
        for layer, loss in zip(model.selected_layers, perceptual_losses):
            if layer in scales: loss *= scales[layer]
            rv[layer] = loss
        # Sum the *scaled* per-layer terms so per-layer scales affect the
        # total in both 'dfc' and 'combo'.
        percep_loss = sum([rv[layer] for layer in model.selected_layers])
        if 'percep_loss' in scales: percep_loss *= scales['percep_loss']
        rv['percep_loss'] = percep_loss
        per_example = percep_loss + per_example

    if mode == 'vae' or mode == 'combo':
        # Reconstruction loss
        rc_loss = mse(x, x_r)
        if 'rc_loss' in scales: rc_loss *= scales['rc_loss']
        rv['rc_loss'] = rc_loss
        per_example = rc_loss + per_example

    # Average over the mini-batch.
    rv['total_loss'] = tf.reduce_mean(per_example)

    if test:
        rv['x'] = x
        rv['x_r'] = x_r
    return rv

@tf.function
def train_step(batch, model, optimizer, mode, scales):
    """Run one optimisation step and return the loss dictionary.

    ``optimizer`` updates ``model.inference_net``. The module-level optimizers
    ``opt2`` and ``opt3`` update ``model.generative_net`` and, in the 'dfc' and
    'combo' modes, ``model.percep_net``. The perceptual network receives the
    gradients computed for the inference network's variables.

    Args:
        batch: Batch of input images.
        model: The VAE being trained.
        optimizer: Optimizer for the inference network.
        mode: Training mode forwarded to ``compute_loss``.
        scales: Loss multipliers forwarded to ``compute_loss``.

    Returns:
        The dict produced by ``compute_loss`` for this batch.
    """
    inf_variables = model.inference_net.trainable_variables
    gen_variables = model.generative_net.trainable_variables
    # Request both gradient sets in a single call, so the tape does not have
    # to be persistent and its resources are released right away.
    with tf.GradientTape() as tape:
        loss_dict = compute_loss(model, batch, mode, scales)
    inf_gradients, gen_gradients = tape.gradient(
        loss_dict['total_loss'], [inf_variables, gen_variables])
    optimizer.apply_gradients(zip(inf_gradients, inf_variables))
    opt2.apply_gradients(zip(gen_gradients, gen_variables))
    if mode == 'dfc' or mode == 'combo':
        opt3.apply_gradients(zip(inf_gradients, model.percep_net.trainable_variables))
    return loss_dict

# Use a class to create tf.variables on call for AutoGraph
class test:
    """Callable that evaluates a model on a test set and writes summaries.

    Metric objects are created in ``__init__`` and ``__call__`` is a
    ``tf.function`` that only updates and resets them.
    """

    def __init__(self, loss_dict, image_size):
        """Create one ``Mean`` metric per loss and placeholder image tensors.

        Args:
            loss_dict: A loss dictionary as returned by ``train_step``; its keys
                name the metrics and ``loss_dict['kl_loss'].shape[0]`` gives the
                batch dimension of the placeholders.
            image_size: Edge length of the square RGB images.
        """
        #testing metrics
        self.metric_dict = {key: tf.metrics.Mean() for key in loss_dict}
        # Copy the dict so the placeholder entries added below do not leak
        # into the caller's loss_dict.
        self.losses_dict = dict(loss_dict)
        # Placeholders give losses_dict the same keys that
        # compute_loss(..., test=True) returns before the loop in __call__
        # reassigns it (presumably needed by AutoGraph -- TODO confirm).
        self.losses_dict['x']=tf.zeros(shape=(loss_dict['kl_loss'].shape[0], image_size, image_size, 3))
        self.losses_dict['x_r']=tf.zeros(shape=(loss_dict['kl_loss'].shape[0], image_size, image_size, 3))

    @tf.function
    def __call__(self, model, test_set, step, mode, scales):
        """Evaluate ``model`` on ``test_set`` and log the results at ``step``.

        Every loss is averaged over all batches and written with
        ``tf.summary.scalar``; the metrics are then reset. Up to three input
        and reconstructed images from ``self.losses_dict`` (as left by the
        loop) are written with ``tf.summary.image``.

        Returns:
            The mean 'total_loss' over the test set.
        """
        with tf.device('/gpu:0'):
            for batch in test_set:
                self.losses_dict = compute_loss(model, batch, mode, scales, test=True)
                for loss, metric in self.metric_dict.items():
                    metric.update_state(self.losses_dict[loss])
        # Read the average before the metrics are reset below.
        rv = self.metric_dict['total_loss'].result()
        for loss, metric in self.metric_dict.items():
            tf.summary.scalar(loss, metric.result(), step=step)
            metric.reset_states()
        with tf.device('/cpu:0'):
            tf.summary.image('input', self.losses_dict['x'], step = step, max_outputs=3)
            tf.summary.image('output', self.losses_dict['x_r'], step = step, max_outputs=3)
        return rv

In [None]:
###########  Parameters  ############
#folder to save weights and images
DIR = '../test2'
BATCH_SIZE = 32
# Edge length of the square input images; must be divisible by 8 (see VAE).
image_size = 192
epochs = 1
latent_dim = 50
# Fix the TensorFlow seed before the model is built so weight initialisation
# and latent sampling are reproducible across runs.
SEED = 42
tf.random.set_seed(SEED)
# One optimizer per network (see train_step): `optimizer` updates the
# inference net, `opt2` the generative net and `opt3` the perceptual net.
optimizer = tf.optimizers.Adam(1e-4)
opt2 = tf.optimizers.Adam(1e-4)
opt3 = tf.optimizers.Adam(1e-4)
# Print the step counter every `log_freq` optimizer iterations.
log_freq = 10
kernelsize = 3
# Training mode: 'vae', 'dfc' or 'combo' (see compute_loss).
mode = 'dfc'
# Pass kernelsize explicitly so the setting above is actually honoured.
model = VAE(latent_dim, image_size, mode, kernelsize=kernelsize)
# Optional loss multipliers keyed by loss name ('kl_loss', 'rc_loss',
# 'percep_loss') or by selected layer name; empty means no scaling.
scales = {}
#####################################

In [None]:
def _has_entries(directory):
    """Return True when ``directory`` exists and contains at least one entry."""
    return os.path.exists(directory) and len(os.listdir(directory)) != 0


image_dir = '../Documents/img_align_celeba'
all_image_paths = get_paths(image_dir)

# First 6400 paths for training, last 640 for testing.
train_paths = all_image_paths[:6400]
test_paths = all_image_paths[-640:]

# Only the test set is built here; the train set is built inside the epoch loop.
test_set = from_path_to_tensor(test_paths, BATCH_SIZE, size=image_size)

# TensorBoard log directories.
train_dir = './{}/train'.format(DIR)
test_dir = './{}/test'.format(DIR)

# Refuse to continue if either log directory already holds event files.
train_exists = _has_entries(train_dir)
test_exists = _has_entries(test_dir)
assert (not train_exists), "You are going to overwrite your train event files."
assert (not test_exists), "You are going to overwrite your test event files."

train_summary_writer = tf.summary.create_file_writer(train_dir)
test_summary_writer = tf.summary.create_file_writer(test_dir)

In [None]:
for epoch in range(1, epochs + 1):
    # A fresh training dataset is built at the start of every epoch.
    train_set = from_path_to_tensor(train_paths, BATCH_SIZE, size=image_size)
    start_time = time.time()

    for batch_index, batch in enumerate(train_set):
        loss_dict = train_step(batch, model, optimizer, mode, scales)

        # The running means are (re)created on the first batch of each epoch.
        if batch_index == 0:
            metrics_dict = {key: tf.metrics.Mean() for key in loss_dict}
        for name in loss_dict:
            metrics_dict[name].update_state(loss_dict[name])

        # Progress marker every `log_freq` optimizer steps.
        if tf.equal(optimizer.iterations % log_freq, 0):
            print('log', optimizer.iterations.numpy())

        # Every 100 optimizer steps: evaluate on the test set, then save.
        if tf.equal(optimizer.iterations % 100, 0):
            with test_summary_writer.as_default():
                tester = test(loss_dict, image_size)
                avg_loss = tester(model, test_set, optimizer.iterations, mode, scales)
                elapsed_minutes = (time.time() - start_time)/60
                print(
                    'Epoch: {}, test set average loss: {:.4f},'.format(epoch, avg_loss),
                    'time elapsed for current epoch: {:.2f}'.format(elapsed_minutes),
                    'minutes')
            model.saver(DIR, '19')