In [2]:
%matplotlib inline
from models import vae, rnn
import numpy as np
import gym
from gym.spaces.box import Box
from gym.envs.box2d.car_racing import CarRacing
import numpy as np
from skimage.transform import resize
import matplotlib.pyplot as plt
import os
from models.vae import ConvVAE
from models.rnn import MDMRNN
import torch
from torch import nn, optim
from torch.nn import functional as F
import glob
import random
from IPython.display import clear_output
import time
import pickle
import math
import matplotlib.pyplot as plt
import os

# Hyper-parameters used by the __main__ block to build the MDMRNN, to
# initialise its recurrent state and to size the data batches.
num_mixtures = 5      # mixture components for the MDMRNN (name-based; confirm in models.rnn)
hidden_size = 256     # recurrent state width; also passed to init_hidden
input_size = 35       # presumably VAE latent size (32) + action size (3) -- TODO confirm
num_layers = 1        # recurrent layers; also passed to init_hidden
batch_size = 1        # files drawn per batch by data_iterator; also passed to init_hidden
sequence_len = 1000
output_size = 32      # presumably the VAE latent size -- TODO confirm


def init_hidden(num_layers, batch_size, hidden_size, device):
    """Return a zero-filled ``(hidden, cell)`` recurrent-state pair.

    Both tensors have shape ``(num_layers, batch_size, hidden_size)`` and are
    allocated directly on ``device``; they are independent tensors, not views
    of one another.
    """
    state_shape = (num_layers, batch_size, hidden_size)
    h0 = torch.zeros(state_shape, device=device)
    c0 = torch.zeros(state_shape, device=device)
    return h0, c0

def data_iterator(batch_size, pattern='data/obs_data_VAE_*'):
    """Endlessly yield ``(states, actions, target_states)`` batches from pickled files.

    Each batch concatenates the contents of ``batch_size`` files drawn uniformly
    at random (with replacement) from the files matching ``pattern``. Every file
    must unpickle to an iterable of ``(state, action, target_state)`` tuples
    (the ``zip(*data)`` unpacking below requires exactly three fields).

    Args:
        batch_size: number of files drawn and concatenated per yielded batch.
            May exceed the number of available files.
        pattern: glob pattern selecting the pickled data files.

    Yields:
        Three NumPy arrays ``(states, actions, target_states)``. For ``states``
        and ``target_states``, axis 3 is moved to axis 1. This is meant to turn
        channels-last frames into channels-first; it assumes 4-D image stacks,
        TODO confirm against the data writer.

    Raises:
        FileNotFoundError: on the first ``next()`` if no file matches ``pattern``.
    """
    data_files = glob.glob(pattern)
    if not data_files:
        raise FileNotFoundError(f"no data files match {pattern!r}")
    while True:
        states_list = []
        actions_list = []
        target_states_list = []
        for _ in range(batch_size):
            # random.choice picks one file uniformly; unlike random.sample it
            # does not fail when batch_size > len(data_files).
            # NOTE(security): pickle.load can execute arbitrary code -- only
            # point `pattern` at trusted, locally generated data.
            with open(random.choice(data_files), "rb") as f:
                data = pickle.load(f)
            states, actions, target_states = zip(*data)
            states_list += states
            actions_list += actions
            target_states_list += target_states
        states = np.array(states_list)
        actions = np.array(actions_list)
        target_states = np.array(target_states_list)
        states = np.moveaxis(states, 3, 1)  # Reshape so that channels first
        target_states = np.moveaxis(target_states, 3, 1)  # Reshape so that channels first
        yield states, actions, target_states

if __name__ == "__main__":
    # Restrict CUDA to physical GPU 2, with GPUs ordered by PCI bus id.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create model objects and load trained weights if checkpoints exist.
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    vae_model = vae.ConvVAE()
    if os.path.exists("checkpoints/vae_checkpoint.pth"):
        vae_model.load_state_dict(
            torch.load("checkpoints/vae_checkpoint.pth", map_location=device))
    vae_model.to(device)
    vae_model = vae_model.eval()

    rnn_model = rnn.MDMRNN(num_mixtures, hidden_size, input_size, num_layers,
                           batch_size, sequence_len, output_size)
    if os.path.exists("checkpoints/rnn_checkpoint.pth"):
        rnn_model.load_state_dict(
            torch.load("checkpoints/rnn_checkpoint.pth", map_location=device))
    rnn_model.to(device)
    rnn_model = rnn_model.eval()

    training_data = data_iterator(batch_size)

    # Loops forever; stop it by interrupting the kernel/process.
    while True:
        states, actions, target_states = next(training_data)

        # Pure inference: no_grad stops autograd from recording a graph
        # on every iteration.
        with torch.no_grad():
            states = torch.tensor(states, device=device)
            actions = torch.tensor(actions, device=device).type(torch.float32)
            target_states = torch.tensor(target_states, device=device)

            embedded, _, _, _ = vae_model(states)
            # NOTE(review): panel 2 below shows the VAE reconstruction of
            # target_states, not of states -- confirm this pairing is intended.
            _, recon_target, _, _ = vae_model(target_states)

            hidden_state, cell_state = init_hidden(num_layers, batch_size,
                                                   hidden_size, device)
            inputs = torch.cat((embedded, actions), dim=1)
            pi, mean, sigma, _, _ = rnn_model(inputs, hidden_state, cell_state)
            recon_next_state = vae_model.decode(rnn.reparameterize(pi, mean, sigma))

        # Back to NumPy with channels moved from axis 1 to axis 3 for imshow.
        states = np.moveaxis(states.cpu().numpy(), 1, 3)
        recon_target = np.moveaxis(recon_target.detach().cpu().numpy(), 1, 3)
        target_states = np.moveaxis(target_states.cpu().numpy(), 1, 3)
        recon_next_state = np.moveaxis(recon_next_state.detach().cpu().numpy(), 1, 3)

        # 2x2 grid: state | reconstruction ; target state | predicted next state
        panels = (states, recon_target, target_states, recon_next_state)
        for idx in range(states.shape[0]):
            clear_output(wait=True)
            fig = plt.figure(figsize=(8, 8))
            for position, images in enumerate(panels, start=1):
                fig.add_subplot(2, 2, position)
                plt.imshow(np.reshape(images[idx], (64, 64, 3)))
            plt.show()
            plt.close(fig)  # release the figure; one is created per frame

KeyboardInterrupt: 