In [None]:
import numpy as np
import matplotlib.pyplot as plt
from vae import VAE
from functools import reduce
import operator

### Load LunarLander Training Data
Go ahead and load the data from disk.

In [None]:
lunar_data_path = 'LunarLander-v2_img_10_200.npz'
lunar_data = np.load(lunar_data_path)

And the model itself, of course.

In [None]:
lunar_vae_32 = VAE()
lunar_vae_32.make_vae(lunar_data_path, 32)
lunar_vae_32.load_model('LunarLander_32.h5')

lunar_vae_64 = VAE()
lunar_vae_64.make_vae(lunar_data_path, 64)
lunar_vae_64.load_model('LunarLander_64.h5')

### LunarLander Visualization
Here's what a typical frame from this environment will look like.

In [None]:
img0 = lunar_data['arr_0'][5]
fig = plt.figure(figsize = (6,6))
fig.add_subplot(111).imshow(img0)

The VAE managed to capture all of the above in a single 32-dimensional latent vector.

In [None]:
z0 = lunar_vae_32.encode_image(np.array([img0]))
z0

Side-by-side comparison of original data & VAE reconstruction, for your viewing pleasure.

In [None]:
rec0 = lunar_vae_32.decode_latent(z0)[0]
f, axarr = plt.subplots(1,2,figsize=(12,12))
axarr[0].imshow(img0)
axarr[1].imshow(rec0)
plt.show()

Curious about just how compressed the information is?

In [None]:
# Number of values in the reconstructed frame divided by the latent size
compression = reduce(operator.mul, rec0.shape) / z0.shape[1]
print(f"{compression}x compression ratio!")
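
The ratio above is just the number of values in one frame divided by the number of latent values. Here is a minimal standalone sketch of that arithmetic. The helper `compression_ratio` and the `(64, 64, 3)` frame shape are made up for illustration; the real frame shape depends on how the data was saved.

```python
from functools import reduce
import operator


def compression_ratio(image_shape, latent_dim):
    """Values in one frame divided by values in its latent vector."""
    n_values = reduce(operator.mul, image_shape, 1)
    return n_values / latent_dim


# Hypothetical frame shape, chosen only for illustration
example_shape = (64, 64, 3)
for latent_dim in (32, 64):
    ratio = compression_ratio(example_shape, latent_dim)
    print(f"latent_dim={latent_dim}: {ratio}x")
```

Doubling the latent size halves the ratio, which is the trade-off explored later with the 64-dim model.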

The reconstruction isn't exactly the same (there's some noise), but the VAE gets the gist of it!

### VAE Resiliency
Now, what happens if we add some encoding noise?

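Let's nudge our latent vector a bit by shifting every component by the same constant.

In [None]:
# Shift every latent component by 6.0 (modifies z0 in place)
z0 += 5 * 1.2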
rec0 = lunar_vae_32.decode_latent(z0)[0]
fig = plt.figure(figsize = (6,6))
fig.add_subplot(111).imshow(rec0)
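
The cell above shifts every latent component by the same constant. If you'd rather test against actual random noise, here is a minimal sketch using NumPy only. The helper `perturb_latent` and the all-zeros stand-in vector are assumptions for illustration, not part of the `vae` module.

```python
import numpy as np


def perturb_latent(z, scale=1.0, rng=None):
    """Return a copy of z with independent Gaussian noise added to each component."""
    rng = np.random.default_rng() if rng is None else rng
    return z + rng.normal(loc=0.0, scale=scale, size=z.shape)


# Stand-in for an encoded frame: one latent vector of 32 values
z_demo = np.zeros((1, 32))
z_noisy = perturb_latent(z_demo, scale=1.0, rng=np.random.default_rng(0))
print(z_noisy.shape)
```

Passing `z_noisy` to `decode_latent` in place of the shifted `z0` would show how the decoder copes with per-component noise; a larger `scale` means a harder test.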

Now let's introduce a spaceship.

In [None]:
img1 = lunar_data['arr_8'][30]
fig = plt.figure(figsize = (6,6))
fig.add_subplot(111).imshow(img1)

What does our new latent vector look like?

In [None]:
z1 = lunar_vae_32.encode_image(np.array([img1]))
z1

Alright, let's try it out!

In [None]:
rec1 = lunar_vae_32.decode_latent(z1)[0]
f, axarr = plt.subplots(1,2,figsize=(12,12))
axarr[0].imshow(img1)
axarr[1].imshow(rec1)
plt.show()

Terrain looks good (maybe a bit fuzzy?), but where's the spaceship?!

A lot of the reconstruction loss is associated with the black and white pixels, since they are the extremes. Getting the purple wrong is not a huge deal, since it's pretty close to black anyway. It is in the net's interest to get the terrain right first, so it will dedicate most of the 32-dim latent vector to that.
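
To put rough numbers on that argument, here is a toy sketch. It assumes a per-pixel mean squared error as a stand-in for the reconstruction loss (the actual loss used by the `vae` module may differ) and a synthetic 64x64 frame with a white terrain strip and a 4x4 purple ship; all sizes and colors are made up for illustration.

```python
import numpy as np

H, W = 64, 64

# Synthetic frame: black sky, white terrain in the bottom quarter, small purple ship
frame = np.zeros((H, W, 3))
frame[48:, :, :] = 1.0
frame[10:14, 30:34, :] = (0.5, 0.0, 0.5)

# Two imperfect "reconstructions": one drops the ship, one drops the terrain
no_ship = frame.copy()
no_ship[10:14, 30:34, :] = 0.0
no_terrain = frame.copy()
no_terrain[48:, :, :] = 0.0


def mse(a, b):
    """Mean squared error over all pixel values."""
    return float(np.mean((a - b) ** 2))


mse_no_ship = mse(frame, no_ship)
mse_no_terrain = mse(frame, no_terrain)
print(f"missing ship:    {mse_no_ship:.6f}")
print(f"missing terrain: {mse_no_terrain:.6f}")
```

In this toy setup, dropping the ship costs a few hundred times less than dropping the terrain, so a model with limited capacity loses very little by ignoring the ship.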

Let's try a bigger latent vector.

In [None]:
z1 = lunar_vae_64.encode_image(np.array([img1]))
z1

Any improvement?

In [None]:
rec1 = lunar_vae_64.decode_latent(z1)[0]
f, axarr = plt.subplots(1,2,figsize=(12,12))
axarr[0].imshow(img1)
axarr[1].imshow(rec1)
plt.show()

Terrain looks crisper, and there's definitely a discernible spaceship now (exact ship details don't really matter). Looks good!
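
Eyeballing works, but a number helps when comparing the 32-dim and 64-dim models. One common choice is peak signal-to-noise ratio (PSNR), where higher is better. The sketch below is a generic implementation, not something taken from the `vae` module, and it assumes both images are arrays on the same scale (here floats in [0, 1], hence `max_value=1.0`).

```python
import numpy as np


def psnr(original, reconstruction, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    original = np.asarray(original, dtype=np.float64)
    reconstruction = np.asarray(reconstruction, dtype=np.float64)
    mse = np.mean((original - reconstruction) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_value ** 2 / mse))


# Toy check: a black image versus one that is uniformly off by 0.1
demo_original = np.zeros((8, 8, 3))
demo_reconstruction = np.full((8, 8, 3), 0.1)
print(psnr(demo_original, demo_reconstruction))
```

Calling `psnr(img1, rec1)` once per model would give a single score for each latent size, provided the frames and reconstructions share the same value range.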

However: does it generalize to other environments?

## Space Invaders VAE

Let's try out the model with SpaceInvaders-v0. We'll go ahead and load the data & model from disk.

In [None]:
space_data_path = 'SpaceInvaders-v0_img_10_250.npz'
space_data = np.load(space_data_path)

In [None]:
space_vae = VAE()
space_vae.make_vae(space_data_path, 64)
space_vae.load_model('SpaceInvaders_64.h5')

What does a typical frame look like?

In [None]:
img2 = space_data['arr_0'][0]
fig = plt.figure(figsize = (6,6))
fig.add_subplot(111).imshow(img2)

What does the latent encoding of the above frame look like?

In [None]:
z2 = space_vae.encode_image(np.array([img2]))
z2

Let's do a side-by-side comparison with the original & reconstruction!

In [None]:
rec2 = space_vae.decode_latent(z2)[0]
f, axarr = plt.subplots(1,2,figsize=(12, 12))
axarr[0].imshow(img2)
axarr[1].imshow(rec2)
plt.show()

It's pretty damn good. Some slight fuzziness if you squint, but very, very good reconstruction accuracy overall!

In [None]:
# Compare original and reconstruction for a few more frames
for idx in [-1, -28, -46]:
    img2 = space_data['arr_5'][idx]  # also try 'arr_8'
    z2 = space_vae.encode_image(np.array([img2]))
    rec2 = space_vae.decode_latent(z2)[0]
    f, axarr = plt.subplots(1, 2, figsize=(8, 8))
    axarr[0].imshow(img2)
    axarr[1].imshow(rec2)
    plt.show()

Pretty much all the elements are there (and in the right colors too)!

Exactly how small can we make these latent vectors?

## CartPole VAE

Obviously our overall method is overkill for CartPole, but it's a good way to see just how well our VAE can learn features.

Let's go ahead and load the model & data from disk.

In [None]:
cart_data_path = 'CartPole-v0_img_10_50.npz'
cart_data = np.load(cart_data_path)

In [None]:
cart_vae = VAE()
cart_vae.make_vae(cart_data_path, 8)
cart_vae.load_model('CartPole_8.h5')

What does a typical frame look like?

In [None]:
img3 = cart_data['arr_0'][-28] #(0, 0), (0, -5), (1, 15)
plt.imshow(img3)

What about a typical latent encoding?

In [None]:
z3 = cart_vae.encode_image(np.array([img3]))
z3

Only 8 dimensions, that's right. Theoretically, we could get down to as few as 2, since a single frame is essentially determined by the cart position and the pole angle, but it would require copious training time.

Let's go ahead and do a side-by-side original & reconstruction comparison!

In [None]:
rec3 = cart_vae.decode_latent(z3)[0]
f, axarr = plt.subplots(1,2,figsize=(8,8))
axarr[0].imshow(img3)
axarr[1].imshow(rec3)
plt.show()

Looks good to me!