<a href="https://colab.research.google.com/github/emilianodesu/RLA2/blob/main/cartpole/cart-pole-ppo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PPO on classic control

This notebook will explore the implementation of PPO from Stable Baselines3 on a classic-control environment.

### What will you learn?
* Using PPO from Stable Baselines3
* Training on the CartPole environment

## Setup
Install necessary dependencies.

In [None]:
!pip install gymnasium
!pip install "stable-baselines3[extra]"
!pip install moviepy
!sudo apt-get update
!apt-get install -y xvfb ffmpeg freeglut3-dev

Import libraries

In [None]:
# Import necessary libraries and modules
import os
import gymnasium as gym
import stable_baselines3
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import base64
from IPython import display as ipythondisplay

# Show which library versions this run uses (printed as name='version')
print(f"gym.__version__={gym.__version__!r}")
print(f"stable_baselines3.__version__={stable_baselines3.__version__!r}")

Configure environment

In [None]:
# Start a virtual X server (Xvfb) in the background and point DISPLAY at it,
# so code that needs a display can run on a headless cloud machine.
xvfb_display = ":1"
os.system(f"Xvfb {xvfb_display} -screen 0 1024x768x24 &")
os.environ["DISPLAY"] = xvfb_display

## Callbacks and directory setup

In [None]:
# Callback for saving the model at regular intervals
class SaveOnIntervalCallback(BaseCallback):
    """Save the model each time training crosses a multiple of ``save_interval`` timesteps.

    The timestep counter is not guaranteed to land exactly on a multiple of
    ``save_interval`` (with vectorized environments Stable-Baselines3 advances
    it by the number of environments per step), so the callback counts how many
    whole intervals have elapsed instead of testing ``num_timesteps % interval == 0``.
    When the counter does land on exact multiples, the checkpoints and their
    file names are the same as with the modulo test.

    :param save_interval: number of timesteps between two checkpoints (must be > 0)
    :param save_path: directory in which ``model_<num_timesteps>.zip`` files are written
    :param verbose: when > 0, print a line for every checkpoint written
    :raises ValueError: if ``save_interval`` is not strictly positive
    """

    def __init__(self, save_interval: int, save_path: str, verbose=1):
        super().__init__(verbose)
        if save_interval <= 0:
            raise ValueError(f"save_interval must be a positive integer, got {save_interval!r}")
        self.save_interval = save_interval
        self.save_path = save_path
        # Number of whole intervals for which a checkpoint already exists.
        self._intervals_saved = 0

    def _init_callback(self) -> None:
        # Make sure the target directory exists before the first save.
        os.makedirs(self.save_path, exist_ok=True)

    def _on_training_start(self) -> None:
        # Re-sync with the current counter so a reused callback (or a run that
        # continues from earlier timesteps) keeps saving on interval boundaries.
        self._intervals_saved = self.num_timesteps // self.save_interval

    def _on_step(self) -> bool:
        intervals_elapsed = self.num_timesteps // self.save_interval
        if intervals_elapsed > self._intervals_saved:
            self._intervals_saved = intervals_elapsed
            save_file = os.path.join(self.save_path, f'model_{self.num_timesteps}')
            self.model.save(save_file)
            if self.verbose > 0:
                print(f'Saving model to {save_file}.zip')
        # Returning True tells the training loop to continue.
        return True

In [None]:
# Output locations for the Monitor logs, the model checkpoints and the videos
log_dir = "ppo/logs/"
models_dir = "ppo/models/"
videos_dir = "ppo/videos/"

# Create every declared directory up front (videos_dir was previously declared
# but never created); exist_ok makes the cell safe to re-run.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)
os.makedirs(videos_dir, exist_ok=True)

## Cart pole environment

In [None]:
# Create the CartPole environments
env_id = "CartPole-v1"
n_envs = 4


def make_env(rank: int):
    """
    Build a factory for one Monitor-wrapped environment.

    ``rank`` is bound when this function is called. A bare ``lambda`` inside a
    list comprehension would read the loop variable only when it is invoked,
    so every environment would see the last index and log to the same file.

    :param rank: index of the environment, used as its Monitor log file name
    :return: a zero-argument callable that creates the environment
    """
    def _init():
        return Monitor(gym.make(env_id), os.path.join(log_dir, str(rank)))
    return _init


