<a href="https://colab.research.google.com/github/dauparas/tensorflow_examples/blob/master/linear_beta_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#Step 1: import dependencies
from __future__ import division

import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from keras import regularizers
from tensorflow.keras import layers

%matplotlib inline
plt.style.use('ggplot')

In [0]:
#Step 2: build a synthetic data set of 4 features driven by hidden uniform factors
N_points = 501 #how many samples to draw
np.random.seed(1) #fix the NumPy RNG so the data set is reproducible

#Hidden generating factors, each drawn uniformly from [0, 1)
#(t3 is drawn but not used by any feature below; the draw is kept so that the
# noise samples that follow come from the same position of the random stream)
t1, t2, t3 = (np.random.uniform(0, 1, N_points) for _ in range(3))

#Every feature is a scaled copy of one factor plus a little Gaussian noise:
#x1 and x2 follow t1, x3 and x4 follow t2
noise_level = 0.01
columns = []
for scale, factor in [(1.0, t1), (2.0, t1), (3.0, t2), (5.0, t2)]:
    noisy = scale*factor+noise_level*np.random.randn(N_points)
    columns.append(noisy.reshape(-1,1))
x1, x2, x3, x4 = columns
N_features = 4
X = np.hstack(columns)

#Standardise each feature to zero mean and unit standard deviation
mean, std = X.mean(axis=0), X.std(axis=0)
X = ((X - mean) /std).astype(np.float32)
assert X.shape == (N_points, N_features)

In [0]:
#Scatter the two hidden factors t1 and t2 against each other
fig, ax = plt.subplots()
ax.scatter(t1, t2)
ax.axis('equal')
ax.set_xlabel('t1')
ax.set_ylabel('t2')
ax.set_title('Generating variables');

In [0]:
#Plot pairs of columns of the standardized data matrix X, one figure per pair.
#x1 and x2 are both generated from t1, x3 and x4 both from t2, while x1 and x3
#come from different factors.
#NOTE: the matrix is called X throughout this notebook; the name `data` that this
#cell used before is not defined anywhere and raised a NameError on a fresh kernel.
feature_pairs = [(0, 1), (0, 2), (2, 3)] #(column on the x axis, column on the y axis)
for fig_num, (col_x, col_y) in enumerate(feature_pairs, 1):
    plt.figure(fig_num)
    plt.scatter(X[:,col_x], X[:,col_y])
    plt.axis('equal')
    plt.xlabel('x{0}'.format(col_x+1)) #features are numbered from 1 in the labels
    plt.ylabel('x{0}'.format(col_y+1))
    plt.title('Data points')

In [0]:
#Second-moment matrix X^T X / N of the data matrix X.
#X was standardized per feature when it was generated (zero mean, unit standard
#deviation), so this matrix is its covariance and its correlation matrix at once.
#It is stored as X_corr rather than C, because C is the name of the TensorFlow
#capacity placeholder created when the graph is built; sharing the name would
#make re-running this cell break the training cell.
X_corr = np.matmul(X.T, X)/X.shape[0]

#Plot the correlation matrix on an explicitly created axes
f, ax = plt.subplots()
sns.heatmap(X_corr, square=True, linewidths=1.0, ax=ax)
ax.set_xlabel('features')
ax.set_ylabel('features')
ax.set_title('Correlation matrix');

