# MDNRNN model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from PIL import Image
import pandas as pd
import numpy as np
import os

In [None]:
# Pick the compute device once: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

## Init VAE Model

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder half of the VAE.

    Four stride-2 ``Conv2d`` layers (kernel size 4, no padding), each followed
    by a ReLU, feed two linear heads that output the mean and the
    log-variance of the latent distribution.

    The linear heads take ``2 * 2 * 256`` flattened features, which is what
    the convolution stack produces for a 3x64x64 input
    (64 -> 31 -> 14 -> 6 -> 2 along each spatial axis).

    Args:
        latent_dim: Number of outputs of each linear head.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Channel progression: 3 -> 32 -> 64 -> 128 -> 256.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2)

        # Both heads read the same flattened feature map.
        flat_features = 2 * 2 * 256
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        self.activation = nn.ReLU()

    def forward(self, x):
        """Encode a batch of images and return ``(mu, logvar)``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = self.activation(conv(features))

        # Collapse channel and spatial axes, keeping the batch axis.
        flat = features.view(features.shape[0], -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder half of the VAE.

    A linear layer expands the latent vector to 1024 features, which are
    viewed as a 1024-channel 1x1 map and upsampled by four stride-2
    ``ConvTranspose2d`` layers (1 -> 5 -> 13 -> 30 -> 64 along each spatial
    axis). The last layer is followed by a sigmoid, so outputs lie in (0, 1).

    Args:
        latent_dim: Size of the latent vector fed to the decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 1024)

        # Layers are numbered in reverse (conv4 runs first, conv1 last).
        self.conv4 = nn.ConvTranspose2d(1024, 128, kernel_size=5, stride=2)
        self.conv3 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2)
        self.conv2 = nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2)
        self.conv1 = nn.ConvTranspose2d(32, 3, kernel_size=6, stride=2)

        self.ReLU_activation = nn.ReLU()

    def forward(self, x):
        """Decode a batch of latent vectors into 3-channel images."""
        hidden = self.fc(x)
        # Treat the 1024 features as channels of a 1x1 feature map.
        hidden = hidden.view(hidden.shape[0], 1024, 1, 1)
        for deconv in (self.conv4, self.conv3, self.conv2):
            hidden = self.ReLU_activation(deconv(hidden))
        return torch.sigmoid(self.conv1(hidden))
        

In [None]:
class VariationalAutoencoder(nn.Module):
    """VAE that chains an ``Encoder`` and a ``Decoder`` of the same latent size.

    Args:
        latent_dim: Size of the latent vector shared by encoder and decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, x):
        """Return ``(reconstruction, latent_mu, latent_logvar)`` for input ``x``."""
        latent_mu, latent_logvar = self.encoder(x)
        z = self.latent_sample(latent_mu, latent_logvar)
        return self.decoder(z), latent_mu, latent_logvar

    def latent_sample(self, mu, logvar):
        """Draw a latent vector in training mode; return ``mu`` in eval mode."""
        if not self.training:
            return mu
        # logvar holds log(sigma^2), so sigma = exp(logvar / 2).
        std = torch.exp(logvar * 0.5)
        posterior = torch.distributions.Normal(loc=mu, scale=std)
        # rsample keeps the draw differentiable with respect to mu and std.
        return posterior.rsample()

In [None]:
# Load the VAE checkpoint and place it on the selected device.
# map_location remaps tensors that were saved on another device, so a
# checkpoint written on a GPU can still be opened on a CPU-only machine.
# NOTE: torch.load unpickles arbitrary Python objects; only open trusted files.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects whole-module pickles - pass weights_only=False there if needed.
vae = torch.load("./BACKUP_MODELS/vae", map_location=device).to(device)

# Dataset

In [None]:
def _load_frame_tensor(path, transform):
    """Open the image at ``path`` and return ``transform(image)`` as floats divided by 255."""
    # The context manager closes the file handle as soon as the pixels are read.
    with Image.open(path) as img:
        return transform(img).float() / 255