# For PPO, it's common to use multiple environments in parallel
env = DummyVecEnv([make_env(rank) for rank in range(n_envs)])

## PPO Model

In [None]:
# Custom network: two layers of 32 units under both the "pi" and "vf" entries
policy_kwargs = {"net_arch": {"pi": [32, 32], "vf": [32, 32]}}

model = PPO(
    policy="MlpPolicy",
    env=env,
    policy_kwargs=policy_kwargs,  # use the custom network defined above
    # Rollout / optimisation schedule
    n_steps=256,
    batch_size=64,
    n_epochs=10,
    learning_rate=3e-4,
    # Return and advantage estimation
    gamma=0.99,
    gae_lambda=0.95,
    # PPO objective
    clip_range=0.2,
    ent_coef=0.0,
    verbose=1,
)

## Train the model

In [None]:
# Checkpoint the model into models_dir every `save_interval` timesteps
save_interval = 10_000
save_callback = SaveOnIntervalCallback(save_interval, models_dir)

# Training budget, in timesteps
total_timesteps = 50_000

# Train the PPO agent; the callback writes the intermediate checkpoints
model.learn(total_timesteps=total_timesteps, callback=save_callback)

# Store the fully trained model alongside the checkpoints
final_model_path = os.path.join(models_dir, f'model_{total_timesteps}')
model.save(final_model_path)

## Performance evaluation

In [None]:
def plot_results(log_folder: str, window: int = 100):
    """
    Plots the training curve from the Monitor log file(s).

    The raw per-episode rewards are drawn faintly, with a moving average on
    top. The averaging window is clamped to the number of logged episodes:
    with fewer episodes than the window, ``np.convolve(..., mode='valid')``
    would otherwise return an array whose length no longer matches the
    x values, and the plot call would fail.

    :param log_folder: the save directory of the Monitor logs
    :param window: number of episodes in the moving average (default 100)
    """
    x, y = ts2xy(load_results(log_folder), 'timesteps')

    if len(y) == 0:
        # Nothing was logged; there is no curve to draw.
        print(f"No completed episodes found in {log_folder!r}; nothing to plot.")
        return

    # Smooth the curve with a window no larger than the data itself
    window = max(1, min(int(window), len(y)))
    y_smooth = np.convolve(y, np.ones(window) / window, mode='valid')
    # Pair each average with the timestep of the last episode in its window
    x_smooth = x[len(x) - len(y_smooth):]

    # clear=True gives a fresh canvas if a figure with this name is still open
    fig, ax = plt.subplots(num="Training Curve", clear=True)
    ax.plot(x_smooth, y_smooth, label="Smoothed Reward")
    ax.plot(x, y, alpha=0.2, label="Raw Reward")
    ax.set_xlabel("Number of Timesteps")
    ax.set_ylabel("Rewards")
    ax.set_title("Training Curve")
    ax.legend()
    plt.show()

# Call the function after training to see the learning curve
plot_results(log_dir)

## Video Recording and Display Functions

In [None]:
# Functions to record and show videos of the agent playing

def show_videos(video_path="", prefix=""):
    """
    Display, inline in the notebook, every ``.mp4`` file in a folder.

    Each matching file is base64-encoded and embedded in an HTML ``<video>``
    tag, and the tags are joined with ``<br>`` and rendered as one HTML output.

    :param video_path: folder searched for videos
    :param prefix: only files whose name starts with this prefix are shown
    """
    def _as_video_tag(clip):
        # Embed the file contents directly in the page as a base64 data URI.
        encoded = base64.b64encode(clip.read_bytes()).decode('ascii')
        return (
            f'''<video alt="{clip.name}" autoplay loop controls style="height: 400px;">
                  <source src="data:video/mp4;base64,{encoded}" type="video/mp4" />
             </video>'''
        )

    tags = [_as_video_tag(clip) for clip in Path(video_path).glob(f"{prefix}*.mp4")]
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(tags)))

