In [15]:
import os
import gym
import numpy as np
import matplotlib.pyplot as plt

from stable_baselines3 import PPO
from stable_baselines3.ppo.policies import MlpPolicy
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.evaluation import evaluate_policy


## Create the environment and define the agent

In [23]:
# Single CartPole environment, used both for training and by the evaluation cells below.
env = gym.make("CartPole-v1")

# Fix the seed so training (and therefore the reported rewards) is reproducible
# under Restart & Run All; SB3 forwards `seed` to its pseudo-random generators.
SEED = 42
model = PPO(MlpPolicy, env, verbose=0, seed=SEED)



In [20]:
# Baseline: score the freshly initialised (untrained) policy over 100 episodes.
# warn=False silences SB3's warning that `env` has no Monitor wrapper.
mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=100,
    warn=False,
)
print(f"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}")

mean_reward: 35.89 +/- 8.26


## Train the agent

In [25]:
# Training budget in environment steps.
TOTAL_TIMESTEPS = 10_000
# learn() returns the model itself; the trailing `;` hides that repr in the cell output.
model.learn(total_timesteps=TOTAL_TIMESTEPS);

<stable_baselines3.ppo.ppo.PPO at 0x14d101150>

In [None]:
# Score the trained policy over the same 100 episodes as the baseline so the two
# numbers are directly comparable. warn=False matches the baseline cell: `env` is
# a plain gym env without a Monitor wrapper, which otherwise triggers an SB3 warning.
# NOTE(review): this reuses the training env instance; SB3 generally recommends a
# separate evaluation env.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, warn=False)
print(f"mean_reward: {mean_reward:.2f} +/- {std_reward:.2f}")

### Persist the trained agent

In [None]:
# Persist the trained agent so it can be reloaded without retraining
# (SB3 writes it as a .zip archive).
model.save(path="./saves/ppo_cartpole")

## Reload and visualize the trained agent

### Load the agent from disk

In [None]:
# Restore the saved agent from disk; no env is attached, which is sufficient
# for calling predict().
model_loaded = PPO.load(path="./saves/ppo_cartpole")

In [None]:
# Watch the *reloaded* agent play, which confirms the saved checkpoint round-trips.
# render_mode="human" opens an on-screen window.
N_RENDER_STEPS = 500
RENDER_SEED = 42

env_render = gym.make("CartPole-v1", render_mode="human")
try:
    observation, info = env_render.reset(seed=RENDER_SEED)
    for _ in range(N_RENDER_STEPS):
        # deterministic=True: use the policy's deterministic action instead of sampling.
        action, _ = model_loaded.predict(observation, deterministic=True)
        observation, reward, terminated, truncated, info = env_render.step(action)

        # Start a new episode once the current one has ended (terminated) or was cut off (truncated).
        if terminated or truncated:
            observation, info = env_render.reset()
finally:
    # Always release the render window, even if the loop raises or is interrupted.
    env_render.close()