# Stable-Baselines3 Introduction

Train an A2C agent on the CartPole-v1 environment, then watch the trained agent play inside the notebook.

In [None]:
import gymnasium as gym
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
from stable_baselines3 import A2C

# Training environment: no render_mode is passed, so no frames are produced
# while the agent learns.
env = gym.make("CartPole-v1")

# Number of environment steps used to fit the agent.
TOTAL_TIMESTEPS = 10_000

# Fit an A2C agent with an MLP policy; `learn` returns the model itself,
# so construction and training can be chained.
model = A2C(policy="MlpPolicy", env=env, verbose=1).learn(
    total_timesteps=TOTAL_TIMESTEPS,
)

print("Training complete!")

In [None]:
# Test the trained agent, rendering frames inline in the notebook.
# Episodes are reset (not stopped) when they end, so the loop always runs
# N_EVAL_STEPS environment steps in total across however many episodes fit.
N_EVAL_STEPS = 50_000  # total environment steps for the whole evaluation
RENDER_EVERY = 10  # redraw only every Nth step; each redraw builds a full matplotlib figure

test_env = gym.make("CartPole-v1", render_mode="rgb_array")
try:
    obs, _ = test_env.reset()

    episode_count = 0
    episode_steps = 0
    # Per-episode printouts are wiped by the next clear_output() call, so
    # lengths are also recorded here and summarised after the loop.
    episode_lengths = []

    for i in range(N_EVAL_STEPS):
        action, _state = model.predict(obs, deterministic=True)
        obs, _reward, terminated, truncated, _info = test_env.step(action)
        episode_steps += 1

        # Render frame in notebook (only every RENDER_EVERY steps to limit updates)
        if i % RENDER_EVERY == 0:
            frame = test_env.render()
            fig, ax = plt.subplots(figsize=(5, 4))
            ax.imshow(frame)
            ax.axis("off")
            ax.set_title(f"Step {i+1} | Episode {episode_count+1} | Ep Steps: {episode_steps}")
            clear_output(wait=True)
            display(fig)
            plt.close(fig)  # free the figure; display() has already emitted it

        if terminated or truncated:
            print(f"Episode {episode_count+1} finished after {episode_steps} steps")
            episode_lengths.append(episode_steps)
            obs, _ = test_env.reset()  # Reset for next episode instead of breaking
            episode_count += 1
            episode_steps = 0

    # Summary survives because no further clear_output() follows it.
    print(f"Completed episodes: {len(episode_lengths)}")
    if episode_lengths:
        mean_len = sum(episode_lengths) / len(episode_lengths)
        print(f"Mean episode length: {mean_len:.1f} steps")
    if episode_steps:
        print(f"Episode {episode_count+1} was still running after {episode_steps} steps when the step budget ran out")
finally:
    # Close both environments even if the long loop is interrupted
    # (e.g. KeyboardInterrupt from the notebook's stop button).
    test_env.close()
    env.close()