def Create_dataset(vae, portion_train, path_rollouts_csv="./DATASET_ROLLOUTS/rollouts.csv"):
    """Build (or reload) the index of per-rollout latent csv files and split it.

    If ``./DATASET_ROLLOUTS/dataset_rnn.csv`` already exists it is loaded and
    ``vae`` is not used. Otherwise every rollout listed in
    ``path_rollouts_csv`` is encoded with ``vae.encoder``: for each rollout an
    ``rnn.csv`` is written into the rollout directory holding, column-wise,
    both encoder outputs for the current frames, both encoder outputs for the
    next frames, and the ``Action`` column. The list of those files is then
    saved as ``dataset_rnn.csv``.

    Args:
        vae: Model exposing an ``encoder`` callable that returns two tensors
            per batch of frames. Only needed when the dataset csv is missing.
        portion_train: Fraction of rollouts assigned to the train set. The
            first ``int(len * portion_train)`` rows go to training, the
            remaining rows to testing.
        path_rollouts_csv: ``;``-separated csv listing the rollouts. Column 1
            is read as the rollout directory and column 2 as the path of the
            rollout's observation csv.

    Returns:
        ``(trainset, testset)``; ``testset`` is ``None`` when
        ``portion_train`` is not below 1.
    """
    path_dataset_csv = "./DATASET_ROLLOUTS/dataset_rnn.csv"

    # load dataset from csv if it already exists
    if os.path.exists(path_dataset_csv):
        res_df = pd.read_csv(path_dataset_csv, sep=";", skipinitialspace=True)
        print("Dataset loaded from csv")
    else:
        transform = transforms.PILToTensor()
        df_rollouts = pd.read_csv(path_rollouts_csv, sep=";", skipinitialspace=True)
        n_rollouts = len(df_rollouts)
        res_paths = []

        # iterate for each rollout
        for idx_roll in range(n_rollouts):
            # retrieve paths of the rollout
            path_obs_csv = df_rollouts.iloc[idx_roll, 2]
            path_obs = df_rollouts.iloc[idx_roll, 1]

            # load csv rollout
            obs_df = pd.read_csv(path_obs_csv, sep=";", skipinitialspace=True)

            # stack every frame of the rollout so the vae encodes them in one batch
            # (column 1: current frame path, column 3: next frame path)
            stack_frames = torch.stack(
                [_load_frame_tensor(p, transform) for p in obs_df.iloc[:, 1]]
            )
            stack_next_frames = torch.stack(
                [_load_frame_tensor(p, transform) for p in obs_df.iloc[:, 3]]
            )

            # encode frames; no gradients are needed for dataset creation
            with torch.no_grad():
                frame_mu, frame_sigma = vae.encoder(stack_frames.to(device))
                next_frame_mu, next_frame_sigma = vae.encoder(stack_next_frames.to(device))

            # create df of the results
            new_obs_df = pd.concat(
                [
                    pd.DataFrame(frame_mu.cpu().detach().numpy()),
                    pd.DataFrame(frame_sigma.cpu().detach().numpy()),
                    pd.DataFrame(next_frame_mu.cpu().detach().numpy()),
                    pd.DataFrame(next_frame_sigma.cpu().detach().numpy()),
                    obs_df[["Action"]],
                ],
                axis=1,
            )

            # save df into rollout directory
            path_roll_csv = os.path.join(path_obs, "rnn.csv")
            new_obs_df.to_csv(path_roll_csv, sep=";", index=False)

            # add result to main dataframe
            res_paths.append(path_roll_csv)

            # report progress against the real number of rollouts
            print(f"Observation {idx_roll+1}/{n_rollouts}")

        # save resulting dataframes into csv
        res_df = pd.DataFrame(res_paths, columns=["Path rollout csv"])
        res_df.to_csv(path_dataset_csv, sep=";", index=False)

    # create train and test set: the first n_train rows train, the rest test
    n_train = int(len(res_df) * portion_train)
    trainset = Trainset(res_df.iloc[:n_train])
    print("Train set created")
    testset = None
    if portion_train < 1:
        testset = Testset(res_df.iloc[n_train:])
        print("Test set created")
    return trainset, testset

