In [1]:
#!/usr/bin/env python3
import numpy as np
import gym
import os
import io
import base64
import multiprocessing

import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML
import time
from IPython import display
%matplotlib inline

from stable_baselines.common.cmd_util import mujoco_arg_parser
from stable_baselines import bench, logger
from stable_baselines.common import set_global_seeds
from stable_baselines.common.vec_env.vec_normalize import VecNormalize
from stable_baselines.ppo2 import PPO2
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv

def get_newest_model(env_id):
    """Return the path of the most recently changed file in ``models_<env_id>/``.

    Args:
        env_id: Environment id used to build the directory name, e.g. ``'Ant-v2'``.

    Returns:
        The path ``'models_<env_id>/<file>'`` of the entry with the largest
        ``os.path.getctime`` value.

    Raises:
        FileNotFoundError: if the directory does not exist (from ``os.listdir``).
        ValueError: if the directory is empty.
    """
    dir_name = 'models_{}/'.format(env_id)
    paths = [os.path.join(dir_name, name) for name in os.listdir(dir_name)]
    if not paths:
        # Fail with an explicit message instead of max()'s "empty sequence" error.
        raise ValueError('no saved models found in {!r}'.format(dir_name))
    # getctime is the metadata-change time on Unix and the creation time on Windows.
    return max(paths, key=os.path.getctime)

  from ._conv import register_converters as _register_converters


In [6]:
# Set the offscreen ('rgb_array') camera's trackbodyid to 1.
# NOTE(review): relies on `env` from the environment-setup cell further down;
# on a fresh kernel this cell raises NameError unless that cell has run first.
ant_viewer = env.venv.envs[0]._get_viewer('rgb_array')
ant_viewer.cam.trackbodyid = 1

In [None]:
# NOTE(review): this only binds a notebook-level name; nothing in the visible
# cells reads `trackbodyid`, so it does not change the camera. To affect
# rendering, assign to the viewer camera instead, e.g.
# env.venv.envs[0]._get_viewer('rgb_array').cam.trackbodyid = 0
trackbodyid = 0

In [None]:
# Run configuration.
# NOTE(review): env_id only names the saved model, the TensorBoard log dir and
# the output video; the environment itself is always the Ant-specific AntEnvMod
# defined below, so switching to e.g. 'HalfCheetah-v2' also needs another env class.
env_id = 'Ant-v2'
seed = 123
load_model = True  # True: load "model_<env_id>"; False: build a fresh PPO2 model

# Apply the configured seed (it was previously defined but never used) via
# stable_baselines' global seeding helper.
set_global_seeds(seed)

# gym.make(env_id) is deliberately not used: the customized AntEnvMod below is
# instantiated directly.
from gym.envs.mujoco.ant import AntEnv

class AntEnvMod(AntEnv):
    """Ant environment whose offscreen ('rgb_array') camera follows the torso.

    Call :meth:`update_view` once per step before rendering with
    ``self.sim.render``: it re-centres the camera on the torso and zooms out
    gradually until ``max_zoom`` is reached.
    """

    # Camera distances, expressed as multiples of the model's spatial extent.
    camera_start_scale = 0.5   # distance applied in __init__
    camera_max_scale = 2.      # update_view() stops zooming out at this distance
    camera_zoom_rate = 1.05    # multiplicative zoom-out per update_view() call

    def __init__(self):
        # Run AntEnv's own initializer rather than skipping past it in the MRO.
        # NOTE(review): in gym, AntEnv.__init__ loads 'ant.xml' with frame_skip=5
        # (the arguments previously passed here) and also initializes EzPickle,
        # which the old super(AntEnv, self) call bypassed -- confirm for the
        # installed gym version.
        super(AntEnvMod, self).__init__()
        extent = self.model.stat.extent
        self._get_viewer('rgb_array').cam.distance = extent * self.camera_start_scale
        self.max_zoom = extent * self.camera_max_scale

    def update_view(self):
        """Centre the offscreen camera on the torso and zoom out one step, capped at ``max_zoom``."""
        cam = self._get_viewer('rgb_array').cam
        x, y, z = self.get_body_com('torso')
        cam.lookat[0] = x
        cam.lookat[1] = y
        cam.lookat[2] = z
        if cam.distance < self.max_zoom:
            # Clamp so the last zoom step never overshoots max_zoom.
            cam.distance = min(cam.distance * self.camera_zoom_rate, self.max_zoom)
        
