In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_datasets as tfds

tf.logging.set_verbosity(tf.logging.ERROR)


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [2]:
import numpy as np
from IPython.display import clear_output
%config InlineBackend.figure_formats = ['svg']

In [3]:
from tfp_vae.vae import VAE

## hyperparams
hps = tf.contrib.training.HParams(
    batch_size=64,
    img_height=32,
    img_width=32,
    img_channels=1,
    z_dim=100,
    discrete_outputs=True,
    num_epochs=3)

# Construct a tf.data.Dataset
ds_train, ds_test = tfds.load(name="mnist", split=["train", "test"])

def process_dataset(ds, epochs):
    f1 = lambda row: row["image"]
    f2 = lambda img: tf.cast(img, dtype=tf.int32)
    f3 = lambda img: tf.cast(img, dtype=tf.float32)
    f4 = lambda img: img * tf.constant((1.0 / 255.0))
    f5 = lambda img: tf.image.resize_images(img, [32, 32])
    f6 = lambda img: tf.round(img)
    f7 = lambda img: tf.cast(img, dtype=tf.int32)
    normalize_pixels = lambda row: f7(f6(f5(f4(f3(f2(f1(row)))))))
    ds = ds.map(normalize_pixels)
    ds = ds.shuffle(1000)
    ds = ds.batch(hps.batch_size)
    ds = ds.repeat(epochs)
    ds = ds.prefetch(10)
    return ds

ds_train = process_dataset(ds_train, epochs=hps.num_epochs)
ds_test = process_dataset(ds_test, epochs=1)

vae = VAE(hps)
init_op = tf.global_variables_initializer()

sess = tf.Session()
_ = sess.run(init_op)

In [4]:
iteration = 0

for batch in tfds.as_numpy(ds_train):
    images = batch
    elbo = vae.train(sess, images)
    
    print('Iter {}... ELBO: {}'.format(iteration, elbo))
    
    iteration += 1

Iter 0... ELBO: -729.9469604492188
Iter 1... ELBO: -724.716796875
Iter 2... ELBO: -721.4344482421875
Iter 3... ELBO: -717.4259643554688
Iter 4... ELBO: -714.0306396484375
Iter 5... ELBO: -710.7200927734375
Iter 6... ELBO: -707.55859375
Iter 7... ELBO: -704.4791870117188
Iter 8... ELBO: -701.6035766601562
Iter 9... ELBO: -697.7581787109375
Iter 10... ELBO: -694.9827880859375
Iter 11... ELBO: -691.2391357421875
Iter 12... ELBO: -687.4039306640625
Iter 13... ELBO: -685.2200927734375
Iter 14... ELBO: -681.516357421875
Iter 15... ELBO: -676.09423828125
Iter 16... ELBO: -673.21630859375
Iter 17... ELBO: -668.0511474609375
Iter 18... ELBO: -662.0806884765625
Iter 19... ELBO: -654.5950927734375
Iter 20... ELBO: -651.5343017578125
Iter 21... ELBO: -642.81982421875
Iter 22... ELBO: -637.383056640625
Iter 23... ELBO: -630.26708984375
Iter 24... ELBO: -615.0050048828125
Iter 25... ELBO: -608.4891967773438
Iter 26... ELBO: -595.5088500976562
Iter 27... ELBO: -589.334228515625
Iter 28... ELBO: -583.

Iter 225... ELBO: -250.72630310058594
Iter 226... ELBO: -252.87667846679688
Iter 227... ELBO: -268.11041259765625
Iter 228... ELBO: -253.179931640625
Iter 229... ELBO: -253.58177185058594
Iter 230... ELBO: -256.003662109375
Iter 231... ELBO: -250.8680877685547
Iter 232... ELBO: -259.0232849121094
Iter 233... ELBO: -241.64590454101562
Iter 234... ELBO: -257.96044921875
Iter 235... ELBO: -250.3859100341797
Iter 236... ELBO: -242.46688842773438
Iter 237... ELBO: -247.91494750976562
Iter 238... ELBO: -263.80499267578125
Iter 239... ELBO: -249.2689208984375
Iter 240... ELBO: -249.92352294921875
Iter 241... ELBO: -249.36102294921875
Iter 242... ELBO: -260.3346252441406
Iter 243... ELBO: -251.89901733398438
Iter 244... ELBO: -266.955322265625
Iter 245... ELBO: -262.4355163574219
Iter 246... ELBO: -252.81959533691406
Iter 247... ELBO: -260.0534362792969
Iter 248... ELBO: -242.4871826171875
Iter 249... ELBO: -262.6478271484375
Iter 250... ELBO: -262.05267333984375
Iter 251... ELBO: -263.6594543

Iter 444... ELBO: -234.63555908203125
Iter 445... ELBO: -228.69952392578125
Iter 446... ELBO: -224.96832275390625
Iter 447... ELBO: -233.28155517578125
Iter 448... ELBO: -221.59133911132812
Iter 449... ELBO: -237.85946655273438
Iter 450... ELBO: -230.77947998046875
Iter 451... ELBO: -235.0552978515625
Iter 452... ELBO: -210.64031982421875
Iter 453... ELBO: -230.98294067382812
Iter 454... ELBO: -221.6182098388672
Iter 455... ELBO: -218.37596130371094
Iter 456... ELBO: -209.65200805664062
Iter 457... ELBO: -233.4285125732422
Iter 458... ELBO: -232.9909210205078
Iter 459... ELBO: -238.99652099609375
Iter 460... ELBO: -226.603271484375
Iter 461... ELBO: -238.44146728515625
Iter 462... ELBO: -222.57272338867188
Iter 463... ELBO: -234.42263793945312
Iter 464... ELBO: -212.9147491455078
Iter 465... ELBO: -225.89454650878906
Iter 466... ELBO: -245.46124267578125
Iter 467... ELBO: -218.0048828125
Iter 468... ELBO: -229.4516143798828
Iter 469... ELBO: -223.68190002441406
Iter 470... ELBO: -229.1

KeyboardInterrupt: 

In [None]:
def interpolate_random_pairs(x, num_pairs=5):
    rows = []
    def interpolate_pair():
        i = np.random.randint(hps.batch_size)
        j = np.random.randint(hps.batch_size)
        x1 = x[i]
        x2 = x[j]
        interpolations = vae.interpolate(sess, x1, x2)
        
        row = np.concatenate([x1] + [interpolations[i] for i in range(0, 5)] + [x2], axis=1)
        if hps.img_channels == 1:
            row = row[:,:,0]
            
        return row
    
    rows = [interpolate_pair() for _ in range(0, num_pairs)]
    viz = np.concatenate(rows, axis=0)
    cmap_spec = 'gray' if hps.img_channels == 1 else None
    plt.imshow(viz, cmap=cmap_spec)
    plt.show()

In [None]:
x = mnist_dataset.get_reference_batch()
interpolate_random_pairs(x)

In [None]:
x = mnist_dataset.get_reference_batch()
interpolate_random_pairs(x)