# World Models

**Algorithm**

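1. Collect 10k episodes using a random policy. `extract.py`
2. Train the VAE on the collected observations. `vae_train.py`
3. Pre-process the collected data by encoding it with the trained VAE. `series.py`
4. Train the MDN-RNN on the encoded sequences. `rnn_train.py`
5. Train the controller with CMA-ES. `train.py`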

# VAE

In [30]:
# Source: https://pytorch-lightning.readthedocs.io/en/latest/notebooks/course_UvA-DL/08-deep-autoencoders.html
# ! pip install --quiet "torch>=1.6, <1.9" "pytorch-lightning>=1.3" "torchvision" "seaborn" "torchmetrics>=0.3"
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from collections import OrderedDict, deque, namedtuple
from typing import List, Tuple
from torch.utils.data.dataset import IterableDataset

# Seed the global RNGs through Lightning so runs are repeatable.
pl.seed_everything(42)
# Ask cuDNN for deterministic kernels. The attribute must be spelled exactly
# "deterministic": assigning a misspelled name merely creates an unused module
# attribute and silently leaves cuDNN non-deterministic.
torch.backends.cudnn.deterministic = True
# Disable autotuning, which can select different kernels from run to run.
torch.backends.cudnn.benchmark = False
# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

class Encoder(nn.Module):
    """Three-layer MLP mapping input features to a latent vector z.

    Args:
        input_dim: Number of input features.
        latent_dim: Dimensionality of the latent representation z; also used
            as the width of both hidden layers.
        act_fn: Activation class; a fresh instance is created for each of the
            two activation slots between the linear layers.
    """

    def __init__(self, input_dim: int, latent_dim: int, act_fn: object = nn.GELU):
        super().__init__()
        layers = [nn.Linear(input_dim, latent_dim), act_fn()]
        layers += [nn.Linear(latent_dim, latent_dim), act_fn()]
        layers.append(nn.Linear(latent_dim, latent_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the latent code produced by ``self.net`` for ``x``."""
        return self.net(x)
    
class Decoder(nn.Module):
    """Three-layer MLP mapping a latent vector z back to feature space.

    Args:
        output_dim: Number of output features.
        latent_dim: Dimensionality of the latent representation z; also used
            as the width of both hidden layers.
        act_fn: Activation class; a fresh instance is created for each of the
            two activation slots between the linear layers.
    """

    def __init__(self, output_dim: int, latent_dim: int, act_fn: object = nn.GELU):
        super().__init__()
        hidden = [
            nn.Linear(latent_dim, latent_dim),
            act_fn(),
            nn.Linear(latent_dim, latent_dim),
            act_fn(),
        ]
        self.net = nn.Sequential(*hidden, nn.Linear(latent_dim, output_dim))

    def forward(self, x):
        """Return the reconstruction produced by ``self.net`` for latent ``x``."""
        return self.net(x)
    
class Autoencoder(pl.LightningModule):
    """Deterministic autoencoder trained on a summed squared-error reconstruction loss.

    Args:
        input_dim: Number of features per sample; also the decoder's output size.
        latent_dim: Size of the latent code produced by the encoder.
        encoder_class: Callable ``(input_dim, latent_dim) -> nn.Module`` used to
            build the encoder.
        decoder_class: Callable ``(output_dim, latent_dim) -> nn.Module`` used to
            build the decoder.
    """

    def __init__(
        self,
        input_dim: int,
        latent_dim: int,
        encoder_class: object = Encoder,
        decoder_class: object = Decoder,
    ):
        super().__init__()
        # Saving hyperparameters of autoencoder
        self.save_hyperparameters()
        # Creating encoder and decoder; the decoder reconstructs the input, so
        # its output size is input_dim.
        self.encoder = encoder_class(input_dim, latent_dim)
        self.decoder = decoder_class(input_dim, latent_dim)
        # Example input array needed for visualizing the graph of the network
        self.example_input_array = torch.zeros(2, input_dim)

    def forward(self, x):
        """Encode ``x`` to a latent code and return the decoded reconstruction."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch):
        """Return the reconstruction loss for a batch.

        The per-element squared error is summed over every non-batch dimension
        and then averaged over the batch (dim 0). For a ``(batch, features)``
        input this equals the original ``sum(dim=1).mean(dim=0)``. It also stays
        a scalar for higher-rank inputs (e.g. image batches), where summing only
        dim 1 would leave a non-scalar loss.
        """
        x = batch
        x_hat = self.forward(x)
        # Conventional (prediction, target) order; squared error is symmetric.
        loss = F.mse_loss(x_hat, x, reduction="none")
        loss = loss.flatten(start_dim=1).sum(dim=1).mean(dim=0)
        return loss

    def configure_optimizers(self):
        """Adam at lr=1e-3, with the LR reduced on a plateau of ``val_loss``."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def training_step(self, batch, batch_idx):
        """Log and return the training reconstruction loss."""
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss; the LR scheduler monitors this key."""
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log the test reconstruction loss."""
        loss = self._get_reconstruction_loss(batch)
        self.log("test_loss", loss)
        
class ObsDataset(data.Dataset):
    """Map-style dataset over an in-memory, indexable collection of observations.

    Args:
        data: Any object supporting ``len()`` and indexing (for example a NumPy
            array); items are returned exactly as stored, with no transform.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Return the number of stored observations."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the observation at ``index`` unchanged."""
        observation = self.data[index]
        return observation

In [32]:
# Smoke test on a tiny constant dataset: 100 samples of 3 features, all ones.
dataset = ObsDataset(np.ones((100,3), dtype=np.float32))
train_loader = data.DataLoader(dataset, batch_size=32)
# NOTE: validation reuses the training data, so val_loss measures fit only.
val_loader = data.DataLoader(dataset, batch_size=32)

model = Autoencoder(input_dim=3, latent_dim=16)
# Request a GPU only when CUDA is available. A hard-coded gpus=1 makes the
# Trainer fail on CPU-only machines, even though `device` above falls back to CPU.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=500)
trainer.fit(model, train_loader, val_loader)

In [36]:
"""
Variational encoder model, used as a visual model
for our model of the world.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    """VAE decoder: two ReLU hidden layers mapping a latent vector to output space.

    Args:
        output_dim: Size of the reconstructed output.
        latent_dim: Size of the latent input z.
        n_units: Width of both hidden layers.
    """

    def __init__(self, output_dim=8, latent_dim=16, n_units=128):
        super().__init__()
        self.output_dim, self.latent_dim, self.n_units = output_dim, latent_dim, n_units

        self.fc1 = nn.Linear(latent_dim, n_units)
        self.fc2 = nn.Linear(n_units, n_units)
        self.fc3 = nn.Linear(n_units, output_dim)

    def forward(self, x): # pylint: disable=arguments-differ
        """Apply fc1 and fc2 with ReLU, then the linear (unactivated) fc3."""
        h = x
        for hidden_layer in (self.fc1, self.fc2):
            h = F.relu(hidden_layer(h))
        return self.fc3(h)

class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
    """VAE encoder: two ReLU hidden layers followed by two linear heads.

    The heads' outputs are returned as ``(mu, logsigma)``. The VAE in this
    notebook exponentiates ``logsigma`` to get the sampling scale.

    Args:
        input_dim: Size of the input vector.
        latent_dim: Size of each head's output.
        n_units: Width of both hidden layers.
    """

    def __init__(self, input_dim=8, latent_dim=16, n_units=128):
        super().__init__()
        self.input_dim, self.latent_dim, self.n_units = input_dim, latent_dim, n_units

        self.fc1 = nn.Linear(input_dim, n_units)
        self.fc2 = nn.Linear(n_units, n_units)
        self.fc_mu = nn.Linear(n_units, latent_dim)
        self.fc_logsigma = nn.Linear(n_units, latent_dim)

    def forward(self, x): # pylint: disable=arguments-differ
        """Return ``(mu, logsigma)`` computed from shared hidden features of ``x``."""
        features = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc_mu(features), self.fc_logsigma(features)

class VAE(nn.Module):
    """Variational autoencoder composed of this module's ``Encoder`` and ``Decoder``.

    Args:
        input_dim: Size of each input vector; also the reconstruction size.
        latent_dim: Size of the latent vector z.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, latent_dim)
        # Decoder's first positional argument is its output size, which must equal input_dim.
        self.decoder = Decoder(input_dim, latent_dim)

    def forward(self, x):
        """Encode ``x``, sample a latent z and decode it.

        Returns:
            Tuple ``(recon_x, mu, logsigma)``.
        """
        mu, logsigma = self.encoder(x)
        sigma = torch.exp(logsigma)
        # Reparameterization: z = mu + sigma * eps with eps ~ N(0, I). The noise
        # is drawn separately, so gradients flow to mu and logsigma.
        noise = torch.randn_like(sigma)
        latent = mu + noise * sigma
        return self.decoder(latent), mu, logsigma

In [43]:
X = torch.ones(28, 8, dtype=torch.float32)  # dummy batch: 28 samples of 8 features, all ones

In [44]:
# Build a VAE matching X's 8 features with a 16-dim latent space, then run one forward pass.
vae = VAE(input_dim=8, latent_dim=16)
vae(X)

(tensor([[ 1.7793e-01, -1.1581e-01,  9.9639e-02,  1.2234e-01,  2.7219e-01,
           1.6652e-01,  3.9213e-02,  1.3524e-02],
         [ 1.5171e-01,  1.0672e-01,  8.0299e-02,  2.1089e-01,  1.6755e-01,
           8.2226e-02, -3.6395e-02, -2.8142e-02],
         [ 3.1949e-01,  1.5323e-01,  5.4087e-02,  6.0500e-02,  2.6304e-01,
           6.5669e-02, -6.9476e-03,  3.0243e-02],
         [ 1.4787e-01, -5.6834e-02,  2.2940e-01,  1.7896e-01,  1.5475e-01,
           1.3417e-02,  4.5627e-02,  1.4061e-01],
         [ 1.0682e-01,  2.5448e-02,  4.2714e-02,  1.1468e-02,  1.8435e-01,
          -2.0532e-03,  9.6457e-03,  9.7366e-02],
         [ 1.6481e-01,  1.3478e-01,  1.2075e-01,  2.0632e-01,  1.7861e-01,
           3.3075e-02,  7.4786e-03,  7.9657e-02],
         [ 1.4285e-01, -9.6923e-02,  1.5459e-01,  9.4515e-02,  2.7128e-01,
           5.1347e-02,  1.1681e-01,  3.0045e-02],
         [ 8.5337e-02, -1.0595e-02,  8.1042e-02,  2.1737e-01,  1.5248e-01,
           6.2465e-02,  2.3276e-02,  1.2906e-01],


# MDN-RNN