In [1]:
!pip install -q imitation gymnasium
import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

# Single seed reused for environment creation, expert rollouts, PPO and evaluation.
SEED = 42

# Switch between a short demonstration run (True) and the long training budget (False).
FAST = True

# Step budget handed to the AIRL trainer.
# NOTE(review): in the recorded FAST run the mean reward after training (53.66) is
# lower than before training (102.6); presumably 100k steps is too few to converge —
# confirm by re-running with FAST = False.
N_RL_TRAIN_STEPS = 100_000 if FAST else 2_000_000

def _attach_rollout_info(env, _unused):
    """Wrap one environment in RolloutInfoWrapper; the second argument is ignored.

    The wrapper is needed for computing rollouts later.
    """
    return RolloutInfoWrapper(env)


# Vectorised seals CartPole environment (8 copies) built from a SEED-initialised generator.
venv = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[_attach_rollout_info],
)
# Identifiers of the pretrained expert checkpoint
# (presumably resolved on the HuggingFace hub — TODO confirm).
expert_source = {
    "organization": "HumanCompatibleAI",
    "env_name": "seals/CartPole-v0",
}

# Expert policy whose behaviour AIRL will imitate, bound to the same vectorised env.
expert = load_policy("ppo-huggingface", venv=venv, **expert_source)

from imitation.data import rollout

# Stop condition for data collection: min_episodes=60, with no minimum timestep count.
sample_until_60_episodes = rollout.make_sample_until(
    min_timesteps=None, min_episodes=60
)

# Fresh generator seeded with SEED, independent of the one used to build the env.
rollout_rng = np.random.default_rng(SEED)

# Expert demonstrations that are later handed to the AIRL trainer.
rollouts = rollout.rollout(expert, venv, sample_until_60_episodes, rng=rollout_rng)

from imitation.algorithms.adversarial.airl import AIRL
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy


# PPO hyperparameters for the generator policy, grouped in one place for easy tuning.
ppo_hyperparams = dict(
    learning_rate=0.0005,
    batch_size=64,
    n_epochs=5,
    gamma=0.95,
    clip_range=0.1,
    ent_coef=0.0,
    vf_coef=0.1,
)

# The learner (generator) is a PPO agent with an MLP policy acting in `venv`.
learner = PPO(policy=MlpPolicy, env=venv, seed=SEED, **ppo_hyperparams)
# Reward network handed to the AIRL trainer; it is sized from the vectorised
# environment's spaces and uses RunningNorm as its input-normalisation layer.
reward_net = BasicShapedRewardNet(
    normalize_input_layer=RunningNorm,
    action_space=venv.action_space,
    observation_space=venv.observation_space,
)
# Tunable AIRL settings kept apart from the objects that are merely wired together.
airl_settings = dict(
    demo_batch_size=2048,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=16,
)

# AIRL trainer combining the expert demonstrations, the PPO learner and the reward net.
airl_trainer = AIRL(
    demonstrations=rollouts,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
    **airl_settings,
)

# Number of evaluation episodes used for both the "before" and "after" measurement.
N_EVAL_EPISODES = 100


def _evaluate_learner():
    """Re-seed `venv` with SEED and evaluate `learner` for N_EVAL_EPISODES episodes.

    Returns the pair produced by `evaluate_policy(..., return_episode_rewards=True)`;
    the caller keeps the first element (per-episode rewards) and discards the second.
    """
    venv.seed(SEED)
    return evaluate_policy(
        learner,
        venv,
        n_eval_episodes=N_EVAL_EPISODES,
        return_episode_rewards=True,
    )


# Baseline: performance of the untrained learner.
learner_rewards_before_training, _ = _evaluate_learner()

# Adversarial training for the configured step budget.
airl_trainer.train(N_RL_TRAIN_STEPS)

# Same evaluation protocol (same seed, same episode count) after training.
learner_rewards_after_training, _ = _evaluate_learner()

# Print mean and standard deviation of the episode rewards for both evaluations.
for label, episode_rewards in (
    ("Rewards before training:", learner_rewards_before_training),
    ("Rewards after training:", learner_rewards_after_training),
):
    print(label, np.mean(episode_rewards), "+/-", np.std(episode_rewards))

2025-01-05 01:17:03.303069: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
round:   0%|          | 0/6 [00:00<?, ?it/s]