class Trainset(Dataset):
    """Dataset of rollouts; each item is one whole rollout sequence.

    Every row of ``df`` points (first column) to a ``;``-separated csv whose
    columns are laid out as four consecutive groups of ``latent_dim`` values
    (current-frame mu, current-frame log-variance, next-frame mu, next-frame
    log-variance) plus an ``Action`` column.

    Args:
        df: DataFrame whose first column holds the rollout csv paths.
        latent_dim: Width of each latent column group (default 512).
    """

    def __init__(self, df, latent_dim=512):
        self.df = df
        self.latent_dim = latent_dim

    def __len__(self):
        return len(self.df)

    def _latent_block(self, rollout_df, block_idx):
        """Return the ``block_idx``-th group of ``latent_dim`` columns as a float32 tensor on ``device``."""
        start = block_idx * self.latent_dim
        values = rollout_df.iloc[:, start:start + self.latent_dim].to_numpy().astype(np.float32)
        return torch.from_numpy(values).to(device)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for rollout ``idx``.

        ``x`` is a latent sampled from the stored current-frame distribution
        with the scaled action appended as the last feature; ``y`` is the
        stored next-frame mu.
        """
        rollout_df = pd.read_csv(self.df.iloc[idx, 0], sep=";", skipinitialspace=True)

        # retrieve tensors from dataframe (the next-frame log-variance group is not needed)
        frame_mu = self._latent_block(rollout_df, 0)
        frame_sigma = self._latent_block(rollout_df, 1)
        next_frame_mu = self._latent_block(rollout_df, 2)

        # actions are divided by 5 - presumably the largest action value; TODO confirm
        action = torch.from_numpy(rollout_df[["Action"]].to_numpy().astype(np.float32)).to(device) / 5

        # sample x from mu and the stored log-variance each time the item is read
        std = (frame_sigma * 0.5).exp()
        x = torch.distributions.Normal(loc=frame_mu, scale=std).rsample()
        x = torch.cat((x, action), dim=1)
        y = next_frame_mu
        return (x, y)
    
class Testset(Dataset):
    """Held-out dataset of rollouts; each item is one whole rollout sequence.

    Every row of ``df`` points (first column) to a ``;``-separated csv whose
    columns are laid out as four consecutive groups of ``latent_dim`` values
    (current-frame mu, current-frame log-variance, next-frame mu, next-frame
    log-variance) plus an ``Action`` column.

    Args:
        df: DataFrame whose first column holds the rollout csv paths.
        latent_dim: Width of each latent column group (default 512).
    """

    def __init__(self, df, latent_dim=512):
        self.df = df
        self.latent_dim = latent_dim

    def __len__(self):
        return len(self.df)

    def _latent_block(self, rollout_df, block_idx):
        """Return the ``block_idx``-th group of ``latent_dim`` columns as a float32 tensor on ``device``."""
        start = block_idx * self.latent_dim
        values = rollout_df.iloc[:, start:start + self.latent_dim].to_numpy().astype(np.float32)
        return torch.from_numpy(values).to(device)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for rollout ``idx``.

        ``x`` is a latent sampled from the stored current-frame distribution
        with the scaled action appended as the last feature; ``y`` is the
        stored next-frame mu.
        """
        rollout_df = pd.read_csv(self.df.iloc[idx, 0], sep=";", skipinitialspace=True)

        # retrieve tensors from dataframe (the next-frame log-variance group is not needed)
        frame_mu = self._latent_block(rollout_df, 0)
        frame_sigma = self._latent_block(rollout_df, 1)
        next_frame_mu = self._latent_block(rollout_df, 2)

        # tensor of normalized actions: divided by 5 exactly as the train set does,
        # so evaluation inputs are on the same scale the model was trained on
        action = torch.from_numpy(rollout_df[["Action"]].to_numpy().astype(np.float32)).to(device) / 5

        # sample x from mu and the stored log-variance each time the item is read
        std = (frame_sigma * 0.5).exp()
        x = torch.distributions.Normal(loc=frame_mu, scale=std).rsample()
        x = torch.cat((x, action), dim=1)
        y = next_frame_mu
        return (x, y)

