Imitation Leaning experiments:
Conclusion: BC works. Others don't seem to be stable even on CartPole.

In [1]:
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms import bc
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper

In [2]:
env = gym.make("CartPole-v1")
rng = np.random.default_rng(0)

In [3]:
## Behavior Cloning (BC)

In [4]:
rng = np.random.default_rng(0)
env = gym.make("CartPole-v1")
expert = PPO(policy=MlpPolicy, env=env)
expert.learn(10000)

<stable_baselines3.ppo.ppo.PPO at 0x110b6fd30>

In [5]:
reward, _ = evaluate_policy(
    expert.policy,  # type: ignore[arg-type]
    env,
    n_eval_episodes=1,
    render=False,
)
print(f"Reward of expert: {reward}")

Reward of expert: 308.0


In [6]:
rollouts = rollout.rollout(
    expert,
    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
    rollout.make_sample_until(min_timesteps=None, min_episodes=50),
    rng=rng,
)

In [7]:
transitions = rollout.flatten_trajectories(rollouts)

In [8]:
bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)

In [9]:
bc_trainer.train(n_epochs=10)

0batch [00:00, ?batch/s]

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 0         |
|    ent_loss       | -0.000693 |
|    entropy        | 0.693     |
|    epoch          | 0         |
|    l2_loss        | 0         |
|    l2_norm        | 36.5      |
|    loss           | 0.693     |
|    neglogp        | 0.694     |
|    prob_true_act  | 0.5       |
|    samples_so_far | 32        |
---------------------------------


