In [35]:
import tensorflow as tf
import pandas as pd
import numpy as np
from tensorflow import keras
import matplotlib.pyplot as plt
import copy

In [36]:
INPUT_SHAPE=(32,32)
BATCH_SIZE = 32
DIM_Z = 50
DIM_Z_HIDDEN = 25
DIM_X_HIDDEN = 50
DIM_X=np.prod(INPUT_SHAPE)

In [37]:
tf.reset_default_graph()
tf.enable_eager_execution()
tf.keras.backend.clear_session()

In [38]:
def get_model(dim_z_hidden, dim_z, input_shape, dim_x_hidden):
    flatten_encode = keras.layers.Flatten()
    dense_encode = keras.layers.Dense(dim_z_hidden, activation='tanh')
    mu_encode = keras.layers.Dense(dim_z, activation='linear')
    log_sigma_encode = keras.layers.Dense(dim_z, activation='linear')
    inputs_encode = keras.layers.Input(shape=(*input_shape,))
    encoder=keras.models.Model(
        inputs=inputs_encode,
        outputs=(
            mu_encode(dense_encode(flatten_encode(inputs_encode))),
            log_sigma_encode(dense_encode(flatten_encode(inputs_encode)))
        )
    )
    
    dim_x = np.prod(input_shape)    
    dense_decode = keras.layers.Dense(dim_x_hidden, activation='tanh')
    mu_decode = keras.layers.Dense(dim_x, activation='linear')
    log_sigma_decode = keras.layers.Dense(dim_x, activation='linear')
    inputs_decode = keras.layers.Input(shape=(dim_z,))
    decoder=keras.models.Model(
        inputs=inputs_decode,
        outputs=(
            mu_decode(dense_decode((inputs_decode))),
            log_sigma_decode(dense_decode((inputs_decode)))
        )
    )
    
    return encoder, decoder

In [39]:
def elbo(x,encoder, decoder, L=100, seed=0):
    batch = x.shape[0]
    mu_z, log_sigma_z = encoder(x)
    
    dim = mu_z.shape[1]
    np.random.seed(seed)
    eps = np.random.normal(0, 1, size = (L, batch, dim))
    
    zs = tf.reshape(eps *tf.exp(log_sigma_z) + mu_z, (-1, dim))
    mu_x, log_sigma_x = decoder(zs) # (L * batch, dim_x)
    mu_x = tf.reshape(mu_x, (L, batch, -1))
    log_sigma_x = tf.reshape(log_sigma_x, (L, batch, -1))
    
    
    minus_log_q = eps**2/2 + log_sigma_z + 0.5*tf.log(2*np.pi)
    log_p = -(tf.dtypes.cast(tf.reshape(x, (batch, -1)), tf.float32)-mu_x)**2/(2 * tf.exp(2*log_sigma_x)) -log_sigma_x - 0.5*tf.log(2*np.pi)
    log_pz = -zs**2/2 - 0.5*tf.log(2*np.pi)
    return (tf.math.reduce_sum(log_p) + tf.math.reduce_sum(minus_log_q) + tf.math.reduce_sum(log_pz))/L
    

In [40]:
def bad_elbo(x,encoder, decoder, L=100, seed=0):
    batch = x.shape[0]
    mu_z, log_sigma_z = encoder(x)
    sigma_z = np.exp(log_sigma_z)
    dim = mu_z.shape[1]
    
    np.random.seed(seed)
    eps = np.random.normal(0,1, size = (L, batch, dim))
    zs = eps * sigma_z + mu_z
    mu_x, log_sigma_x = decoder(np.reshape(zs, (-1, dim)))
    sigma_x = np.exp(log_sigma_x)
    mu_x = np.reshape(mu_x, (L, batch, -1))
    sigma_x = np.reshape(sigma_x, (L, batch, -1))
    p_x_z = np.prod(np.exp(-zs**2/2)/np.sqrt(2*np.pi)) * np.prod(np.exp(-(np.reshape(x, (batch, -1))-mu_x)**2/(sigma_x**2*2))/np.sqrt(2*np.pi*sigma_x**2))
    q_z_x = np.exp(-(zs-mu_z)**2/(sigma_z**2*2))/np.sqrt(2*np.pi*sigma_z**2)
    return (np.log(np.prod(p_x_z)) - np.log(np.prod(q_z_x)))/L