# ### MODIFICATION ###: Updated video folder and env_id
def record_video(env_id, model, video_length=500, prefix="", video_folder="ppo/videos"):
    """
    Run ``model`` in a fresh environment and record the rollout as a video.

    :param env_id: id of the environment to create with ``gym.make``
    :param model: trained model; only its ``predict`` method is used
    :param video_length: number of environment steps to run and record
    :param prefix: name prefix given to the recorded video file
    :param video_folder: folder the video is written to (created if missing)
    """
    os.makedirs(video_folder, exist_ok=True)
    eval_env = VecVideoRecorder(
        DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")]),
        video_folder=video_folder,
        record_video_trigger=lambda step: step == 0,  # start recording at the first step
        video_length=video_length,
        name_prefix=prefix,
    )
    try:
        obs = eval_env.reset()
        for _ in range(video_length):
            # Use deterministic actions for evaluation
            action, _state = model.predict(obs, deterministic=True)
            obs, _, _, _ = eval_env.step(action)
    finally:
        # Always close the recorder and its environment, even if the rollout
        # raises part-way through, so they are not leaked.
        eval_env.close()

In [None]:
# Functions for analyzing the trained models

# Function to get the identifiers of saved models
def get_model_identifiers(models_dir):
    """
    Collect the identifiers of the model files stored in ``models_dir``.

    A model file is any entry whose name starts with ``model_``. Its identifier
    is the text between the first and second underscore, so a file extension
    is kept (``model_10000.zip`` gives ``10000.zip``).

    :param models_dir: directory to scan
    :return: list of identifiers, in the order ``os.listdir`` returned the files
    """
    identifiers = []
    for filename in os.listdir(models_dir):
        if not filename.startswith('model_'):
            continue  # not a saved model
        identifiers.append(filename.split('_')[1])
    return identifiers

# Function to find key identifiers (earliest, middle, final)
def _identifier_sort_key(identifier):
    """Sort key that orders identifiers by their leading number, e.g. '5000.zip' before '10000.zip'."""
    text = str(identifier)
    stem = text.split('.', 1)[0]  # drop an extension such as ".zip"
    if stem.isdecimal():
        return (0, int(stem), text)
    # Non-numeric identifiers go after the numeric ones, in plain string order.
    return (1, 0, text)


def find_key_identifiers(identifiers):
    """
    Pick the earliest, middle and final identifiers of a training run.

    Identifiers are ordered by numeric value rather than as plain strings, so
    that for example ``"5000"`` comes before ``"10000"``. The input list is
    left unmodified.

    :param identifiers: non-empty list of identifiers such as ``"10000.zip"``
    :return: tuple ``(earliest, middle, final)``; ``middle`` is the element at
        index ``len(identifiers) // 2`` of the ordered list
    :raises IndexError: if ``identifiers`` is empty
    """
    if not identifiers:
        raise IndexError("find_key_identifiers() needs at least one identifier")
    ordered = sorted(identifiers, key=_identifier_sort_key)
    earliest = ordered[0]
    final = ordered[-1]
    middle = ordered[len(ordered) // 2]
    return earliest, middle, final

# Function to view videos of the models at different training stages
def view(models_dir):
    """
    Record and display videos of the agent at the beginning, middle and end of training.

    The checkpoints found in ``models_dir`` are listed, three of them are
    picked with ``find_key_identifiers``, and each is loaded, recorded on
    CartPole-v1 and shown inline.

    :param models_dir: directory containing the saved ``model_<identifier>`` files
    """
    identifiers = get_model_identifiers(models_dir)  # Getting model identifiers
    print(identifiers)
    earliest, middle, final = find_key_identifiers(identifiers)  # Finding key identifiers

    # Recording and displaying videos at the beginning, middle, and end of training
    for stage, identifier in zip(["beginning", "middle", "end"], [earliest, middle, final]):
        model_path = os.path.join(models_dir, f'model_{identifier}')  # Forming the model path
        # The checkpoints were written by a PPO model, so load them with PPO
        # (DQN is never imported in this notebook and raised a NameError here).
        model = PPO.load(model_path)
        record_video("CartPole-v1", model, video_length=5000, prefix=f'ppo-cartpole-{stage}')  # Recording video
        show_videos("ppo/videos/", prefix=f'ppo-cartpole-{stage}')  # Showing videos

In [None]:
# Directory holding the saved checkpoints; change it to inspect another run
models_dir = "ppo/models"
view(models_dir)  # record and show the beginning / middle / end videos