## Vanilla Policy Gradient

Code written by Alan Cooney

Copied from https://github.com/alan-cooney/rl-from-scratch/blob/main/src/vanilla_policy_gradient.py

In [22]:
%pip install gymnasium > /dev/null 2>&1
import copy
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam, Optimizer
import numpy as np
import gymnasium  # type: ignore

from matplotlib import animation
from matplotlib.animation import FuncAnimation, PillowWriter
import matplotlib.pyplot as plt
from IPython.display import HTML, Video, Image
from base64 import b64encode
import datetime

Note: you may need to restart the kernel to use updated packages.


In [2]:
def create_model(number_observation_features: int, number_actions: int) -> nn.Module:
    """Build the policy network: a two-layer perceptron with one ReLU.

    Args:
        number_observation_features (int): Size of the (flat) observation
            vector fed to the first linear layer
        number_actions (int): Number of discrete actions; the network emits
            one un-normalised logit per action

    Returns:
        nn.Module: ``Linear -> ReLU -> Linear`` stack with 32 hidden units
    """
    hidden_width = 32

    # Layers are created in input-to-output order
    layers = [
        nn.Linear(number_observation_features, hidden_width),
        nn.ReLU(),
        nn.Linear(hidden_width, number_actions),
    ]
    return nn.Sequential(*layers)


def get_policy(model: nn.Module, observation: np.ndarray) -> Categorical:
    """Run the model on one observation and wrap its output as a distribution

    Args:
        model (nn.Module): Callable mapping an observation tensor to one
            logit per action
        observation (np.ndarray): Environment observation; it is converted
            to a float32 tensor before being passed to the model

    Returns:
        Categorical: Distribution over actions built from the model's logits
    """
    # No explicit softmax is needed: Categorical normalises the logits itself
    return Categorical(
        logits=model(torch.as_tensor(observation, dtype=torch.float32))
    )


def get_action(policy: Categorical) -> tuple[int, torch.Tensor]:
    """Draw one action from the policy and report how likely it was

    Args:
        policy (Categorical): Distribution over the available actions

    Returns:
        tuple[int, torch.Tensor]: The sampled action as a plain int (the form
        passed to ``env.step``) and its log probability as a tensor, which is
        needed later to compute the loss
    """
    sampled = policy.sample()  # Single-element tensor holding the action index

    # The log probability is taken from the tensor, the int from its value
    return int(sampled.item()), policy.log_prob(sampled)


def calculate_loss(epoch_log_probability_actions: torch.Tensor, epoch_action_rewards: torch.Tensor) -> torch.Tensor:
    """Build the pseudo-loss whose gradient is the policy gradient

    Formula for gradient at
    https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#deriving-the-simplest-policy-gradient

    This is not a loss in the usual sense: it is the negated mean of each
    action's log probability multiplied by the reward weight supplied for
    that action. It exists only so that back-propagating through it yields
    the policy gradient.

    Args:
        epoch_log_probability_actions (torch.Tensor): Log probabilities of the
            actions taken
        epoch_action_rewards (torch.Tensor): Reward weight for each of these
            actions, element-for-element

    Returns:
        torch.Tensor: Scalar pseudo-loss
    """
    weighted_log_probabilities = epoch_log_probability_actions * epoch_action_rewards

    # Negated so that minimising this quantity maximises the weighted
    # log-probability
    return -weighted_log_probabilities.mean()


