<center><img src='https://i.postimg.cc/TPR1n1rp/AI-Tech-PL-RGB.png' height="60"></center>

AI TECH - Akademia Innowacyjnych Zastosowań Technologii Cyfrowych. Programu Operacyjnego Polska Cyfrowa na lata 2014-2020
<hr>

<center><img src='https://i.postimg.cc/Gpq2KRQz/logotypy-aitech.jpg'></center>

<center>
Projekt współfinansowany ze środków Unii Europejskiej w ramach Europejskiego Funduszu Rozwoju Regionalnego
Program Operacyjny Polska Cyfrowa na lata 2014-2020,
Oś Priorytetowa nr 3 "Cyfrowe kompetencje społeczeństwa" Działanie  nr 3.2 "Innowacyjne rozwiązania na rzecz aktywizacji cyfrowej"
Tytuł projektu:  „Akademia Innowacyjnych Zastosowań Technologii Cyfrowych (AI Tech)”
</center>

# Lab 06: Model-based Reinforcement Learning

In this lab, we will reproduce the _Learning Inside of a Dream_ experiment of the seminal [World Models](https://worldmodels.github.io/) paper.

In [None]:
#@title Mount your Google Drive

#@markdown Your work will be stored in a folder called `rl_lab_2022` by default.

#@markdown Run each section with Shift+Enter

#@markdown Double-click on section headers to show code.

import os
from google.colab import drive
drive.mount('/content/gdrive')

LAB_PATH = '/content/gdrive/MyDrive/rl_lab_2022/model_based_rl'
if not os.path.exists(LAB_PATH):
  %mkdir -p $LAB_PATH

# cd into the lab directory
%cd $LAB_PATH

## 0. Collect data

In [None]:
!pip install scikit-video

In [None]:
import itertools
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import skvideo.io
import tensorflow as tf

from base64 import b64encode
from IPython.display import HTML

def show_video(file_name):
    """Embed an MP4 file inline in the notebook.

    The file is read from disk, base64-encoded and wrapped in an HTML5
    ``<video>`` element (256 px wide, with playback controls) through a data URL.

    Args:
        file_name: Path to the ``.mp4`` file to display.

    Returns:
        An ``IPython.display.HTML`` object; leave it as the last expression of a
        cell so that it renders.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(file_name, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML("""
    <video width=256 controls>
        <source src="%s" type="video/mp4">
    </video>
    """ % data_url)

# Discrete action set; each entry is the button vector passed to `game.make_action`.
ACTIONS = [
    [0, 0],  # no-op
    [0, 1],  # right
    [1, 0],  # left
]
# Game tics each chosen action is repeated for (other values worth trying: 4, 8, 12).
FRAMES_PER_ACTION = 2
# Height and width (in pixels) of the preprocessed observations.
OBS_SIZE = (64, 64)
# Environment steps to collect for the vision dataset (alternative: 2_000_000).
TOTAL_STEPS = 225_000

_Install VizDoom..._

In [None]:
%%bash
# Refresh the apt package index before installing ViZDoom's build dependencies.
apt-get update

# Install deps from
# https://github.com/mwydmuch/ViZDoom/blob/master/doc/Building.md#-linux
# (-qq: quiet output, -y: non-interactive)
apt-get -qq -y install build-essential zlib1g-dev libsdl2-dev libjpeg-dev \
nasm tar libbz2-dev libgtk2.0-dev cmake git libfluidsynth-dev libgme-dev \
libopenal-dev timidity libwildmidi-dev unzip

# Boost libraries
apt-get -qq -y install libboost-all-dev

# Lua binding dependencies
apt-get -qq -y install liblua5.1-dev

In [None]:
!pip install vizdoom

In [None]:
import os
import vizdoom as vzd

# Scenario config shipped with ViZDoom (the scenario is "take cover"; the
# constant keeps its historical name because other cells refer to it).
TAKE_OVER_CONFIG = os.path.join(vzd.scenarios_path, 'take_cover.cfg')

def initialize_doom_game():
    """Create and start a headless ViZDoom game for the take_cover scenario.

    Returns:
        An initialized ``vzd.DoomGame`` in PLAYER mode rendering 640x480 frames,
        with its window hidden.
    """
    doom = vzd.DoomGame()
    doom.load_config(TAKE_OVER_CONFIG)
    # Headless environment (e.g. Colab): never open a window.
    doom.set_window_visible(False)
    doom.set_mode(vzd.Mode.PLAYER)
    # `preprocess` crops rows 80:400, which relies on this 480-pixel-high frame.
    doom.set_screen_resolution(vzd.ScreenResolution.RES_640X480)
    doom.init()
    return doom

In [None]:
@tf.function
def preprocess(obs):
    """Turn one raw channels-first ViZDoom frame into a small HWC uint8 image.

    Args:
        obs: Screen buffer laid out as (channels, height, width).

    Returns:
        A ``tf.uint8`` tensor of shape (*OBS_SIZE, channels).
    """
    hwc = tf.transpose(obs, perm=(1, 2, 0))  # channels-last
    cropped = hwc[80:400]  # keep rows 80..399
    resized = tf.image.resize(cropped, OBS_SIZE, method='area')
    return tf.cast(resized, tf.uint8)

@tf.function
def batch_preprocess(obs):
    """Batched `preprocess`: (batch, C, H, W) frames -> (batch, *OBS_SIZE, C) uint8."""
    nhwc = tf.transpose(obs, perm=(0, 2, 3, 1))  # channels-last
    cropped = nhwc[:, 80:400]  # same row crop as the single-frame version
    resized = tf.image.resize(cropped, OBS_SIZE, method='area')
    return tf.cast(resized, tf.uint8)

In [None]:
# Play one episode with uniformly random actions, recording every observation.
game = initialize_doom_game()
game.new_episode()
obs_list = []
while not game.is_episode_finished():
    screen = game.get_state().screen_buffer
    obs_list.append(preprocess(screen))
    game.make_action(random.choice(ACTIONS), FRAMES_PER_ACTION)
frames = tf.stack(obs_list)
print('Return: ', game.get_total_reward())

In [None]:
# Encode the recorded random-policy episode as an MP4 and embed it inline.
file_name = 'take_over.mp4'
skvideo.io.vwrite(file_name, frames)
show_video(file_name)

### Collect a new dataset (optional!)

In [None]:
# Collect TOTAL_STEPS transitions with a random policy that holds each randomly
# chosen action for a random number of steps, then save them to disk.
obs_array = np.empty([TOTAL_STEPS, *OBS_SIZE, 3], dtype=np.uint8)
# `np.int` was removed in NumPy 1.24; use an explicit 64-bit integer dtype.
act_array = np.empty([TOTAL_STEPS, 1], dtype=np.int64)
done_array = np.empty([TOTAL_STEPS, 1], dtype=np.int64)

# Steps left before a new action is drawn. Counting down (instead of testing
# `i % repeat`, whose period changes every time `repeat` is redrawn) makes each
# action last exactly the sampled number of steps.
repeat_left = 0
game = initialize_doom_game()
iter_time = time.time()
for i in range(TOTAL_STEPS):
    if (i + 1) % 1000 == 0:
        steps_per_second = 1000 / (time.time() - iter_time)
        print(f'Step {(i + 1)//1000}k/{TOTAL_STEPS//1000}k, Steps per second: {steps_per_second:.0f}')
        iter_time = time.time()

    if repeat_left == 0:
        # Hold the action for 1..(10 // FRAMES_PER_ACTION) steps, i.e. up to ~10 tics.
        repeat_left = np.random.randint(1, (10 // FRAMES_PER_ACTION) + 1)
        act = random.randint(0, 2)
    repeat_left -= 1

    # Record the observation the action is taken from, then step the game.
    obs = preprocess(game.get_state().screen_buffer)
    game.make_action(ACTIONS[act], FRAMES_PER_ACTION)
    done = game.is_episode_finished()

    if done:
        game.new_episode()

    obs_array[i] = obs
    act_array[i] = act
    done_array[i] = done

np.savez('doom_data_225k.npz', observations=obs_array, actions=act_array, dones=done_array)

## 1. Vision

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Keras shorthands.
tfk = tf.keras
tfkl = tf.keras.layers

# VAE hyper-parameters.
BATCH_SIZE = 256
BASE_DEPTH = 32
OBS_SIZE = (64, 64)
INPUT_SHAPE = (*OBS_SIZE, 3)
KL_TOLERANCE = 0.5
LATENT_SIZE = 64

# Output locations for the vision (VAE) stage.
LOGS_DIR = 'logs'
VISION_IMAGES_DIR = os.path.join(LOGS_DIR, 'images/vision')
VISION_CKPT_PATH = os.path.join(LOGS_DIR, 'best_vision')
VISION_LOGS_PATH = os.path.join(LOGS_DIR, 'train_vision.csv')
VISION_WEIGHTS_PATH = os.path.join(LOGS_DIR, 'best_vision.h5')

gpu_name = tf.test.gpu_device_name()
if gpu_name == '/device:GPU:0':
    print('SUCCESS: Found GPU: {}'.format(gpu_name))
else:
    print('WARNING: GPU device not found.')

In [None]:
# Create the log directories. exist_ok makes the cell safe to re-run; a plain
# os.makedirs raises FileExistsError the second time.
os.makedirs(LOGS_DIR, exist_ok=True)
os.makedirs(VISION_IMAGES_DIR, exist_ok=True)

### Load the dataset

In [None]:
# Load the random-policy dataset (written by the optional collection cell above).
if os.path.exists('doom_data_225k.npz'):
    with np.load('doom_data_225k.npz') as data:
        # NOTE: unpacking relies on the archive's member order (observations,
        # actions, dones), i.e. the keyword order used when it was saved.
        obs_array, act_array, done_array = data.values()
else:
    raise ValueError('Collect the data first!')

TOTAL_STEPS = obs_array.shape[0]
print('Number of samples: ', obs_array.shape[0])

In [None]:
def vision_preprocess(sample):
    """Scale uint8 pixels to float32 in [0, 1]; return them as an (input, target) pair for the autoencoder."""
    scaled = tf.cast(sample, tf.float32) / 255.
    return scaled, scaled

In [None]:
# Frames [0, 200000) train the VAE; the remaining frames are held out for validation.
train_dataset = tf.data.Dataset.from_tensor_slices(obs_array[:200000])
train_dataset = train_dataset.shuffle(10000)
train_dataset = train_dataset.batch(BATCH_SIZE).map(vision_preprocess)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices(obs_array[200000:])
test_dataset = test_dataset.batch(BATCH_SIZE).map(vision_preprocess)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
def plot_samples(epoch, logs, vae, samples_iter, title, prefix):
    """Save a figure of 12 input frames (top row) over their VAE reconstructions (bottom row).

    Frames are drawn with replacement from the next batch of ``samples_iter``.
    The figure is written to ``VISION_IMAGES_DIR/{prefix}_epoch_{epoch}``.

    Args:
        epoch: Epoch index used in the file name.
        logs: Keras epoch logs (unused; accepted for the callback signature).
        vae: Model mapping images to reconstructions.
        samples_iter: Iterator over ``(inputs, targets)`` batches.
        title: Figure title prefix.
        prefix: File-name prefix.
    """
    batch_images = next(samples_iter)[0].numpy()
    picked = batch_images[np.random.randint(0, batch_images.shape[0], size=12)]
    reconstructed = vae.predict(picked)

    n_cols = picked.shape[0]
    figsize = plt.figaspect(2 / n_cols)
    stacked = np.concatenate((picked, reconstructed))

    fig, axes = plt.subplots(2, n_cols, figsize=figsize)
    for k, ax in enumerate(axes.ravel()):
        ax.imshow(stacked[k], interpolation='none')
        ax.set(xticks=[], yticks=[])

    fig.suptitle(f'{title} [Ground Truth over Predicted]', fontsize=16)
    plt.savefig(os.path.join(VISION_IMAGES_DIR, f'{prefix}_epoch_{epoch}'))
    plt.close(fig)

class PlotSamplesCallable:
    """``on_epoch_end`` hook that forwards to `plot_samples` with fixed arguments."""

    def __init__(self, vae, samples_iter, title, prefix):
        self.vae, self.samples_iter = vae, samples_iter
        self.title, self.prefix = title, prefix

    def __call__(self, epoch, logs):
        plot_samples(epoch, logs, self.vae, self.samples_iter,
                     self.title, self.prefix)

### Exercise

Based on this tutorial (https://www.tensorflow.org/tutorials/generative/cvae), fill in the gaps (the `...` placeholders) in the VAE code below.

In [None]:
# --- Encoder: image -> (z, mu, logvar) ---------------------------------------
encoder_input = tfk.Input(shape=INPUT_SHAPE)

encoder_body = tfk.Sequential([
    # Center the pixel values (in [0, 1] after `vision_preprocess`) around zero.
    tfkl.Lambda(lambda x: x - 0.5),
    # TODO(exercise): add the convolutional stack here (see the linked CVAE tutorial).
    ...
    tfkl.Flatten()
])(encoder_input)

# TODO(exercise): posterior mean, LATENT_SIZE units computed from `encoder_body`.
mu = ...
# TODO(exercise): posterior log-variance, LATENT_SIZE units.
logvar = ...
# Standard deviation; the rest of the notebook uses sigma = exp(logvar / 2).
sigma = ...
# Reparameterized sample: z = mu + sigma * eps, with eps ~ N(0, I).
z = ...

# Later cells read the outputs by position: [0] = z, [1] = mu, [2] = logvar.
encoder = tf.keras.Model(inputs=encoder_input,
                         outputs=[z, mu, logvar])

In [None]:
# --- Decoder: z -> reconstructed image ----------------------------------------
decoder_input = tfk.Input(shape=[LATENT_SIZE])

decoder_body = tfk.Sequential([
    tfkl.Dense(32 * BASE_DEPTH, activation=tf.nn.relu),
    # View the dense features as a 1x1 spatial map to be upsampled.
    tfkl.Reshape([1, 1, 32 * BASE_DEPTH]),
    # TODO(exercise): add transposed convolutions that upsample to INPUT_SHAPE.
    # NOTE(review): later cells multiply decoder outputs by 255 to get uint8
    # pixels, so the output presumably has to lie in [0, 1] (e.g. a sigmoid) — confirm.
    ...
])(decoder_input)

decoder = tf.keras.Model(inputs=decoder_input,
                         outputs=decoder_body)

In [None]:
def mse_loss(x_true, x_pred):
    """Reconstruction loss between target images and VAE reconstructions.

    Args:
        x_true: Target images (pixels scaled to [0, 1] by `vision_preprocess`).
        x_pred: Decoder reconstructions with the same shape as ``x_true``.

    Returns:
        Scalar: ``tf.reduce_mean`` of the per-example error computed below.
    """
    # Reconstruction loss
    # NOTE: Shall be a logistic loss (binary crossentropy),
    #       but MSE was used in the official implementation.
    # TODO(exercise): per-example squared error; presumably summed over the
    # height, width and channel axes as in the original World Models code — confirm.
    r_loss = ...
    r_loss = tf.reduce_mean(r_loss)

    return r_loss

In [None]:
# (Augmented) KL loss
# TODO(exercise): per-example KL(q(z|x) || N(0, I)) computed from the encoder's
# `mu` and `logvar` tensors (presumably summed over the latent dimensions).
kl_loss = ...
# "Free bits": per-example KL below KL_TOLERANCE * LATENT_SIZE is not penalized.
kl_loss = tf.maximum(kl_loss, KL_TOLERANCE * LATENT_SIZE)
kl_loss = tf.reduce_mean(kl_loss)

In [None]:
# End-to-end VAE: encode, take the sampled z (encoder output 0), then decode.
sampled_z = encoder.outputs[0]
vae = tfk.Model(inputs=encoder.inputs, outputs=decoder(sampled_z))
# The KL regularizer is attached as an extra model loss next to `mse_loss`.
vae.add_loss(kl_loss)
vae.compile(loss=mse_loss,
            optimizer=tf.optimizers.Adam(learning_rate=1e-4))

In [None]:
callbacks = [
    # Stop once the monitored loss has not improved for 25 epochs.
    tf.keras.callbacks.EarlyStopping(patience=25),
    # Keep only the best checkpoint seen so far.
    tf.keras.callbacks.ModelCheckpoint(VISION_CKPT_PATH,
                                       verbose=1,
                                       save_best_only=True),
    # Append per-epoch metrics to a CSV log.
    tf.keras.callbacks.CSVLogger(VISION_LOGS_PATH,
                                 append=True)
]

# Save reconstruction figures for train and test samples after every epoch.
# Each epoch draws one batch from these iterators. They are built on repeated
# datasets because the datasets above are finite: a plain iterator would raise
# StopIteration inside the callback (aborting `fit`) once all batches are used.
for samples_dataset, samples_title, samples_prefix in (
        (train_dataset, 'Train examples', 'train'),
        (test_dataset, 'Test examples', 'test')):
    callbacks.append(tf.keras.callbacks.LambdaCallback(
        on_epoch_end=PlotSamplesCallable(vae=vae,
                                         samples_iter=iter(samples_dataset.repeat()),
                                         title=samples_title,
                                         prefix=samples_prefix)))

In [None]:
# Epoch to start counting from. Keep 0 for a fresh run; only set it to the last
# finished epoch when resuming (after restoring the weights), so the epoch
# numbers in the CSV log and in the saved figure names stay consistent.
VISION_INITIAL_EPOCH = 0

# `epochs` is effectively unbounded: training runs until EarlyStopping fires.
history = vae.fit(train_dataset,
                  epochs=224444444,
                  initial_epoch=VISION_INITIAL_EPOCH,
                  validation_data=test_dataset,
                  callbacks=callbacks)

In [None]:
# Persist the final VAE weights (HDF5 file, see VISION_WEIGHTS_PATH).
vae.save_weights(filepath=VISION_WEIGHTS_PATH)

<h3><center>...or...</center></h3>

In [None]:
# Restore the trained VAE weights saved above.
# Alternative: rebuild the whole model from the checkpoint with
#   tfk.models.load_model(VISION_CKPT_PATH, custom_objects={'mse_loss': mse_loss})
vae.load_weights(filepath=VISION_WEIGHTS_PATH)

## 2. Memory

In [None]:
import pickle
import random
import time

import numpy as np

# MDN-RNN (memory) hyper-parameters.
ACTIONS_NUM = 3
BATCH_SIZE = 100
DONE_WEIGHT = 10.  # Loss weight for done = 1 targets, which are rare.
HIDDEN_DIM = 512
MASK_VALUE = 0.0  # Padding value; all-MASK_VALUE timesteps are masked out.
NUM_GAUSSIANS = 5
TEMPERATURE = 1.15  # Sampling temperature used by `mdn_sample`.
TF_LOG_SQRT_TWO_PI = tf.math.log(tf.math.sqrt(2 * np.pi))

# Output locations for the memory (MDN-RNN) stage.
MODEL_CKPT_PATH = os.path.join(LOGS_DIR, 'best_model')
MODEL_IMAGES_DIR = os.path.join(LOGS_DIR, 'images/model')
MODEL_LOGS_PATH = os.path.join(LOGS_DIR, 'train_model.csv')
MODEL_WEIGHTS_PATH = os.path.join(LOGS_DIR, 'best_model.h5')

In [None]:
# exist_ok makes the cell safe to re-run (a plain makedirs raises FileExistsError).
os.makedirs(MODEL_IMAGES_DIR, exist_ok=True)

In [None]:
def plot_imgs(x_true, x_pred):
    """Show ``x_true`` images in the top row and ``x_pred`` images below them.

    Args:
        x_true: Array of images; its first axis indexes images and sets the
            number of columns.
        x_pred: Array of images to show in the second row.
    """
    n_cols = x_true.shape[0]
    figsize = plt.figaspect(2 / n_cols) * 2
    stacked = np.concatenate((x_true, x_pred))

    fig, axes = plt.subplots(2, n_cols, figsize=figsize)
    for k, ax in enumerate(axes.ravel()):
        ax.imshow(stacked[k], interpolation='none')
        ax.set(xticks=[], yticks=[])

    plt.show()

### Create the 10k-episode (2M-step) dataset

In [None]:
# Collect TOTAL_STEPS_10K random-policy steps. Batch by batch, store only the VAE
# posterior of each frame (mu in channel 0, logvar in channel 1), not raw pixels.
TOTAL_STEPS_10K = 2_000_000
OBS_BATCH_SIZE = 1000

latent_array = np.empty([TOTAL_STEPS_10K, LATENT_SIZE, 2], dtype=np.float32)
obs_batch = np.empty([OBS_BATCH_SIZE, *OBS_SIZE, 3], dtype=np.uint8)
# `np.int` was removed in NumPy 1.24; use an explicit 64-bit integer dtype.
act_array = np.empty([TOTAL_STEPS_10K, 1], dtype=np.int64)
done_array = np.empty([TOTAL_STEPS_10K, 1], dtype=np.int64)

game = initialize_doom_game()
iter_time = time.time()
for itr in range(TOTAL_STEPS_10K // OBS_BATCH_SIZE):
    idx_start = itr * OBS_BATCH_SIZE
    idx_end = idx_start + OBS_BATCH_SIZE

    # Collect: a random action is held for a random number of steps
    # (`repeat` counts down the steps left for the current action).
    repeat = 0
    for step in range(OBS_BATCH_SIZE):
        if repeat == 0:
            repeat = np.random.randint(1, (10 // FRAMES_PER_ACTION) + 1)
            act = random.randint(0, 2)

        obs = preprocess(game.get_state().screen_buffer)
        game.make_action(ACTIONS[act], FRAMES_PER_ACTION)
        done = game.is_episode_finished()

        if done:
            game.new_episode()

        obs_batch[step] = obs
        act_array[idx_start + step] = act
        done_array[idx_start + step] = done

        repeat -= 1

    # Encode the whole batch at once; the encoder returns [z, mu, logvar].
    # verbose=0: this cell prints its own progress line below.
    z_pred = encoder.predict(vision_preprocess(obs_batch)[0], verbose=0)

    latent_array[idx_start:idx_end, :, 0] = z_pred[1]
    latent_array[idx_start:idx_end, :, 1] = z_pred[2]

    # Log
    steps_per_second = OBS_BATCH_SIZE / (time.time() - iter_time)
    iter_time = time.time()

    steps_left = TOTAL_STEPS_10K - idx_end
    print(f'Step {idx_end//1000}k/{TOTAL_STEPS_10K//1000}k, steps/sec: {steps_per_second:.0f}, ETA: {(steps_left / steps_per_second)//60:.0f} mins')


### Encode the dataset...

In [None]:
# Encode every frame of the loaded dataset with the VAE, storing the posterior
# mean (channel 0) and log-variance (channel 1) of each latent dimension.
ENCODE_BATCH_SIZE = 1000
# Ceil division, so a final partial batch is encoded too; otherwise the trailing
# rows of the np.empty buffer would keep uninitialized values.
num_batches = -(-TOTAL_STEPS // ENCODE_BATCH_SIZE)

latent_array = np.empty([TOTAL_STEPS, LATENT_SIZE, 2], dtype=np.float32)

iter_time = time.time()
for i in range(num_batches):
    idx_start = i * ENCODE_BATCH_SIZE
    idx_end = min(idx_start + ENCODE_BATCH_SIZE, TOTAL_STEPS)

    obs_batch = vision_preprocess(obs_array[idx_start:idx_end])[0]

    z_pred = encoder.predict(obs_batch, verbose=0)
    latent_array[idx_start:idx_end, :, 0] = z_pred[1]  # mu
    latent_array[idx_start:idx_end, :, 1] = z_pred[2]  # logvar

    # ETA = duration of this batch times the number of batches still to run.
    print(f'Step {(i + 1):3d}/{num_batches}, ETA: {(time.time() - iter_time) * (num_batches - i - 1):.0f}s')
    iter_time = time.time()

In [None]:
# Split the flat step arrays into episodes. Episode k covers steps
# [ep_borders[k], ep_borders[k + 1]); a border is placed right after every done
# step, so steps after the last done (an unfinished episode) are dropped.
ep_borders = [0, *(np.flatnonzero(done_array) + 1)]

episode_slices = [slice(lo, hi) for lo, hi in zip(ep_borders[:-1], ep_borders[1:])]
latent_list = [latent_array[s] for s in episode_slices]
action_list = [act_array[s] for s in episode_slices]
done_list = [done_array[s] for s in episode_slices]

with open('doom_data_2M.pkl', 'wb') as file:
    pickle.dump([latent_list, action_list, done_list], file)

### ...or load the dataset

In [None]:
# Load the episode-split latent dataset written above.
# NOTE: pickle can execute arbitrary code on load — only open files you created.
with open('doom_data_2M.pkl', 'rb') as pkl_file:
    loaded = pickle.load(pkl_file)
latent_list, action_list, done_list = loaded

In [None]:
# Episodes have different lengths, so each list becomes a RaggedTensor whose
# first axis indexes episodes.
latent_ragged, action_ragged, done_ragged = (
    tf.ragged.constant(episodes) for episodes in (latent_list, action_list, done_list))

In [None]:
# Display the ragged shape: first dim = number of episodes; ragged dims print as None.
latent_ragged.shape

In [None]:
def memory_preprocess(rlatent, raction, rdone):
    """Turn one recorded episode into MDN-RNN (inputs, targets).

    A latent z_t is sampled from each stored posterior (mu, logvar), actions are
    one-hot encoded and done flags are mapped from {0, 1} to {-1, +1} so that the
    padding value 0 can be told apart from real labels.

    Returns:
        ``((z[:-1], action[:-1]), (z[1:], done[1:]))``: the model sees the latent
        and action at step t and predicts the latent and done flag at step t+1.
    """
    latent = rlatent.to_tensor()
    action = raction.to_tensor()
    done = rdone.to_tensor()

    mu = latent[..., 0]
    logvar = latent[..., 1]
    # Reparameterized sample from N(mu, exp(logvar)).
    z = mu + tf.exp(logvar / 2.0) * tf.random.normal(tf.shape(mu))

    action_one_hot = tf.one_hot(tf.squeeze(action, axis=1), depth=ACTIONS_NUM)
    done_signed = 2. * tf.cast(done, tf.float32) - 1.

    inputs = (z[:-1], action_one_hot[:-1])
    targets = (z[1:], done_signed[1:])
    return inputs, targets

# The first 10000 episodes train the MDN-RNN; the rest are held out for validation.
train_episodes = (latent_ragged[:10000], action_ragged[:10000], done_ragged[:10000])
test_episodes = (latent_ragged[10000:], action_ragged[10000:], done_ragged[10000:])

train_dataset = tf.data.Dataset.from_tensor_slices(train_episodes).shuffle(10000)
train_dataset = (train_dataset.map(memory_preprocess)
                 .padded_batch(BATCH_SIZE, padding_values=MASK_VALUE)
                 .prefetch(tf.data.AUTOTUNE))

test_dataset = (tf.data.Dataset.from_tensor_slices(test_episodes)
                .map(memory_preprocess)
                .padded_batch(BATCH_SIZE, padding_values=MASK_VALUE)
                .prefetch(tf.data.AUTOTUNE))

In [None]:
# Grab one validation batch, structured ((z_in, act_in), (z_next, done_next)),
# for the visual sanity checks below.
test_iter = iter(test_dataset)
batch = next(test_iter)

In [None]:
# Shape of episode 3's (padded) input latents: (padded_length, LATENT_SIZE).
batch[0][0][3].shape

In [None]:
# Decode 20 consecutive steps of episode 3: inputs z_t (top) vs. targets z_{t+1} (bottom).
episode_inputs = batch[0][0][3]
episode_targets = batch[1][0][3]
x_one = decoder.predict(episode_inputs[150:170])
x_two = decoder.predict(episode_targets[150:170])
plot_imgs(x_one, x_two)

### Exercise

Based on the original code (https://github.com/hardmaru/WorldModelsExperiments/blob/fd982b9691a941b52c6addbde29bc801ca6202c8/doomrnn/doomrnn.py), implement the MDN-RNN losses.

In [None]:
# Inputs: per-step latent z_t and one-hot action a_t (variable-length sequences).
obs = tfkl.Input(shape=[None, LATENT_SIZE])
act = tfkl.Input(shape=[None, ACTIONS_NUM])
# Timesteps whose features all equal MASK_VALUE (padding) are skipped downstream.
masking = tfkl.Masking(mask_value=MASK_VALUE)
obs_act = tfkl.Concatenate(axis=-1)([obs, act])
masked_input = masking(obs_act)

lstm_output = tfkl.LSTM(HIDDEN_DIM, return_sequences=True)(masked_input)

# Per-step mixture parameters (mu, logstd, logit-mix) for every latent dim and Gaussian.
mdn_output = tfkl.TimeDistributed(tfkl.Dense(LATENT_SIZE * NUM_GAUSSIANS * 3))(lstm_output)
# Per-step done logit (log unnormalized probability).
done_output = tfkl.TimeDistributed(tfkl.Dense(1))(lstm_output)

def compute_mask(tensor):
    """Boolean mask of non-padding positions.

    A position counts as padding when every value along the last axis equals
    MASK_VALUE. The last axis is kept with size 1 so the mask broadcasts
    against per-step losses.
    """
    is_real = tf.math.not_equal(tensor, MASK_VALUE)
    return tf.reduce_any(is_real, axis=-1, keepdims=True)

def unstack_mdn_coef(output):
    """Split flat MDN head outputs into ``[mu, logstd, logitmix]``.

    ``output`` is reshaped to (batch, -1, LATENT_SIZE, NUM_GAUSSIANS, 3), with
    the second axis inferred (time steps), and unstacked along the last axis.
    """
    mdn_shape = [tf.shape(output)[0], -1, LATENT_SIZE, NUM_GAUSSIANS, 3]
    coefficients = tf.reshape(output, mdn_shape)
    return tf.unstack(coefficients, axis=-1)

def mdn_sample(mdn_output, temperature=TEMPERATURE):
    """Draw one sample per latent dimension from the MDN mixture.

    For every (batch, time, latent-dim) position, a mixture component is chosen
    from the temperature-scaled logits and a Gaussian sample is drawn from it.
    As in the original World Models implementation, the temperature also scales
    the Gaussian noise by ``sqrt(temperature)``, so it controls the overall
    uncertainty of the sampled dream and not only the mixture choice.

    Args:
        mdn_output: Flat MDN head output (reshaped by `unstack_mdn_coef`).
        temperature: Sampling temperature; defaults to TEMPERATURE.

    Returns:
        Tensor of shape (batch, time, LATENT_SIZE, 1).
    """
    mu, logstd, logitmix = unstack_mdn_coef(mdn_output)
    sigma = tf.exp(logstd)

    idxs = tf.random.categorical(
        tf.reshape(logitmix, [-1, NUM_GAUSSIANS]) / temperature, 1)
    # Build the index shape with tensor ops: list-unpacking a `tf.shape` tensor
    # only works in eager mode, not inside a traced tf.function.
    idx_shape = tf.concat([tf.shape(mu)[:3], [1]], axis=0)
    idxs = tf.reshape(idxs, idx_shape)

    mu_sampled = tf.gather(mu, idxs, batch_dims=3)
    sigma_sampled = tf.gather(sigma, idxs, batch_dims=3)

    noise_scale = tf.sqrt(tf.cast(temperature, mu_sampled.dtype))
    noise = tf.random.normal(tf.shape(mu_sampled)) * noise_scale
    return mu_sampled + sigma_sampled * noise

def mdn_loss(z_true, mdn_output):
    """Masked negative log-likelihood of the next latent under the MDN mixture.

    Args:
        z_true: Target latents, (batch, time, LATENT_SIZE); padded steps are all
            MASK_VALUE.
        mdn_output: Flat MDN head output (see `unstack_mdn_coef`).

    Returns:
        Scalar: the per-step loss (summed over latent dimensions) averaged over
        the non-padded steps.
    """
    # (batch, time, 1) mask: 1.0 for real steps, 0.0 for padding.
    mask = tf.cast(compute_mask(z_true), dtype=tf.float32)
    # Add a mixture axis so z_true broadcasts against the NUM_GAUSSIANS components.
    z_true = tf.expand_dims(z_true, axis=-1)

    mu, logstd, logitmix = unstack_mdn_coef(mdn_output)
    # Normalize the log mixing coefficient
    # TODO(exercise): log-softmax of `logitmix` over the mixture axis.
    logmix = ...
    sigma = tf.exp(logstd)

    # TODO(exercise): per-component Gaussian log-density of z_true
    # (TF_LOG_SQRT_TWO_PI is available for the constant term).
    lognormal = ...
    # TODO(exercise): combine with `logmix`.
    logmixnormal = ...

    # Sum (via log-sum-exp) over the mixture components: the density is a
    # mixture of Gaussians, not a joint probability.
    logprob = ...
    # Negative log-likelihood summed over latent dimensions (kept as a size-1 axis).
    loss = -tf.reduce_sum(logprob, axis=-1, keepdims=True)

    return tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)

def done_loss(done_true, done_pred):
    """Masked, class-weighted loss for predicting the done flag.

    Args:
        done_true: Targets in {-1, +1} (see `memory_preprocess`); padded steps
            hold MASK_VALUE (0).
        done_pred: Done logits from the model's done head.

    Returns:
        Scalar: the weighted per-step loss summed and divided by the number of
        non-padded steps.
    """
    mask = tf.cast(compute_mask(done_true), dtype=tf.float32)
    # Map {-1, +1} back to {0, 1}; padded steps become 0.
    done_label = ((done_true + 1.) / 2.) * mask
    # Weight: 0 for padding, 1 for real steps, DONE_WEIGHT for the rare done = 1 steps.
    weight = mask + done_label * (DONE_WEIGHT - 1.) # Weight pred. done = 1.

    # TODO(exercise): per-step loss between `done_label` and the logits
    # `done_pred` (presumably sigmoid cross-entropy).
    loss = ...

    return tf.reduce_sum(loss * weight) / tf.reduce_sum(mask)

mdn_rnn = tfk.models.Model(inputs=[obs, act], outputs=[mdn_output, done_output])
# One loss per output, in the same order as `outputs`.
mdn_rnn.compile(loss=[mdn_loss, done_loss],
                optimizer=tf.optimizers.Adam(learning_rate=1e-3))

In [None]:
def plot_samples(epoch, logs, mdn_rnn, decoder, batch, title, prefix, timestep=42):
    """Save a figure comparing true next frames (top) with MDN-RNN predictions (bottom).

    For 20 episodes drawn (with replacement) from ``batch``, the model's sampled
    prediction at input position ``timestep`` and the matching ground-truth
    target are decoded with ``decoder``. The figure is written to
    ``MODEL_IMAGES_DIR/{prefix}_epoch_{epoch}``.

    Args:
        epoch: Epoch index used in the file name.
        logs: Keras epoch logs (unused; accepted for the callback signature).
        mdn_rnn: Model returning ``(mdn_output, done_output)``.
        decoder: VAE decoder mapping latents to images.
        batch: ``((z_in, act_in), (z_next, done_next))`` from the memory dataset.
        title: Figure title prefix.
        prefix: File-name prefix.
        timestep: Sequence position to visualize.
    """
    idxs = np.random.randint(0, batch[0][0].shape[0], size=20)

    mdn_output, _ = mdn_rnn(batch[0])
    # Output position t predicts target position t: the targets are already
    # shifted by one step in `memory_preprocess`, so both use the same index.
    z_pred = mdn_sample(mdn_output).numpy()[idxs, timestep, ...]
    z_true = batch[1][0].numpy()[idxs, timestep, ...]

    x_pred = decoder.predict(z_pred)
    x_true = decoder.predict(z_true)

    nrows, ncols = 2, x_true.shape[0]
    dx, dy = 1, 1
    figsize = plt.figaspect(float(dy * nrows) / float(dx * ncols))

    imgs = np.empty_like(np.concatenate((x_true, x_pred)))
    imgs[:ncols] = x_true
    imgs[ncols:] = x_pred

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    for i, ax in enumerate(axes.flat):
        ax.imshow(imgs[i], interpolation='none')
        ax.set(xticks=[], yticks=[])

    fig.suptitle(f'{title} [Ground Truth over Predicted]',
                 fontsize=16)
    plt.savefig(os.path.join(MODEL_IMAGES_DIR, f'{prefix}_epoch_{epoch}'))
    plt.close(fig)

class PlotSamplesCallable:
    """``on_epoch_end`` hook that forwards to the MDN-RNN `plot_samples` with fixed arguments."""

    def __init__(self, mdn_rnn, decoder, batch, title, prefix):
        self.mdn_rnn, self.decoder, self.batch = mdn_rnn, decoder, batch
        self.title, self.prefix = title, prefix

    def __call__(self, epoch, logs):
        plot_samples(epoch, logs, self.mdn_rnn, self.decoder, self.batch,
                     self.title, self.prefix)

In [None]:
callbacks = [
    # Stop once the monitored loss has not improved for 7 epochs.
    tf.keras.callbacks.EarlyStopping(patience=7),
    # Keep only the best checkpoint seen so far.
    tf.keras.callbacks.ModelCheckpoint(MODEL_CKPT_PATH,
                                       verbose=1,
                                       save_best_only=True),
    # Append per-epoch metrics to a CSV log.
    tf.keras.callbacks.CSVLogger(MODEL_LOGS_PATH,
                                 append=True)
]

# After every epoch, plot predictions on one fixed train batch and one fixed test batch.
for plot_dataset, plot_title, plot_prefix in (
        (train_dataset, 'Train memory examples', 'train_memory'),
        (test_dataset, 'Test memory examples', 'test_memory')):
    callbacks.append(tf.keras.callbacks.LambdaCallback(
        on_epoch_end=PlotSamplesCallable(mdn_rnn=mdn_rnn,
                                         decoder=decoder,
                                         batch=next(iter(plot_dataset)),
                                         title=plot_title,
                                         prefix=plot_prefix)))

In [None]:
# Train until EarlyStopping fires; `epochs` is only a very large upper bound.
# (To resume an interrupted run, pass `initial_epoch=<last finished epoch>`.)
history = mdn_rnn.fit(train_dataset,
                      validation_data=test_dataset,
                      epochs=224444444,
                      callbacks=callbacks)

In [None]:
# Persist the final MDN-RNN weights; `WorldModel` later loads them from this file.
mdn_rnn.save_weights(filepath=MODEL_WEIGHTS_PATH)

<h3><center>...or...</center></h3>

In [None]:
# Restore the trained MDN-RNN weights.
# Alternative: rebuild the whole model from the checkpoint with
#   tfk.models.load_model(MODEL_CKPT_PATH,
#                         custom_objects={'mdn_loss': mdn_loss, 'done_loss': done_loss})
mdn_rnn.load_weights(filepath=MODEL_WEIGHTS_PATH)

In [None]:
# Predict on the validation batch and show its MDN loss as the cell output.
mdn_output_, done_ = mdn_rnn.predict(batch[0])
mdn_loss(batch[1][0], mdn_output_)

In [None]:
# Sample predicted next-step latents for every episode and timestep of the batch.
z_pred = mdn_sample(mdn_output_)

In [None]:
# Compare one episode's true terminal step with the steps where the model's done
# logit is positive (predicted done probability > 0.5).
batch_idx = 8
true_done_step = tf.where(batch[1][1][batch_idx] == 1.0)[0, 0]
print('GT done: ', true_done_step)
pred_done_steps = tf.where(done_[batch_idx] > 0.0)[:, 0]
print('Pred. done: ', pred_done_steps)
print('Pred. values: ', tf.gather(done_[batch_idx], pred_done_steps)[:, 0])

In [None]:
# Decode 20 steps of that episode: ground truth (top) vs. MDN-RNN samples (bottom).
steps = slice(80, 100)
x_one = decoder.predict(batch[1][0][batch_idx][steps])
x_two = decoder.predict(z_pred[batch_idx][steps])
plot_imgs(x_one, x_two)

### Create the world model!

In [None]:
class WorldModel(tf.keras.Model):
    """Stateful, single-step version of the MDN-RNN, used as a "dream" simulator.

    It has the same trainable parts as `mdn_rnn` (an LSTM, an MDN head and a done
    head), but consumes one (latent, action) step per call and keeps the LSTM
    state between calls, so it can be rolled forward action by action.
    """
    def __init__(self):
        super(WorldModel, self).__init__()
        # NOTE(review): the weights are loaded from the file saved by `mdn_rnn`
        # (`load_weights(MODEL_WEIGHTS_PATH)`), which presumably matches layers
        # by order — keep lstm / mdn_head / done_head created in this order.
        # stateful=True carries the LSTM state across calls (cleared by `reset`);
        # return_state=True also returns the final hidden and cell states.
        self.lstm = tfkl.LSTM(HIDDEN_DIM, return_state=True, stateful=True)
        # Predict the mixture mu, logstd, logmix
        self.mdn_head = tfkl.Dense(LATENT_SIZE * NUM_GAUSSIANS * 3)
        # Predict the done (log unnormalized) probability
        self.done_head = tfkl.Dense(1)

        # Latent fed to the next call when `latent` is omitted.
        self._current_latent = None

    def call(self, action, latent=None):
        """Advance the simulator by one step.

        Args:
            action: One-hot actions, shape (batch, ACTIONS_NUM).
            latent: Current latents, shape (batch, LATENT_SIZE). If None, the
                latent sampled by the previous call (or passed to `reset`) is used.

        Returns:
            ``(next_latent, lstm_hidden, if_done)``: the sampled next latent
            (batch, LATENT_SIZE), the LSTM hidden state (batch, HIDDEN_DIM), and a
            (batch, 1) boolean tensor that is True where the done logit is positive.
        """
        if latent is None:
            latent = self._current_latent

        concat_input = tf.concat([latent, action], axis=-1)
        # Expand the time dimension
        concat_input = tf.expand_dims(concat_input, axis=1)

        # Without return_sequences, `lstm_output` is the output of the single step.
        lstm_output, lstm_hidden, _ = self.lstm(concat_input)
        mdn_output = self.mdn_head(lstm_output)
        done_output = self.done_head(lstm_output)

        next_latent = mdn_sample(mdn_output)
        # A positive logit means a predicted done probability above 0.5.
        if_done = done_output > 0.0

        # Flatten the time dimension
        next_latent = tf.reshape(next_latent, shape=[-1, LATENT_SIZE])
        if_done = tf.reshape(if_done, shape=[-1, 1])

        # Remember the sample so the next call can omit `latent`.
        self._current_latent = next_latent
        # Add the cell state too?
        return (next_latent, lstm_hidden, if_done)

    def reset(self, init_latent=None):
        """Clear the LSTM state and set the latent used by the next call that omits `latent`."""
        self.lstm.reset_states()
        self._current_latent = init_latent

In [None]:
# Build a batch-size-1 simulator. Calling it once creates the variables, so the
# trained MDN-RNN weights can then be loaded into them.
world_model = WorldModel()
world_model(latent=tf.zeros([1, LATENT_SIZE]), action=tf.zeros([1, ACTIONS_NUM]))
world_model.load_weights(MODEL_WEIGHTS_PATH)

In [None]:
# Initial action-repeat length and number of imagined steps to roll out.
repeat, rollout_length = 4, 400

In [None]:
# Roll out the world model ("dream") from the first latent of a validation
# episode, under a random policy that holds each action for a random number of steps.
z_array = np.empty([rollout_length, LATENT_SIZE], dtype=np.float32)
d_array = np.empty([rollout_length, 1], dtype=np.float32)

z_array[0] = batch[0][0][0, 0, :]
# There is no done prediction for the initial state. (`np.NaN` was removed in NumPy 2.0.)
d_array[0] = np.nan

world_model.reset(z_array[None, 0])
# Count down the remaining repeats instead of testing `i % repeat`, whose period
# changes every time `repeat` is redrawn.
repeat_left = 0
for i in range(rollout_length - 1):
    if repeat_left == 0:
        repeat_left = np.random.randint(1, (10 // FRAMES_PER_ACTION) + 1)
        action = tf.one_hot(random.randint(0, 2), depth=ACTIONS_NUM)
    repeat_left -= 1

    next_z, _, done = world_model(action[None])

    z_array[i + 1] = next_z[0]
    d_array[i + 1] = done[0]

In [None]:
# Steps at which the world model predicted done (d_array[0] is NaN and never matches).
print(tf.where(d_array > 0.0)[:, 0])

In [None]:
# Decode all imagined latents into uint8 RGB frames with one batched call
# instead of one `predict` call (and progress bar) per frame.
decoded = decoder.predict(z_array, verbose=0)
# Clip before the cast so that decoder outputs slightly outside [0, 1] cannot
# wrap around in uint8.
frames = (np.clip(decoded, 0., 1.) * 255.).astype(np.uint8)

# Show imagined game!

Note that:
1. The VAE neural network learnt to **render** the game state!
2. The RNN neural network learnt to **simulate** the game mechanics!

Stop for a second and think about it. Programmers write this logic by hand: a game engine typically operates on a discrete state and applies discontinuous operations to it. Here, the neural networks learnt to approximate that engine purely from random interactions with it, using nothing but linear algebra and simple nonlinearities!

In [None]:
# Encode the imagined rollout as an MP4 and embed it inline.
file_name = 'dream_take_over.mp4'
skvideo.io.vwrite(file_name, frames)
show_video(file_name)

## 3. Controller

In [None]:
!pip install cma

In [None]:
import cma

# CMA-ES / controller-training hyper-parameters.
MAX_STEPS = 500        # dream steps per rollout
NUM_EPISODES = 5       # rollouts averaged per fitness evaluation
POPULATION_SIZE = 128  # CMA-ES candidates per generation
STD_DEV_INIT = 0.1     # initial CMA-ES step size
WEIGHT_DECAY = 0.001   # L2 penalty subtracted from the fitness

@tf.function
def preprocess(obs):
    """Controller-stage preprocessing: same crop/resize, but returns float32 in [0, 1].

    NOTE: this redefines the earlier uint8 `preprocess`; cells run after this one get floats.
    """
    frame = tf.transpose(obs, perm=(1, 2, 0))  # channels-last
    frame = frame[80:400]
    frame = tf.image.resize(frame, OBS_SIZE, method='area')
    return tf.cast(frame, tf.float32) / 255.

In [None]:
class LinearModel():
    """A population of linear controllers evaluated in parallel.

    Agent ``p`` maps the concatenation of its latent and LSTM hidden state
    (``LATENT_SIZE + HIDDEN_DIM`` features) to ``ACTIONS_NUM`` logits with its own
    kernel and bias, and acts greedily (argmax).
    """
    def __init__(self, num_agents=POPULATION_SIZE):
        self.kernel = np.zeros([num_agents, LATENT_SIZE + HIDDEN_DIM, ACTIONS_NUM])
        self.bias = np.zeros([num_agents, ACTIONS_NUM])

    def __call__(self, latent, hidden):
        """Return each agent's greedy action index.

        Args:
            latent: Array-like of shape (num_agents, LATENT_SIZE), one row per agent.
            hidden: Array-like of shape (num_agents, HIDDEN_DIM).

        Returns:
            Integer array of shape (num_agents,).
        """
        concat_input = np.concatenate([latent, hidden], axis=-1)
        # Per-agent matrix-vector product: (P, D, 1) * (P, D, A), summed over D -> (P, A).
        logits = np.sum((concat_input[..., np.newaxis] * self.kernel), axis=1) + self.bias
        return np.argmax(logits, axis=-1)

    def set_parameters(self, params):
        """Load flat parameter vectors, one row per agent.

        Each row is laid out as ``[bias (ACTIONS_NUM values), flattened kernel]``.
        A single 1-D vector is accepted as well and treated as a population of one.

        Args:
            params: Array-like of shape (num_agents, num_parameters) or (num_parameters,).
        """
        # Copy so the model never aliases the caller's (e.g. CMA-ES) buffer.
        copy_params = np.copy(np.atleast_2d(params))
        self.bias = copy_params[:, :ACTIONS_NUM]
        self.kernel = copy_params[:, ACTIONS_NUM:].reshape(-1, LATENT_SIZE + HIDDEN_DIM, ACTIONS_NUM)

    @property
    def num_parameters(self):
        """Number of parameters of a single agent (kernel + bias)."""
        return (LATENT_SIZE + HIDDEN_DIM) * ACTIONS_NUM + ACTIONS_NUM

In [None]:
# Dream simulator with one row per CMA-ES candidate (built with that batch size
# before loading the trained MDN-RNN weights).
world_model = WorldModel()
world_model(latent=tf.zeros([POPULATION_SIZE, LATENT_SIZE]),
            action=tf.zeros([POPULATION_SIZE, ACTIONS_NUM]))
world_model.load_weights(MODEL_WEIGHTS_PATH)

controller = LinearModel()

# CMA-ES over the flat controller parameters, starting from all zeros.
initial_mean = controller.num_parameters * [0.]
es = cma.CMAEvolutionStrategy(initial_mean, STD_DEV_INIT, {'popsize': POPULATION_SIZE})

In [None]:
# Initial dream states: the first-step posterior (mu, sigma) of the first
# POPULATION_SIZE recorded episodes, one per CMA-ES candidate.
first_steps = [episode[0] for episode in latent_list[:POPULATION_SIZE]]
INIT_MU = np.array([step[:, 0] for step in first_steps])
INIT_SIGMA = np.array([np.exp(step[:, 1] / 2.0) for step in first_steps])

### Exercise

Optimize the controller using the methods `stop`, `ask`, and `tell`. Documentation: https://cma-es.github.io/apidocs-pycma/cma.evolution_strategy.CMAEvolutionStrategy.html

In [None]:
while ...:  # TODO(exercise): loop until CMA-ES reports convergence (see `es.stop`).
    # Sample
    # TODO(exercise): ask CMA-ES for a new population of candidate solutions.
    solutions = np.array(...)

    # Evaluate: fitness = number of dream steps survived, averaged over episodes.
    returns = np.zeros([POPULATION_SIZE])
    controller.set_parameters(solutions)
    for ep_idx in range(NUM_EPISODES):
        # Reset the done history every episode. If it accumulated across
        # episodes, segment_min below would always find the first episode's
        # earliest done, and every episode would re-score episode 0.
        dones = []
        latent = INIT_MU + INIT_SIGMA * np.random.randn(*INIT_MU.shape)
        hidden = np.zeros([POPULATION_SIZE, HIDDEN_DIM])
        world_model.reset(latent)
        for _ in range(MAX_STEPS):
            action = controller(latent, hidden)
            latent, hidden, done = world_model(tf.one_hot(action, depth=ACTIONS_NUM), latent)
            dones.append(done)
        # Sentinel "done" after the last step so every agent has at least one.
        dones.append(np.ones_like(dones[0]))
        # (agents, steps); a separate name so the dataset's `done_array` is not clobbered.
        dream_dones = np.squeeze(np.array(dones), axis=-1).transpose(1, 0)
        indices = tf.where(dream_dones)
        # Index of each agent's first done == number of steps it survived.
        returns += np.asarray(tf.math.segment_min(indices[:, 1], indices[:, 0]))
    returns /= NUM_EPISODES

    # Improve
    if WEIGHT_DECAY > 0:
        l2_decay = np.mean(solutions * solutions, axis=1)
        returns -= WEIGHT_DECAY * l2_decay

    # Convert maximizer to minimizer (CMA-ES minimizes). `returns` is a NumPy
    # array, so no `.numpy()` call is needed.
    values = (-1 * np.asarray(returns)).tolist()

    # Improve!
    # TODO(exercise): report `values` for `solutions` back to CMA-ES.
    ...

    # Log
    es.disp()

In [None]:
# Persist both the CMA-ES distribution mean ("favorite" solution, result field 5)
# and the best candidate evaluated so far.
np.save('mean_controller.npy', es.result.xfavorite)
np.save('best_controller.npy', es.best.x)

<h3><center>...or...</center></h3>

In [None]:
# Controller parameters to evaluate: CMA-ES mean, best candidate, and a random baseline.
mean_params = np.load('mean_controller.npy')
best_params = np.load('best_controller.npy')
rand_params = np.random.standard_normal(best_params.shape)

In [None]:
# Batch-size-1 simulator (tracks the LSTM hidden state while playing the real
# game), a controller, and a fresh ViZDoom instance.
world_model = WorldModel()
world_model(latent=tf.zeros([1, LATENT_SIZE]), action=tf.zeros([1, ACTIONS_NUM]))
world_model.load_weights(MODEL_WEIGHTS_PATH)

controller = LinearModel()
game = initialize_doom_game()

In [None]:
# Evaluate a controller in the real game. Swap in `mean_params` or `rand_params`
# to reproduce the other rows of the results below.
eval_params = best_params
eval_episodes = 42
returns = []
controller.set_parameters(eval_params[None, ...])
mean_time = 0.
iter_time = time.time()
for i in range(eval_episodes):
    game.new_episode()
    world_model.reset()
    hidden = np.zeros([1, HIDDEN_DIM])
    while not game.is_episode_finished():
        # Get the current latent state (posterior mean). verbose=0 avoids
        # printing one progress bar per game step.
        obs = preprocess(game.get_state().screen_buffer)
        latent = encoder.predict(tf.expand_dims(obs, axis=0), verbose=0)[1] # mean

        # Step the environment and the world model
        action = controller(latent, hidden)
        game.make_action(ACTIONS[int(action[0])], FRAMES_PER_ACTION)
        _, hidden, _ = world_model(tf.one_hot(action, depth=ACTIONS_NUM), latent)
    returns.append(game.get_total_reward())

    mean_time = (mean_time * i + (time.time() - iter_time)) / (i + 1)
    # ETA covers only the episodes still to run after this one.
    print(f'Step {(i + 1):2d}/{eval_episodes}, ETA: {mean_time * (eval_episodes - i - 1):.0f}s')
    iter_time = time.time()
print('Avg. return: ', np.mean(returns),  '; Std. dev.: ', np.std(returns))

Reference results (average return +/- std. dev. over the evaluation episodes):

* Random agent: 299 +/- 101
* Best agent: 493 +/- 318
* Mean agent: 897 +/- 516

In [None]:
# Record one real-game episode played by the CMA-ES mean controller.
controller.set_parameters(mean_params[None, ...])

game.new_episode()
world_model.reset()
hidden = np.zeros([1, HIDDEN_DIM])
obs_list = []
while not game.is_episode_finished():
    # Get the current latent state (posterior mean). verbose=0 avoids printing
    # one progress bar per game step.
    obs = preprocess(game.get_state().screen_buffer)
    obs_list.append(obs)
    latent = encoder.predict(tf.expand_dims(obs, axis=0), verbose=0)[1] # mean

    # Step the environment and the world model
    action = controller(latent, hidden)
    game.make_action(ACTIONS[int(action[0])], FRAMES_PER_ACTION)
    _, hidden, _ = world_model(tf.one_hot(action, depth=ACTIONS_NUM), latent)
# Observations are float32 in [0, 1] here; convert back to uint8 for the video.
frames = tf.cast(tf.stack(obs_list) * 255.0, tf.uint8)

In [None]:
# Save under a separate name so the random-policy video written earlier
# ('take_over.mp4') is not overwritten.
file_name = 'agent_take_over.mp4'
skvideo.io.vwrite(file_name, frames)
show_video(file_name)