In [41]:
def manual_elbo(x_, encoder, decoder, L=100, seed=0):
    x = np.reshape(x_,(x_.shape[0], -1))
    batch = x.shape[0]
    encode_weights = encoder.get_weights()
    encode_hidden = np.tanh(np.matmul(x, encode_weights[0]) + encode_weights[1])
    mu_z = np.matmul(encode_hidden, encode_weights[2]) + encode_weights[3]
    log_sigma_z = np.matmul(encode_hidden, encode_weights[4]) + encode_weights[5]
    
    decode_weights = decoder.get_weights()
    dim = mu_z.shape[1]
    np.random.seed(seed)
    eps = np.random.normal(0, 1, size = (L, batch, dim))
    zs = np.reshape(eps*np.exp(log_sigma_z) + mu_z, (-1, dim))
    
    decode_hidden = np.tanh(np.matmul(zs, decode_weights[0]) + decode_weights[1])
    mu_x = np.matmul(decode_hidden, decode_weights[2]) + decode_weights[3]
    log_sigma_x = np.matmul(decode_hidden, decode_weights[4]) + decode_weights[5]
    mu_x = tf.reshape(mu_x, (L, batch, -1))
    log_sigma_x = tf.reshape(log_sigma_x, (L, batch, -1))
    
    minus_log_q = eps**2/2 + log_sigma_z + 0.5*tf.log(2*np.pi)
    log_p = -(x-mu_x)**2/(2 * np.exp(2*log_sigma_x)) -log_sigma_x - 0.5*np.log(2*np.pi)
    log_pz = -zs**2/2 - 0.5*np.log(2*np.pi)
    return (np.sum(log_p) + np.sum(minus_log_q) + np.sum(log_pz))/L
    

In [42]:
def grad_elbo(x, encoder, decoder, L=100, seed=0):
    with tf.GradientTape() as tape:
        loss = elbo(x, encoder, decoder, L, seed)
    return loss, tape.gradient(loss, [encoder.trainable_variables, decoder.trainable_variables])

In [44]:
epochs=1000
batch_size=32
x = np.random.normal(0,1, size=(3,4,2))
encoder, decoder = get_model(3,2,x.shape[1:],1)
optimizer = tf.keras.optimizers.Adamax()

In [47]:
for epoch in range(epochs):
    optimizer.minimize(lambda : elbo(x ,encoder, decoder), [encoder.trainable_variables, decoder.trainable_variables])
    print(elbo(x ,encoder, decoder))