def train_one_epoch(
        env: gymnasium.Env,
        model: nn.Module,
        optimizer: Optimizer,
        max_timesteps=5000,
        episode_timesteps=200) -> float:
    """Train the model for one epoch

    Episodes are collected until more than ``max_timesteps`` steps have been
    taken in total, then a single gradient step is applied using every
    recorded action, each weighted by the total reward of the episode it
    belongs to.

    Args:
        env (gymnasium.Env): Gymnasium environment
        model (nn.Module): Model
        optimizer (Optimizer): Optimizer
        max_timesteps (int, optional): Max timesteps per epoch. Note if an
            episode is part-way through, it will still complete before finishing
            the epoch. Defaults to 5000.
        episode_timesteps (int, optional): Maximum timesteps per episode; an
            episode that has not finished by then is cut off. Defaults to 200.

    Raises:
        ValueError: If ``episode_timesteps`` is less than 1 (no steps could
            ever be collected, so the epoch would never finish).

    Returns:
        float: Average return from the epoch
    """
    if episode_timesteps < 1:
        raise ValueError("episode_timesteps must be at least 1")

    epoch_total_timesteps = 0

    # Returns from each episode (to keep track of progress)
    epoch_returns: list[float] = []

    # Action log probabilities and rewards per step (for calculating loss).
    # These two lists must stay the same length: entry i of the rewards is
    # the weight for entry i of the log probabilities.
    epoch_log_probability_actions: list[torch.Tensor] = []
    epoch_action_rewards: list[float] = []

    # Keep starting new episodes until we've gone over the timestep budget
    while epoch_total_timesteps <= max_timesteps:

        # Running total of this episode's rewards and steps
        episode_reward: float = 0
        episode_length = 0

        # Reset the environment and get a fresh observation
        observation, _ = env.reset()

        # Loop through timesteps until the episode is done (or the max is hit)
        for _ in range(episode_timesteps):
            # Get the policy and act
            policy = get_policy(model, observation)
            action, log_probability_action = get_action(policy)
            observation, reward, terminated, truncated, _ = env.step(action)

            episode_length += 1
            episode_reward += float(reward)
            epoch_log_probability_actions.append(log_probability_action)

            # Plain truthiness, so non-`bool` flags (e.g. numpy booleans)
            # also end the episode
            if terminated or truncated:
                break

        epoch_total_timesteps += episode_length

        # One reward weight per step taken, whether the episode ended by
        # itself or was cut off by `episode_timesteps`. Recording these only
        # for finished episodes would shift every later log probability onto
        # another episode's return.
        epoch_action_rewards.extend([episode_reward] * episode_length)

        # Increment the epoch returns
        epoch_returns.append(episode_reward)

    # Calculate the policy gradient
    epoch_loss = calculate_loss(
        torch.stack(epoch_log_probability_actions),
        torch.tensor(epoch_action_rewards, dtype=torch.float32),
    )

    # Update the weights. Gradients are cleared first so that anything left
    # over from before this call cannot leak into the update.
    optimizer.zero_grad()
    epoch_loss.backward()
    optimizer.step()

    return float(np.mean(epoch_returns))

In [13]:
def save_frames_as_gif(frames, path='./', filename=None, filetype="gif", fps=10):
    """Render a sequence of image frames to an animated gif or an mp4 file

    Args:
        frames: Non-empty sequence of image arrays. The first frame's
            ``shape[0]`` / ``shape[1]`` are used as the figure height / width
            in pixels.
        path (str, optional): Directory to write into. Defaults to './'.
        filename (str, optional): Output name, also drawn as the figure
            title. ``.{filetype}`` is appended when it is missing. When
            omitted, a timestamp is used as the name and no title is drawn.
        filetype (str, optional): "gif" or "mp4". Defaults to "gif".
        fps (int, optional): Frames per second of the output. Defaults to 10.

    Raises:
        ValueError: If ``frames`` is empty, or the resulting file name does
            not end in "gif" or "mp4".

    Returns:
        str: Path of the file that was written
    """
    # Local import: only this function needs it
    import os

    # Validate everything before creating a figure, so that a bad call
    # cannot leave an open figure behind
    if len(frames) == 0:
        raise ValueError("Can't save an animation with no frames")

    title = filename
    if filename is None:
        now = str(datetime.datetime.now())[:19].replace(" ", "_")
        filename = f"{now}.{filetype}"
    elif not filename.endswith(f".{filetype}"):
        filename += f".{filetype}"

    if not filename.endswith(("mp4", "gif")):
        raise ValueError(f"Can't save {filename}; need gif or mp4")

    # Works whether or not `path` ends with a separator
    destination = os.path.join(path, filename)

    # At 72 dpi one figure "inch" is 72 pixels, so the figure matches the
    # pixel size of the frames
    fig = plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=72)
    try:
        if title is not None:
            fig.suptitle(title, fontsize=36)

        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = FuncAnimation(fig, animate, frames=len(frames))
        if filename.endswith("mp4"):
            anim.save(destination, fps=fps)
        else:
            anim.save(destination, writer=PillowWriter(fps=fps))
    finally:
        # Always release the figure, even when saving fails
        plt.close(fig)

    return destination

