<a href="https://colab.research.google.com/github/felix0097/CVAE_mnist/blob/master/cvae_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install tensorflow==2.1.0

In [0]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_datasets as tfds
import numpy as np
from typing import Tuple

tf.compat.v1.disable_eager_execution()

**Define Convolutional CVAE model**

In [0]:
class ConvCVAE(tf.keras.Model):

  def __init__(self, 
               input_shape_img: Tuple[int, int, int],
               input_shape_cond: int, 
               latent_dim: int):

    super(ConvCVAE, self).__init__()

    self.input_shape_img = input_shape_img
    self.latent_dim = latent_dim

    self.conv_enc = tf.keras.Sequential(
        [tf.keras.layers.InputLayer(input_shape=input_shape_img),
         tf.keras.layers.Conv2D(filters=32, 
                                kernel_size=3, 
                                activation='relu',
                                padding='same'),
         tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
                                      padding='same'),
         tf.keras.layers.Conv2D(filters=64, 
                                kernel_size=3,
                                activation='relu',
                                padding='same'),
         tf.keras.layers.MaxPooling2D(pool_size=(2, 2),
                                      padding='same'),
         tf.keras.layers.Flatten()
      ],
      name='encoder')
    
    self.enc = tf.keras.Sequential(
        [tf.keras.layers.InputLayer(input_shape=self.conv_enc.output_shape[1] + input_shape_cond),
         tf.keras.layers.Dense(20*latent_dim,
                               activation='relu'),
         tf.keras.layers.Dense(2*latent_dim)]
    )

    self.dec = tf.keras.Sequential(
        [tf.keras.layers.InputLayer(input_shape=(latent_dim + input_shape_cond)),
         tf.keras.layers.Dense(units=self.conv_enc.output_shape[1],
                               activation=tf.nn.relu),
         tf.keras.layers.Reshape(target_shape=self.conv_enc.layers[-2].output_shape[1:]),
         tf.keras.layers.Conv2DTranspose(filters=64,
                                         kernel_size=3,
                                         activation='relu',
                                         padding='same'),
         tf.keras.layers.UpSampling2D(size=(2, 2)),
         tf.keras.layers.Conv2DTranspose(filters=32,
                                         kernel_size=3,
                                         activation='relu',
                                         padding='same'),
         tf.keras.layers.UpSampling2D(size=(2, 2)),
         tf.keras.layers.Conv2DTranspose(filters=input_shape_img[2], 
                                         kernel_size=3, 
                                         strides=(1, 1),
                                         padding='same'),
         tf.keras.layers.Activation('sigmoid')
        ],
        name='decoder')
    
    def call(self, inputs, training=False):

      img_input = inputs[0]
      cond_input = inputs[1]

      enc_img = self.conv_enc(img_input)
      enc_output = self.enc(tf.concat([enc_img, cond_input], axis=1))
      mean, log_scale = tf.split(enc_output, num_or_size_splits=2, axis=1)
      scale= tf.math.exp(log_scale)

      latent_dist = tfp.distributions.MultivariateNormalDiag(loc=mean,
                                                             scale_diag=scale)
      
      ref_dist = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(self.latent_dim))

      kl_divergence = tfp.distributions.kl_divergence(latent_dist, ref_dist)
      self.add_loss(tf.math.reduce_sum(kl_divergence, name='KL_divergence_loss'))

      input_dec = tf.concat([latent_dist.sample(), cond_input], axis=1)
      dec_img = self.dec(input_dec)
      dec_img = tf.image.resize_with_crop_or_pad(dec_img, 
                                                 self.input_shape_img[0],
                                                 self.input_shape_img[1])
      
      return dec_img
      

**Prepare data set for fitting**

In [0]:
def preprocess_data(elem):
    img = elem['image']
    label = elem['label']
    # convert input image to [0, 1]
    img = tf.cast(img, dtype=tf.float32) / tf.cast(255., dtype=tf.float32)
    # one hot encode label
    label = tf.one_hot(tf.cast(label, dtype=tf.uint8), depth=10)

    return ((img, label), img)

ds_train = tfds.load(name="mnist", split="train")
ds_val = tfds.load(name="mnist", split="test")

ds_train = ds_train.map(preprocess_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_val = ds_val.map(preprocess_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)

ds_train = ds_train.shuffle(5000, reshuffle_each_iteration=True)\
                   .repeat()\
                   .batch(64)\
                   .prefetch(64)
ds_val = ds_val.shuffle(5000)\
               .repeat()\
               .batch(64)\
               .prefetch(64)


**Fit model**

In [4]:
ds_train

<DatasetV1Adapter shapes: (((None, 28, 28, 1), (None, 10)), (None, 28, 28, 1)), types: ((tf.float32, tf.float32), tf.float32)>

In [4]:
conv_cvae = ConvCVAE(input_shape_img=(28, 28, 1),
                     input_shape_cond=10,
                     latent_dim=10)

conv_cvae.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.BinaryCrossentropy())

conv_cvae.fit(ds_train,
              validation_data=ds_val,
              epochs=10
              )

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.


AttributeError: ignored