103batch [00:00, 1027.63batch/s]
353batch [00:00, 1194.67batch/s][A
473batch [00:00, 1181.52batch/s][A

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 500       |
|    ent_loss       | -0.000621 |
|    entropy        | 0.621     |
|    epoch          | 2         |
|    l2_loss        | 0         |
|    l2_norm        | 42.1      |
|    loss           | 0.615     |
|    neglogp        | 0.616     |
|    prob_true_act  | 0.57      |
|    samples_so_far | 16032     |
---------------------------------


595batch [00:00, 1194.67batch/s]
845batch [00:00, 1224.23batch/s][A
968batch [00:00, 1211.46batch/s][A

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 1000      |
|    ent_loss       | -0.000613 |
|    entropy        | 0.613     |
|    epoch          | 4         |
|    l2_loss        | 0         |
|    l2_norm        | 42.3      |
|    loss           | 0.64      |
|    neglogp        | 0.641     |
|    prob_true_act  | 0.567     |
|    samples_so_far | 32032     |
---------------------------------


1090batch [00:00, 1212.40batch/s]
1343batch [00:01, 1238.64batch/s][A
1467batch [00:01, 1237.81batch/s][A

--------------------------------
| batch_size        | 32       |
| bc/               |          |
|    batch          | 1500     |
|    ent_loss       | -0.00063 |
|    entropy        | 0.63     |
|    epoch          | 6        |
|    l2_loss        | 0        |
|    l2_norm        | 42.9     |
|    loss           | 0.564    |
|    neglogp        | 0.564    |
|    prob_true_act  | 0.591    |
|    samples_so_far | 48032    |
--------------------------------



1719batch [00:01, 1245.14batch/s][A
1974batch [00:01, 1256.86batch/s][A

---------------------------------
| batch_size        | 32        |
| bc/               |           |
|    batch          | 2000      |
|    ent_loss       | -0.000615 |
|    entropy        | 0.615     |
|    epoch          | 8         |
|    l2_loss        | 0         |
|    l2_norm        | 43.1      |
|    loss           | 0.559     |
|    neglogp        | 0.56      |
|    prob_true_act  | 0.597     |
|    samples_so_far | 64032     |
---------------------------------



2226batch [00:01, 1248.07batch/s][A
2270batch [00:01, 1222.35batch/s][A


In [10]:
reward, _ = evaluate_policy(bc_trainer.policy, env, 50)
print("Reward:", reward)

Reward: 252.56


In [30]:
#### GAIL algorithm - train a PPO to fool a discriminator into thinking the trajectory is expert

In [11]:
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
import seals

In [12]:
env = gym.make("seals/CartPole-v0")
expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
expert.learn(20000)

<stable_baselines3.ppo.ppo.PPO at 0x154c73490>

In [13]:
rng = np.random.default_rng()
rollouts = rollout.rollout(
    expert,
    make_vec_env(
        "seals/CartPole-v0",
        n_envs=5,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
        rng=rng,
    ),
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=rng,
)

In [14]:
venv = make_vec_env("seals/CartPole-v0", n_envs=8, rng=rng)
learner = PPO(
    env=venv,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
)
reward_net = BasicRewardNet(
    venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net
)

In [15]:
learner_rewards_before_training, _ = evaluate_policy(
    learner, venv, 100, return_episode_rewards=True
)
print(np.array(learner_rewards_before_training).sum()/len(learner_rewards_before_training))

8.31


In [16]:
gail_trainer.train(300000)  # Note: set to 300000 for better results

round:   0%|                                             | 0/18 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 500      |
|    gen/rollout/ep_rew_mean  | 35.8     |
|    gen/time/fps             | 16474    |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 0        |
|    gen/time/total_timesteps | 16384    |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.431    |
|    disc/disc_acc_expert             | 0.231    |
|    disc/disc_acc_gen                | 0.63     |
|    disc/disc_entropy                | 0.692    |
|    disc/disc_loss                   | 0.702    |
|    disc/disc_proportion_expert_pred | 0.301    |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
-

round:   6%|██                                   | 1/18 [00:03<00:56,  3.35s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 30.2        |
|    gen/rollout/ep_rew_wrapped_mean | 327         |
|    gen/time/fps                    | 18377       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 32768       |
|    gen/train/approx_kl             | 0.008570074 |
|    gen/train/clip_fraction         | 0.0254      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.688      |
|    gen/train/explained_variance    | 0.0455      |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.614       |
|    gen/train/n_updates             | 10          |
|    gen/train/policy_gradient_loss  | -0.00226    |
|    gen/train/value_loss            | 16.5   

round:  11%|████                                 | 2/18 [00:06<00:52,  3.27s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 32.1        |
|    gen/rollout/ep_rew_wrapped_mean | 334         |
|    gen/time/fps                    | 18286       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 49152       |
|    gen/train/approx_kl             | 0.006876415 |
|    gen/train/clip_fraction         | 0.0408      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.682      |
|    gen/train/explained_variance    | 0.625       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 0.4         |
|    gen/train/n_updates             | 20          |
|    gen/train/policy_gradient_loss  | -0.00113    |
|    gen/train/value_loss            | 5.25   

round:  17%|██████▏                              | 3/18 [00:09<00:48,  3.25s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 31.1         |
|    gen/rollout/ep_rew_wrapped_mean | 338          |
|    gen/time/fps                    | 18297        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 65536        |
|    gen/train/approx_kl             | 0.0061582853 |
|    gen/train/clip_fraction         | 0.0402       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.675       |
|    gen/train/explained_variance    | 0.567        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 2.46         |
|    gen/train/n_updates             | 30           |
|    gen/train/policy_gradient_loss  | -0.00161     |
|    gen/train/value_loss   

round:  22%|████████▏                            | 4/18 [00:13<00:45,  3.24s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 34.3         |
|    gen/rollout/ep_rew_wrapped_mean | 344          |
|    gen/time/fps                    | 18131        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 81920        |
|    gen/train/approx_kl             | 0.0047906237 |
|    gen/train/clip_fraction         | 0.0303       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.671       |
|    gen/train/explained_variance    | 0.526        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 0.327        |
|    gen/train/n_updates             | 40           |
|    gen/train/policy_gradient_loss  | -0.00175     |
|    gen/train/value_loss   

round:  28%|██████████▎                          | 5/18 [00:16<00:42,  3.25s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 33.5        |
|    gen/rollout/ep_rew_wrapped_mean | 345         |
|    gen/time/fps                    | 17632       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 98304       |
|    gen/train/approx_kl             | 0.005959993 |
|    gen/train/clip_fraction         | 0.0319      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.663      |
|    gen/train/explained_variance    | 0.669       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 1.75        |
|    gen/train/n_updates             | 50          |
|    gen/train/policy_gradient_loss  | -0.00126    |
|    gen/train/value_loss            | 13.1   

round:  33%|████████████▎                        | 6/18 [00:19<00:39,  3.26s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 30.2        |
|    gen/rollout/ep_rew_wrapped_mean | 344         |
|    gen/time/fps                    | 18094       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 114688      |
|    gen/train/approx_kl             | 0.006669498 |
|    gen/train/clip_fraction         | 0.0568      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.653      |
|    gen/train/explained_variance    | 0.762       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 8.96        |
|    gen/train/n_updates             | 60          |
|    gen/train/policy_gradient_loss  | -0.00417    |
|    gen/train/value_loss            | 13     

round:  39%|██████████████▍                      | 7/18 [00:22<00:35,  3.26s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 26.4        |
|    gen/rollout/ep_rew_wrapped_mean | 343         |
|    gen/time/fps                    | 18233       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 131072      |
|    gen/train/approx_kl             | 0.007987807 |
|    gen/train/clip_fraction         | 0.0897      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.621      |
|    gen/train/explained_variance    | 0.844       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 1.82        |
|    gen/train/n_updates             | 70          |
|    gen/train/policy_gradient_loss  | -0.00317    |
|    gen/train/value_loss            | 9.23   

round:  44%|████████████████▍                    | 8/18 [00:26<00:32,  3.28s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 25.3        |
|    gen/rollout/ep_rew_wrapped_mean | 342         |
|    gen/time/fps                    | 18278       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 147456      |
|    gen/train/approx_kl             | 0.007992975 |
|    gen/train/clip_fraction         | 0.0869      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.58       |
|    gen/train/explained_variance    | 0.914       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 1.42        |
|    gen/train/n_updates             | 80          |
|    gen/train/policy_gradient_loss  | -0.00251    |
|    gen/train/value_loss            | 9.12   

round:  50%|██████████████████▌                  | 9/18 [00:29<00:29,  3.27s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 28.4        |
|    gen/rollout/ep_rew_wrapped_mean | 340         |
|    gen/time/fps                    | 18117       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 163840      |
|    gen/train/approx_kl             | 0.007016823 |
|    gen/train/clip_fraction         | 0.0716      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.572      |
|    gen/train/explained_variance    | 0.888       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 7.09        |
|    gen/train/n_updates             | 90          |
|    gen/train/policy_gradient_loss  | -0.00169    |
|    gen/train/value_loss            | 12     

round:  56%|████████████████████                | 10/18 [00:32<00:26,  3.27s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 38.2        |
|    gen/rollout/ep_rew_wrapped_mean | 337         |
|    gen/time/fps                    | 17732       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 180224      |
|    gen/train/approx_kl             | 0.008935659 |
|    gen/train/clip_fraction         | 0.0707      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.54       |
|    gen/train/explained_variance    | 0.795       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 1.39        |
|    gen/train/n_updates             | 100         |
|    gen/train/policy_gradient_loss  | -0.00218    |
|    gen/train/value_loss            | 17.8   

round:  61%|██████████████████████              | 11/18 [00:35<00:23,  3.29s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 43.3         |
|    gen/rollout/ep_rew_wrapped_mean | 339          |
|    gen/time/fps                    | 18145        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 196608       |
|    gen/train/approx_kl             | 0.0039812494 |
|    gen/train/clip_fraction         | 0.043        |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.568       |
|    gen/train/explained_variance    | 0.682        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 6.37         |
|    gen/train/n_updates             | 110          |
|    gen/train/policy_gradient_loss  | -0.000628    |
|    gen/train/value_loss   

round:  67%|████████████████████████            | 12/18 [00:39<00:19,  3.28s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 44.2         |
|    gen/rollout/ep_rew_wrapped_mean | 345          |
|    gen/time/fps                    | 18277        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 212992       |
|    gen/train/approx_kl             | 0.0036734813 |
|    gen/train/clip_fraction         | 0.0435       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.552       |
|    gen/train/explained_variance    | 0.606        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 2.43         |
|    gen/train/n_updates             | 120          |
|    gen/train/policy_gradient_loss  | -0.000974    |
|    gen/train/value_loss   

round:  72%|██████████████████████████          | 13/18 [00:42<00:16,  3.27s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 40.9         |
|    gen/rollout/ep_rew_wrapped_mean | 356          |
|    gen/time/fps                    | 17914        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 229376       |
|    gen/train/approx_kl             | 0.0038760426 |
|    gen/train/clip_fraction         | 0.0296       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.543       |
|    gen/train/explained_variance    | 0.593        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 6.68         |
|    gen/train/n_updates             | 130          |
|    gen/train/policy_gradient_loss  | -0.000713    |
|    gen/train/value_loss   

round:  78%|████████████████████████████        | 14/18 [00:45<00:13,  3.27s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 41.8         |
|    gen/rollout/ep_rew_wrapped_mean | 356          |
|    gen/time/fps                    | 18213        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 245760       |
|    gen/train/approx_kl             | 0.0034987563 |
|    gen/train/clip_fraction         | 0.0388       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.544       |
|    gen/train/explained_variance    | 0.256        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 3.55         |
|    gen/train/n_updates             | 140          |
|    gen/train/policy_gradient_loss  | -0.000371    |
|    gen/train/value_loss   

round:  83%|██████████████████████████████      | 15/18 [00:49<00:09,  3.27s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 45.9         |
|    gen/rollout/ep_rew_wrapped_mean | 347          |
|    gen/time/fps                    | 18212        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 262144       |
|    gen/train/approx_kl             | 0.0043584188 |
|    gen/train/clip_fraction         | 0.0465       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.54        |
|    gen/train/explained_variance    | 0.245        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 31           |
|    gen/train/n_updates             | 150          |
|    gen/train/policy_gradient_loss  | -0.000356    |
|    gen/train/value_loss   

round:  89%|████████████████████████████████    | 16/18 [00:52<00:06,  3.27s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 48.8        |
|    gen/rollout/ep_rew_wrapped_mean | 332         |
|    gen/time/fps                    | 18174       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 278528      |
|    gen/train/approx_kl             | 0.004511253 |
|    gen/train/clip_fraction         | 0.0412      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.526      |
|    gen/train/explained_variance    | 0.171       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 30.7        |
|    gen/train/n_updates             | 160         |
|    gen/train/policy_gradient_loss  | -0.000516   |
|    gen/train/value_loss            | 49     

round:  94%|██████████████████████████████████  | 17/18 [00:55<00:03,  3.27s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 51.7         |
|    gen/rollout/ep_rew_wrapped_mean | 315          |
|    gen/time/fps                    | 18212        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 294912       |
|    gen/train/approx_kl             | 0.0038500326 |
|    gen/train/clip_fraction         | 0.0391       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.508       |
|    gen/train/explained_variance    | 0.176        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 41.4         |
|    gen/train/n_updates             | 170          |
|    gen/train/policy_gradient_loss  | -0.00133     |
|    gen/train/value_loss   

round: 100%|████████████████████████████████████| 18/18 [00:58<00:00,  3.27s/it]


In [17]:
learner_rewards_after_training, _ = evaluate_policy(
    learner, venv, 100, return_episode_rewards=True
)
print(np.array(learner_rewards_after_training).sum()/len(learner_rewards_after_training))

66.96


In [18]:
learner_rewards_after_training

[89.0,
 66.0,
 56.0,
 57.0,
 85.0,
 100.0,
 91.0,
 64.0,
 62.0,
 62.0,
 41.0,
 67.0,
 76.0,
 86.0,
 56.0,
 61.0,
 67.0,
 57.0,
 55.0,
 58.0,
 52.0,
 66.0,
 80.0,
 58.0,
 82.0,
 35.0,
 81.0,
 56.0,
 33.0,
 60.0,
 70.0,
 119.0,
 42.0,
 71.0,
 60.0,
 21.0,
 73.0,
 86.0,
 64.0,
 58.0,
 48.0,
 47.0,
 30.0,
 56.0,
 82.0,
 81.0,
 45.0,
 74.0,
 77.0,
 72.0,
 71.0,
 80.0,
 108.0,
 90.0,
 111.0,
 81.0,
 65.0,
 85.0,
 61.0,
 67.0,
 57.0,
 59.0,
 86.0,
 68.0,
 55.0,
 66.0,
 127.0,
 71.0,
 59.0,
 68.0,
 67.0,
 24.0,
 53.0,
 83.0,
 71.0,
 61.0,
 70.0,
 62.0,
 30.0,
 73.0,
 59.0,
 96.0,
 54.0,
 38.0,
 57.0,
 66.0,
 31.0,
 58.0,
 69.0,
 72.0,
 60.0,
 65.0,
 84.0,
 86.0,
 86.0,
 87.0,
 71.0,
 53.0,
 94.0,
 46.0]

In [None]:
#### COPY PASTE OF GAIL CODE AGAIN

In [None]:
import numpy as np
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms.adversarial.gail import GAIL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env

rng = np.random.default_rng(0)

env = gym.make("seals/CartPole-v0")
expert = PPO(policy=MlpPolicy, env=env, n_steps=64)
expert.learn(1000)

rollouts = rollout.rollout(
    expert,
    make_vec_env(
        "seals/CartPole-v0",
        n_envs=5,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
        rng=rng,
    ),
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=rng,
)

venv = make_vec_env("seals/CartPole-v0", n_envs=8, rng=rng)
learner = PPO(env=venv, policy=MlpPolicy)
reward_net = BasicRewardNet(
    venv.observation_space,
    venv.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
)

gail_trainer.train(20000)
rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
print("Rewards:", rewards)

In [None]:
### AIRL CODE PASTE TO TRY

In [20]:
import numpy as np
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms.adversarial.airl import AIRL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
import seals

rng = np.random.default_rng(0)

env = gym.make("seals/CartPole-v0")
expert = PPO(policy=MlpPolicy, env=env)
expert.learn(50000)

rollouts = rollout.rollout(
    expert,
    make_vec_env(
        "seals/CartPole-v0",
        rng=rng,
        n_envs=5,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
    ),
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=rng,
)

venv = make_vec_env("seals/CartPole-v0", rng=rng, n_envs=8)
learner = PPO(env=venv, policy=MlpPolicy)
reward_net = BasicShapedRewardNet(
    venv.observation_space,
    venv.action_space,
    normalize_input_layer=RunningNorm,
)
airl_trainer = AIRL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
)
airl_trainer.train(200000)
rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
print("Rewards:", rewards)

round:   0%|                                             | 0/12 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 500      |
|    gen/rollout/ep_rew_mean  | 29.8     |
|    gen/time/fps             | 17021    |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 0        |
|    gen/time/total_timesteps | 16384    |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.5      |
|    disc/disc_acc_expert             | 1        |
|    disc/disc_acc_gen                | 0        |
|    disc/disc_entropy                | 0.602    |
|    disc/disc_loss                   | 0.806    |
|    disc/disc_proportion_expert_pred | 1        |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
-

round:   8%|███                                  | 1/12 [00:03<00:36,  3.36s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 25.3        |
|    gen/rollout/ep_rew_wrapped_mean | 354         |
|    gen/time/fps                    | 16997       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 32768       |
|    gen/train/approx_kl             | 0.009204772 |
|    gen/train/clip_fraction         | 0.0408      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.685      |
|    gen/train/explained_variance    | -0.0582     |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 3.06        |
|    gen/train/n_updates             | 10          |
|    gen/train/policy_gradient_loss  | -0.003      |
|    gen/train/value_loss            | 88.8   

round:  17%|██████▏                              | 2/12 [00:06<00:33,  3.33s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 25           |
|    gen/rollout/ep_rew_wrapped_mean | 340          |
|    gen/time/fps                    | 17010        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 49152        |
|    gen/train/approx_kl             | 0.0115054175 |
|    gen/train/clip_fraction         | 0.0935       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.67        |
|    gen/train/explained_variance    | 0.668        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 12.9         |
|    gen/train/n_updates             | 20           |
|    gen/train/policy_gradient_loss  | -0.00547     |
|    gen/train/value_loss   

round:  25%|█████████▎                           | 3/12 [00:09<00:29,  3.33s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 21.2        |
|    gen/rollout/ep_rew_wrapped_mean | 277         |
|    gen/time/fps                    | 17027       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 65536       |
|    gen/train/approx_kl             | 0.014686596 |
|    gen/train/clip_fraction         | 0.143       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.638      |
|    gen/train/explained_variance    | 0.903       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 2.9         |
|    gen/train/n_updates             | 30          |
|    gen/train/policy_gradient_loss  | -0.00609    |
|    gen/train/value_loss            | 21.3   

round:  33%|████████████▎                        | 4/12 [00:13<00:26,  3.35s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 20.2         |
|    gen/rollout/ep_rew_wrapped_mean | 250          |
|    gen/time/fps                    | 17047        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 81920        |
|    gen/train/approx_kl             | 0.0070656263 |
|    gen/train/clip_fraction         | 0.0459       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.596       |
|    gen/train/explained_variance    | 0.899        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 4.53         |
|    gen/train/n_updates             | 40           |
|    gen/train/policy_gradient_loss  | -0.00214     |
|    gen/train/value_loss   

round:  42%|███████████████▍                     | 5/12 [00:16<00:23,  3.34s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 16.5        |
|    gen/rollout/ep_rew_wrapped_mean | 231         |
|    gen/time/fps                    | 17002       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 98304       |
|    gen/train/approx_kl             | 0.009441974 |
|    gen/train/clip_fraction         | 0.0758      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.534      |
|    gen/train/explained_variance    | 0.885       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 4.53        |
|    gen/train/n_updates             | 50          |
|    gen/train/policy_gradient_loss  | -0.00321    |
|    gen/train/value_loss            | 19.2   

round:  50%|██████████████████▌                  | 6/12 [00:19<00:19,  3.33s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 14.2        |
|    gen/rollout/ep_rew_wrapped_mean | 241         |
|    gen/time/fps                    | 17057       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 114688      |
|    gen/train/approx_kl             | 0.013196757 |
|    gen/train/clip_fraction         | 0.106       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.472      |
|    gen/train/explained_variance    | 0.934       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 2.65        |
|    gen/train/n_updates             | 60          |
|    gen/train/policy_gradient_loss  | -0.00709    |
|    gen/train/value_loss            | 9.82   

round:  58%|█████████████████████▌               | 7/12 [00:23<00:16,  3.31s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 12          |
|    gen/rollout/ep_rew_wrapped_mean | 213         |
|    gen/time/fps                    | 17040       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 131072      |
|    gen/train/approx_kl             | 0.003924359 |
|    gen/train/clip_fraction         | 0.0634      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.408      |
|    gen/train/explained_variance    | 0.958       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 6.19        |
|    gen/train/n_updates             | 70          |
|    gen/train/policy_gradient_loss  | -0.00362    |
|    gen/train/value_loss            | 4.08   

round:  67%|████████████████████████▋            | 8/12 [00:26<00:13,  3.32s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 11.1         |
|    gen/rollout/ep_rew_wrapped_mean | 152          |
|    gen/time/fps                    | 17189        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 147456       |
|    gen/train/approx_kl             | 0.0028734354 |
|    gen/train/clip_fraction         | 0.0365       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.383       |
|    gen/train/explained_variance    | 0.963        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 1.59         |
|    gen/train/n_updates             | 80           |
|    gen/train/policy_gradient_loss  | -0.000981    |
|    gen/train/value_loss   

round:  75%|███████████████████████████▊         | 9/12 [00:29<00:09,  3.31s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 11          |
|    gen/rollout/ep_rew_wrapped_mean | 92.1        |
|    gen/time/fps                    | 16629       |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 0           |
|    gen/time/total_timesteps        | 163840      |
|    gen/train/approx_kl             | 0.017922994 |
|    gen/train/clip_fraction         | 0.0559      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.293      |
|    gen/train/explained_variance    | 0.936       |
|    gen/train/learning_rate         | 0.0003      |
|    gen/train/loss                  | 2.13        |
|    gen/train/n_updates             | 90          |
|    gen/train/policy_gradient_loss  | -0.00269    |
|    gen/train/value_loss            | 2.21   

round:  83%|██████████████████████████████      | 10/12 [00:33<00:06,  3.31s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 10.6         |
|    gen/rollout/ep_rew_wrapped_mean | 27.6         |
|    gen/time/fps                    | 17109        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 180224       |
|    gen/train/approx_kl             | 0.0021357045 |
|    gen/train/clip_fraction         | 0.0223       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.28        |
|    gen/train/explained_variance    | 0.919        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 0.571        |
|    gen/train/n_updates             | 100          |
|    gen/train/policy_gradient_loss  | -0.000156    |
|    gen/train/value_loss   

round:  92%|█████████████████████████████████   | 11/12 [00:36<00:03,  3.30s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 10.5         |
|    gen/rollout/ep_rew_wrapped_mean | -36.8        |
|    gen/time/fps                    | 17012        |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 0            |
|    gen/time/total_timesteps        | 196608       |
|    gen/train/approx_kl             | 0.0032767793 |
|    gen/train/clip_fraction         | 0.0195       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.252       |
|    gen/train/explained_variance    | 0.908        |
|    gen/train/learning_rate         | 0.0003       |
|    gen/train/loss                  | 0.634        |
|    gen/train/n_updates             | 110          |
|    gen/train/policy_gradient_loss  | -0.00102     |
|    gen/train/value_loss   

round: 100%|████████████████████████████████████| 12/12 [00:39<00:00,  3.32s/it]


Rewards: [7.0, 8.0, 9.0, 8.0, 8.0, 9.0, 9.0, 7.0, 8.0, 9.0, 7.0, 8.0, 9.0, 7.0, 8.0, 8.0, 8.0, 9.0, 9.0, 8.0, 8.0, 9.0, 8.0, 10.0, 9.0, 8.0, 8.0, 8.0, 7.0, 9.0, 7.0, 8.0, 7.0, 9.0, 8.0, 9.0, 9.0, 8.0, 9.0, 9.0, 9.0, 9.0, 9.0, 8.0, 7.0, 9.0, 8.0, 7.0, 9.0, 9.0, 9.0, 9.0, 7.0, 8.0, 8.0, 8.0, 7.0, 10.0, 8.0, 8.0, 9.0, 9.0, 9.0, 8.0, 7.0, 8.0, 9.0, 7.0, 7.0, 8.0, 9.0, 8.0, 9.0, 8.0, 7.0, 9.0, 7.0, 9.0, 10.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 7.0, 8.0, 9.0, 9.0, 9.0, 8.0, 8.0, 9.0, 9.0]