def animate_model(model, num_episodes=1, name=None, fps=10, filetype="mp4") -> "Video | Image":
    """Create a simple animation of the model running on CartPole

    Runs the model for ``num_episodes`` episodes, prints the average number
    of timesteps survived, saves every rendered frame as one animation and
    returns an IPython object that displays it.

    Args:
        model: Callable mapping an observation to one logit per action (a
            trained ``nn.Module`` or a hand-written function)
        num_episodes (int, optional): Number of episodes to record.
            Defaults to 1.
        name (str, optional): Output file name / title, forwarded to
            ``save_frames_as_gif``. Defaults to None.
        fps (int, optional): Frames per second of the animation.
            Defaults to 10.
        filetype (str, optional): "mp4" or "gif". Defaults to "mp4".

    Returns:
        Video | Image: An embedded ``Video`` when ``filetype`` is "mp4",
        otherwise an ``Image`` of the saved file
    """
    env = gymnasium.make("CartPole-v1", render_mode="rgb_array")

    survival = []
    frames = []
    try:
        # Nothing is trained here, so there is no need to record gradients
        with torch.no_grad():
            for _ in range(num_episodes):
                observation, _ = env.reset()
                current_survival = 0
                # Hard cap of 1000 steps per episode
                for _step in range(1000):
                    current_survival += 1
                    policy = get_policy(model, observation)
                    action, _ = get_action(policy)
                    observation, _, terminated, truncated, _ = env.step(action)
                    frames.append(env.render())
                    if terminated or truncated:
                        break
                survival.append(current_survival)
    finally:
        # Release the environment even if the rollout fails
        env.close()

    print(f"Survived {np.mean(survival):.1f} timesteps on average")
    fn = save_frames_as_gif(frames, filename=name, fps=fps, filetype=filetype)
    if filetype == "mp4":
        return Video(fn, embed=True)
    return Image(filename=fn)

