# PPO

### Setup

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
import gymnasium as gym
from gymnasium.wrappers import TransformObservation
from itertools import chain

In [2]:
def pytorch_env(env):
    env = TransformObservation(env, lambda x: torch.from_numpy(x).float(), env.observation_space)
    return env

In [3]:
env = pytorch_env(gym.make('HalfCheetah-v5'))

In [4]:
def update_plot(data, title="", xlabel="", ylabel="", grid=True, sleep=0.01):
    clear_output(wait=True)
    plt.plot(data)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.grid(grid)
    plt.show()
    time.sleep(sleep)

In [5]:
N_OBSERVATIONS = env.observation_space.shape[0]
N_ACTIONS = env.action_space.shape[0]

TEST = False
TRAIN = True
SAVE = False

In [21]:
print(N_ACTIONS)

6


### Hyperparameters

In [6]:
# Hyperparameters
n_agents = 4
n_timesteps = 128
gamma = 0.99
epsilon = 0.2
v = 1.0 # Value constant
learning_rate = 3e-4
batch_size = 64
n_epochs = 10
n_trainsteps = 20
max_grad_norm = 0.5
layer_dim = 64
std = 1.

### Model

Aren't there some inconsistencies with the naming here (obs vs state_batch)? Also, get_distributions doesn't handle the case where only one observation is given as input.

In [7]:
class Actor(nn.Module):
    def __init__(self, n_observations, n_actions, layer_dim):
        """
        Actor network with learned standard deviations and vectorized operations.
        
        Args:
            n_observations (int): Dimension of the observation space
            n_actions (int): Dimension of the action space
            layer_dim (int): Dimension of hidden layers
        """
        super().__init__()
        self.n_actions = n_actions

        # Network for computing mean of actions
        self.mu_network = nn.Sequential(
            nn.Linear(n_observations, layer_dim),
            nn.LayerNorm(layer_dim),
            nn.ReLU(),
            nn.Linear(layer_dim, layer_dim),
            nn.LayerNorm(layer_dim),
            nn.ReLU(),
            nn.Linear(layer_dim, n_actions)
        )

        # Learnable log stds
        # Initialize to log(0.5) = -0.69
        self.log_std = nn.Parameter(torch.ones(n_actions) * -0.69)

        self.n_actions = n_actions

    def forward(self, obs):
        """
        Compute the mean actions for given observations.

        Args:
            obs (torch.Tensor): Observations, shape [batch_size, n_observations] or [n_observations] for a single observation
        Returns:
            torch.Tensor: Mean actions
        """
        if obs.dim() == 1:
            obs = obs.unsqueeze(0)
        
        means = self.mu_network(obs)
        return means
       
    def get_distribution(self, state_batch):
        """
        Create a batched MultivariateNormal distribution for the given states

        Args:
            state_batch (torch.Tensor): Batch of states, shape [batch_size, n_observations]
        Returns:
            torch.distributions.MultivariateNormal: Distribution with batched parameters
        """
        

    def select_action(self, state):
        with torch.no_grad():
            means = self.forward(state)
        cov_matrix = torch.eye(self.n_actions) * (self.std ** 2)
        action_distribution = torch.distributions.MultivariateNormal(means, cov_matrix)
        action = action_distribution.sample()
        return action

In [8]:
class Critic(nn.Module):
	def __init__(self, n_observations, layer_dim):
		super().__init__()
		self.v_network = nn.Sequential(
			nn.Linear(n_observations, layer_dim),
			nn.LayerNorm(layer_dim),
			nn.Tanh(),
			nn.Linear(layer_dim, layer_dim),
			nn.LayerNorm(layer_dim),
			nn.Tanh(),
			nn.Linear(layer_dim, 1)
		)

	def forward(self, obs):
		value = self.v_network(obs)
		return value

### Test model

In [9]:
def test_model(actor, env):
	episode_reward = 0.
	state, _ = env.reset()
	terminated, truncated = False, False
	while not (terminated or truncated):
		action = actor.select_action(state)
		state, reward, terminated, truncated, _ = env.step(action.numpy())
		episode_reward += reward
	env.close()
	return episode_reward

### PPO Dataset