## Init MDNRNN Model

In [None]:
class MDNRNN(nn.Module):
    """LSTM followed by a mixture-density head over the next latent vector.

    The LSTM consumes ``z_dim + 1`` features per step (a latent vector plus
    one action value). Three linear heads turn every LSTM output into, for
    each latent dimension, ``n_gaussians`` mixture weights, standard
    deviations and means.

    The LSTM state of the last call is kept on the module and reused by the
    next call; call ``reset_state()`` before feeding an unrelated batch.

    Args:
        hidden_units: Size of the LSTM hidden state.
        z_dim: Size of the latent vector.
        num_layers: Number of stacked LSTM layers.
        n_gaussians: Number of mixture components per latent dimension.
    """

    def __init__(self, hidden_units, z_dim, num_layers, n_gaussians):
        super().__init__()
        self.hidden_units = hidden_units
        self.z_dim = z_dim
        self.num_layers = num_layers
        self.n_gaussians = n_gaussians
        self.hidden = None
        self.cell = None

        # RNN (num_layers is forwarded so the constructor argument takes effect)
        self.lstm = nn.LSTM(
            self.z_dim + 1,
            self.hidden_units,
            num_layers=self.num_layers,
            batch_first=True,
        )

        # MDN
        # weights for the results of the gaussians
        self.z_pi = nn.Linear(self.hidden_units, self.n_gaussians * self.z_dim)
        # parameters of the gaussians
        self.z_sigma = nn.Linear(self.hidden_units, self.n_gaussians * self.z_dim)
        self.z_mu = nn.Linear(self.hidden_units, self.n_gaussians * self.z_dim)

    def forward(self, x):
        """Run the LSTM on ``x`` and return ``(pi, sigma, mu)``.

        Args:
            x: Either a ``PackedSequence`` or a batched, batch-first tensor of
                shape ``(batch, seq_len, z_dim + 1)``.

        Returns:
            Three tensors of shape ``(batch, seq_len, n_gaussians, z_dim)``:
            mixture weights (softmax over the gaussian axis), positive
            standard deviations, and means.
        """
        # start from a fresh state on the first run of a sequence,
        # otherwise continue from the stored hidden and cell states
        if self.hidden is None or self.cell is None:
            z, state = self.lstm(x)
        else:
            z, state = self.lstm(x, (self.hidden, self.cell))
        self.hidden, self.cell = state

        # packed input yields packed output: pad it back to (batch, seq_len, hidden)
        if isinstance(z, nn.utils.rnn.PackedSequence):
            z, _ = nn.utils.rnn.pad_packed_sequence(z, batch_first=True)
        seq_len = z.shape[1]

        out_shape = (-1, seq_len, self.n_gaussians, self.z_dim)
        # mixture weights, normalised so they sum to 1 over the gaussians
        pi = F.softmax(self.z_pi(z).view(*out_shape), dim=2)
        # exponential ensures every sigma is positive
        sigma = torch.exp(self.z_sigma(z)).view(*out_shape)
        # compute mus
        mu = self.z_mu(z).view(*out_shape)
        return pi, sigma, mu

    def reset_state(self):
        """Forget the stored LSTM state so the next call starts from zeros."""
        self.hidden = None
        self.cell = None
        

## Train