In [23]:
def train(epochs=40) -> "tuple[nn.Module, nn.Module | None]":
    """Train a Vanilla Policy Gradient model on CartPole

    Progress is printed for the first epoch and then roughly every tenth of
    the run.

    Args:
        epochs (int, optional): The number of epochs to run for. Defaults to 40.

    Returns:
        tuple[nn.Module, nn.Module | None]: The model after the final epoch,
        and a copy of the model taken after the epoch with the highest
        average return (``None`` when ``epochs`` is 0, as no epoch ran).
    """

    # Create the Gym Environment
    # https://www.gymlibrary.dev/environments/classic_control/cart_pole/
    env = gymnasium.make("CartPole-v1")

    try:
        # Seed PyTorch so weight initialisation and action sampling repeat.
        # NOTE(review): the environment is reset without a seed, so runs are
        # not fully deterministic - confirm whether that is intended.
        torch.manual_seed(0)

        # Create the MLP model
        number_observation_features = env.observation_space.shape[0]
        number_actions = env.action_space.n
        model = create_model(number_observation_features, number_actions)

        # Create the optimizer
        optimizer = Adam(model.parameters(), 1e-2)

        # Report about ten times per run. `epochs // 10` is 0 for fewer than
        # 10 epochs, which would make the modulo below divide by zero.
        report_every = max(1, epochs // 10)

        best_model = None
        best_return = -np.inf
        # Loop for each epoch
        for epoch in range(epochs):
            average_return = train_one_epoch(env, model, optimizer)

            # Snapshot the weights, as later epochs keep changing `model`
            if average_return > best_return:
                best_model = copy.deepcopy(model)
                best_return = average_return

            if epoch == 0 or (epoch + 1) % report_every == 0:
                print('Epoch: %3d Reward: %.1f' % (epoch + 1, average_return))
    finally:
        # Release the environment even if training fails part-way
        env.close()

    return model, best_model

In [14]:
def move_left_model(*args):
    """Hand-written stand-in for a model that ignores its input

    Whatever is passed in, it returns the fixed pair ``[1, 0]``. When used
    through ``get_policy`` these values are treated as logits, so action 0 is
    favoured rather than chosen every time.
    """
    fixed_logits = [1, 0]
    return torch.tensor(fixed_logits)

# Baseline: record 5 episodes of the fixed move_left_model policy at 3 fps
animate_model(move_left_model, num_episodes=5, name="MoveLeft", filetype="mp4", fps=3)

Survived 14.4 timesteps on average


In [15]:
def simple(observation):
    """Hand-written policy that reacts to the sign of one observation entry

    Returns the logits ``[1, 0]`` (favouring action 0) when ``observation[3]``
    is negative, and ``[0, 1]`` (favouring action 1) otherwise.

    NOTE(review): this entry was originally named ``pole_angle``, but the
    Gymnasium CartPole documentation lists index 3 as the pole's angular
    velocity (the angle is index 2) - confirm which one was intended.
    """
    signal = observation[3]
    logits = [1, 0] if signal < 0 else [0, 1]
    return torch.tensor(logits)

# Baseline: record 5 episodes of the hand-written `simple` policy at 10 fps
animate_model(simple, num_episodes=5, name="Simple", filetype="mp4", fps=10)

Survived 64.8 timesteps on average


In [24]:
# Untrained baseline: with epochs=0 the training loop never runs, so `model`
# still has its initial weights (the discarded second value is None)
model, _ = train(epochs=0)
animate_model(model, num_episodes=5, name="0 Epochs", filetype="mp4", fps=10)

Survived 23.8 timesteps on average


In [25]:
# Train for 10 epochs, then animate the copy taken after the epoch with the
# highest average return
last_model10, best_model10 = train(epochs=10)
animate_model(best_model10, num_episodes=5, name="10 Epochs", filetype="mp4", fps=10)

Epoch:   1 Reward: 19.8
Epoch:   2 Reward: 19.8
Epoch:   3 Reward: 22.5
Epoch:   4 Reward: 24.2
Epoch:   5 Reward: 24.0
Epoch:   6 Reward: 27.8
Epoch:   7 Reward: 30.0
Epoch:   8 Reward: 31.0
Epoch:   9 Reward: 35.4
Epoch:  10 Reward: 37.0
Survived 43.6 timesteps on average


In [26]:
# Train for 100 epochs, keeping both the final model and the best-epoch copy
last_model100, best_model100 = train(epochs=100)

Epoch:   1 Reward: 19.4
Epoch:  10 Reward: 35.3
Epoch:  20 Reward: 58.1
Epoch:  30 Reward: 91.3
Epoch:  40 Reward: 146.9
Epoch:  50 Reward: 106.0
Epoch:  60 Reward: 128.3
Epoch:  70 Reward: 149.6
Epoch:  80 Reward: 92.6
Epoch:  90 Reward: 117.7
Epoch: 100 Reward: 176.6


In [27]:
# Animate the model as it was after the final (100th) epoch
animate_model(last_model100, num_episodes=5, name="100 Epochs", filetype="mp4", fps=60)

Survived 215.6 timesteps on average


In [29]:
# Animate the best-epoch copy from the same 100-epoch run.
# NOTE(review): this reuses name="100 Epochs", so it has the same title and
# output file name as the animation of last_model100 and overwrites that file.
animate_model(best_model100, num_episodes=5, name="100 Epochs", filetype="mp4", fps=60)

Survived 265.2 timesteps on average