In [22]:
class PPODataset(Dataset):
	def __init__(self, n_agents, n_timesteps, n_observations, n_actions, gamma):
		# Store variables
		self.n_agents = n_agents
		self.n_timesteps = n_timesteps
		self.n_observations = n_observations
		self.n_actions = n_actions
		self.gamma = gamma

		# Create tensors
		self.states = torch.zeros((n_agents, n_timesteps, n_observations))
		self.actions = torch.zeros((n_agents, n_timesteps, N_ACTIONS))
		self.rewards = torch.zeros((n_agents, n_timesteps))
		self.logprobs = torch.zeros((n_agents, n_timesteps))
		self.pred_values = torch.zeros((n_agents, n_timesteps))
		self.target_values = torch.zeros((n_agents, n_timesteps))
		self.advantages = torch.zeros((n_agents, n_timesteps))

		# Store episode ends
		self.episode_ends = [[] for _ in range(n_agents)]

	def add_step(self, agent, t, state, action, reward, logprob, pred_value):
		self.states[agent, t] = state
		self.actions[agent, t] = action
		self.rewards[agent, t] = reward
		self.logprobs[agent, t] = logprob
		self.pred_values[agent, t] = pred_value

	def mark_episode_end(self, agent, t):
		self.episode_ends[agent].append(t)

	def __compute_advantages_and_target_values(self):
		for agent in range(self.n_agents):
			# Iterate from last to first
			for t in range(self.n_timesteps-1, -1, -1):
				# Get reward and predicted value
				reward = self.rewards[agent, t]
				pred_value = self.pred_values[agent, t]

				# If step is terminal (end of episode)
				if t in self.episode_ends[agent]:
					target_value = reward
				# If step is last (t = n_timesteps - 1) but not terminal
				elif t == self.n_timesteps - 1:
					target_value = pred_value
				# non-terminal non-last step
				else:
					# Get previous target value (of t+1) and compute
					previous_target_value = self.target_values[agent][t+1]
					target_value = reward + self.gamma * previous_target_value
				
				# Compute advantage
				advantage = target_value - pred_value

				# Store target value and advantage
				self.target_values[agent, t] = target_value
				self.advantages[agent, t] = advantage

	def __flatten_data(self):
		# Flatten all tensors that will be fetched
		self.states = self.states.view(-1, self.n_observations)
		self.actions = self.actions.view(-1)
		self.logprobs = self.logprobs.view(-1, self.n_actions)
		self.target_values = self.target_values.view(-1)
		self.advantages = self.advantages.view(-1)

	def compute_advantages_and_target_values_and_flatten_data(self):
		self.__compute_advantages_and_target_values()
		self.__flatten_data()

	def __len__(self):
		return self.n_agents * self.n_timesteps
	
	def __getitem__(self, i):
		# Don't use before calling compute_advantages_and_target_values and flatten_data
		state = self.states[i]
		action = self.actions[i]
		logprob = self.logprobs[i]
		target_value = self.target_values[i]
		advantage = self.advantages[i]

		return state, action, logprob, target_value, advantage

### Loss function(s)

In [23]:
def clipped_objective_fn(pred_logprob_batch, old_logprob_batch, action_batch, advantage_batch, epsilon):
	# Compute p_ratio, ratio of new and old probabilities
    p_ratio = torch.exp(pred_logprob_batch - old_logprob_batch)

    # Compute unclipped and clipped surrogate objectives
    unclipped_surrogate_objective = p_ratio * advantage_batch
    clipped_surrogate_objective = torch.clamp(p_ratio, 1. - epsilon, 1. + epsilon) * advantage_batch

    # Compute elementwise minimum of two and return
    clipped_objective = torch.mean(torch.min(unclipped_surrogate_objective, clipped_surrogate_objective))
    return clipped_objective

In [24]:
def value_loss_fn(pred_values, target_value_batch):
    # Compute and return loss
    smoothl1 = nn.SmoothL1Loss(reduction='mean')
    loss = smoothl1(pred_values, target_value_batch)
    return loss

This is an unreadable mess add comments for the love of god

In [25]:
def loss_fn(actor, critic, state_batch, action_batch, old_logprob_batch, target_value_batch, advantage_batch, epsilon, v):
    # Get pred log probs and values
    distribution_list = actor.distribution_list(state_batch)
    pred_logprob_batch = []
    for i in range(len(state_batch)):
        distribution = distribution_list[i]
        action = action_batch[i].unsqueeze(0)
        pred_logprob = distribution.log_prob(action)
        pred_logprob_batch.append(pred_logprob)
    pred_logprob_batch = torch.stack(pred_logprob_batch)
    pred_value_batch = critic(state_batch)

    # Compute individual losses
    clipped_objective = clipped_objective_fn(pred_logprob_batch, old_logprob_batch, action_batch, advantage_batch, epsilon)
    value_loss = value_loss_fn(pred_value_batch, target_value_batch)

    # Compute and return total loss
    loss = - clipped_objective + v * value_loss
    return loss

### Train loop