In [0]:
#Create encoder and decoder for beta-VAE      
def encoder(x_in, N_features, N_latent):
    """Linear Gaussian encoder of the beta-VAE.

    Two separate affine maps of `x_in` give the mean and the log standard
    deviation of a Gaussian over the latent code; one sample of that Gaussian
    is produced as mean + noise * exp(log_sd) with standard-normal noise.

    Args:
        x_in: input tensor, multiplied from the right by weights of shape
            [N_features, N_latent].
        N_features: number of rows of the weight matrices.
        N_latent: number of columns of the weight matrices and length of the biases.

    Returns:
        A tuple (z, enc_mean, enc_log_sd, w1, w2): the latent sample, the mean,
        the log standard deviation, and the weight variables of the mean and
        log-sd maps.
    """
    def affine(weight_name, bias_name):
        #x_in * W + b with a Glorot-uniform W and a zero-initialised b;
        #variables are created in whatever variable scope is active at call time
        weight = tf.get_variable(weight_name, [N_features, N_latent],
                                 initializer=tf.glorot_uniform_initializer())
        bias = tf.get_variable(bias_name, [N_latent],
                               initializer=tf.constant_initializer(0.0))
        return tf.matmul(x_in, weight)+bias, weight

    with tf.variable_scope("encoder", reuse=None):
        enc_mean, w1 = affine('w1', 'b1') #mean of the encoded Gaussian
        enc_log_sd, w2 = affine('w2', 'b2') #log of its standard deviation

        #Draw the latent sample: standard-normal noise scaled by the standard deviation
        epsilon = tf.random_normal(tf.shape(enc_mean))
        z = enc_mean + tf.multiply(epsilon, tf.exp(enc_log_sd))
    return z, enc_mean, enc_log_sd, w1, w2
      
def decoder(z_in, N_features, N_latent):
    """Linear decoder of the beta-VAE.

    Maps latent codes to the mean of the Gaussian output with one affine layer.
    The variable scope uses AUTO_REUSE, so calling this function again shares
    the same w3 and b3 instead of creating new variables.

    Args:
        z_in: latent tensor, multiplied from the right by a [N_latent, N_features] weight.
        N_features: number of columns of the weight matrix and length of the bias.
        N_latent: number of rows of the weight matrix.

    Returns:
        A tuple (dec_mean, w3): the output mean and the weight variable.
    """
    weight_init = tf.glorot_uniform_initializer()
    bias_init = tf.constant_initializer(0.0)
    with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        w3 = tf.get_variable('w3', [N_latent, N_features], initializer=weight_init)
        b3 = tf.get_variable('b3', [N_features], initializer=bias_init)
        dec_mean = tf.matmul(z_in, w3)+b3 #mean of the Gaussian output
    return dec_mean, w3

In [0]:
#Regular 2-D grid of latent codes used to test disentanglement
N_latent_test = 21 #grid points along each latent axis
z1 = np.linspace(-0.5,0.5,N_latent_test)
z2 = np.linspace(-0.5,0.5,N_latent_test)
#With the default 'xy' indexing: grid_z1[i, j] == z1[j] and grid_z2[i, j] == z2[i]
grid_z1, grid_z2 = np.meshgrid(z1, z2)
#One row per grid node; column 0 is z1 (changes fastest along the rows), column 1 is z2
z = np.stack([grid_z1.ravel(), grid_z2.ravel()], axis=1).astype(np.float32)
assert z.shape == (N_latent_test**2, 2)

In [0]:
#Assemble the beta-VAE graph: input pipeline -> encoder -> decoder -> losses -> optimizer
tf.reset_default_graph() #start from an empty default graph

#Training hyper-parameters
n_epochs = 1000
batch_size = N_points #a single batch holds every row of X
N_latent = 2
learning_rate = 0.1

#Placeholders whose values are supplied through feed_dict
BATCH_SIZE = tf.placeholder(tf.int64, name='batch') #batch size of the input pipeline
Z = tf.placeholder(tf.float32, name='Z_test') #latent codes pushed through the decoder
C = tf.placeholder(tf.float32, name='C') #target capacity of the encoder

#Input pipeline over X; inputs and targets are the same for the autoencoder
train_data = (
    tf.data.Dataset.from_tensor_slices((X, X))
    .shuffle(buffer_size=10000)
    .batch(BATCH_SIZE)
)

#Iterator defined by structure only; running train_init binds it to train_data
iterator = tf.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)
train_init = iterator.make_initializer(train_data)
features, labels = iterator.get_next()

#Training path: data -> encoder -> sampled latent code -> decoder
z_sampled, enc_mean, enc_log_sd, w1, w2 = encoder(features, N_features, N_latent)
dec_mean, w3 = decoder(z_sampled, N_features, N_latent)

#Test path: externally supplied latent codes -> the same decoder
#(the decoder scope is AUTO_REUSE, so w3 and b3 are shared with the training path)
dec_mean_out, _ = decoder(Z, N_features, N_latent)

