In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

In [9]:
class PolicyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        out = self.model(x)
        return out
    
    def get_action(self, x):
        out = self(x)
        action = torch.distributions.Categorical(out).sample()
        return action

class ValueModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        out = self.model(x)
        return out

class Agent:
    def __init__(self, gamma=0.99, gae_lambda=0.95, epsilon=0.2, lr=0.001, env_name="CartPole-v1"):
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.policy_model = PolicyModel()
        self.value_model = ValueModel()
        self.policy_optimizer = optim.Adam(list(self.policy_model.parameters()), lr=lr)
        self.value_optimizer = optim.Adam(list(self.value_model.parameters()), lr=lr)
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.epsilon = epsilon
    
    def update(self, policy_loss = None, value_loss = None):
        if policy_loss:
            self.policy_optimizer.zero_grad()
            if value_loss:
                policy_loss.backward(retain_graph=True)
            else:
                policy_loss.backward()
            self.policy_optimizer.step()
        if value_loss:
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()
    
    def run_episode(self, env_name="CartPole-v1"):
        # records state, action, reward for each step
        env = gym.make(env_name)  # Remove render_mode for training
        episode_obs = torch.tensor([])
        episode_aux = torch.tensor([])
        observation, info = env.reset()
        obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
        terminated = False
        truncated = False
        while not terminated and not truncated:
            obs_input = obs_output
            action = self.policy_model.get_action(obs_input)
            observation, reward, terminated, truncated, info = env.step(action.numpy()[0])
            obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
            episode_obs = torch.cat([episode_obs, torch.cat((obs_input, obs_output))[None, :]])
            episode_aux = torch.cat([episode_aux, torch.tensor([action, reward])[None, :]])
        return episode_obs, episode_aux

    def get_losses(self, all_episodes, base_probs, epsilon=0.2, get_value=False):
        if get_value:
            value_loss = torch.tensor([])
            for episode_obs, episode_aux in all_episodes:
                advantages = episode_aux[:, 1] + self.gamma * self.value_model(episode_obs[:, 1])[:,0] - self.value_model(episode_obs[:, 0])[:,0]
                value_loss = torch.cat([value_loss, advantages**2])
            return None, value_loss.mean()
        else:
            policy_loss = torch.tensor([])
            for (episode_obs, episode_aux), base_prob in zip(all_episodes, base_probs):
                td_errors = episode_aux[:, 1] + self.gamma * self.value_model(episode_obs[:, 1])[:,0] - self.value_model(episode_obs[:, 0])[:,0]
                decay_schedule = (self.gae_lambda * self.gamma) ** torch.arange(len(episode_obs))
                advantages = (td_errors * decay_schedule).flip(dims=[0]).cumsum(dim=0).flip(dims=[0]) / decay_schedule

                curr_probs = self.policy_model(episode_obs[:, 0])[torch.arange(len(episode_obs)), episode_aux[:, 0].to(torch.int64)]
                clipped_weighted_advantages = advantages * torch.clip(curr_probs/base_prob, 1-epsilon, 1+epsilon)
                weighted_advantages = advantages * curr_probs/base_prob

                policy_loss = torch.cat([policy_loss, -torch.min(clipped_weighted_advantages, weighted_advantages)])
            return policy_loss.mean(), None
    
    def ppo_update(self, all_episodes, steps=10):
        base_probs = [self.policy_model(episode_obs[:, 0])[torch.arange(len(episode_obs)), episode_aux[:, 0].to(torch.int64)].detach() for episode_obs, episode_aux in all_episodes]
        for _ in range(steps):
            policy_loss, value_loss = self.get_losses(all_episodes, base_probs)
            self.update(policy_loss=policy_loss, value_loss=value_loss)
        _, value_loss = self.get_losses(all_episodes, base_probs, get_value=True)
        self.update(value_loss=value_loss)
        return policy_loss, value_loss

    def avg_reward(self, episodes):
        return torch.tensor([episode[1][:, 1].sum() for episode in episodes]).mean()

    def train(self, num_episodes=100, print_loss=True):
        # collects episodes, updates policy and value models
        all_episodes = []
        for i in range(num_episodes):
            episode = self.run_episode()
            all_episodes.append(episode)
        
        policy_loss, value_loss = self.ppo_update(all_episodes)
        total_reward = self.avg_reward(all_episodes).item()

        if print_loss:
            print(f"Episode {i} policy loss: {policy_loss.item()}")
            print(f"Episode {i} value loss: {value_loss.item()}")
            print(f"Episode {i} average total reward: {total_reward}")
        return (policy_loss, value_loss, total_reward)
    
    def demo(self, env_name="CartPole-v1"):
        env = gym.make(env_name, render_mode="human")
        observation, info = env.reset()
        obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
        terminated = False
        truncated = False
        while not terminated and not truncated:
            action = self.policy_model.get_action(obs_output)
            observation, reward, terminated, truncated, info = env.step(action.numpy()[0])
            obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
        env.close()

In [8]:
agent = Agent()
policy_losses = []
value_losses = []
total_rewards = []
for i in tqdm(range(1000), desc="Training"):
    policy_loss, value_loss, total_reward = agent.train(num_episodes=100, print_loss=False)
    policy_losses.append(policy_loss)
    value_losses.append(value_loss)
    total_rewards.append(total_reward)
    if total_reward > 500:
        print(f"Episode {i} average total reward: {total_reward}")
        break

agent.demo()

Training:   8%|▊         | 76/1000 [03:20<40:52,  2.65s/it]  

In [6]:
%debug

> [0;32m/var/folders/22/ff9c0t7s29vfz_hk8wdcgkcc0000gn/T/ipykernel_52418/194965015.py[0m(90)[0;36mget_losses[0;34m()[0m
[0;32m     88 [0;31m                [0mtd_errors[0m [0;34m=[0m [0mepisode_aux[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;36m1[0m[0;34m][0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mgamma[0m [0;34m*[0m [0mself[0m[0;34m.[0m[0mvalue_model[0m[0;34m([0m[0mepisode_obs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;36m1[0m[0;34m][0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m[0;36m0[0m[0;34m][0m [0;34m-[0m [0mself[0m[0;34m.[0m[0mvalue_model[0m[0;34m([0m[0mepisode_obs[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;36m0[0m[0;34m][0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     89 [0;31m                [0mdecay_schedule[0m [0;34m=[0m [0;34m([0m[0mself[0m[0;34m.[0m[0mgae_lambda[0m [0;34m*[0m [0mself[0m[0;34m.[0m[0mgamma[0m[0;34m)[0m [0;34m**