In [1]:
!pip install -q imitation gymnasium
import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

SEED = 42
# Single source of truth for the environment name. It is used both to build
# the vectorized env and to look up the matching pretrained expert, so the
# two can never drift apart.
ENV_NAME = "seals/CartPole-v0"

# 8 parallel copies of the environment, seeded from SEED.
env = make_vec_env(
    # "seals:" prefix: gymnasium's `module:env_id` form, which imports the
    # `seals` package (registering its envs) before looking up the id.
    f"seals:{ENV_NAME}",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[
        # Needed so rollout.rollout() can compute expert rollouts later.
        # The second (index) argument supplied by make_vec_env is unused.
        lambda single_env, _: RolloutInfoWrapper(single_env)
    ],
)
# Pretrained PPO expert published by HumanCompatibleAI for this same
# environment, loaded through the "ppo-huggingface" policy loader.
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name=ENV_NAME,
    venv=env,
)

from imitation.data import rollout

# Expert demonstrations for GAIL: run the expert policy on `env` until at
# least 60 episodes have been gathered. No minimum is set on the number of
# timesteps.
rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_episodes=60, min_timesteps=None),
    rng=np.random.default_rng(SEED),
)

from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

# Generator: the PPO learner that GAIL trains against the learned reward.
# NOTE(review): PPO(seed=...) presumably also seeds the global torch RNG, so
# keep it constructed before BasicRewardNet to keep that net's initial
# weights reproducible.
learner = PPO(
    policy=MlpPolicy,
    env=env,
    seed=SEED,
    # Optimisation hyperparameters.
    learning_rate=4e-4,
    batch_size=64,
    n_epochs=5,
    # Discount factor and entropy bonus (disabled).
    gamma=0.95,
    ent_coef=0.0,
)

# Discriminator / reward network over (observation, action).
# Its inputs are normalised with a RunningNorm layer.
reward_net = BasicRewardNet(
    action_space=env.action_space,
    observation_space=env.observation_space,
    normalize_input_layer=RunningNorm,
)

# Adversarial trainer that alternates discriminator and generator updates.
# demo_batch_size matches the disc/n_expert and disc/n_generated counts
# (1.02e+03) in the captured log.
gail_trainer = GAIL(
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
)

def _evaluate_learner(n_episodes=100):
    """Return the per-episode rewards of ``learner`` evaluated on ``env``.

    ``env`` is re-seeded with ``SEED`` first. The before- and after-training
    evaluations therefore start from the same environment seed and can be
    compared directly.
    """
    env.seed(SEED)
    episode_rewards, _ = evaluate_policy(
        learner, env, n_eval_episodes=n_episodes, return_episode_rewards=True
    )
    return episode_rewards


def _report(label, episode_rewards):
    """Print mean +/- (population) std of ``episode_rewards``."""
    print(
        f"Rewards {label} training:",
        np.mean(episode_rewards),
        "+/-",
        np.std(episode_rewards),
    )


learner_rewards_before_training = _evaluate_learner()

# Train for 200_000 generator timesteps. The captured log shows 16,384
# timesteps per round, i.e. 12 rounds in total.
gail_trainer.train(200_000)

learner_rewards_after_training = _evaluate_learner()

_report("before", learner_rewards_before_training)
_report("after", learner_rewards_after_training)

2025-01-05 01:15:00.282035: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
round:   0%|          | 0/12 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 500      |
|    gen/rollout/ep_rew_mean  | 29.8     |
|    gen/time/fps             | 2605     |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 6        |
|    gen/time/total_timesteps | 16384    |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.5      |
|    disc/disc_acc_expert             | 0        |
|    disc/disc_acc_gen                | 1        |
|    disc/disc_entropy                | 0.69     |
|    disc/disc_loss                   | 0.696    |
|    disc/disc_proportion_expert_pred | 0        |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
-

