In [1]:
import os

import gym
import numpy as np
import matplotlib.pyplot as plt

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines.common import set_global_seeds
from stable_baselines import PPO2
from stable_baselines.bench import Monitor
from stable_baselines.results_plotter import load_results, ts2xy
from stable_baselines.gail import generate_expert_traj, ExpertDataset

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


#### Extract Trajectory

In [2]:
env_id = "CartPole-v1"
expert_path = "tmp/gym/CartPole/behaviour_cloning/expert_model.zip"
save_path = "tmp/gym/CartPole/behaviour_cloning/expert_model_traj.npz"

In [3]:
if os.path.exists(save_path) is False:
    env = DummyVecEnv([lambda: gym.make(env_id)])
    env.max_episode_steps = 500
    model = PPO2.load(expert_path, env=env)
    generate_expert_traj(model, save_path=save_path, env=env, n_episodes=1000)
else:
    print('trajectory already exists. Extraction will not be performed')

trajectory already exists. Extraction will not be performed


In [4]:
dataset = ExpertDataset(expert_path=save_path, batch_size=128)

actions (498858, 1)
obs (498858, 4)
rewards (498858, 1)
episode_returns (1000,)
episode_starts (498858,)
Total trajectories: -1
Total transitions: 498858
Average returns: 498.858
Std for returns: 18.93720771391601


#### Pretraining

In [5]:
best_mean_reward, n_steps = -np.inf, 0

def callback(_locals, _globals):
    """
    Callback called at each step (for DQN an others) or after n steps (see ACER or PPO2)
    :param _locals: (dict)
    :param _globals: (dict)
    """
    global n_steps, best_mean_reward
    # Print stats every 1000 calls
    if (n_steps + 1) % 10 == 0:
        # Evaluate policy training performance
        x, y = ts2xy(load_results(log_dir), 'timesteps')
        if len(x) > 0:
            mean_reward = np.mean(y[-100:])
            print(x[-1], 'timesteps')
            print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(best_mean_reward, mean_reward))

            # New best model, you could save the agent here
            if mean_reward > best_mean_reward:
                best_mean_reward = mean_reward
                # Example for saving best model
                print("Saving new best model")
                _locals['self'].save(log_dir + 'best_model.pkl')
    n_steps += 1
    return True

In [6]:
env_id = "CartPole-v1"
env = gym.make(env_id)
env = DummyVecEnv([lambda: env])

In [7]:
model = PPO2(MlpPolicy, env, verbose=1)
model.pretrain(dataset, n_epochs=100, val_interval=1)

Pretraining with Behavior Cloning...
==== Training progress 1.00% ====
Epoch 1
Training loss: 0.552960, Validation loss: 0.526862

==== Training progress 2.00% ====
Epoch 2
Training loss: 0.522406, Validation loss: 0.518123

==== Training progress 3.00% ====
Epoch 3
Training loss: 0.519236, Validation loss: 0.517809

==== Training progress 4.00% ====
Epoch 4
Training loss: 0.519103, Validation loss: 0.517788

==== Training progress 5.00% ====
Epoch 5
Training loss: 0.519108, Validation loss: 0.517702

==== Training progress 6.00% ====
Epoch 6
Training loss: 0.519120, Validation loss: 0.517585

==== Training progress 7.00% ====
Epoch 7
Training loss: 0.519264, Validation loss: 0.517529

==== Training progress 8.00% ====
Epoch 8
Training loss: 0.519108, Validation loss: 0.517818

==== Training progress 9.00% ====
Epoch 9
Training loss: 0.519132, Validation loss: 0.517596

==== Training progress 10.00% ====
Epoch 10
Training loss: 0.519204, Validation loss: 0.517634

==== Training progres

Process Process-1:
Process Process-2:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/jiachengweng/opt/anaconda3/envs/stable-baselines/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/jiachengweng/opt/anaconda3/envs/stable-baselines/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Users/jiachengweng/opt/anaconda3/envs/stable-baselines/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/jiachengweng/opt/anaconda3/envs/stable-baselines/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/jiachengweng/opt/anaconda3/envs/stable-baselines/lib/python3.6/site-packages/stable_baselines/gail/dataset/dataset.py", line 290, in _run
    self.queue.put((obs, actions))
  File "/Users/jiachengweng/opt/anaconda3/envs/stable-baselines/lib/python

KeyboardInterrupt: 

#### Pretrain Visualization

In [None]:
obs = env.reset()
for i in range(2000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()