tf.Tensor(-177.4398, shape=(), dtype=float32)
tf.Tensor(-178.68, shape=(), dtype=float32)
tf.Tensor(-179.92976, shape=(), dtype=float32)
tf.Tensor(-181.18898, shape=(), dtype=float32)
tf.Tensor(-182.45787, shape=(), dtype=float32)
tf.Tensor(-183.73634, shape=(), dtype=float32)
tf.Tensor(-185.02464, shape=(), dtype=float32)
tf.Tensor(-186.32266, shape=(), dtype=float32)
tf.Tensor(-187.63077, shape=(), dtype=float32)
tf.Tensor(-188.94878, shape=(), dtype=float32)
tf.Tensor(-190.27673, shape=(), dtype=float32)
tf.Tensor(-191.61494, shape=(), dtype=float32)
tf.Tensor(-192.96344, shape=(), dtype=float32)
tf.Tensor(-194.32213, shape=(), dtype=float32)
tf.Tensor(-195.69133, shape=(), dtype=float32)
tf.Tensor(-197.07109, shape=(), dtype=float32)
tf.Tensor(-198.4613, shape=(), dtype=float32)
tf.Tensor(-199.86226, shape=(), dtype=float32)
tf.Tensor(-201.27373, shape=(), dtype=float32)
tf.Tensor(-202.69624, shape=(), dtype=float32)
tf.Tensor(-204.12962, shape=(), dtype=float32)
tf.Tensor(-205.573

tf.Tensor(-647.28906, shape=(), dtype=float32)
tf.Tensor(-651.9991, shape=(), dtype=float32)
tf.Tensor(-656.74225, shape=(), dtype=float32)
tf.Tensor(-661.51807, shape=(), dtype=float32)
tf.Tensor(-666.3267, shape=(), dtype=float32)
tf.Tensor(-671.1689, shape=(), dtype=float32)
tf.Tensor(-676.04456, shape=(), dtype=float32)
tf.Tensor(-680.95386, shape=(), dtype=float32)
tf.Tensor(-685.8976, shape=(), dtype=float32)
tf.Tensor(-690.87537, shape=(), dtype=float32)
tf.Tensor(-695.88806, shape=(), dtype=float32)
tf.Tensor(-700.9356, shape=(), dtype=float32)
tf.Tensor(-706.0183, shape=(), dtype=float32)
tf.Tensor(-711.13654, shape=(), dtype=float32)
tf.Tensor(-716.2905, shape=(), dtype=float32)
tf.Tensor(-721.4807, shape=(), dtype=float32)
tf.Tensor(-726.70685, shape=(), dtype=float32)
tf.Tensor(-731.96985, shape=(), dtype=float32)
tf.Tensor(-737.26953, shape=(), dtype=float32)
tf.Tensor(-742.6068, shape=(), dtype=float32)
tf.Tensor(-747.98096, shape=(), dtype=float32)
tf.Tensor(-753.3933, s

tf.Tensor(-2396.3062, shape=(), dtype=float32)
tf.Tensor(-2413.3684, shape=(), dtype=float32)
tf.Tensor(-2430.5479, shape=(), dtype=float32)
tf.Tensor(-2447.85, shape=(), dtype=float32)
tf.Tensor(-2465.2715, shape=(), dtype=float32)
tf.Tensor(-2482.8171, shape=(), dtype=float32)
tf.Tensor(-2500.4856, shape=(), dtype=float32)
tf.Tensor(-2518.279, shape=(), dtype=float32)
tf.Tensor(-2536.195, shape=(), dtype=float32)
tf.Tensor(-2554.239, shape=(), dtype=float32)
tf.Tensor(-2572.4075, shape=(), dtype=float32)
tf.Tensor(-2590.7056, shape=(), dtype=float32)
tf.Tensor(-2609.1294, shape=(), dtype=float32)
tf.Tensor(-2627.6833, shape=(), dtype=float32)
tf.Tensor(-2646.3672, shape=(), dtype=float32)
tf.Tensor(-2665.1833, shape=(), dtype=float32)
tf.Tensor(-2684.1282, shape=(), dtype=float32)
tf.Tensor(-2703.2068, shape=(), dtype=float32)
tf.Tensor(-2722.4182, shape=(), dtype=float32)
tf.Tensor(-2741.7632, shape=(), dtype=float32)
tf.Tensor(-2761.2468, shape=(), dtype=float32)
tf.Tensor(-2780.86

tf.Tensor(-8241.901, shape=(), dtype=float32)
tf.Tensor(-8299.164, shape=(), dtype=float32)
tf.Tensor(-8356.825, shape=(), dtype=float32)
tf.Tensor(-8414.889, shape=(), dtype=float32)
tf.Tensor(-8473.362, shape=(), dtype=float32)
tf.Tensor(-8532.242, shape=(), dtype=float32)
tf.Tensor(-8591.538, shape=(), dtype=float32)
tf.Tensor(-8651.249, shape=(), dtype=float32)
tf.Tensor(-8711.389, shape=(), dtype=float32)
tf.Tensor(-8771.95, shape=(), dtype=float32)
tf.Tensor(-8832.937, shape=(), dtype=float32)
tf.Tensor(-8894.354, shape=(), dtype=float32)
tf.Tensor(-8956.209, shape=(), dtype=float32)
tf.Tensor(-9018.505, shape=(), dtype=float32)
tf.Tensor(-9081.246, shape=(), dtype=float32)
tf.Tensor(-9144.441, shape=(), dtype=float32)
tf.Tensor(-9208.081, shape=(), dtype=float32)
tf.Tensor(-9272.177, shape=(), dtype=float32)
tf.Tensor(-9336.738, shape=(), dtype=float32)
tf.Tensor(-9401.758, shape=(), dtype=float32)
tf.Tensor(-9467.245, shape=(), dtype=float32)
tf.Tensor(-9533.204, shape=(), dtyp

tf.Tensor(-28992.006, shape=(), dtype=float32)
tf.Tensor(-29201.807, shape=(), dtype=float32)
tf.Tensor(-29413.18, shape=(), dtype=float32)
tf.Tensor(-29626.113, shape=(), dtype=float32)
tf.Tensor(-29840.588, shape=(), dtype=float32)
tf.Tensor(-30056.69, shape=(), dtype=float32)
tf.Tensor(-30274.363, shape=(), dtype=float32)
tf.Tensor(-30493.67, shape=(), dtype=float32)
tf.Tensor(-30714.555, shape=(), dtype=float32)
tf.Tensor(-30937.09, shape=(), dtype=float32)
tf.Tensor(-31161.295, shape=(), dtype=float32)
tf.Tensor(-31387.133, shape=(), dtype=float32)
tf.Tensor(-31614.615, shape=(), dtype=float32)
tf.Tensor(-31843.812, shape=(), dtype=float32)
tf.Tensor(-32074.71, shape=(), dtype=float32)
tf.Tensor(-32307.29, shape=(), dtype=float32)
tf.Tensor(-32541.592, shape=(), dtype=float32)
tf.Tensor(-32777.625, shape=(), dtype=float32)
tf.Tensor(-33015.41, shape=(), dtype=float32)
tf.Tensor(-33254.95, shape=(), dtype=float32)
tf.Tensor(-33496.24, shape=(), dtype=float32)
tf.Tensor(-33739.32, s

tf.Tensor(-112947.85, shape=(), dtype=float32)
tf.Tensor(-113776.89, shape=(), dtype=float32)
tf.Tensor(-114612.08, shape=(), dtype=float32)
tf.Tensor(-115453.4, shape=(), dtype=float32)
tf.Tensor(-116301.02, shape=(), dtype=float32)
tf.Tensor(-117154.84, shape=(), dtype=float32)
tf.Tensor(-118014.97, shape=(), dtype=float32)
tf.Tensor(-118881.34, shape=(), dtype=float32)
tf.Tensor(-119754.26, shape=(), dtype=float32)
tf.Tensor(-120633.72, shape=(), dtype=float32)
tf.Tensor(-121519.57, shape=(), dtype=float32)
tf.Tensor(-122411.98, shape=(), dtype=float32)
tf.Tensor(-123310.94, shape=(), dtype=float32)
tf.Tensor(-124216.44, shape=(), dtype=float32)
tf.Tensor(-125128.91, shape=(), dtype=float32)
tf.Tensor(-126047.94, shape=(), dtype=float32)
tf.Tensor(-126973.79, shape=(), dtype=float32)
tf.Tensor(-127906.66, shape=(), dtype=float32)
tf.Tensor(-128846.16, shape=(), dtype=float32)
tf.Tensor(-129792.78, shape=(), dtype=float32)
tf.Tensor(-130746.25, shape=(), dtype=float32)
tf.Tensor(-131

In [9]:
def test_elbo_1():
    x = np.random.normal(0,1, size=(3,4,2))
    encoder, decoder = get_model(3,2,x.shape[1:],1)
    L=5
    seed=0
    print (elbo(x, encoder, decoder, L,seed).numpy(),
    manual_elbo(x, encoder, decoder, L, seed),
    bad_elbo(x, encoder, decoder, L, seed))
test_elbo_1()

-46.603218 -46.60321758116592 -46.60322114257292


In [10]:
def test_grad_elbo_1():
    x = np.random.normal(0,1, size=(3,4,2))
    encoder, decoder = get_model(3,2,x.shape[1:],1)
    L=5
    seed=0
    print (grad_elbo(x, encoder, decoder, L, seed))
test_grad_elbo_1()

W0903 09:08:11.294821 139768131282688 deprecation.py:323] From /home/dung/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1205: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


(<tf.Tensor: id=734, shape=(), dtype=float32, numpy=-41.36326>, [[<tf.Tensor: id=948, shape=(8, 3), dtype=float32, numpy=
array([[ 4.511608  , -0.19963868, -0.5984243 ],
       [-2.3367553 ,  0.08292091,  0.00629209],
       [ 5.466874  , -0.31411502, -0.19245052],
       [ 0.26483363, -0.16757119,  0.9171245 ],
       [ 3.5577078 , -0.26744398, -0.35635006],
       [-2.018921  ,  0.21768755,  0.32522786],
       [ 2.5862007 , -0.09949231, -1.0005776 ],
       [ 4.8295403 , -0.18388346, -1.2133842 ]], dtype=float32)>, <tf.Tensor: id=949, shape=(3,), dtype=float32, numpy=array([-3.2936754 ,  0.22852534, -0.10238004], dtype=float32)>, <tf.Tensor: id=937, shape=(3, 2), dtype=float32, numpy=
array([[-0.54802805,  3.387109  ],
       [ 0.9924328 , -4.329119  ],
       [-0.8883408 ,  3.639003  ]], dtype=float32)>, <tf.Tensor: id=934, shape=(2,), dtype=float32, numpy=array([-1.1085176,  4.637015 ], dtype=float32)>, <tf.Tensor: id=942, shape=(3, 2), dtype=float32, numpy=
array([[-6.9688573,  2

In [14]:
def verify_grad_elbo():
    x = np.random.normal(0,1, size=(5,2))
    encoder, decoder = get_model(2,1,x.shape[1:],2)
    L=5
    seed=0
    eps = 0.000001
    ERR_GRAD = 1
    grad = grad_elbo(x, encoder, decoder, L, seed)[1]
    elbo_at = elbo(x, encoder, decoder, L, seed)
    encoder_weights = encoder.get_weights()
    decoder_weights = decoder.get_weights()
    weights = [encoder_weights, decoder_weights]
    coders = (encoder, decoder)
    encoder.set_weights(encoder_weights)
    decoder.set_weights(decoder_weights)
    for i in range(len(weights)):
        for j in range(len(weights[i])):
            layer = weights[i][j]
            if len(layer.shape) == 1:
                for k in range(len(layer)):
                    layer[k] += eps
                    elbo_near = elbo(x, encoder, decoder, L, seed)
                    layer[k] -= eps
                    numer_grad = (elbo_near - elbo_at)/eps
                    assert(abs(grad[i][j].numpy()[k] - numer_grad) < ERR_GRAD)
            else:
                for k in range(len(layer)):
                    for l in range(len(layer[k])):
                        layer[k][l] += eps
                        elbo_near = elbo(x, encoder, decoder, L, seed)
                        layer[k][l] -= eps
                        numer_grad = (elbo_near - elbo_at)/eps
                        print(grad[i][j].numpy()[k][l], numer_grad)
                        assert(abs(grad[i][j].numpy()[k][l] - numer_grad) < ERR_GRAD)

In [13]:
optimizer = tf.keras.optimizers.Adamax()
global_step = tf.Variable(0)

In [34]:
type(encoder.trainable_weights)

list