In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

In [2]:
class PolicyModel(nn.Module):
    """Diagonal-Gaussian policy network.

    The network maps an observation to ``2 * action_dim`` values squashed by
    ``tanh``: the first ``action_dim`` entries are used as the mean and the
    remaining ``action_dim`` entries as the log standard deviation of an
    independent Normal distribution per action dimension.

    Args:
        input_dim: Size of the observation vector.
        action_dim: Number of action dimensions (defaults to 8, the value the
            network was originally hard-coded for).
    """

    def __init__(self, input_dim=4, action_dim=8):
        super().__init__()
        self.action_dim = action_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 2 * action_dim),
            nn.Tanh()
        )

    def forward(self, x):
        """Return the raw ``(batch, 2 * action_dim)`` network output."""
        return self.model(x)

    def _distribution(self, x):
        """Build the per-dimension Normal distribution for observations ``x``."""
        out = self(x)
        mean = out[:, :self.action_dim]
        log_std = out[:, self.action_dim:]
        return torch.distributions.Normal(mean, torch.exp(log_std))

    def get_action(self, x):
        """Sample one action per row of ``x``; result has shape ``(batch, action_dim)``."""
        return self._distribution(x).sample()

    def get_probs(self, obs, action):
        """Return the joint probability density of ``action`` given ``obs``.

        The action dimensions are independent, so the joint log-density is the
        SUM of the per-dimension log-densities; exponentiating gives the
        density itself, which is what a ratio ``new_probs / old_probs`` needs.
        (Multiplying log-probabilities, as was done before, yields a quantity
        that is neither a probability nor a log-probability.)

        Returns:
            Tensor of shape ``(batch,)`` with strictly positive densities.
        """
        log_prob = self._distribution(obs).log_prob(action).sum(dim=-1)
        return log_prob.exp()