#Reconstruction loss: squared error summed over features and averaged over the batch,
#i.e. the output distribution is assumed to be Gaussian(dec_mean, 1)
squared_error = tf.square(features-dec_mean)
dec_loss = tf.reduce_mean(tf.reduce_sum(squared_error, 1))

#KL divergence between the encoder distribution Gaussian(enc_mean, enc_variance) and the
#prior p(z) = Gaussian(0,1), see equation (10) in the Kingma, Welling paper
#(https://arxiv.org/pdf/1312.6114.pdf). The loss penalises its distance from the capacity C.
lam1 = 0.5
enc_log_var = 2.0*enc_log_sd
kl_terms = 1.0 + enc_log_var - tf.square(enc_mean) - tf.exp(enc_log_var)
kl = tf.reduce_mean(-0.5 * tf.reduce_sum(kl_terms, 1))
kl_loss = lam1*tf.abs(kl-C)

#L1 penalty on the three weight matrices (biases are not penalised)
lam2 = 0.1
l1_w1 = tf.reduce_sum(tf.abs(w1))
l1_w2 = tf.reduce_sum(tf.abs(w2))
l1_w3 = tf.reduce_sum(tf.abs(w3))
l1_loss = lam2*(l1_w1+l1_w2+l1_w3)

#Objective minimised during training
loss = dec_loss+kl_loss+l1_loss

#Training op: one Adam step on the total loss
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)

In [0]:
#Train the model, evaluate it on the training data and on the latent test grid, then save it
saver = tf.train.Saver()
with tf.Session() as sess:
    start_time = time.time()
    sess.run(tf.global_variables_initializer())
    for i in range(n_epochs):
        #Target capacity of the encoder, ramped linearly from 0 towards 5 over the run.
        #It is tied to n_epochs (instead of a literal 1000) so that the schedule keeps
        #its shape when the number of epochs is changed.
        c = 5.0*(1.0*i)/n_epochs
        #Re-initialise the input pipeline for this epoch; only the batch size is needed here
        sess.run(train_init, feed_dict={BATCH_SIZE: batch_size})
        total_loss = 0 #total loss
        total_kl = 0
        total_dec_loss = 0 #total_decoder_loss
        total_kl_loss = 0
        total_l1_loss = 0
        n_batches = 0
        try:
            while True:
                _, l, kl_out, l_dec, l_kl, l_l1 = sess.run(
                    [optimizer, loss, kl, dec_loss, kl_loss, l1_loss], feed_dict={C: c})
                total_loss += l
                total_kl += kl_out
                total_dec_loss += l_dec
                total_kl_loss += l_kl
                total_l1_loss += l_l1
                n_batches += 1
        except tf.errors.OutOfRangeError:
            pass #the iterator is exhausted: end of the epoch
        #Report the per-batch averages every 50 epochs (skipped if no batch was produced,
        #which would otherwise divide by zero)
        if i % 50 == 0 and n_batches > 0:
            print('Epoch: {0}, loss: {1:.4f}, KL: {2:.4f}, dec_loss: {3:.4f}, kl_loss: {4:.4f}, l1_loss: {5:.4f}'
                  .format(i, total_loss/n_batches, total_kl/n_batches, total_dec_loss/n_batches,
                          total_kl_loss/n_batches, total_l1_loss/n_batches))
    print('Total time: {0} seconds'.format(time.time() - start_time))

    #Reconstruct and encode the whole training set in ONE pass over a single full-size batch.
    #The pipeline shuffles the data, so fetching the three tensors from the same batch is
    #what makes row k of dec_mean_train, enc_mean_out and enc_log_sd_out belong to the same
    #sample (the rows are still in shuffled order, not in the order of X).
    sess.run(train_init, feed_dict={BATCH_SIZE: X.shape[0]})
    dec_mean_train, enc_mean_out, enc_log_sd_out = sess.run([dec_mean, enc_mean, enc_log_sd])

    #Decode the latent test grid z
    dec_mean_test = sess.run([dec_mean_out], feed_dict={Z: z})

    #Save the model
    save_path = saver.save(sess, "./ae.ckpt")
    print("Model saved in path: %s" % save_path)

In [0]:
tf.reset_default_graph()