In [30]:
def train_loop(actor, critic, optimizer, batch_size, n_epochs, n_trainsteps, n_agents, n_timesteps, n_observations, n_actions, gamma, epsilon, v, max_grad_norm):
	# Initialize rewards and losses
	rewards = []
	losses = []

	# Initilize all environments
	envs = []
	for agent in range(n_agents):
		env = pytorch_env(gym.make('HalfCheetah-v5'))
		state, _ = env.reset()
		terminated, truncated = False, False
		envs.append([env, state, terminated, truncated])
			
	for _ in range(n_trainsteps):
		# Initialize dataset
		dataset = PPODataset(n_agents, n_timesteps, n_observations, n_actions, gamma)

		# Collect data
		for agent in range(n_agents):
			env = envs[agent][0]
			state, terminated, truncated = envs[agent][1:]

			for t in range(n_timesteps):
				# THIS SUCKS I SHOULD HAVE A BETTER WAY TO COMPUTE DISTRIBUTION AND ACTION
				# But it works or at least seems to work
				# Compute distributions predicted value
				with torch.no_grad():
					pred_value = critic(state)
					distribution = actor.distribution_list([state])[0]
				
				# Select and perform action
				action = actor.select_action(state)
				next_state, reward, terminated, truncated, _ = env.step(action.numpy())

				# Compute logprob
				logprob = distribution.log_prob(action)

				# Store data
				dataset.add_step(agent, t, state, action, reward, logprob, pred_value)

				# If terminated reset env and mark end, otherwise update state
				if terminated or truncated:
					dataset.mark_episode_end(agent, t)
					state, _ = env.reset()
					terminated, truncated = False, False
				else:
					state = next_state

		# Compute target values and advantages and flatten data
		dataset.compute_advantages_and_target_values_and_flatten_data()

		# Create dataloader
		dataloader = DataLoader(
			dataset=dataset,
			batch_size=batch_size,
			shuffle=True,
		)

		# Train model on policy for N_EPOCHS
		for _ in range(n_epochs):
			for state_batch, action_batch, logprob_batch, target_value_batch, advantage_batch in dataloader:
				# Compute loss and optimize
				loss = loss_fn(
					actor,
					critic,
					state_batch,
					action_batch,
					logprob_batch,
					target_value_batch,
					advantage_batch,
					epsilon,
					v,
				)
				loss.backward()
				# Clip gradients DISABLED FOR NOW
				#clip_grad_norm_(model.parameters(), max_grad_norm)
				optimizer.step()
				optimizer.zero_grad()

				# Append losses
				losses.append(loss.item())

		# Test model, append reward and display
		test_reward = test_model(actor, env)
		rewards.append(test_reward)
		update_plot(rewards, "Test rewards", "Reward", "Train loop")

	return rewards, losses

### Train

It seems to be working now, although EXTREMELY slowly. So next I should fix the speed. But I'm really glad it's working!

In [31]:
actor = Actor(N_OBSERVATIONS, N_ACTIONS, layer_dim, std)
critic = Critic(N_OBSERVATIONS, layer_dim)
optimizer = optim.Adam(chain(actor.parameters(), critic.parameters()))

In [32]:
if TRAIN:
    rewards, losses = train_loop(
        actor,
        critic,
        optimizer,
        batch_size,
        n_epochs,
        n_trainsteps,
        n_agents,
        n_timesteps,
        N_OBSERVATIONS,
        N_ACTIONS,
        gamma,
        epsilon,
        v,
        max_grad_norm,
	)

RuntimeError: shape '[-1, 6]' is invalid for input of size 512

In [None]:
plt.plot(losses, label="losses")
plt.legend()
plt.grid(True)
plt.show()

NameError: name 'plt' is not defined

In [None]:
test_model(actor, pytorch_env(gym.make('HalfCheetah-v5', render_mode='human')))

41.0

In [2]:
import torch
from torch import nn
import numpy as np
import gymnasium as gym
from gymnasium.wrappers import TransformObservation
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time
import sympy

def pytorch_env(env):
    """Convert environment observations to PyTorch tensors."""
    return TransformObservation(env, lambda x: torch.from_numpy(x).float(), env.observation_space)

def update_plot(data, title="", xlabel="", ylabel="", grid=True, sleep=0.01):
    """Update training progress plot."""
    clear_output(wait=True)
    plt.plot(data)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.grid(grid)
    plt.show()
    time.sleep(sleep)

