In [8]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Criando um sampling layer

Primeiro ele define uma camada (layer) extra que faz a amostragem no espaço latente. Ela recebe a média e o log da variância
produzidos pela camada anterior e amostra de uma distribuição normal com esses parâmetros.

In [11]:
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.

    Implements the reparameterization step ``z = z_mean + sigma * epsilon``
    with ``epsilon ~ N(0, 1)`` and ``sigma = exp(0.5 * z_log_var)``.
    """

    def call(self, inputs):
        """Draw one latent sample per row of the inputs.

        Args:
            inputs: A pair ``(z_mean, z_log_var)`` of tensors whose first two
                axes are read as (batch, latent dimension).

        Returns:
            A tensor computed as ``z_mean + exp(0.5 * z_log_var) * epsilon``,
            where ``epsilon`` is standard-normal noise of shape (batch, dim).
        """
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        # `shape` is a keyword argument of random_normal; writing
        # `shape(batch, dim)` would call an undefined function named `shape`.
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        # exp(0.5 * log(sigma^2)) == sigma, so this scales the noise by the
        # standard deviation before shifting it by the mean.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

Só queria fazer uma nota: ele vai gerar um ruído $\epsilon \sim N(0, 1)$, multiplicá-lo por um número e somar outro. Lembrando que
$$E(X) = 0 \Rightarrow E(X + \mu) = \mu$$
$$V(X) = 1 \Rightarrow V(\sigma X) = \sigma^2$$

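Aqui ele está gerando $\epsilon \sim N(0, 1)$, multiplicando por $\exp(0.5 \cdot \text{z\_log\_var})$ e somando z_mean; então teremos
$$E(\epsilon + \text{z\_mean}) = \text{z\_mean}$$
para a média e, escrevendo $\text{z\_log\_var} = \log \sigma^2$,
$$
\begin{aligned}
V(\epsilon \cdot \exp\{0.5 \cdot \text{z\_log\_var}\}) &= V(\epsilon \cdot \exp\{\tfrac{1}{2} \log \sigma^2\}) \\
&= V(\epsilon \cdot \exp\{\log \sigma^2 \cdot \tfrac{1}{2}\}) \\
&= V(\epsilon \cdot \exp\{\log \sigma\}) \\
&= V(\epsilon \cdot \sigma) \\
&= \sigma^2 V(\epsilon) \\
&= \sigma^2
\end{aligned}
$$
para a variância.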

Criando o encoder

No exemplo ele usa latent_dim = 2 para que nosso código tenha dimensão 2. Assim a nossa representação das variáveis seguirá uma distribuição normal bivariada e, sendo de 2 dimensões, vamos conseguir visualizá-la facilmente.

In [None]:
# Size of the latent code produced by the encoder.
latent_dim = 2

# Encoder: maps a 28x28x1 image to (z_mean, z_log_var, sampled z).
encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(
    filters=32, kernel_size=3, strides=2, padding='same', activation='relu'
)(encoder_inputs)
x = layers.Conv2D(
    filters=64, kernel_size=3, strides=2, padding='same', activation='relu'
)(x)
x = layers.Flatten()(x)
x = layers.Dense(units=16, activation='relu')(x)
# Two parallel linear heads give the parameters the Sampling layer consumes.
z_mean = layers.Dense(units=latent_dim, name='z_mean')(x)
z_log_var = layers.Dense(units=latent_dim, name='z_log_var')(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(
    inputs=encoder_inputs,
    outputs=[z_mean, z_log_var, z],
    name='encoder',
)

encoder.summary()

Build the decoder

In [None]:
# Decoder: maps a latent vector of size `latent_dim` back to an image.
latent_inputs = keras.Input(shape=(latent_dim,))
# Project to a 7x7x64 feature map so the transposed convolutions can upsample it.
x = layers.Dense(7 * 7 * 64, activation='relu')(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
# Each stride-2 transposed convolution with 'same' padding doubles the
# spatial size: 7x7 -> 14x14 -> 28x28.
x = layers.Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same')(x)
x = layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same')(x)
# Sigmoid keeps the single output channel in [0, 1], matching the scaled pixels.
decoder_outputs = layers.Conv2DTranspose(1, 3, activation='sigmoid', padding='same')(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()

Define o VAE em um modelo customizado.

Essa parte é um pouco mais complicada.

Em VAE.__init__ ele define o VAE, que junta o encoder com o decoder e calcula algumas métricas de interesse (perda de reconstrução, perda de KL e perda total).

In [None]:
class VAE(keras.Model):
    """Variational autoencoder that trains an encoder/decoder pair jointly.

    The model overrides ``train_step`` to minimise the sum of a
    reconstruction loss (binary cross-entropy summed over the image axes)
    and the KL-divergence term computed from ``z_mean`` and ``z_log_var``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        """Store the sub-models and create running-mean loss trackers.

        Args:
            encoder: Model called as ``encoder(data)`` that returns
                ``(z_mean, z_log_var, z)``.
            decoder: Model called as ``decoder(z)`` that returns the
                reconstruction.
            **kwargs: Forwarded to ``keras.Model.__init__``.
        """
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name='total_loss')
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name='reconstruction_loss'
        )
        self.kl_loss_tracker = keras.metrics.Mean(name='kl_loss')

    @property
    def metrics(self):
        """Return the three loss trackers owned by this model."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: Batch of images; when a tuple is passed (for example
                ``(x, y)``), only its first element is used.

        Returns:
            Dict with the running means of the total, reconstruction and KL
            losses under the keys ``'loss'``, ``'reconstruction_loss'`` and
            ``'kl_loss'``.
        """
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Sum the per-pixel cross-entropy over the two image axes, then
            # average over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            # KL term per latent dimension, summed over axis 1 and averaged
            # over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            'loss': self.total_loss_tracker.result(),
            'reconstruction_loss': self.reconstruction_loss_tracker.result(),
            'kl_loss': self.kl_loss_tracker.result(),
        }

Treinando o modelo

In [None]:
# Labels are not needed for training, so train and test images are merged
# into a single array.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
# Add a trailing channel axis and rescale the pixel values by 1/255.
mnist_digits = np.expand_dims(mnist_digits, -1).astype('float32') / 255

vae = VAE(encoder, decoder)
# The optimizer classes live in `keras.optimizers`.
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(mnist_digits, epochs=30, batch_size=128)

Display a grid of sampled digits

In [None]:
# `plt` must be the pyplot interface: the top-level `matplotlib` package has
# no callable `figure`, `imshow` or `show`.
import matplotlib.pyplot as plt

In [None]:
def plot_latent_space(va, n=30, figsize=15):
    """Display an n x n grid of digits decoded from a 2-D latent grid.

    Args:
        va: Object exposing ``decoder.predict``; it is called with one
            latent point of shape (1, 2) at a time.
        n: Number of grid points along each latent axis.
        figsize: Width and height passed to ``plt.figure``.
    """
    # Imported locally so the function always gets the pyplot interface,
    # whatever the module-level name `plt` happens to be bound to.
    import matplotlib.pyplot as plt

    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    # Linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space.
    grid_x = np.linspace(-scale, scale, n)
    # Reversed so the largest z[1] value is drawn in the top row.
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = va.decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    # Put one tick at the centre of each digit tile, labelled with the
    # latent coordinate that produced it.
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.imshow(figure, cmap='Greys_r')
    plt.show()
    
# Draw the grid of decoded digits using the `vae` model and the default grid settings.
plot_latent_space(vae)

In [None]:
def plot_label_clusters(vae, data, labels):
    """Scatter the encoder's z_mean for each sample, coloured by label.

    Args:
        vae: Object exposing ``encoder.predict`` that returns
            ``(z_mean, z_log_var, z)``; only the first two columns of
            ``z_mean`` are plotted.
        data: Batch of inputs passed to ``vae.encoder.predict``.
        labels: One value per sample, used as the scatter colour.
    """
    # Imported locally so the function always gets the pyplot interface,
    # whatever the module-level name `plt` happens to be bound to.
    import matplotlib.pyplot as plt

    z_mean, _, _ = vae.encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.show()


# The labels are needed here, so the training split is loaded again and
# given the same preprocessing as the training images (channel axis, 1/255 scale).
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype('float32') / 255

plot_label_clusters(vae, x_train, y_train)