In [1]:
import tensorflow as tf
# tfe = tf.contrib.eager
#tf.enable_eager_execution()

from tensorflow.keras import layers

from tensorflow.keras import backend as K

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
#import PIL
import imageio
from IPython import display
import pathlib
AUTOTUNE=tf.data.experimental.AUTOTUNE

In [2]:
tf.__version__

'2.0.0-beta0'

In [3]:
dir=pathlib.Path.cwd()/'../Documents/img_align_celeba/'

In [4]:
all_image_paths=list(dir.glob('*'))
all_image_paths = [str(path) for path in all_image_paths]
image_count = len(all_image_paths)
image_count

#train_paths=all_image_paths[:-20000]
train_paths=all_image_paths[:6400]
test_paths=all_image_paths[-320:]

In [5]:
all_image_paths[0]

'/home/isaiahk/intro_dfc/../Documents/img_align_celeba/008239.jpg'

In [6]:
def preprocess_image(image):
  image = tf.image.decode_jpeg(image, channels=3)
  image = tf.image.resize(image, [192, 192])
  image /= 255.0  # normalize to [0,1] range
  #image = tf.image.convert_image_dtype(image, tf.float16)

  return image

def load_and_preprocess_image(path):
  image = tf.io.read_file(path)
  return preprocess_image(image)


def from_path_to_tensor(paths, batch_size):
    path_ds=tf.data.Dataset.from_tensor_slices(paths)
    ds=path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    ds=ds.batch(batch_size)
    ds=ds.prefetch(buffer_size=AUTOTUNE)
    return ds

In [7]:
BATCH_SIZE = 32
BUFFER_SIZE=image_count//9

train_set= from_path_to_tensor(train_paths, BATCH_SIZE)



test_set=from_path_to_tensor(test_paths, BATCH_SIZE)

In [8]:
class CVAE(tf.keras.Model):
  def __init__(self, latent_dim):
    super(CVAE, self).__init__()
    self.latent_dim = latent_dim
    self.inference_net = tf.keras.Sequential(
      [
          tf.keras.layers.InputLayer(input_shape=(192, 192, 3)),
          tf.keras.layers.Conv2D(
              filters=8, kernel_size=3, strides=(2, 2), activation='relu'),
          tf.keras.layers.Conv2D(
              filters=4, kernel_size=3, strides=(2, 2), activation='relu'),
          tf.keras.layers.Conv2D(
              filters=2, kernel_size=3, strides=(2, 2), activation='relu'),
          tf.keras.layers.Flatten(),
          # No activation
          tf.keras.layers.Dense(latent_dim + latent_dim),
      ]
    )

    self.generative_net = tf.keras.Sequential(
        [
          tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
          tf.keras.layers.Dense(units=24*24*32, activation=tf.nn.relu),
          tf.keras.layers.Reshape(target_shape=(24, 24, 32)),
          tf.keras.layers.Conv2DTranspose(
              filters=2,
              kernel_size=3,
              strides=(2, 2),
              padding="SAME",
              activation='relu'),
          tf.keras.layers.Conv2DTranspose(
              filters=4,
              kernel_size=3,
              strides=(2, 2),
              padding="SAME",
              activation='relu'),
          tf.keras.layers.Conv2DTranspose(
              filters=8,
              kernel_size=3,
              strides=(2, 2),
              padding="SAME",
              activation='relu'),  
          # No activation
          tf.keras.layers.Conv2DTranspose(
              filters=3, kernel_size=3, strides=(1, 1), padding="SAME", activation='sigmoid'),
        ]
    )
  
  @tf.function
  def sample(self, eps=None):
    if eps is None:
      eps = tf.random.normal(shape=(4, self.latent_dim))
    return self.decode(eps)
  @tf.function
  def encode(self, x):
    mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)
    return mean, logvar
  
  @tf.function
  def reparameterize(self, mean, logvar):
    eps = tf.random.normal(shape=mean.shape)
    return eps * tf.exp(logvar * .5) + mean
  @tf.function
  def decode(self, z):
    return self.generative_net(z)

In [9]:
def log_normal_pdf(sample, mean, logvar, raxis=1):
  log2pi = tf.math.log(2. * np.pi)
  return tf.reduce_sum(
      -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
      axis=raxis)

def compute_loss(model, x, test=False):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_r = model.decode(z)

    rc_loss = K.sum(K.binary_crossentropy(
    K.batch_flatten(x), 
    K.batch_flatten(x_r)), axis=-1)

    # Regularization term (KL divergence)
    kl_loss = -0.5 * K.sum(1 + logvar \
                             - K.square(mean) \
                             - K.exp(logvar), axis=-1)
    
    # Average over mini-batch
    total_loss = K.mean(rc_loss + kl_loss)
    
    if test:
        return rc_loss, kl_loss, total_loss, x, x_r
    else:
        return rc_loss, kl_loss, total_loss


def compute_gradients(model, x):
  with tf.GradientTape() as tape:
    _, _2, loss = compute_loss(model, x)
  return tape.gradient(loss, model.trainable_variables), loss

def apply_gradients(optimizer, gradients, variables):
     optimizer.apply_gradients(zip(gradients, variables))
    

In [10]:
epochs = 1
latent_dim = 50
num_examples_to_generate = 4

# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement.
random_vector_for_generation = tf.random.normal(
    shape=[num_examples_to_generate, latent_dim])
model = CVAE(latent_dim)

In [11]:
model.inference_net.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 95, 95, 8)         224       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 47, 47, 4)         292       
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 23, 23, 2)         74        
_________________________________________________________________
flatten (Flatten)            (None, 1058)              0         
_________________________________________________________________
dense (Dense)                (None, 100)               105900    
Total params: 106,490
Trainable params: 106,490
Non-trainable params: 0
_________________________________________________________________