class ValueModel(nn.Module):
    """State-value network: one hidden ReLU layer of width 128, scalar output.

    Args:
        input_dim: Size of the observation vector.
    """

    def __init__(self, input_dim=4):
        super().__init__()
        hidden_width = 128
        # Layers are created in the same order as they are applied.
        layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return value estimates of shape ``(batch, 1)`` for observations ``x``."""
        return self.model(x)

In [3]:
def discounted_returns(rewards, gamma):
    """Compute discounted returns ``R[t] = sum_{k >= t} gamma**(k - t) * rewards[k]``.

    Args:
        rewards: Tensor whose first dimension is time (length ``T``).
        gamma: Discount factor (Python number or 0-dim tensor).

    Returns:
        Tensor with the same leading length ``T``. The result uses the dtype
        of ``rewards`` when it is floating point, otherwise the default float
        dtype.

    Note:
        Builds a dense ``T x T`` discount matrix, so memory is quadratic in
        the episode length.
    """
    T = rewards.shape[0]
    device = rewards.device
    dtype = rewards.dtype if rewards.is_floating_point() else torch.get_default_dtype()
    rewards = rewards.to(dtype)

    steps = torch.arange(T, device=device)
    # exps[t, k] = k - t; negative entries (k < t) are clamped because they
    # are masked out below anyway.
    exps = (steps[None, :] - steps[:, None]).clamp(min=0).to(dtype)

    base = torch.as_tensor(gamma, dtype=dtype, device=device)
    # Keep only the upper triangle (k >= t, diagonal included). Masking
    # explicitly, instead of relying on gamma**1e8 underflowing to zero,
    # keeps the result correct for gamma == 1 as well.
    G = torch.triu(torch.pow(base, exps))

    return G @ rewards

In [4]:
from torch.utils.data import Dataset, DataLoader

# Define a custom dataset for episodes and additional tensors
class EpisodeDataset(Dataset):
    """Index-aligned view over six parallel per-transition sequences.

    Item ``idx`` is the tuple
    ``(obs, action, reward, base_prob, base_advantage, rtg)`` taken at
    position ``idx`` of each sequence. The dataset length is the length of the
    observation sequence.
    """

    def __init__(self, all_episodes_obs, all_episodes_acts, all_episodes_rews, base_probs, base_advantages, rtgs):
        self.all_episodes_obs = all_episodes_obs
        self.all_episodes_acts = all_episodes_acts
        self.all_episodes_rews = all_episodes_rews
        self.base_probs = base_probs
        self.base_advantages = base_advantages
        self.rtgs = rtgs

    def __len__(self):
        return len(self.all_episodes_obs)

    def __getitem__(self, idx):
        obs = self.all_episodes_obs[idx]
        act = self.all_episodes_acts[idx]
        rew = self.all_episodes_rews[idx]
        prob = self.base_probs[idx]
        adv = self.base_advantages[idx]
        rtg = self.rtgs[idx]
        return (obs, act, rew, prob, adv, rtg)

    def collate_fn(self, batch):
        """Transpose a list of 6-tuples into a 6-tuple of lists (no stacking)."""
        return tuple([sample[field] for sample in batch] for field in range(6))



In [9]:

class Agent:
    """PPO-style agent with separate policy and value networks.

    Args:
        gamma: Reward discount factor.
        gae_lambda: Lambda used for generalized advantage estimation.
        epsilon: Clipping range for the probability ratio.
        lr: Learning rate shared by both Adam optimizers.
        env_name: Gymnasium environment id.
    """

    def __init__(self, gamma=0.99, gae_lambda=0.95, epsilon=0.2, lr=0.0001, env_name="Ant-v5"):
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.policy_model = PolicyModel(input_dim=self.observation_space.shape[0])
        self.value_model = ValueModel(input_dim=self.observation_space.shape[0])
        self.policy_optimizer = optim.Adam(self.policy_model.parameters(), lr=lr)
        self.value_optimizer = optim.Adam(self.value_model.parameters(), lr=lr)
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.epsilon = epsilon

    def update(self, policy_loss=None, value_loss=None):
        """Take one optimizer step per provided loss; a ``None`` loss is skipped."""
        # The two losses are built from different networks, so their autograd
        # graphs are independent and neither backward pass needs retain_graph.
        if policy_loss is not None:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()
        if value_loss is not None:
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()

    def run_episode(self):
        """Roll out one episode with the current policy.

        Returns:
            ``(episode_obs, episode_acts, episode_rews)`` where
            ``episode_obs[t, 0]`` / ``episode_obs[t, 1]`` are the observation
            before / after step ``t``, ``episode_acts[t]`` is the action taken
            and ``episode_rews[t]`` the reward received (float32).
        """
        # Reuse the training environment created in __init__ instead of
        # building (and never closing) a fresh one for every episode.
        env = self.env
        obs_pairs, actions, rewards = [], [], []
        observation, info = env.reset()
        current_obs = torch.tensor(observation, dtype=torch.float32)[None, :]
        terminated = False
        truncated = False
        while not terminated and not truncated:
            with torch.no_grad():  # acting needs no autograd graph
                action = self.policy_model.get_action(current_obs)
            observation, reward, terminated, truncated, info = env.step(action.numpy()[0])
            next_obs = torch.tensor(observation, dtype=torch.float32)[None, :]
            obs_pairs.append(torch.cat((current_obs, next_obs)))
            actions.append(action)
            rewards.append(float(reward))
            current_obs = next_obs
        # Collect in lists and combine once; repeated torch.cat inside the
        # loop is quadratic in the episode length.
        episode_obs = torch.stack(obs_pairs)
        episode_acts = torch.cat(actions)
        episode_rews = torch.tensor(rewards, dtype=torch.float32)
        return episode_obs, episode_acts, episode_rews

    def get_losses(self, states, actions, base_probs, base_advantages, real_rtg, epsilon=None):
        """Compute the clipped-surrogate policy loss and the MSE value loss.

        Args:
            states: Batch of observations.
            actions: Actions taken in ``states``.
            base_probs: ``policy_model.get_probs`` values recorded before the update.
            base_advantages: Advantage estimates recorded before the update.
            real_rtg: Reward-to-go regression targets for the value network.
            epsilon: Clip range; defaults to ``self.epsilon`` when omitted.

        Returns:
            ``(policy_loss, value_loss)`` as scalar tensors.
        """
        if epsilon is None:
            epsilon = self.epsilon
        mse = nn.MSELoss()
        value_loss = mse(self.value_model(states)[:, 0], real_rtg)

        ratio = self.policy_model.get_probs(states, actions) / base_probs
        clipped_weighted_advantages = base_advantages * torch.clip(ratio, 1 - epsilon, 1 + epsilon)
        weighted_advantages = base_advantages * ratio
        policy_loss = -torch.min(clipped_weighted_advantages, weighted_advantages).mean()
        return policy_loss, value_loss

    @staticmethod
    def _discounted_cumsum(values, discount):
        """Return ``out[t] = sum_{k >= t} discount**(k - t) * values[k]`` via a backward recursion."""
        out = torch.empty_like(values)
        running = torch.zeros((), dtype=values.dtype, device=values.device)
        for t in range(values.shape[0] - 1, -1, -1):
            running = values[t] + discount * running
            out[t] = running
        return out

    def compute_statistics(self, all_episodes_obs, all_episodes_acts, all_episodes_rews):
        """Compute pre-update probabilities, GAE advantages and rewards-to-go.

        Returns:
            ``(base_probs, base_advantages, rtgs)``, each the concatenation over
            all episodes, with no autograd history.
        """
        base_probs = []
        base_advantages = []
        rtgs = []
        with torch.no_grad():
            for episode_obs, episode_acts, episode_rews in zip(all_episodes_obs, all_episodes_acts, all_episodes_rews):
                states = episode_obs[:, 0]
                next_states = episode_obs[:, 1]
                base_probs.append(self.policy_model.get_probs(states, episode_acts))

                # NOTE(review): the final step also bootstraps from V(next_state),
                # even when the episode terminated; no done flag is recorded to
                # distinguish termination from truncation.
                td_error = (
                    episode_rews
                    + self.gamma * self.value_model(next_states)[:, 0]
                    - self.value_model(states)[:, 0]
                )
                # A backward recursion replaces the "scale by discount**t, cumsum,
                # divide by discount**t" trick, whose schedule underflows to 0 on
                # long episodes and then yields NaN/inf.
                base_advantages.append(self._discounted_cumsum(td_error, self.gae_lambda * self.gamma))
                rtgs.append(self._discounted_cumsum(episode_rews, self.gamma))
        return torch.cat(base_probs), torch.cat(base_advantages), torch.cat(rtgs)

    def ppo_update(self, all_episodes_obs, all_episodes_acts, all_episodes_rews, steps=4, batch_size=32):
        """Run ``steps`` passes of shuffled mini-batch updates over the collected transitions.

        Returns:
            ``(policy_loss, value_loss)`` of the last mini-batch, or
            ``(None, None)`` when no mini-batch was processed.
        """
        base_probs, base_advantages, rtgs = self.compute_statistics(all_episodes_obs, all_episodes_acts, all_episodes_rews)

        # Keep everything as tensors (in the default float dtype) rather than
        # round-tripping through nested Python lists for every mini-batch.
        dtype = torch.get_default_dtype()
        dataset = EpisodeDataset(
            torch.cat(all_episodes_obs).to(dtype),
            torch.cat(all_episodes_acts).to(dtype),
            torch.cat(all_episodes_rews).to(dtype),
            base_probs.to(dtype),
            base_advantages.to(dtype),
            rtgs.to(dtype),
        )
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate_fn)

        policy_loss = None
        value_loss = None
        for _ in range(steps):
            for obs_b, act_b, _rew_b, prob_b, adv_b, rtg_b in dataloader:
                policy_loss, value_loss = self.get_losses(
                    torch.stack(obs_b)[:, 0],
                    torch.stack(act_b),
                    torch.stack(prob_b),
                    torch.stack(adv_b),
                    torch.stack(rtg_b),
                )
                self.update(policy_loss=policy_loss, value_loss=value_loss)

        return policy_loss, value_loss

    def avg_reward(self, episodes):
        """Return the mean total reward over ``episodes`` (tuples from ``run_episode``)."""
        # Index 2 of an episode tuple is the reward tensor; index 1 holds the
        # actions, so summing a column of it does not measure reward.
        return torch.tensor([float(episode[2].sum()) for episode in episodes]).mean()

    def train(self, num_episodes=100, print_loss=True):
        """Collect ``num_episodes`` episodes, then update the policy and value models.

        Returns:
            ``(policy_loss, value_loss, total_reward)``: the last mini-batch
            losses as detached scalar tensors and the average episode reward
            as a Python float.
        """
        all_episodes = []
        for i in range(num_episodes):
            all_episodes.append(self.run_episode())

        all_episodes_obs = [episode[0] for episode in all_episodes]
        all_episodes_acts = [episode[1] for episode in all_episodes]
        all_episodes_rews = [episode[2] for episode in all_episodes]
        policy_loss, value_loss = self.ppo_update(all_episodes_obs, all_episodes_acts, all_episodes_rews)
        total_reward = self.avg_reward(all_episodes).item()

        # Detach so callers that keep the losses around do not also keep the
        # autograd graphs alive.
        if policy_loss is not None:
            policy_loss = policy_loss.detach()
        if value_loss is not None:
            value_loss = value_loss.detach()

        if print_loss:
            print(f"Episode {i} policy loss: {policy_loss.item()}")
            print(f"Episode {i} value loss: {value_loss.item()}")
            print(f"Episode {i} average total reward: {total_reward}")
        return (policy_loss, value_loss, total_reward)

    def demo(self):
        """Render one episode played by the current policy in a human-visible window."""
        env = gym.make(self.env_name, render_mode="human")
        try:
            observation, info = env.reset()
            obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
            terminated = False
            truncated = False
            while not terminated and not truncated:
                with torch.no_grad():
                    action = self.policy_model.get_action(obs_output)
                observation, reward, terminated, truncated, info = env.step(action.numpy()[0])
                obs_output = torch.tensor(observation, dtype=torch.float32)[None, :]
        finally:
            # Always release the render window, even if stepping fails.
            env.close()

In [None]:
# Training configuration: number of collect/update iterations, episodes
# gathered per iteration, and the average reward at which training stops early.
NUM_ITERATIONS = 50
EPISODES_PER_ITERATION = 10
REWARD_TARGET = 500

agent = Agent()
policy_losses = []
value_losses = []
total_rewards = []
for i in tqdm(range(NUM_ITERATIONS), desc="Training"):
    policy_loss, value_loss, total_reward = agent.train(num_episodes=EPISODES_PER_ITERATION, print_loss=False)
    # Store detached loss tensors so the history does not keep every
    # iteration's autograd graph alive; `.item()` still works on them later.
    policy_losses.append(policy_loss.detach())
    value_losses.append(value_loss.detach())
    total_rewards.append(total_reward)
    if total_reward > REWARD_TARGET:
        print(f"Episode {i} average total reward: {total_reward}")
        break

# Render one episode with the trained policy.
agent.demo()

: 

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Global seaborn look for every figure in this notebook
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.2)

# One stacked panel per tracked metric
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 10))

# (axis, series, title, y-label) for each panel; losses are stored as scalar
# tensors, so they are converted to plain floats before plotting.
panels = (
    (ax1, total_rewards, 'Total Rewards over Time', 'Average Total Reward'),
    (ax2, [loss.item() for loss in policy_losses], 'Policy Loss over Time', 'Policy Loss'),
    (ax3, [loss.item() for loss in value_losses], 'Value Loss over Time', 'Value Loss'),
)
for panel_ax, series, title, y_label in panels:
    sns.lineplot(data=series, ax=panel_ax)
    panel_ax.set_title(title)
    panel_ax.set_xlabel('Training Iteration')
    panel_ax.set_ylabel(y_label)

plt.tight_layout()
plt.show()

In [None]:
# %debug  # leftover post-mortem debugger call, disabled so "Restart & Run All" cannot block on an interactive prompt