# Build the environment: the custom Ant env, wrapped in a single-env
# DummyVecEnv and then VecNormalize. Distinct names keep the factory lambda
# from depending on a global that is rebound on the following lines.
ant_env = AntEnvMod()
# Apply the configured seed to the env's own RNG as well.
# NOTE(review): presumably this is what drives the reset noise in gym's
# MujocoEnv -- confirm for the installed gym version.
ant_env.seed(seed)
vec_env = DummyVecEnv([lambda: ant_env])
# NOTE(review): VecNormalize starts from fresh running statistics; when a
# trained model is loaded, the normalization statistics from training are not
# restored here -- confirm this matches how the model was trained.
env = VecNormalize(vec_env)

# Either build a fresh PPO2 model with the hyperparameters below, or restore a
# previously saved one. (This cell never calls model.learn().)
if not load_model:
    ppo_kwargs = dict(
        n_steps=2048,
        nminibatches=32,
        lam=0.95,
        gamma=0.99,
        noptepochs=10,
        ent_coef=0.0,
        learning_rate=3e-4,
        cliprange=0.2,
        verbose=1,
        tensorboard_log='./{}/'.format(env_id),
    )
    policy = MlpPolicy
    model = PPO2(policy=policy, env=env, **ppo_kwargs)
else:
    model = PPO2.load("model_{}".format(env_id))


# Roll out the policy for a single episode, recording one rendered frame per step.
logger.log("Running trained model")

# AntEnvMod is instantiated directly rather than through gym.make, so nothing
# visible here caps the episode length; stop after this many steps regardless.
# NOTE(review): 1000 mirrors the step limit gym registers for Ant-v2 -- confirm.
max_steps = 1000
ant = env.venv.envs[0]  # the AntEnvMod instance inside DummyVecEnv / VecNormalize

obs = np.zeros((env.num_envs,) + env.observation_space.shape)
obs[:] = env.reset()

total_reward = 0
count = 0
frames = []
while count < max_steps:
    actions = model.step(obs)[0]
    # NOTE(review): `reward` is what VecNormalize.step returns -- presumably the
    # normalized reward, so total_reward may differ from the raw episode return.
    obs[:], reward, done, info = env.step(actions)
    total_reward += reward

    # Re-aim the camera at the torso's current position before rendering, so
    # each frame is centred on the current state rather than the previous one.
    ant.update_view()
    frames.append(ant.sim.render(500, 500))

    count += 1
    if done.any():
        break

print("Reward:", total_reward)
print("Iters :", count)

Loading a model without an environment, this model cannot be trained until it has a valid environment.
Running trained model


In [None]:
# Encode the recorded frames as an MP4 and embed it inline as a base64 <video>.
video_path = '{}.mp4'.format(env_id)

fig, ax = plt.subplots()
ax.axis('off')
# origin='lower' flips the rendered frames vertically, as in the original display.
img = ax.imshow(frames[0], origin='lower')


def animate(i):
    """Show frame ``i`` of ``frames`` in the animation's image artist."""
    img.set_data(frames[i])
    return (img,)


ani = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50, repeat_delay=1000)
ani.save(video_path)
# Close the figure so the notebook does not also display a static first frame.
plt.close(fig)

# Read-only binary open inside `with` so the file handle is released.
with open(video_path, 'rb') as video_file:
    video = video_file.read()
encoded = base64.b64encode(video)
HTML(data='''<video alt="test" controls>
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii')))