round:   8%|▊         | 1/12 [00:16<02:57, 16.11s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 31.9        |
|    gen/rollout/ep_rew_wrapped_mean | 268         |
|    gen/time/fps                    | 3144        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 5           |
|    gen/time/total_timesteps        | 32768       |
|    gen/train/approx_kl             | 0.009048734 |
|    gen/train/clip_fraction         | 0.0295      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.686      |
|    gen/train/explained_variance    | 0.0301      |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.127       |
|    gen/train/n_updates             | 5           |
|    gen/train/policy_gradient_loss  | -0.0015     |
|    gen/train/value_loss            | 4.43   

round:  17%|█▋        | 2/12 [00:28<02:16, 13.68s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 33.9        |
|    gen/rollout/ep_rew_wrapped_mean | 275         |
|    gen/time/fps                    | 2558        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 6           |
|    gen/time/total_timesteps        | 49152       |
|    gen/train/approx_kl             | 0.010725586 |
|    gen/train/clip_fraction         | 0.132       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.686      |
|    gen/train/explained_variance    | 0.841       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0172      |
|    gen/train/n_updates             | 10          |
|    gen/train/policy_gradient_loss  | -0.00802    |
|    gen/train/value_loss            | 0.247  

round:  25%|██▌       | 3/12 [00:43<02:11, 14.59s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 37.7        |
|    gen/rollout/ep_rew_wrapped_mean | 277         |
|    gen/time/fps                    | 2086        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 7           |
|    gen/time/total_timesteps        | 65536       |
|    gen/train/approx_kl             | 0.016089637 |
|    gen/train/clip_fraction         | 0.207       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.673      |
|    gen/train/explained_variance    | 0.82        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.019      |
|    gen/train/n_updates             | 15          |
|    gen/train/policy_gradient_loss  | -0.0131     |
|    gen/train/value_loss            | 0.0429 

round:  33%|███▎      | 4/12 [00:58<01:58, 14.76s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 39.2        |
|    gen/rollout/ep_rew_wrapped_mean | 283         |
|    gen/time/fps                    | 2815        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 5           |
|    gen/time/total_timesteps        | 81920       |
|    gen/train/approx_kl             | 0.017104484 |
|    gen/train/clip_fraction         | 0.24        |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.654      |
|    gen/train/explained_variance    | 0.892       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0361     |
|    gen/train/n_updates             | 20          |
|    gen/train/policy_gradient_loss  | -0.0214     |
|    gen/train/value_loss            | 0.0187 

round:  42%|████▏     | 5/12 [01:11<01:37, 13.88s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 37.6        |
|    gen/rollout/ep_rew_wrapped_mean | 285         |
|    gen/time/fps                    | 3102        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 5           |
|    gen/time/total_timesteps        | 98304       |
|    gen/train/approx_kl             | 0.010718755 |
|    gen/train/clip_fraction         | 0.126       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.636      |
|    gen/train/explained_variance    | 0.882       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0112     |
|    gen/train/n_updates             | 25          |
|    gen/train/policy_gradient_loss  | -0.01       |
|    gen/train/value_loss            | 0.0116 

round:  50%|█████     | 6/12 [01:24<01:22, 13.77s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 34.3         |
|    gen/rollout/ep_rew_wrapped_mean | 282          |
|    gen/time/fps                    | 3280         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 4            |
|    gen/time/total_timesteps        | 114688       |
|    gen/train/approx_kl             | 0.0073878746 |
|    gen/train/clip_fraction         | 0.0788       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.629       |
|    gen/train/explained_variance    | 0.891        |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | -0.0101      |
|    gen/train/n_updates             | 30           |
|    gen/train/policy_gradient_loss  | -0.00389     |
|    gen/train/value_loss   

round:  58%|█████▊    | 7/12 [01:37<01:06, 13.37s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 30.4        |
|    gen/rollout/ep_rew_wrapped_mean | 275         |
|    gen/time/fps                    | 2677        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 6           |
|    gen/time/total_timesteps        | 131072      |
|    gen/train/approx_kl             | 0.008307765 |
|    gen/train/clip_fraction         | 0.0816      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.623      |
|    gen/train/explained_variance    | 0.938       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.0135     |
|    gen/train/n_updates             | 35          |
|    gen/train/policy_gradient_loss  | -0.00574    |
|    gen/train/value_loss            | 0.0138 

round:  67%|██████▋   | 8/12 [01:52<00:55, 13.95s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 30.9        |
|    gen/rollout/ep_rew_wrapped_mean | 260         |
|    gen/time/fps                    | 3269        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 5           |
|    gen/time/total_timesteps        | 147456      |
|    gen/train/approx_kl             | 0.007876955 |
|    gen/train/clip_fraction         | 0.0751      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.624      |
|    gen/train/explained_variance    | 0.928       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.00911     |
|    gen/train/n_updates             | 40          |
|    gen/train/policy_gradient_loss  | -0.00404    |
|    gen/train/value_loss            | 0.0253 

round:  75%|███████▌  | 9/12 [02:03<00:39, 13.11s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 35.8        |
|    gen/rollout/ep_rew_wrapped_mean | 242         |
|    gen/time/fps                    | 3182        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 5           |
|    gen/time/total_timesteps        | 163840      |
|    gen/train/approx_kl             | 0.008375626 |
|    gen/train/clip_fraction         | 0.0865      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.627      |
|    gen/train/explained_variance    | 0.909       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.00897     |
|    gen/train/n_updates             | 45          |
|    gen/train/policy_gradient_loss  | -0.00455    |
|    gen/train/value_loss            | 0.0462 

round:  83%|████████▎ | 10/12 [02:16<00:26, 13.10s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 44.4        |
|    gen/rollout/ep_rew_wrapped_mean | 226         |
|    gen/time/fps                    | 2565        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 6           |
|    gen/time/total_timesteps        | 180224      |
|    gen/train/approx_kl             | 0.012100431 |
|    gen/train/clip_fraction         | 0.146       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.62       |
|    gen/train/explained_variance    | 0.925       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | -0.00326    |
|    gen/train/n_updates             | 50          |
|    gen/train/policy_gradient_loss  | -0.00836    |
|    gen/train/value_loss            | 0.0602 

round:  92%|█████████▏| 11/12 [02:32<00:13, 13.92s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 500        |
|    gen/rollout/ep_rew_mean         | 52.9       |
|    gen/rollout/ep_rew_wrapped_mean | 209        |
|    gen/time/fps                    | 2621       |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 6          |
|    gen/time/total_timesteps        | 196608     |
|    gen/train/approx_kl             | 0.01598726 |
|    gen/train/clip_fraction         | 0.207      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -0.607     |
|    gen/train/explained_variance    | 0.956      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | 0.0126     |
|    gen/train/n_updates             | 55         |
|    gen/train/policy_gradient_loss  | -0.0112    |
|    gen/train/value_loss            | 0.0597     |
------------

round: 100%|██████████| 12/12 [02:45<00:00, 13.81s/it]


Rewards before training: 102.6 +/- 24.11514047232568
Rewards after training: 177.13 +/- 181.77814252544223
