In [3]:
import gymnasium as gym
from imitation.algorithms.adversarial.gail import GAIL
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from imitation.rewards.reward_nets import BasicRewardNet
from stable_baselines3 import PPO
from stable_baselines3.a2c import MlpPolicy
from imitation.data import rollout
import numpy as np
from stable_baselines3.common.evaluation import evaluate_policy
# https://imitation.readthedocs.io/en/latest/_api/imitation.data.huggingface_utils.html#imitation.data.huggingface_utils.TrajectoryDatasetSequence look at this for link
gym.register(
    id='WordleGame-v0',
    entry_point='Wordle:WordleEnv',
    max_episode_steps=6
)

env = gym.make('WordleGame-v0')
#check_env(env)

venv = make_vec_env(
    "WordleGame-v0",
    rng=np.random.default_rng(),
    n_envs=1,
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
)
#rollouts = trajectories
rollouts = np.load('data/trajectories_all.npy', allow_pickle=True)

transitions = rollout.flatten_trajectories_with_rew(rollouts)
learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=None,
)

reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=transitions,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True
)

# evaluate the learner before training
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True,
)

Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html for more information.


In [6]:
# train the learner and evaluate again
gail_trainer.train(20000)  # Train for 800_000 steps to match expert.
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True,
)

round:   0%|          | 0/9 [00:00<?, ?it/s]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 6          |
|    gen/rollout/ep_rew_mean         | -8.14      |
|    gen/rollout/ep_rew_wrapped_mean | 4.05       |
|    gen/time/fps                    | 80         |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 25         |
|    gen/time/total_timesteps        | 4096       |
|    gen/train/approx_kl             | 0.04498999 |
|    gen/train/clip_fraction         | 0.502      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -9.47      |
|    gen/train/explained_variance    | -0.158     |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | -0.147     |
|    gen/train/n_updates             | 5          |
|    gen/train/policy_gradient_loss  | -0.113     |
|    gen/train/value_loss            | 0.369      |
------------

