# Doom VAE

The idea of this notebook is to construct a small variational auto-encoder that can reproduce images from the "Doom" video game.

This notebook is meant to run in Google Colab, and it includes some fun interactive controls (Colab form sliders) further down for exploring the model's latent space.


In [1]:
import keras
from keras.layers import Input, Dense, Lambda, Flatten, Reshape, Layer
from keras.layers import Conv2D, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras.models import Model
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

from IPython.display import Image, display
from keras.preprocessing import image

In [None]:
import os
import urllib.request

# Download the dataset only when there is no local copy yet, so a full
# re-run of the notebook does not fetch it again.
DATA_URL = 'https://metatonetransfer.com/datasets/doom_images.npz'
DATA_PATH = 'doom_images.npz'
if not os.path.exists(DATA_PATH):
    try:
        urllib.request.urlretrieve(DATA_URL, DATA_PATH)
    except Exception as e:
        # Best effort: report the failure; np.load below raises if no file is available.
        print(f"Download of {DATA_URL} failed: {e}")
        # Don't leave a partial download behind, or later runs would treat it as cached.
        if os.path.exists(DATA_PATH):
            os.remove(DATA_PATH)

# Load the image array ('arr_0' is numpy's default name for an unnamed array in np.savez).
with np.load(DATA_PATH) as data:
    x_train = data['arr_0']

# View a random input, enlarged for visibility.
print("Here's a sample image:")
img = image.array_to_img(x_train[np.random.randint(len(x_train))], scale=False)
display(img.resize((300, 300)))

# Every image is used for training; there is no separate test split.
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')

x_train = x_train.astype('float32') / 255  # scale to [0,1]

In [3]:
# Network and training hyperparameters.

# Input images: 64 x 64 pixels with 3 colour channels.
img_rows, img_cols, img_chns = 64, 64, 3
img_size = (img_rows, img_cols, img_chns)
original_dim = img_rows * img_cols * img_chns  # number of values in one flattened image

# Model architecture.
latent_dim = 16          # length of the latent vector z
intermediate_dim = 128   # width of the dense layer on either side of the latent space
filters = 32             # channel count of the conv / transposed-conv layers
num_conv = 3             # used as a conv kernel size (kernel_size=num_conv) in the model cells

# Sampling noise and training schedule.
epsilon_std = 1.0        # stddev of the noise drawn by the sampling function
epochs = 100
batch_size = 128

In [None]:
# Encoder: image -> (z_mean, z_log_var). With 'same' padding each stride-2 conv
# halves the spatial size (64x64 -> 32x32 -> 16x16 for img_size (64, 64, 3)).
input_img = Input(shape=img_size, name='encoder_input')
# Note: this first conv has only img_chns (3) filters.
x = Conv2D(img_chns, kernel_size=(2,2), padding='same', activation='relu')(input_img)
# Two stride-2 downsampling convs.
x = Conv2D(filters, kernel_size=(2,2), padding='same', activation='relu', strides=(2,2))(x)
x = Conv2D(filters, kernel_size=(2,2), padding='same', activation='relu', strides=(2,2))(x)
# x = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding='same')(x) # try a max pooling layer here instead of the previous stride
x = Conv2D(filters, kernel_size=num_conv, padding='same', activation='relu', strides=1)(x)
# Remember the feature-map shape (index 0 is the batch axis) so the decoder can
# reshape its dense output back to it.
shape_before_flattening = x.shape
x = Flatten()(x)
x = Dense(intermediate_dim, activation='relu', name='latent_project')(x)

print("Shape before flattening:",shape_before_flattening)

# Two heads for the latent distribution. Despite the layer name 'Z_var', the
# second head is treated as a log-variance by the KL term and the sampling function.
z_mean = Dense(latent_dim, name='Z_mean')(x)
z_log_var = Dense(latent_dim, name='Z_var')(x)

# Encoder model mapping an image batch to z_mean only; the cells that encode
# example images call encoder.predict and use its output as the latent vector.
encoder = Model(input_img, z_mean)


In [None]:

# Sampling layer (reparameterization trick).
def sampling(args):
    """Draw z ~ N(z_mean, exp(z_log_var)) in a differentiable way.

    Args:
        args: pair (z_mean, z_log_var) of tensors, each of shape
            (batch, latent_dim); z_log_var is a log-variance.

    Returns:
        Tensor of shape (batch, latent_dim): z_mean + sigma * epsilon with
        sigma = exp(0.5 * z_log_var) and epsilon ~ N(0, epsilon_std**2).
    """
    z_mean, z_log_var = args
    batch = keras.ops.shape(z_mean)[0]
    epsilon = keras.random.normal(shape=(batch, latent_dim), mean=0., stddev=epsilon_std)
    # The KL term treats z_log_var as log(sigma^2), so the standard deviation is
    # exp(0.5 * z_log_var); exp(z_log_var) would be the variance.
    return z_mean + keras.ops.exp(0.5 * z_log_var) * epsilon

# Lambda's output_shape excludes the batch dimension.
z = Lambda(sampling, name="Z_sample", output_shape=(latent_dim,))([z_mean, z_log_var])

# Decoder: latent vector -> image, mirroring the encoder.
decoder_input = Input(shape=(latent_dim,))
y = Dense(intermediate_dim, activation='relu')(decoder_input)
# int(): Dense expects a Python int unit count, np.prod returns a numpy integer.
y = Dense(int(np.prod(shape_before_flattening[1:])), activation='relu')(y)
y = Reshape(shape_before_flattening[1:])(y)
y = Conv2DTranspose(filters, kernel_size=num_conv, padding='same', strides=1, activation='relu', name='deconv_1')(y)
y = Conv2DTranspose(filters, kernel_size=num_conv, padding='same', strides=(2,2), activation='relu', name='deconv_2')(y)
# 'valid' transposed conv upsamples to (in - 1) * 2 + 3 pixels, e.g. 32 -> 65 for 64x64 inputs.
y = Conv2DTranspose(filters, kernel_size=(3, 3), strides=(2, 2), padding='valid', activation='relu', name='deconv_3')(y)
# 2x2 'valid' conv trims one pixel (65 -> 64) and squashes to img_chns channels in [0, 1].
y = Conv2D(img_chns, kernel_size=2, padding='valid', activation='sigmoid', name="mean_squash")(y)
decoder = Model(decoder_input, y, name="Decoder")
z_decoded = decoder(z)



In [18]:
# Adapted from https://keras.io/examples/generative/vae/ (2024).
# Custom train_step code in Keras 3 is backend-specific; this is the TensorFlow
# variant, which records the forward pass on a tf.GradientTape to obtain
# gradients, so it assumes the TensorFlow backend.
from keras import ops
import tensorflow as tf  # needed by VAE.train_step (tf.GradientTape)


class VAE(keras.Model):
    """Variational autoencoder with a custom training step.

    The training loss is a binary cross-entropy reconstruction term plus the
    KL divergence between N(z_mean, exp(z_log_var)) and a standard normal.

    Args:
        encoder: model mapping an image batch to the three tensors
            (z_mean, z_log_var, z); train_step unpacks exactly these.
        decoder: model mapping a latent batch z to reconstructed images,
            presumably with values in [0, 1] since the loss is binary
            cross-entropy.
        **kwargs: forwarded to keras.Model.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them between epochs.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on a batch of images.

        Args:
            data: the batch delivered by fit(); either the image tensor itself
                or a tuple whose first element is the image tensor.

        Returns:
            dict with the running means of "loss", "reconstruction_loss"
            and "kl_loss".
        """
        # Be tolerant of the batch arriving wrapped in a tuple, e.g. (x,).
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # binary_crossentropy reduces the last (channel) axis, leaving
            # (batch, rows, cols) for image batches; sum over pixels, then
            # average over the batch.
            reconstruction_loss = ops.mean(
                ops.sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )
            # Closed-form KL(N(mu, sigma^2) || N(0, 1)) with z_log_var = log(sigma^2).
            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

In [None]:
# VAE.train_step unpacks three encoder outputs (z_mean, z_log_var, z), but
# `encoder` returns only z_mean (which is what the encoding cells use).
# Train through a second encoder model built on the same tensors: it shares,
# and therefore trains, the same layers as `encoder`.
vae_encoder = Model(input_img, [z_mean, z_log_var, z], name="vae_encoder")
vae = VAE(vae_encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
decoder.summary()
vae.summary()

In [22]:
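# Earlier alternative, kept for reference only (everything in this cell is commented out):
# attach the VAE loss through a custom layer's add_loss() instead of using the VAE class.
# def xent(y_true, y_pred):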
#   return keras.metrics.binary_crossentropy(y_true, y_pred)

# def kl_measure(loc, log_var):
#   return -0.5 * keras.ops.mean(1 + log_var - keras.ops.square(loc) - keras.ops.exp(log_var), axis=-1)


# class VAELayer(keras.layers.Layer):    
#     def __init__(self, **kwargs):
#         self.is_placeholder = True
#         super(VAELayer, self).__init__(**kwargs)
      
#     def vae_loss(self, x, z_decoded):
#         x = Flatten()(x)
#         z_decoded = Flatten()(z_decoded)
#         r_loss = original_dim * xent(x, z_decoded)
#         kl_loss = kl_measure(z_mean, z_log_var)
#         print("KL Shape:", kl_loss.shape)
#         print("Xent shape:", r_loss.shape)
#         return keras.ops.mean(r_loss + kl_loss)
    
#     def call(self, inputs):
#         x = inputs[0]
#         z_decoded = inputs[1]
#         loss = self.vae_loss(x, z_decoded)
#         self.add_loss(loss) #, inputs=inputs)
#         return x

# y = VAELayer()([input_img, z_decoded])

# vae = Model(input_img, y, name="VAE")
# vae.compile(optimizer='adam', metrics=['mse','binary_crossentropy'])


# decoder.summary()
# vae.summary()

In [None]:
# Train the VAE. No targets are passed: train_step reconstructs its own input.
history = vae.fit(
    x_train,
    batch_size=batch_size,
    epochs=epochs,
)

In [None]:
# Plot the training losses: the total plus its reconstruction and KL parts
# (the metric names returned by VAE.train_step).
fig, ax = plt.subplots(figsize=(10, 5))
for key in ("loss", "reconstruction_loss", "kl_loss"):
    if key in history.history:
        ax.plot(history.history[key], label=key)
ax.set(title="VAE training loss", xlabel="epochs", ylabel="loss")
ax.legend()
plt.show()

In [None]:
# Alternatively, download pretrained weights and load them instead of training.
# The file is fetched only if it is not already present, so re-running this
# cell neither downloads it again nor loads a stale copy.
import os
import urllib.request

WEIGHTS_URL = "https://metatonetransfer.com/datasets/doom_vae_weights.h5"
WEIGHTS_PATH = "doom_vae_weights.h5"
if not os.path.exists(WEIGHTS_PATH):
    try:
        urllib.request.urlretrieve(WEIGHTS_URL, WEIGHTS_PATH)
    except Exception:
        # Don't leave a partial download behind, or it would be treated as cached.
        if os.path.exists(WEIGHTS_PATH):
            os.remove(WEIGHTS_PATH)
        raise
vae.load_weights(WEIGHTS_PATH)


In [None]:
# Encoder demo: pick a random training image and map it to its latent vector.
ex = x_train[np.random.randint(len(x_train))]

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(ex)  # RGB data, so no colormap is applied
plt.show()

# Add a batch axis of size 1 before encoding.
enc_z = encoder.predict(ex[np.newaxis])
display(enc_z[0])

In [None]:
# Decode the latent vector computed above and show the reconstruction.
# enc_z[:1] keeps the (1, latent_dim) batch shape.
ex_dec = decoder.predict(enc_z[:1])

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(ex_dec[0])  # RGB data, so no colormap is applied
plt.show()

In [None]:
# Let's try sampling different parts of the latent space to see what we have.
# Each tile decodes a latent vector drawn uniformly from [-1, 1] per dimension
# (the same range as the sliders further down).
n = 10  # the grid is n x n decoded images
# Tile size comes from img_rows / img_cols; `img_size` is NOT rebound here because
# the encoder cell uses it as the (rows, cols, channels) input shape.
figure = np.zeros((img_rows * n, img_cols * n, img_chns))

# Draw all latent vectors up front and decode them in one batch instead of
# calling decoder.predict once per tile.
z_samples = np.random.uniform(-1, 1, size=(n * n, latent_dim))
x_decoded = decoder.predict(z_samples)

for k, decoded in enumerate(x_decoded):
    i, j = divmod(k, n)  # row-major tile position
    figure[i * img_rows:(i + 1) * img_rows,
           j * img_cols:(j + 1) * img_cols] = decoded.reshape(img_rows, img_cols, img_chns)

plt.figure(figsize=(20, 20))
plt.imshow(figure)
plt.title("Decoded random latent vectors")
plt.show()

In [None]:
# Save and download models
import os


def save_model_three_ways(model, name="model", directory="./models"):
    """Save a model's weights and its JSON architecture into `directory`.

    Writes `<directory>/<name>.weights.h5` (Keras 3's save_weights only accepts
    filenames ending in ".weights.h5") and `<directory>/<name>_architecture.json`.
    Despite the function's name, only these two formats are written.

    Args:
        model: object providing save_weights(path) and to_json(), e.g. a Keras model.
        name: filename prefix for both output files.
        directory: output folder; created if it does not already exist.
    """
    # Replaces a shell `mkdir`, which fails when the folder already exists.
    os.makedirs(directory, exist_ok=True)
    # Save the weights
    model.save_weights(os.path.join(directory, name + ".weights.h5"))
    # Save the model architecture
    with open(os.path.join(directory, name + "_architecture.json"), "w") as f:
        f.write(model.to_json())

save_model_three_ways(vae, name="vae")
save_model_three_ways(encoder, name="encoder")
save_model_three_ways(encoder, name="decoder")
!tar -czvf doom_models.tar.gz models

#from google.colab import files
#files.download('doom_models.tar.gz')

In [None]:
# Colab only!
#@title Interactive Latent Space Exploration { run: "auto", vertical-output: true, form-width: "50px" }
z_1 = 0.88 #@param {type:"slider", min:-1, max:1, step:0.01}
z_2 = -0.85 #@param {type:"slider", min:-1, max:1, step:0.01}
z_3 = 0.7 #@param {type:"slider", min:-1, max:1, step:0.01}
z_4 = 0.51 #@param {type:"slider", min:-1, max:1, step:0.01}
z_5 = 0.15 #@param {type:"slider", min:-1, max:1, step:0.01}
z_6 = 0.23 #@param {type:"slider", min:-1, max:1, step:0.01}
z_7 = 0.51 #@param {type:"slider", min:-1, max:1, step:0.01}
z_8 = -0.8 #@param {type:"slider", min:-1, max:1, step:0.01}
z_9 = 0.77 #@param {type:"slider", min:-1, max:1, step:0.01}
z_10 = -0.99 #@param {type:"slider", min:-1, max:1, step:0.01}
z_11 = -0.51 #@param {type:"slider", min:-1, max:1, step:0.01}
z_12 = -0.29 #@param {type:"slider", min:-1, max:1, step:0.01}
z_13 = -0.48 #@param {type:"slider", min:-1, max:1, step:0.01}
z_14 = -0.6 #@param {type:"slider", min:-1, max:1, step:0.01}
z_15 = 0.56 #@param {type:"slider", min:-1, max:1, step:0.01}
z_16 = 0.92 #@param {type:"slider", min:-1, max:1, step:0.01}

# Gather the 16 slider values into one latent vector and decode it.
new_z = np.array([z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8,
                  z_9, z_10, z_11, z_12, z_13, z_14, z_15, z_16])
print(new_z)
# Add a batch axis of size 1 before decoding.
dec = decoder.predict(new_z[np.newaxis, :])

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(dec[0])  # RGB data, so no colormap is applied
plt.show()

In [None]:
# Encode a fixed training image (index 1) and copy its latent coordinates into
# the variables z_1 ... z_16.
ex = x_train[1]

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(ex)  # RGB data, so no colormap is applied
plt.show()

enc_z = encoder.predict(ex[np.newaxis])[0]
print(enc_z)

# NOTE: assigning these variables does not move the Colab sliders; running the
# slider cell again re-binds z_1 ... z_16 to the slider values.
(z_1, z_2, z_3, z_4, z_5, z_6, z_7, z_8,
 z_9, z_10, z_11, z_12, z_13, z_14, z_15, z_16) = (enc_z[k] for k in range(16))