In [3]:
from typing import Tuple, Dict

# data
import numpy as np
import matplotlib.pyplot as plt

# torch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical, Normal
from torch.utils.data import DataLoader

# gym
import gymnasium as gym

In [4]:
# Render matplotlib figures inline below the cell that creates them.
# NOTE(review): no cell visible in this notebook plots anything yet (plt is imported but unused).
%matplotlib inline

In [5]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [6]:
class Policy(nn.Module):
    """Actor-critic network for PPO.

    The actor and the critic are separate one-hidden-layer MLPs fed the same input.

    - Discrete actions (``continuous_actions=False``): the actor outputs one logit per
      action and the policy is a Categorical over the softmax of those logits.
    - Continuous actions (``continuous_actions=True``): the actor outputs the mean of a
      Normal per action dimension; the standard deviation is ``exp(log_std)``, a learned
      parameter shared by all states.

    Methods accept a single state of shape (n_inputs,) or a batch of shape
    (N, n_inputs), either as a ``torch.Tensor`` or as a NumPy array (arrays are converted
    to float32 tensors on the notebook-level ``device``).
    """

    def __init__(self, n_inputs: int, n_outputs: int, hidden_size: int = 128, continuous_actions: bool = False):
        super(Policy, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(n_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_outputs),
        )
        self.critic = nn.Sequential(
            nn.Linear(n_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

        self.continuous_actions = continuous_actions

        if continuous_actions:
            # One log standard deviation per action dimension, independent of the state.
            self.log_std = nn.Parameter(torch.zeros(n_outputs))
        else:
            # Normalise over the last (action) axis. dim=0 would normalise across the
            # batch for batched input of shape (N, n_outputs) and give wrong probabilities.
            self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _as_state_tensor(state):
        """Convert a NumPy state to a float32 tensor on `device`; tensors pass through unchanged."""
        if isinstance(state, np.ndarray):
            # float32 matches the nn.Linear weights even if the array is float64.
            return torch.as_tensor(state, dtype=torch.float32, device=device)
        return state

    def _distribution(self, logits: torch.Tensor):
        """Build the action distribution from the actor's output."""
        if self.continuous_actions:
            std = torch.exp(self.log_std).expand_as(logits)
            return Normal(logits, std)
        return Categorical(self.softmax(logits))

    def _joint(self, per_dim: torch.Tensor) -> torch.Tensor:
        """Reduce Normal per-dimension log-probs/entropies to one value per action.

        Action dimensions are independent, so the joint value is the sum over the last
        axis. Categorical values are already one per action and are returned unchanged.
        """
        return per_dim.sum(dim=-1) if self.continuous_actions else per_dim

    def forward(self, x):
        """Return ``(value, dist)`` for state(s) `x`.

        `value` is the critic output (trailing axis of size 1); `dist` is the action
        distribution built from the actor output.
        """
        x = self._as_state_tensor(x)
        value = self.critic(x)
        logits = self.actor(x)  # distribution means if self.continuous_actions
        return value, self._distribution(logits)

    def evaluate_state(self, state: torch.Tensor) -> torch.Tensor:
        """Return the critic's value estimate(s) for `state` (trailing axis of size 1)."""
        return self.critic(self._as_state_tensor(state))

    def sample_action(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample an action from the current policy at `state`.

        Returns:
            (action, log_prob): the sampled action and its log-probability (summed over
            action dimensions for continuous actions).
        """
        dist = self._distribution(self.actor(self._as_state_tensor(state)))
        action = dist.sample()
        return action, self._joint(dist.log_prob(action))

    def action_state_log_probs(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate `action` under the current policy at `state`.

        Returns:
            (log_prob, entropy): the log-probability of each action (summed over action
            dimensions for continuous actions) and the policy entropy averaged over the batch.
        """
        state = self._as_state_tensor(state)
        if isinstance(action, np.ndarray):
            # Keep the array's dtype: discrete actions must stay integer for Categorical.
            action = torch.as_tensor(action, device=device)

        dist = self._distribution(self.actor(state))
        log_prob = self._joint(dist.log_prob(action))
        entropy = self._joint(dist.entropy()).mean()
        return log_prob, entropy
    
    

In [37]:
# Create the LunarLander environment and draw an initial observation.
env = gym.make(id="LunarLander-v2")
(observation, info) = env.reset()

In [27]:
# The environment is deliberately NOT closed here: on Restart & Run All this cell runs
# before the model and training cells, which still use `env` (env.reset() / env.step()).
# Call env.close() only after training has finished.

In [23]:
# Actor-critic sized to the environment's observation vector and discrete action count.
model = Policy(
    env.observation_space.shape[0],
    env.action_space.n,
    continuous_actions=False,
)
model = model.to(device)

optimiser = optim.Adam(params=model.parameters(), lr=3e-4)
loss_mse = nn.MSELoss().to(device)

In [9]:
# --- Return estimation ---
DISCOUNT_FACTOR = 0.99        # gamma, the per-step discount

# --- Loss weighting ---
EPSILON = 0.2                 # PPO clip range: ratios are clamped to [1 - EPSILON, 1 + EPSILON]
CRITIC_DISCOUNT = 0.5         # weight of the value (critic) loss in the total loss
ENTROPY_COEFF = 0.01          # weight of the entropy bonus subtracted from the loss

# --- Rollout / optimisation schedule ---
MAX_TRAJECTORY_SIZE = 128     # maximum environment steps per sampled episode
BATCH_SIZE = 32               # episodes collected per rollout batch
TRAINING_STEPS = 200_000_000  # total environment steps before training stops
EPOCHS = 10                   # optimiser steps taken on each rollout batch

In [10]:
def compute_rtgs(batch_rewards, discount_factor=None):
    """
        Compute the discounted Reward-To-Go of each timestep in a batch given the rewards.

        Parameters:
            batch_rewards - the rewards in a batch, one sequence per episode,
                Shape: (number of episodes, number of timesteps per episode); episodes may differ in length.
            discount_factor - per-step discount; when None, the notebook-level DISCOUNT_FACTOR
                is read at call time.

        Return:
            batch_rtgs - the rewards-to-go of all episodes concatenated in episode order,
                Shape: (number of timesteps in batch)
    """
    if discount_factor is None:
        discount_factor = DISCOUNT_FACTOR

    batch_rtgs = []
    for trajectory in batch_rewards:
        # Walk the episode backwards so each return reuses the next one:
        # G_t = r_t + gamma * G_{t+1}.
        episode_rtgs = []
        discounted_reward = 0.0
        for reward in reversed(trajectory):
            discounted_reward = reward + discounted_reward * discount_factor
            episode_rtgs.append(discounted_reward)
        # Appending then reversing once is O(n), unlike repeated insert(0, ...).
        episode_rtgs.reverse()
        batch_rtgs.extend(episode_rtgs)

    return torch.tensor(batch_rtgs, dtype=torch.float)

In [11]:
def compute_gae(rewards, values, masks, discount_factor=0.99, gae_lambda=0.95):
    """Compute GAE-based return targets (advantage + value) for a batch of trajectories.

    Args:
        rewards: per-trajectory sequences of rewards r_t.
        values: per-trajectory sequences of value estimates V(s_t). A trajectory may carry
            one extra trailing entry, used as the bootstrap value after its last step;
            otherwise that bootstrap value is 0.
        masks: per-trajectory sequences of "not done" multipliers (1 - terminated). A 0 at
            step t stops bootstrapping from s_{t+1} and stops GAE accumulation across t.
        discount_factor: gamma.
        gae_lambda: GAE lambda.

    Returns:
        Float tensor with one return target per timestep, trajectories concatenated in
        order. The input sequences are not modified.
    """
    returns = []
    for ep_rewards, ep_values, ep_masks in zip(rewards, values, masks):
        ep_returns = []
        gae = 0
        for step in reversed(range(len(ep_rewards))):
            next_value = ep_values[step + 1] if step + 1 < len(ep_values) else 0
            delta = ep_rewards[step] + discount_factor * next_value * ep_masks[step] - ep_values[step]
            gae = delta + discount_factor * gae_lambda * ep_masks[step] * gae
            ep_returns.append(gae + ep_values[step])
        # Built backwards; reverse once (O(n)) instead of repeated insert(0, ...).
        ep_returns.reverse()
        returns.extend(ep_returns)
    return torch.tensor(returns, dtype=torch.float)

In [18]:
def sample_batch() -> Dict[str, torch.Tensor]:
    """Roll out BATCH_SIZE episodes (each at most MAX_TRAJECTORY_SIZE steps) with `model` on `env`.

    Returns a dict of CPU tensors (T = total steps over all episodes):
        observations: the observations s_t fed to the policy, stacked over all episodes.
        actions: the sampled actions, stacked over all episodes.
        log_probs: log-probabilities of the sampled actions under the rollout policy.
        episode_lengths: (BATCH_SIZE,) number of steps taken in each episode.
        gaes: (T,) GAE return targets (advantage + V(s_t)) from compute_gae. learn()
            regresses the critic onto this and subtracts V(s_t) from it to get advantages.
        rewards: (T,) per-step rewards of all episodes, concatenated.
    """
    observations, actions, log_probs = [], [], []
    rewards, values, masks, episode_lengths = [], [], [], []

    # Rollouts only collect data, so skip building autograd graphs for every step.
    with torch.no_grad():
        for _ in range(BATCH_SIZE):
            trajectory_rewards, trajectory_masks, trajectory_values = [], [], []
            observation, _ = env.reset()
            terminated = False

            for _ in range(MAX_TRAJECTORY_SIZE):
                observations.append(observation)

                action, log_prob = model.sample_action(observation)
                value = model.evaluate_state(observation)
                observation, reward, terminated, truncated, _ = env.step(action.cpu().numpy())

                actions.append(action.cpu())
                log_probs.append(log_prob.cpu())

                trajectory_rewards.append(reward)
                trajectory_masks.append(1 - terminated)
                trajectory_values.append(value.item())

                if terminated or truncated:
                    break

            if not terminated:
                # The episode was cut off (step limit or env truncation) rather than ended,
                # so the return after its last step is bootstrapped with V(s_T) instead of 0.
                trajectory_values.append(model.evaluate_state(observation).item())

            # len() rather than the loop index: also correct when no step was taken.
            episode_lengths.append(len(trajectory_rewards))
            rewards.append(trajectory_rewards)
            values.append(trajectory_values)
            masks.append(trajectory_masks)

    observations = torch.tensor(np.array(observations))
    actions = torch.stack(actions)
    log_probs = torch.stack(log_probs)
    # compute_gae returns advantage + value, i.e. the return targets learn() expects.
    gaes = compute_gae(rewards, values, masks, discount_factor=DISCOUNT_FACTOR)
    episode_lengths = torch.tensor(episode_lengths)
    all_rewards = torch.tensor(
        [reward for trajectory in rewards for reward in trajectory], dtype=torch.float
    )

    return {
        "observations": observations,
        "actions": actions,
        "log_probs": log_probs,
        "episode_lengths": episode_lengths,
        "gaes": gaes,
        "rewards": all_rewards,
    }

In [17]:
def calculate_gae(rewards, values, masks, discount_factor=0.99, gae_lambda=0.95):
    """Compute Generalized Advantage Estimates A_t for a batch of trajectories.

    Args:
        rewards: per-trajectory sequences of rewards r_t.
        values: per-trajectory sequences of value estimates V(s_t). A trajectory may carry
            one extra trailing entry, used as the bootstrap value after its last step;
            otherwise that bootstrap value is 0.
        masks: per-trajectory sequences of "not done" multipliers (1 - terminated, the
            convention sample_batch uses when building masks). A 0 at step t means the
            episode ended at t.
        discount_factor: gamma.
        gae_lambda: GAE lambda.

    Returns:
        Float tensor with one advantage per timestep, trajectories concatenated in order.
    """
    batch_advantages = []
    for ep_rews, ep_vals, ep_masks in zip(rewards, values, masks):
        advantages = []
        last_advantage = 0

        for t in reversed(range(len(ep_rews))):
            next_value = ep_vals[t + 1] if t + 1 < len(ep_vals) else 0
            # mask[t] == 0 means step t ended the episode: do not bootstrap from s_{t+1}
            # and do not carry the advantage of the following step backwards.
            delta = ep_rews[t] + discount_factor * next_value * ep_masks[t] - ep_vals[t]
            last_advantage = delta + discount_factor * gae_lambda * ep_masks[t] * last_advantage
            advantages.append(last_advantage)

        # Built backwards; reverse once (O(n)) instead of repeated insert(0, ...).
        advantages.reverse()
        batch_advantages.extend(advantages)

    return torch.tensor(batch_advantages, dtype=torch.float)

In [14]:
def normalize(x, eps=1e-8):
    """Return a zero-mean, unit-standard-deviation copy of `x` (torch tensor or NumPy array).

    The caller's `x` is not modified. `eps` avoids division by zero when all elements
    are equal.
    """
    return (x - x.mean()) / (x.std() + eps)

In [29]:
def learn():
    """Train `model` with PPO until TRAINING_STEPS environment steps have been collected.

    Each iteration collects a fresh batch with sample_batch(), then takes EPOCHS
    full-batch optimiser steps on the clipped PPO objective plus a weighted value loss and
    an entropy bonus, and prints progress.

    NOTE(review): batch["gaes"] is used as the critic's return target and advantages are
    formed as gaes - V(s). That is only correct if "gaes" holds return targets
    (advantage + value) — confirm against sample_batch().
    """
    model.train()
    total_timesteps = 0
    while total_timesteps < TRAINING_STEPS:
        batch = sample_batch()
        observations = batch["observations"].to(device)
        actions = batch["actions"].to(device)
        old_log_probs = batch["log_probs"].to(device)
        # Flatten to (N,) so it lines up with the flattened value predictions below.
        returns = batch["gaes"].to(device).reshape(-1)

        total_timesteps += batch["episode_lengths"].sum().item()
        rewards = batch["rewards"]

        # evaluate_state returns (N, 1); squeeze to (N,). Without this, returns - values
        # broadcasts to an (N, N) matrix and corrupts the advantages.
        with torch.no_grad():
            old_values = model.evaluate_state(observations).squeeze(-1)
        advantages = normalize(returns - old_values)

        losses = []
        for _ in range(EPOCHS):
            values = model.evaluate_state(observations).squeeze(-1)
            current_log_probs, entropy = model.action_state_log_probs(observations, actions)

            ratios = torch.exp(current_log_probs - old_log_probs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - EPSILON, 1 + EPSILON) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_loss = CRITIC_DISCOUNT * loss_mse(values, returns)

            loss = actor_loss + critic_loss - ENTROPY_COEFF * entropy
            losses.append(loss.item())

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

        print(f"Total timesteps: {total_timesteps} | Mean loss: {np.mean(losses)} | Mean rewards: {rewards.float().mean().item()}")
    

In [None]:
# Run PPO training. learn() loops until TRAINING_STEPS environment steps have been
# collected (200_000_000 as configured above), so expect this cell to run for a very long time.
learn()

In [39]:
# Persist the trained policy. torch.save(model, ...) pickles the whole module, so loading
# it later requires the Policy class to be importable; saving model.state_dict() is the
# more portable format recommended by PyTorch if downstream loaders can be updated.
torch.save(model, "ppo_2.pt")

# Training is finished, so the environment can now be released.
env.close()