------------------------------------------
| raw/                        |          |
|    gen/rollout/ep_len_mean  | 500      |
|    gen/rollout/ep_rew_mean  | 33.1     |
|    gen/time/fps             | 2481     |
|    gen/time/iterations      | 1        |
|    gen/time/time_elapsed    | 6        |
|    gen/time/total_timesteps | 16384    |
------------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.581    |
|    disc/disc_acc_expert             | 1        |
|    disc/disc_acc_gen                | 0.162    |
|    disc/disc_entropy                | 0.664    |
|    disc/disc_loss                   | 0.676    |
|    disc/disc_proportion_expert_pred | 0.919    |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 2.05e+03 |
|    disc/n_generated                 | 2.05e+03 |
-

round:  17%|█▋        | 1/6 [00:15<01:19, 15.85s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 500         |
|    gen/rollout/ep_rew_mean         | 34.6        |
|    gen/rollout/ep_rew_wrapped_mean | -525        |
|    gen/time/fps                    | 2134        |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 7           |
|    gen/time/total_timesteps        | 32768       |
|    gen/train/approx_kl             | 0.001363658 |
|    gen/train/clip_fraction         | 0.0238      |
|    gen/train/clip_range            | 0.1         |
|    gen/train/entropy_loss          | -0.692      |
|    gen/train/explained_variance    | -0.0116     |
|    gen/train/learning_rate         | 0.0005      |
|    gen/train/loss                  | 3.17        |
|    gen/train/n_updates             | 5           |
|    gen/train/policy_gradient_loss  | 7.75e-06    |
|    gen/train/value_loss            | 117    

round:  33%|███▎      | 2/6 [00:32<01:05, 16.44s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 35.4         |
|    gen/rollout/ep_rew_wrapped_mean | -1.47e+03    |
|    gen/time/fps                    | 2396         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 6            |
|    gen/time/total_timesteps        | 49152        |
|    gen/train/approx_kl             | 0.0010964653 |
|    gen/train/clip_fraction         | 0.00289      |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.691       |
|    gen/train/explained_variance    | 0.178        |
|    gen/train/learning_rate         | 0.0005       |
|    gen/train/loss                  | 171          |
|    gen/train/n_updates             | 10           |
|    gen/train/policy_gradient_loss  | -7.06e-06    |
|    gen/train/value_loss   

round:  50%|█████     | 3/6 [00:48<00:48, 16.18s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 38.3         |
|    gen/rollout/ep_rew_wrapped_mean | -1.52e+03    |
|    gen/time/fps                    | 2924         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 5            |
|    gen/time/total_timesteps        | 65536        |
|    gen/train/approx_kl             | 0.0016215986 |
|    gen/train/clip_fraction         | 0.0487       |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.691       |
|    gen/train/explained_variance    | 0.66         |
|    gen/train/learning_rate         | 0.0005       |
|    gen/train/loss                  | 89.1         |
|    gen/train/n_updates             | 15           |
|    gen/train/policy_gradient_loss  | -0.00034     |
|    gen/train/value_loss   

round:  67%|██████▋   | 4/6 [01:00<00:28, 14.49s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 40.9         |
|    gen/rollout/ep_rew_wrapped_mean | -1.69e+03    |
|    gen/time/fps                    | 3016         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 5            |
|    gen/time/total_timesteps        | 81920        |
|    gen/train/approx_kl             | 0.0029653944 |
|    gen/train/clip_fraction         | 0.149        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.687       |
|    gen/train/explained_variance    | 0.877        |
|    gen/train/learning_rate         | 0.0005       |
|    gen/train/loss                  | 4.79         |
|    gen/train/n_updates             | 20           |
|    gen/train/policy_gradient_loss  | -0.00277     |
|    gen/train/value_loss   

round:  83%|████████▎ | 5/6 [01:14<00:14, 14.17s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_len_mean         | 500          |
|    gen/rollout/ep_rew_mean         | 43.9         |
|    gen/rollout/ep_rew_wrapped_mean | -1.37e+03    |
|    gen/time/fps                    | 2643         |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 6            |
|    gen/time/total_timesteps        | 98304        |
|    gen/train/approx_kl             | 0.0028083713 |
|    gen/train/clip_fraction         | 0.168        |
|    gen/train/clip_range            | 0.1          |
|    gen/train/entropy_loss          | -0.685       |
|    gen/train/explained_variance    | 0.798        |
|    gen/train/learning_rate         | 0.0005       |
|    gen/train/loss                  | 2.62         |
|    gen/train/n_updates             | 25           |
|    gen/train/policy_gradient_loss  | -0.00353     |
|    gen/train/value_loss   

round: 100%|██████████| 6/6 [01:27<00:00, 14.60s/it]


Rewards before training: 102.6 +/- 24.11514047232568
Rewards after training: 53.66 +/- 2.790053762922858