#Declare only the variables to be read back from the checkpoint; each one is created
#under the scoped name ('encoder/...', 'decoder/...') and shape it had during training
restore_specs = [
    ('encoder/w1', [N_features, N_latent]),
    ('encoder/b1', [N_latent]),
    ('decoder/w3', [N_latent,N_features]),
    ('decoder/b3', [N_features]),
]
w1, b1, w3, b3 = [tf.get_variable(var_name, shape=var_shape, dtype=tf.float32)
                  for var_name, var_shape in restore_specs]

#Load the saved values and pull them out of the session as NumPy arrays
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "./ae.ckpt")
    w1_out, b1_out, w3_out, b3_out = sess.run([w1, b1, w3, b3])

In [0]:
#Heatmaps of the learnt weight matrices, all drawn with the same colour range [-1, 1]
heatmap_style = dict(linewidths=1.0, vmin=-1.0, vmax=1.0)

#Encoder matrix w1, transposed so that rows are latent dimensions and columns are features
f, ax = plt.subplots()
sns.heatmap(w1_out.T, square=True, ax=ax, **heatmap_style)
ax.set(xlabel='features', ylabel='latent', title='Encoder matrix')

#Decoder matrix w3 (rows are latent dimensions, columns are features)
f, ax = plt.subplots()
sns.heatmap(w3_out, square=True, ax=ax, **heatmap_style)
ax.set(xlabel='features', ylabel='latent', title='Decoder matrix')

#Product of the encoder and decoder matrices (features x features)
f, ax = plt.subplots()
sns.heatmap(np.matmul(w1_out, w3_out), ax=ax, **heatmap_style)
ax.set(xlabel='features', ylabel='features', title='Encoder * decoder matrix');

In [0]:
#Turn the values fetched from the session into arrays and drop every length-1 axis
dec_mean_train = np.array(dec_mean_train).squeeze()
dec_mean_test = np.array(dec_mean_test).squeeze()
enc_mean_out = np.array(enc_mean_out).squeeze()

In [0]:
#Encoder means of the training samples in the two-dimensional latent space.
#Axis labels and a title are set so the figure can be read on its own, like the
#other plots in this notebook.
plt.scatter(enc_mean_out[:,0], enc_mean_out[:,1]);
plt.axis('equal');
plt.xlabel('z1');
plt.ylabel('z2');
plt.title('Encoder means of the training data');

In [0]:
#Fold the decoded grid points back onto the latent grid: because the grid was built
#from np.meshgrid(z1, z2), dc[i, j] is the decoder output for the code (z1[j], z2[i])
grid_shape = (N_latent_test, N_latent_test, N_features)
dc = np.reshape(dec_mean_test, grid_shape)

In [0]:
#Latent traversals of the decoder drawn over the training data in the (x1, x2) plane.
#The test grid comes from np.meshgrid(z1, z2), flattened row by row, so after the
#reshape dc[i, j] is the decoder output for (z1[j], z2[i]):
#  dc[:,0] keeps z1 fixed at z1[0] and walks along z2,
#  dc[0,:] keeps z2 fixed at z2[0] and walks along z1.
#The legend labels used to be the other way round.
plt.scatter(X[:,0], X[:,1], label='train_data');
plt.scatter(dc[:,0,0], dc[:,0,1], label='changing z2');
plt.scatter(dc[0,:,0], dc[0,:,1], label='changing z1');
plt.xlabel('x1');
plt.ylabel('x2');
plt.legend();

In [0]:
#Latent traversals of the decoder drawn over the training data in the (x3, x4) plane.
#The test grid comes from np.meshgrid(z1, z2), flattened row by row, so after the
#reshape dc[i, j] is the decoder output for (z1[j], z2[i]):
#  dc[:,0] keeps z1 fixed at z1[0] and walks along z2,
#  dc[0,:] keeps z2 fixed at z2[0] and walks along z1.
#The legend labels used to be the other way round.
plt.scatter(X[:,2], X[:,3], label='train_data');
plt.scatter(dc[:,0,2], dc[:,0,3], label='changing z2');
plt.scatter(dc[0,:,2], dc[0,:,3], label='changing z1');
plt.xlabel('x3');
plt.ylabel('x4');
plt.legend();