In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

In [2]:
class PolicyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        out = self.model(x)
        return out
    
    def get_action(self, x):
        out = self(x)
        action = torch.distributions.Categorical(out).sample()
        return action

class ValueModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        out = self.model(x)
        return out

In [3]:
def discounted_returns(rewards, gamma):
    # rewards: shape [T]
    T = rewards.shape[0]
    device = rewards.device

    indices = torch.arange(T, device=device)
    # Create a T x T grid of indices
    j_mat, i_mat = torch.meshgrid(indices, indices, indexing='ij')

    # Mask for upper-triangular (including diagonal): j >= i
    mask = (j_mat > i_mat)

    # Compute exponents for gamma^(j - i)
    exps = i_mat - j_mat
    exps += torch.tensor(1e8, dtype=torch.long) * mask

    # Construct the discount matrix G
    G = gamma ** exps

    # Compute the discounted returns R = G @ rewards
    return G @ rewards

In [4]:
from torch.utils.data import Dataset, DataLoader

# Define a custom dataset for episodes and additional tensors
class EpisodeDataset(Dataset):
    def __init__(self, all_episodes_obs, all_episodes_aux, base_probs, base_advantages, rtgs):
        self.all_episodes_obs = all_episodes_obs
        self.all_episodes_aux = all_episodes_aux
        self.base_probs = base_probs
        self.base_advantages = base_advantages
        self.rtgs = rtgs

    def __len__(self):
        return len(self.all_episodes_obs)

    def __getitem__(self, idx):
        return (self.all_episodes_obs[idx], 
                self.all_episodes_aux[idx], 
                self.base_probs[idx], 
                self.base_advantages[idx], 
                self.rtgs[idx])

    def collate_fn(self, batch):
        return (list([item[0] for item in batch]), 
                list([item[1] for item in batch]), 
                list([item[2] for item in batch]), 
                list([item[3] for item in batch]), 
                list([item[4] for item in batch]))


In [15]:

class Agent:
    def __init__(self, gamma=0.99, gae_lambda=0.95, epsilon=0.2, lr=0.0001, env_name="CartPole-v1"):
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.policy_model = PolicyModel()
        self.value_model = ValueModel()
        self.policy_optimizer = optim.Adam(list(self.policy_model.parameters()), lr=lr)
        self.value_optimizer = optim.Adam(list(self.value_model.parameters()), lr=lr)
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.epsilon = epsilon
    
    def update(self, policy_loss = None, value_loss = None):
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=True)
        self.policy_optimizer.step()
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()
    
    def run_episode(self, env_name="CartPole-v1"):
        # records state, action, reward for each step
        env = gym.make(env_name)  # Remove render_mode for training
        episode_obs = torch.tensor([])
        episode_aux = torch.tensor([])
        observation, info = env.reset()
        obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
        terminated = False
        truncated = False
        while not terminated and not truncated:
            obs_input = obs_output
            action = self.policy_model.get_action(obs_input)
            observation, reward, terminated, truncated, info = env.step(action.numpy()[0])
            obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
            episode_obs = torch.cat([episode_obs, torch.cat((obs_input, obs_output))[None, :]])
            episode_aux = torch.cat([episode_aux, torch.tensor([action, reward])[None, :]])
        return episode_obs, episode_aux

    def get_losses(self, states, actions, base_probs, base_advantages, real_rtg, epsilon=0.2):
        mse = nn.MSELoss()
        value_loss = mse(self.value_model(states)[:, 0], real_rtg)

        curr_probs = self.policy_model(states)[torch.arange(len(states)), actions.to(torch.int64)]
        clipped_weighted_advantages = base_advantages * torch.clip(curr_probs/base_probs, 1-epsilon, 1+epsilon)
        weighted_advantages = base_advantages * curr_probs/base_probs
        policy_loss = -torch.min(clipped_weighted_advantages, weighted_advantages).mean()
        return policy_loss.mean(), value_loss

    def compute_statistics(self, all_episodes_obs, all_episodes_aux):
        base_probs = []
        base_advantages = []
        rtgs = []
        for episode_obs, episode_aux in zip(all_episodes_obs, all_episodes_aux):
            base_probs.append(self.policy_model(episode_obs[:, 0])[torch.arange(episode_obs.shape[0]), episode_aux[:, 0].to(torch.int64)].detach())

            td_error = episode_aux[:, 1] + self.gamma * self.value_model(episode_obs[:, 1])[:,0] - self.value_model(episode_obs[:, 0])[:,0]
            gae_schedule = (self.gae_lambda * self.gamma) ** torch.arange(episode_obs.shape[0])
            new_base_advantage = (td_error * gae_schedule).flip(dims=[0]).cumsum(dim=0).flip(dims=[0]) / gae_schedule
            base_advantages.append(new_base_advantage.detach())

            gamma_schedule = self.gamma ** torch.arange(episode_obs.shape[0])
            real_rtg = (episode_aux[:, 1] * gamma_schedule).flip(dims=[0]).cumsum(dim=0).flip(dims=[0]) / gamma_schedule
            rtgs.append(real_rtg.detach())
        base_probs = torch.cat(base_probs)
        base_advantages = torch.cat(base_advantages)
        rtgs = torch.cat(rtgs)
        return base_probs, base_advantages, rtgs
        
        
    def ppo_update(self, all_episodes_obs, all_episodes_aux, steps=4, batch_size=32):
        base_probs, base_advantages, rtgs = self.compute_statistics(all_episodes_obs, all_episodes_aux)
        
        # Create a DataLoader for mini-batching
        dataset = EpisodeDataset(torch.cat(all_episodes_obs).tolist(), torch.cat(all_episodes_aux).tolist(), base_probs.tolist(), base_advantages.tolist(), rtgs.tolist())
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate_fn)

        for _ in range(steps):
            for states, aux, base_prob, base_adv, rtgs in dataloader:
                policy_loss, value_loss = self.get_losses(torch.tensor(states)[:, 0], torch.tensor(aux)[:, 0], torch.tensor(base_prob), torch.tensor(base_adv), torch.tensor(rtgs))
                self.update(policy_loss=policy_loss, value_loss=value_loss)
        
        return policy_loss, value_loss

    def avg_reward(self, episodes):
        return torch.tensor([episode[1][:, 1].sum() for episode in episodes]).mean()

    def train(self, num_episodes=100, print_loss=True):
        # collects episodes, updates policy and value models
        all_episodes = []
        for i in range(num_episodes):
            episode = self.run_episode()
            all_episodes.append(episode)
        
        all_episodes_obs = [episode[0] for episode in all_episodes]
        all_episodes_aux = [episode[1] for episode in all_episodes]
        policy_loss, value_loss = self.ppo_update(all_episodes_obs, all_episodes_aux)
        total_reward = self.avg_reward(all_episodes).item()

        if print_loss:
            print(f"Episode {i} policy loss: {policy_loss.item()}")
            print(f"Episode {i} value loss: {value_loss.item()}")
            print(f"Episode {i} average total reward: {total_reward}")
        return (policy_loss, value_loss, total_reward)
    
    def demo(self, env_name="CartPole-v1"):
        env = gym.make(env_name, render_mode="human")
        observation, info = env.reset()
        obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
        terminated = False
        truncated = False
        while not terminated and not truncated:
            action = self.policy_model.get_action(obs_output)
            observation, reward, terminated, truncated, info = env.step(action.numpy()[0])
            obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
        env.close()

In [13]:
agent = Agent()
policy_losses = []
value_losses = []
total_rewards = []
for i in tqdm(range(100), desc="Training"):
    policy_loss, value_loss, total_reward = agent.train(num_episodes=40, print_loss=False)
    policy_losses.append(policy_loss)
    value_losses.append(value_loss)
    total_rewards.append(total_reward)
    if total_reward > 500:
        print(f"Episode {i} average total reward: {total_reward}")
        break

agent.demo()

Training:   3%|▎         | 31/1000 [02:00<1:57:49,  7.30s/it]

In [14]:
%debug

> [0;32m/var/folders/22/ff9c0t7s29vfz_hk8wdcgkcc0000gn/T/ipykernel_60265/3961413046.py[0m(81)[0;36mppo_update[0;34m()[0m
[0;32m     79 [0;31m        [0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0msteps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m            [0;32mfor[0m [0mstates[0m[0;34m,[0m [0maux[0m[0;34m,[0m [0mbase_prob[0m[0;34m,[0m [0mbase_adv[0m[0;34m,[0m [0mrtgs[0m [0;32min[0m [0mdataloader[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m                [0mpolicy_loss[0m[0;34m,[0m [0mvalue_loss[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_losses[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0mstates[0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;36m0[0m[0;34m][0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0maux[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;36m0[0m[0;34m][0m[0;34m)[0m[0;34m,[0m [0mtorch