In [12]:
def generate_and_save_images(model, epoch, test_input):
  predictions = model.sample(test_input)
  fig = plt.figure(figsize=(2,2))

  for i in range(predictions.shape[0]):
      plt.subplot(2, 2, i+1)
      plt.imshow(predictions[i, :, :, :])
      plt.axis('off')

  # tight_layout minimizes the overlap between 2 sub-plots
  #plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

mse=tf.losses.MeanSquaredError()

@tf.function
def p_loss(model, x):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_r = model.decode(z)
    
    outputs = [model.inference_net.get_layer(l).output for l in ["conv2d", "conv2d_1", "conv2d_2"]]
    p_model = tf.keras.Model(outputs)
    h1_list = model(x)
    h2_list = model(x_r)
    
    rc_loss = 0.0
    
    for h1, h2, weight in zip(h1_list, h2_list):
        rc_loss += mse(h1,h2)
        
    return rc_loss

In [13]:
def summarize(metrics, tags, test=False):
    if test:
        with test_summary_writer.as_default():
            for metric, tag in zip(metrics, tags):
                tf.summary.scalar(tag, metric.result(), step=optimizer.iterations)
                metric.reset_states()
    else:
        with train_summary_writer.as_default():
            for metric, tag in zip(metrics, tags):
                tf.summary.scalar(tag, metric.result(), step=optimizer.iterations)
                metric.reset_states()
            



@tf.function
def train_step(batch, model, optimizer):
    with tf.GradientTape() as tape:
        rc_loss, kl_loss, loss = compute_loss(model, batch)
    gradients=tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return rc_loss, kl_loss, loss


losses=['loss','rc_loss','kl_loss']

def train_vae(model, optimizer, epochs, dataset, print_interval, save_interval, log_freq=10):
    #summary_writer = tf.summary.create_file_writer(DIR)
    for epoch in range(1, epochs + 1):
        avg_loss=tf.metrics.Mean(name='loss', dtype=tf.float32)
        avg_val=tf.metrics.Mean(name='val_loss', dtype=tf.float32)
        avg_train_rc=tf.metrics.Mean(name='training_rc_loss', dtype=tf.float32)
        avg_val_rc=tf.metrics.Mean(name='val_rc_loss', dtype=tf.float32)
        avg_train_kl=tf.metrics.Mean(name='training_kl_loss', dtype=tf.float32)
        avg_val_kl=tf.metrics.Mean(name='val_kl_loss', dtype=tf.float32)
        avg_p_loss=tf.metrics.Mean(name='p_loss',dtype=tf.float32)
        start_time = time.time()
        for step, batch in enumerate(dataset):
            rc_loss_x, kl_loss_x, loss_x = train_step(batch, model, optimizer)
            avg_loss.update_state(loss_x)
            avg_train_rc.update_state(rc_loss_x)
            avg_train_kl.update_state(kl_loss_x)
            if tf.equal(optimizer.iterations % log_freq, 0):
                print('log', '{}'.format(optimizer.iterations.numpy()))
                summarize([avg_loss, avg_train_rc, avg_train_kl], losses)
            if (time.time()-start_time)>5:
                print('time to check!')
                for batch in test_set:
                    print('do we get here?')
                    rc_val, kl_val, val, x, x_r=compute_loss(model, batch, test=True)
                    ploss=p_loss(model, batch)
                    avg_val.update_state(val)
                    avg_val_rc.update_state(rc_val)
                    avg_val_kl.update_state(kl_val)
                    avg_p_loss.update_state(ploss)
                print('Batch',step,'done.', 'loss on validaiton set: {}'.format(avg_val.result()))
                summarize([avg_val,avg_val_rc, avg_val_kl], losses,test=True)
                with test_summary_writer.as_default():
                    tf.summary.image('input', x, step = optimizer.iterations, max_outputs=3)
                    tf.summary.image('output', x_r, step = optimizer.iterations, max_outputs=3)
                    tf.summary.scalar('p_loss', avg_p_loss.result(), step=optimizer.iterations)
                    avg_p_loss.reset_states()
                start_time=time.time()
            if step+1 % save_interval ==0:
                generate_and_save_images(model, epoch, step+1, random_vector_for_generation)
                model.weight_saver(TRAINING_DIR, epoch, i)
                end_time = time.time()
    #generate_and_save_images(model, epoch,0, random_vector_for_generation)


In [14]:
optimizer=tf.optimizers.Adadelta(1e-4)

In [15]:
TRAINING_DIR='experiment'

In [16]:
train_summary_writer = tf.summary.create_file_writer(TRAINING_DIR+'/summaries/train')
test_summary_writer = tf.summary.create_file_writer(TRAINING_DIR+'/summaries/test')


In [17]:
train_vae(model, optimizer, 1, train_set, 10, 500)

W0620 18:39:44.097499 140534846768896 deprecation.py:323] From /home/isaiahk/miniconda3/envs/tf2b/lib/python3.6/site-packages/tensorflow/python/ops/math_grad.py:1220: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


log 10
log 20
log 30
time to check!
do we get here?


NotImplementedError: in converted code:

    <ipython-input-12-a7b2ed9308f5>:24 p_loss  *
        h1_list = model(x)
    /home/isaiahk/miniconda3/envs/tf2b/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:667 __call__
        outputs = call_fn(inputs, *args, **kwargs)
    /home/isaiahk/miniconda3/envs/tf2b/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py:750 call  *
        raise NotImplementedError('When subclassing the `Model` class, you should'

    NotImplementedError: When subclassing the `Model` class, you should implement a `call` method.
