In [1]:
import argparse
import collections
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import dmc_remastered as dmcr
import gym
import envs

from dm_control import suite
from matplotlib import animation
from algos.dreamer_mpc import DreamerMPC
from algos.dreamer_sac import DreamerSAC
from algos.dreamer_value import DreamerValue
from wrappers.action_repeat_wrapper import ActionRepeat
from wrappers.frame_stack_wrapper import FrameStack
from wrappers.gym_wrapper import GymWrapper
from wrappers.pixel_observation_wrapper import PixelObservation

# NOTE(review): presumably a workaround for the OpenMP "duplicate runtime"
# abort seen when several libraries bundle their own copy of the runtime --
# confirm it is still required on the target machine.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

In [2]:
def save_video_as_gif(frames, filename='video_prediction.gif'):
    """
    Render a sequence of frames as an animated GIF.

    Every frame is drawn with ``imshow`` and titled with its step index. The
    title text labels the left half of each frame as ground truth and the
    right half as the prediction; arranging the frames that way is the
    caller's responsibility, this function only displays them.

    Args:
        frames: non-empty, indexable sequence of images accepted by
            ``imshow``; must support ``len()``.
        filename: path of the GIF to write. Defaults to
            "video_prediction.gif".

    Raises:
        ValueError: if ``frames`` is empty.
    """
    # Fail before allocating a figure; otherwise an empty input would leave
    # an orphaned figure behind and die with a bare IndexError.
    if len(frames) == 0:
        raise ValueError('frames must contain at least one image')

    fig, ax = plt.subplots()
    try:
        patch = ax.imshow(frames[0])
        ax.axis('off')

        def animate(i):
            # Swap the pixel data in place instead of redrawing the axes.
            patch.set_data(frames[i])
            ax.set_title('Left: GT frame' + ' '*20 + 'Right: predicted frame \n Step %d' % (i))

        anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=150)
        anim.save(filename, writer='imagemagick')
    finally:
        # Release the figure even if saving fails, so repeated calls do not
        # accumulate open figures in the kernel.
        plt.close(fig)

In [3]:
# Directory of the training run to evaluate. Set the DREAMER_RUN_DIR
# environment variable to evaluate a different run without editing this cell;
# when it is unset the original location is used.
run_dir = os.environ.get(
    'DREAMER_RUN_DIR',
    '/Users/jan/Development/Projects/thesis/output/FetchReachRandom-v2-2021-05-25-14-03-53'
)

# Evaluation-time settings; they are layered on top of (and override) the
# arguments stored with the training run.
args = {
    'train_args_dir': os.path.join(run_dir, 'config', 'args.json'),
    'load_model_dir': os.path.join(run_dir, 'model', 'model_final'),
    'video_length': 100
}

with open(args['train_args_dir']) as json_file:
    config = json.load(json_file)
config.update(args)

# Expose the merged configuration through attribute access (args.<name>).
# `config` stays a plain dict so later cells can extend it and rebuild `args`.
keys = config.keys()
values = config.values()
args = collections.namedtuple('args', keys)(*values)

In [4]:
# create env
# The backend is chosen by `args.env_type`, which comes from the stored
# training configuration.
if args.env_type == 'dm_control':
    if args.randomize_env:
        # Only the second element of the returned pair is used.
        # NOTE(review): presumably (train_env, test_env) -- confirm against dmc_remastered.
        _, env = dmcr.benchmarks.visual_generalization(args.domain_name, args.task_name, num_levels=100)
    else:
        env = suite.load(args.domain_name, args.task_name, task_kwargs={'random': args.seed})
        env = GymWrapper(env)
elif args.env_type == 'gym':
    env = gym.make(args.env_name)
else:
    # Without this branch an unknown env_type would leave `env` undefined, or
    # silently reuse a stale `env` from an earlier run of this cell.
    raise ValueError('unsupported env_type: %r' % (args.env_type,))

# augment observations by pixel values
env = PixelObservation(env, args.observation_size)

# stack several consecutive frames together
env = FrameStack(env, args.frame_stack)

# repeat actions
env = ActionRepeat(env, args.action_repeat)

# define models
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device_name = 'cpu'
if torch.cuda.is_available():
    device_name = 'cuda'
device = torch.device(device_name)

# Values derived from the wrapped environment that are added to the config.
# The action range is collapsed to a single [lowest, highest] pair taken over
# all action dimensions.
observation_shape = env.observation_space.shape
action_space = env.action_space
env_args = dict(
    observation_shape=observation_shape,
    action_dim=action_space.shape[0],
    action_range=[float(action_space.low.min()), float(action_space.high.max())],
)

# Fold the environment values into the config and rebuild the attribute-style
# `args` tuple so it carries the new fields as well.
config.update(env_args)
keys, values = config.keys(), config.values()
args_type = collections.namedtuple('args', keys)
args = args_type(*values)

# algorithm
# The algorithm is picked by which attribute is present on `args`. The table
# is ordered: the first attribute found decides, matching an if/elif chain.
algorithm_markers = (
    ('sac_batch_size', DreamerSAC),
    ('value_eps', DreamerValue),
    ('controller_type', DreamerMPC),
)
algorithm_cls = next(
    (cls for marker, cls in algorithm_markers if hasattr(args, marker)),
    None
)
if algorithm_cls is None:
    raise ValueError('configuration file not valid')

# NOTE(review): the two None arguments are positional slots this evaluation
# script leaves empty -- check the algorithm constructors for their meaning.
algorithm = algorithm_cls(env, None, None, device, args)

# load models
algorithm.load_model(args.load_model_dir)

eval_episodes = 10
# NOTE(review): eval_steps is not consulted by the loop below; every episode
# runs until the environment reports `done`. Confirm whether a step cap was
# intended before relying on this value.
eval_steps = 250

# Roll out the loaded agent and render each step.
for episode in range(eval_episodes):
    # Convert the reset observation the same way as every later observation,
    # so the agent always receives a tensor on `device`.
    obs = torch.as_tensor(env.reset(), device=device)
    done = False
    while not done:
        action = algorithm.agent.get_action(obs)
        # Reward and info are not needed for a visual roll-out.
        obs, _reward, done, _info = env.step(action)
        obs = torch.as_tensor(obs, device=device)
        env.render()

Creating window glfw
Creating window glfw