round:  11%|█         | 1/9 [00:33<04:24, 33.08s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 6          |
|    gen/rollout/ep_rew_mean         | -8.37      |
|    gen/rollout/ep_rew_wrapped_mean | 4.55       |
|    gen/time/fps                    | 71         |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 28         |
|    gen/time/total_timesteps        | 6144       |
|    gen/train/approx_kl             | 0.05243803 |
|    gen/train/clip_fraction         | 0.688      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -9.46      |
|    gen/train/explained_variance    | 0.833      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | -0.101     |
|    gen/train/n_updates             | 10         |
|    gen/train/policy_gradient_loss  | -0.123     |
|    gen/train/value_loss            | 0.12       |
------------

round:  22%|██▏       | 2/9 [01:10<04:10, 35.80s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 6          |
|    gen/rollout/ep_rew_mean         | -8.21      |
|    gen/rollout/ep_rew_wrapped_mean | 4.43       |
|    gen/time/fps                    | 72         |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 28         |
|    gen/time/total_timesteps        | 8192       |
|    gen/train/approx_kl             | 0.06104154 |
|    gen/train/clip_fraction         | 0.702      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -9.45      |
|    gen/train/explained_variance    | 0.881      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | -0.0772    |
|    gen/train/n_updates             | 15         |
|    gen/train/policy_gradient_loss  | -0.12      |
|    gen/train/value_loss            | 0.142      |
------------

round:  33%|███▎      | 3/9 [01:48<03:38, 36.47s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 6          |
|    gen/rollout/ep_rew_mean         | -8.68      |
|    gen/rollout/ep_rew_wrapped_mean | 3.86       |
|    gen/time/fps                    | 68         |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 29         |
|    gen/time/total_timesteps        | 10240      |
|    gen/train/approx_kl             | 0.07242447 |
|    gen/train/clip_fraction         | 0.693      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -9.45      |
|    gen/train/explained_variance    | 0.827      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | -0.0727    |
|    gen/train/n_updates             | 20         |
|    gen/train/policy_gradient_loss  | -0.119     |
|    gen/train/value_loss            | 0.22       |
------------

round:  44%|████▍     | 4/9 [02:26<03:06, 37.33s/it]

--------------------------------------------------
| raw/                               |           |
|    gen/rollout/ep_len_mean         | 6         |
|    gen/rollout/ep_rew_mean         | -8.2      |
|    gen/rollout/ep_rew_wrapped_mean | 3.53      |
|    gen/time/fps                    | 69        |
|    gen/time/iterations             | 1         |
|    gen/time/time_elapsed           | 29        |
|    gen/time/total_timesteps        | 12288     |
|    gen/train/approx_kl             | 0.0825832 |
|    gen/train/clip_fraction         | 0.699     |
|    gen/train/clip_range            | 0.2       |
|    gen/train/entropy_loss          | -9.44     |
|    gen/train/explained_variance    | 0.708     |
|    gen/train/learning_rate         | 0.0004    |
|    gen/train/loss                  | -0.00246  |
|    gen/train/n_updates             | 25        |
|    gen/train/policy_gradient_loss  | -0.117    |
|    gen/train/value_loss            | 0.424     |
-------------------------------

round:  56%|█████▌    | 5/9 [03:05<02:30, 37.70s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 6           |
|    gen/rollout/ep_rew_mean         | -8.06       |
|    gen/rollout/ep_rew_wrapped_mean | 3.23        |
|    gen/time/fps                    | 68          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 29          |
|    gen/time/total_timesteps        | 14336       |
|    gen/train/approx_kl             | 0.091529116 |
|    gen/train/clip_fraction         | 0.717       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -9.43       |
|    gen/train/explained_variance    | 0.567       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.154       |
|    gen/train/n_updates             | 30          |
|    gen/train/policy_gradient_loss  | -0.113      |
|    gen/train/value_loss            | 0.652  

round:  67%|██████▋   | 6/9 [03:43<01:54, 38.00s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_len_mean         | 6          |
|    gen/rollout/ep_rew_mean         | -8.02      |
|    gen/rollout/ep_rew_wrapped_mean | 2.75       |
|    gen/time/fps                    | 63         |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 32         |
|    gen/time/total_timesteps        | 16384      |
|    gen/train/approx_kl             | 0.09535798 |
|    gen/train/clip_fraction         | 0.686      |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -9.41      |
|    gen/train/explained_variance    | 0.413      |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | 0.111      |
|    gen/train/n_updates             | 35         |
|    gen/train/policy_gradient_loss  | -0.11      |
|    gen/train/value_loss            | 0.972      |
------------

round:  78%|███████▊  | 7/9 [04:26<01:19, 39.51s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 6           |
|    gen/rollout/ep_rew_mean         | -8.35       |
|    gen/rollout/ep_rew_wrapped_mean | 2.79        |
|    gen/time/fps                    | 65          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 31          |
|    gen/time/total_timesteps        | 18432       |
|    gen/train/approx_kl             | 0.093715414 |
|    gen/train/clip_fraction         | 0.656       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -9.4        |
|    gen/train/explained_variance    | 0.344       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.377       |
|    gen/train/n_updates             | 40          |
|    gen/train/policy_gradient_loss  | -0.108      |
|    gen/train/value_loss            | 1.28   

round:  89%|████████▉ | 8/9 [05:08<00:40, 40.40s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_len_mean         | 6           |
|    gen/rollout/ep_rew_mean         | -8.16       |
|    gen/rollout/ep_rew_wrapped_mean | 2.57        |
|    gen/time/fps                    | 53          |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 38          |
|    gen/time/total_timesteps        | 20480       |
|    gen/train/approx_kl             | 0.104001865 |
|    gen/train/clip_fraction         | 0.675       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -9.38       |
|    gen/train/explained_variance    | 0.32        |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.465       |
|    gen/train/n_updates             | 45          |
|    gen/train/policy_gradient_loss  | -0.109      |
|    gen/train/value_loss            | 1.43   

round: 100%|██████████| 9/9 [05:58<00:00, 39.80s/it]


In [18]:
learner.save('gail')
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 5, render=True
)
print("mean reward after training:", np.mean(learner_rewards_after_training))
print("mean reward before training:", np.mean(learner_rewards_before_training))

TypeError: BaseAlgorithm.save() missing 1 required positional argument: 'path'