class Actor(nn.Module):
    def __init__(self, n_observations, n_actions, hidden_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(n_observations, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )
        # Initialize log_std outside the network
        self.log_std = nn.Parameter(torch.zeros(n_actions))
        
    def forward(self, state):
        mean = self.network(state)
        std = torch.exp(self.log_std)
        return mean, std
    
    def sample_action(self, state):
        mean, std = self(state)
        normal = torch.distributions.Normal(mean, std)
        action = normal.sample()
        log_prob = normal.log_prob(action).sum(-1)
        return action, log_prob

class Critic(nn.Module):
    def __init__(self, n_observations, hidden_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(n_observations, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
    def forward(self, state):
        return self.network(state).squeeze(-1)

class PPO:
    def __init__(
        self,
        env_name,
        hidden_dim=64,
        lr=3e-4,
        gamma=0.99,
        epsilon=0.2,
        value_coef=0.5,
        n_steps=2048,
        batch_size=64,
        n_epochs=10
    ):
        # Initialize environment
        self.env = pytorch_env(gym.make(env_name))
        self.n_observations = self.env.observation_space.shape[0]
        self.n_actions = self.env.action_space.shape[0]
        
        # Initialize networks
        self.actor = Actor(self.n_observations, self.n_actions, hidden_dim)
        self.critic = Critic(self.n_observations, hidden_dim)
        
        # Initialize optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)
        
        # Store hyperparameters
        self.gamma = gamma
        self.epsilon = epsilon
        self.value_coef = value_coef
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs

    def compute_returns(self, rewards, dones, last_value):
        """Compute returns with TD(λ) estimation."""
        returns = []
        running_return = last_value
        for reward, done in zip(reversed(rewards), reversed(dones)):
            running_return = reward + self.gamma * running_return * (1 - done)
            returns.insert(0, running_return)
        return torch.tensor(returns)

    def collect_rollout(self):
        """Collect experience data."""
        states, actions, rewards, log_probs, values, dones = [], [], [], [], [], []
        state, _ = self.env.reset()
        done = False
        
        # Collect steps
        episode_reward = 0
        for _ in range(self.n_steps):
            # Get action and value
            with torch.no_grad():
                action, log_prob = self.actor.sample_action(state)
                value = self.critic(state)
            
            # Take step in environment
            next_state, reward, terminated, truncated, _ = self.env.step(action.numpy())
            done = terminated or truncated
            
            # Store transition
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            log_probs.append(log_prob)
            values.append(value)
            dones.append(done)
            
            episode_reward += reward
            
            if done:
                state, _ = self.env.reset()
                done = False
            else:
                state = next_state
        
        # Get final value for return computation
        with torch.no_grad():
            last_value = self.critic(state)
        
        return (
            torch.stack(states),
            torch.stack(actions),
            torch.tensor(rewards),
            torch.stack(log_probs),
            torch.stack(values),
            torch.tensor(dones),
            last_value,
            episode_reward
        )

    def update_policy(self, states, actions, old_log_probs, returns, advantages):
        """Update policy using PPO objective."""
        # Convert to tensors
        states = torch.tensor(states, dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.float32)
        old_log_probs = torch.tensor(old_log_probs, dtype=torch.float32)
        returns = torch.tensor(returns, dtype=torch.float32)
        advantages = torch.tensor(advantages, dtype=torch.float32)
        
        for _ in range(self.n_epochs):
            # Generate random indices
            indices = torch.randperm(len(states))
            
            # Update in mini-batches
            for start in range(0, len(states), self.batch_size):
                end = start + self.batch_size
                batch_indices = indices[start:end]
                
                # Get batch data
                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_returns = returns[batch_indices]
                batch_advantages = advantages[batch_indices]
                
                # Get current policy outputs
                mean, std = self.actor(batch_states)
                dist = torch.distributions.Normal(mean, std)
                new_log_probs = dist.log_prob(batch_actions).sum(-1)
                
                # Compute ratio and surrogate objectives
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * batch_advantages
                
                # Compute actor and critic losses
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_pred = self.critic(batch_states)
                critic_loss = self.value_coef * nn.MSELoss()(critic_pred, batch_returns)
                
                # Update actor
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()
                
                # Update critic
                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                self.critic_optimizer.step()

    def train(self, n_iterations=100):
        """Main training loop."""
        rewards_history = []
        
        for iteration in range(n_iterations):
            # Collect experience
            states, actions, rewards, log_probs, values, dones, last_value, episode_reward = self.collect_rollout()
            
            # Compute returns and advantages
            returns = self.compute_returns(rewards, dones, last_value)
            advantages = returns - values
            
            # Normalize advantages
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
            
            # Update policy
            self.update_policy(states, actions, log_probs, returns, advantages)
            
            # Store rewards and update plot
            rewards_history.append(episode_reward)
            if (iteration + 1) % 10 == 0:
                update_plot(
                    rewards_history,
                    "Training Progress",
                    "Iteration",
                    "Episode Reward"
                )
                print(f"Iteration {iteration + 1}, Reward: {episode_reward:.2f}")
        
        return rewards_history

# Example usage:
if __name__ == "__main__":
    # Create and train PPO agent
    ppo = PPO("HalfCheetah-v5")
    rewards = ppo.train(n_iterations=100)
    
    # Plot final results
    plt.figure(figsize=(10, 5))
    plt.plot(rewards)
    plt.title("Training Results")
    plt.xlabel("Iteration")
    plt.ylabel("Episode Reward")
    plt.grid(True)
    plt.show()

RuntimeError: Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.

In [1]:
from sympy import S