In [None]:
def loss_mdnrnn(out_pi, out_sigma, out_mu, y):
    """Mean negative log-likelihood of ``y`` under the predicted gaussian mixtures.

    The mixture likelihood ``sum_k pi_k * N(y | mu_k, sigma_k)`` is evaluated
    in log space with ``logsumexp``. Exponentiating the per-component
    log-probabilities first underflows to zero for targets far from every
    component, and a constant added inside the log then caps the loss and
    removes the gradient exactly where the fit is worst.

    Args:
        out_pi: Mixture weights, shape ``(batch, seq_len, n_gaussians, z_dim)``.
        out_sigma: Standard deviations, same shape as ``out_pi``.
        out_mu: Means, same shape as ``out_pi``.
        y: Targets, shape ``(batch, seq_len, z_dim)``.

    Returns:
        Scalar tensor: the negative log-likelihood averaged over batch,
        sequence and latent dimensions.
    """
    # add a singleton gaussian axis so y broadcasts against every component
    target = y.unsqueeze(2)
    log_prob = torch.distributions.Normal(loc=out_mu, scale=out_sigma).log_prob(target)

    # clamp before the log so a weight that underflowed to 0 cannot produce -inf/NaN
    smallest = torch.finfo(out_pi.dtype).tiny
    log_pi = torch.log(out_pi.clamp_min(smallest))

    # log sum_k exp(log pi_k + log N_k), computed stably over the gaussian axis
    log_likelihood = torch.logsumexp(log_pi + log_prob, dim=2)
    return -torch.mean(log_likelihood)

In [None]:
# obtain trainset and testset; the loaded VAE is passed so the latent dataset
# can be built when the cached dataset csv does not exist yet (it is unused
# when the cache is present)
trainset, testset = Create_dataset(vae, 1)

In [None]:
def collate_fn_padd(batch):
    """Collate variable-length ``(x, y)`` sequence pairs into packed sequences.

    Sequences in a batch need a common length to be processed together, so
    inputs and targets are zero-padded to the longest sequence and then packed
    with their true lengths so the LSTM can skip the padding.

    Args:
        batch: List of ``(x, y)`` tensor pairs whose first axis is time.

    Returns:
        ``(X, Y)``: two ``PackedSequence`` objects (batch-first, unsorted).
    """
    inputs = [pair[0] for pair in batch]
    targets = [pair[1] for pair in batch]

    # true lengths, measured before any padding is added
    input_lengths = [seq.shape[0] for seq in inputs]
    target_lengths = [seq.shape[0] for seq in targets]

    # pad every sequence to the same length
    padded_inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    padded_targets = nn.utils.rnn.pad_sequence(targets, batch_first=True)

    # pack the padded sequences to be given in input to the lstm
    packed_inputs = nn.utils.rnn.pack_padded_sequence(
        padded_inputs, input_lengths, batch_first=True, enforce_sorted=False
    )
    packed_targets = nn.utils.rnn.pack_padded_sequence(
        padded_targets, target_lengths, batch_first=True, enforce_sorted=False
    )
    return (packed_inputs, packed_targets)

In [None]:
# Hyperparameters of the MDN-RNN and its optimizer.
lr = 1e-3
hidden_units = 1024
z_dim = 512
n_gaussians = 16

mdnrnn = MDNRNN(hidden_units=hidden_units, z_dim=z_dim, num_layers=1, n_gaussians=n_gaussians).to(device)

# the optimizer must receive the parameters of the model created above
optimizer = torch.optim.Adam(params=mdnrnn.parameters(), lr=lr)
batch_size = 10

# With shuffle=True the loader draws a new order on every pass, so it only
# needs to be built once rather than once per epoch.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padd)

# Train until interrupted; the model is checkpointed after every epoch.
i = 0
while True:
    losses = []
    for x, y in train_loader:
        # unpack y values
        y, _ = nn.utils.rnn.pad_packed_sequence(y, batch_first=True)

        # run mdnrnn from a fresh LSTM state: batches are unrelated sequences
        x = x.to(device)
        mdnrnn.reset_state()
        res_pi, res_sigma, res_mu = mdnrnn(x)

        # compute loss and backward
        loss = loss_mdnrnn(res_pi, res_sigma, res_mu, y)
        optimizer.zero_grad()
        loss.backward()
        losses.append(loss.item())
        optimizer.step()

    print(f"EPOCH: {i+1} MEAN LOSSES EPOCH: {sum(losses)/len(losses)}")
    print(f"Loss: {loss.item()}")
    i += 1
    torch.save(mdnrnn, f"./BACKUP_MODELS/mdrnn")