In [None]:
import torch
import torch.nn as nn
from torch.distributions import Categorical
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from jaxtyping import Float, Int64

import numpy as np
import matplotlib.pyplot as plt

import gymnasium as gym
from gymnasium.spaces import Discrete, Box
from typing import cast

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Make `device` the default target for tensors/modules created from here on.
torch.set_default_device(device)
# Last expression of the cell: displays the selected device in the notebook output.
device

In [None]:
def init_layers(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
    """Apply orthogonal weight init and constant bias init to a linear layer.

    Args:
        layer: Module to initialise. Only ``nn.Linear`` instances are modified;
            any other module is returned untouched.
        std: Gain passed to ``torch.nn.init.orthogonal_`` for the weight matrix.
        bias_const: Value every bias entry is set to (when the layer has a bias).

    Returns:
        The same ``layer`` object, initialised in place, so the call can be
        nested directly inside ``nn.Sequential(...)``.
    """
    if isinstance(layer, nn.Linear):
        torch.nn.init.orthogonal_(layer.weight, gain=std)
        # ``nn.Linear(..., bias=False)`` leaves ``bias`` as None; skip it
        # instead of crashing inside ``constant_``.
        if layer.bias is not None:
            torch.nn.init.constant_(layer.bias, bias_const)

    return layer

In [None]:
class ActorNewtwork(nn.Module):
    """Policy MLP: maps a state vector to one unnormalised logit per action."""

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int):
        super().__init__()
        hidden_gain = np.sqrt(2)

        # Layers are created and initialised in input-to-output order.
        input_layer = init_layers(nn.Linear(state_dim, hidden_dim), std=hidden_gain)
        hidden_layer = init_layers(nn.Linear(hidden_dim, hidden_dim), std=hidden_gain)
        # The output layer uses a much smaller gain (0.01) than the hidden layers.
        logits_layer = init_layers(nn.Linear(hidden_dim, action_dim), std=0.01)

        self.model = nn.Sequential(
            input_layer,
            nn.ReLU(),
            hidden_layer,
            nn.ReLU(),
            logits_layer,
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Run ``state`` through the MLP and return its output (action logits)."""
        logits = self.model(state)
        return logits

In [None]:
class CriticNetwork(nn.Module):
    """Value MLP: maps a state vector to a single scalar output."""

    def __init__(self, state_dim: int, hidden_dim: int) -> None:
        super().__init__()
        hidden_gain = np.sqrt(2)

        # Layers are created and initialised in input-to-output order.
        input_layer = init_layers(nn.Linear(state_dim, hidden_dim), std=hidden_gain)
        hidden_layer = init_layers(nn.Linear(hidden_dim, hidden_dim), std=hidden_gain)
        value_layer = init_layers(nn.Linear(hidden_dim, 1), std=1.0)

        self.model = nn.Sequential(
            input_layer,
            nn.ReLU(),
            hidden_layer,
            nn.ReLU(),
            value_layer,
        )

    def forward(self, state: torch.Tensor):
        """Run ``state`` through the MLP; the last dimension of the result has size 1."""
        value = self.model(state)
        return value

In [None]:
class PPOAgent(nn.Module):
    """Actor and critic networks bundled together, with a categorical policy over the actor's logits."""

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int) -> None:
        super().__init__()

        self.actor = ActorNewtwork(state_dim, action_dim, hidden_dim)
        self.critic = CriticNetwork(state_dim, hidden_dim)

    def _policy(self, state: torch.Tensor) -> Categorical:
        """Build the categorical action distribution from the actor's logits for ``state``."""
        return Categorical(logits=self.actor(state))

    def get_action_and_log_prob(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample an action for ``state``.

        Returns:
            ``(action, log_prob)`` where ``log_prob`` is the log-probability of
            the sampled action under the current policy.
        """
        policy = self._policy(state)
        sampled_action = policy.sample()
        return sampled_action, policy.log_prob(sampled_action)
    
    def get_action_log_prob_entropy(self, state: torch.Tensor, action) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate already-chosen actions under the current policy.

        Returns:
            ``(action, log_prob, entropy)``: the ``action`` argument passed
            through unchanged, its log-probability under the current policy,
            and the entropy of the policy distribution for ``state``.
        """
        policy = self._policy(state)
        return action, policy.log_prob(action), policy.entropy()

In [None]:
import numpy as np


class RolloutBuffer(Dataset):
    """Fixed-size storage for one on-policy rollout, usable as a ``Dataset``.

    Transitions are written sequentially with :meth:`add`. After the rollout,
    :meth:`compute_return_target` fills the GAE advantages and return targets,
    and indexing the buffer yields one transition as a dict of tensors.
    :meth:`reset` only rewinds the write cursor; stored values are overwritten
    by subsequent ``add`` calls.
    """

    def __init__(
            self,
            buffer_size: int,
            state_dimension: int,
            device: torch.device):
        """Allocate zero-filled storage for ``buffer_size`` transitions on ``device``."""
        self._state_buffer: Float[Tensor, "buffer_size state_dimension"] = torch.zeros(
            (buffer_size, state_dimension), device=device)
        self._action_buffer: Int64[Tensor, "buffer_size"] = torch.zeros(
            (buffer_size), dtype=torch.int64, device=device)
        self._log_probabilities: Float[Tensor, "buffer_size"] = torch.zeros(
            (buffer_size), device=device)
        self._rewards: Float[Tensor, "buffer_size"] = torch.zeros(
            (buffer_size), device=device)
        self._done_flags: Float[Tensor, "buffer_size"] = torch.zeros(
            (buffer_size), device=device)
        self._state_value_predictions: Float[Tensor, "buffer_size"] = torch.zeros(
            (buffer_size), device=device)
        # Pre-allocate the derived quantities so the properties and
        # __getitem__ work (returning zeros) before the first
        # compute_return_target() call instead of raising AttributeError.
        self._gaes: Float[Tensor, "buffer_size"] = torch.zeros(
            (buffer_size), device=device)
        self._return_targets: Float[Tensor, "buffer_size"] = torch.zeros(
            (buffer_size), device=device)

        self._buffer_size = buffer_size
        self._device = device
        self.reset()

    def add(
            self,
            state: np.ndarray,
            action: int,
            logprob: float,
            reward: float,
            is_episode_done: bool,
            state_value_prediction: float):
        """Store one transition at the write cursor and advance the cursor.

        Raises:
            IndexError: If the buffer already holds ``buffer_size`` transitions.
        """
        if self._pos >= self._buffer_size:
            raise IndexError(
                f"RolloutBuffer is full ({self._buffer_size} transitions); "
                "call reset() before adding more")

        slot = self._pos
        self._state_buffer[slot] = torch.as_tensor(
            state, device=self._device)
        self._action_buffer[slot] = action
        self._log_probabilities[slot] = logprob
        self._rewards[slot] = reward
        # Stored as 0.0 / 1.0 so it can be used directly as a multiplicative mask.
        self._done_flags[slot] = float(is_episode_done)
        self._state_value_predictions[slot] = state_value_prediction

        self._pos = slot + 1

    def compute_return_target(self, last_value, gamma, lamb):
        """Compute GAE advantages and return targets over the whole buffer.

        Every slot of the buffer is processed regardless of the write cursor,
        so the buffer is expected to be completely filled beforehand.

        Args:
            last_value: Value estimate of the state that follows the final
                stored transition. May be a Python float or a one-element
                tensor (e.g. a critic output of shape ``(1,)``).
            gamma: Discount factor.
            lamb: GAE lambda.
        """
        # Normalise to a detached 0-d tensor so the recursion stays scalar and
        # never drags an autograd graph into the stored targets.
        next_value = torch.as_tensor(
            last_value, dtype=self._rewards.dtype, device=self._device
        ).detach().reshape(())
        next_gae = torch.zeros((), dtype=self._rewards.dtype, device=self._device)

        # 1.0 where the episode continues after step t, 0.0 where it ended there.
        not_done = 1.0 - self._done_flags
        gaes = torch.zeros((self._buffer_size), device=self._device)

        for t in reversed(range(self._buffer_size)):
            # TD residual; the bootstrap term is masked out when step t ended the episode.
            delta_t = self._rewards[t] + gamma * next_value * not_done[t] \
                - self._state_value_predictions[t]
            next_gae = delta_t + gamma * lamb * not_done[t] * next_gae
            gaes[t] = next_gae
            next_value = self._state_value_predictions[t]

        self._gaes = gaes
        self._return_targets = gaes + self._state_value_predictions

    @property
    def gaes(self):
        """GAE advantages from the latest ``compute_return_target`` call (zeros before the first)."""
        return self._gaes

    @property
    def return_targets(self):
        """Advantages plus stored value predictions (zeros before the first compute)."""
        return self._return_targets

    def __len__(self):
        """Return the buffer capacity (not the number of transitions written)."""
        return self._state_buffer.shape[0]

    def __getitem__(self, idx: int):
        """Return the transition at ``idx`` as a dict of tensors."""
        return {
            'state_buffer': self._state_buffer[idx],
            'action_buffer': self._action_buffer[idx],
            'log_probabilities': self._log_probabilities[idx],
            'rewards': self._rewards[idx],
            'done_flags': self._done_flags[idx],
            'state_value_predictions': self._state_value_predictions[idx],
            'return_targets': self._return_targets[idx],
            'gaes': self._gaes[idx]
        }
    
    def reset(self):
        """Rewind the write cursor to the start; stored data is left in place."""
        self._pos = 0

In [None]:
# Number of collect-then-update iterations executed by the training loop.
NUM_ITERATIONS = 250
ENV_NAME = 'LunarLander-v3'

# Training environment, created with wind and turbulence enabled.
env = gym.make(ENV_NAME, gravity=-10.0,
               enable_wind=True, wind_power=15.0, turbulence_power=1.5)

# `cast` only narrows the static type; it performs no runtime check.
# Number of discrete actions and length of the observation vector.
action_dim = int(cast(Discrete, env.action_space).n)
state_dim = cast(Box, env.observation_space).shape[0]

In [None]:
HIDDEN_DIM = 64     # hidden width passed to PPOAgent (shared by actor and critic)
BUFFER_SIZE = 1024  # transitions collected per iteration (rollout buffer capacity)
EPOCHS = 25         # passes over the buffer per iteration
BATCH_SIZE = 64     # minibatch size used by the DataLoader

gamma = 0.99            # discount factor passed to compute_return_target
learning_rate = 3.0e-4  # Adam learning rate
# NOTE(review): separate from `clip_epsilon`, which is defined in a later cell;
# confirm which role this constant is meant to play.
epsilon = 0.1

c1 = 0.5   # weight of the value-function term in the training loss
c2 = 0.01  # weight of the entropy term in the training loss

agent = PPOAgent(state_dim, action_dim, HIDDEN_DIM).to(device)
buffer = RolloutBuffer(BUFFER_SIZE, state_dim, device)

# A single optimizer updates actor and critic parameters together.
optimizer = torch.optim.Adam(agent.parameters(), lr=learning_rate)

In [None]:
gae_lamb = 0.95       # GAE lambda passed to buffer.compute_return_target
clip_epsilon = 0.2    # probability ratio is clipped to [1 - clip_epsilon, 1 + clip_epsilon]
max_norm_value = 0.5  # max_norm given to clip_grad_norm_ in the training loop

In [None]:
# Denominator guard for advantage normalisation. It must be tiny so that it
# only matters when a minibatch has (near-)zero advantage spread; a large value
# would shrink every normalised advantage.
ADVANTAGE_EPS = 1e-8

state, _ = env.reset()

# Per-iteration training statistics, consumed by the plotting cell.
avg_losses = []
avg_entropies = []
avg_rewards = []

for iteration in range(NUM_ITERATIONS):
    # ---- Rollout: fill the buffer with BUFFER_SIZE transitions ----
    # `state` carries over between iterations, so an episode may span two rollouts.
    for step in range(BUFFER_SIZE):
        with torch.no_grad():
            state_tensor = torch.as_tensor(state, device=device)
            action, action_log_probability = agent.get_action_and_log_prob(state_tensor)
            state_value_prediction = agent.critic(state_tensor)

        next_state, reward, terminated, truncated, _ = env.step(int(action.item()))
        # NOTE(review): truncation is treated exactly like termination here, so
        # no value is bootstrapped past a truncated episode; confirm that is intended.
        is_episode_done = terminated or truncated
        buffer.add(
            state,
            int(action.item()),
            float(action_log_probability.item()),
            float(reward),
            is_episode_done,
            float(state_value_prediction.item()),
        )

        if is_episode_done:
            state, _ = env.reset()
        else:
            state = next_state

    # Value of the state following the last stored transition; the buffer masks
    # it out when that transition ended an episode.
    with torch.no_grad():
        next_state_value_prediction = agent.critic(torch.as_tensor(state, device=device))

    buffer.compute_return_target(next_state_value_prediction, gamma, gae_lamb)
    train_loader = DataLoader(
        buffer,
        batch_size=BATCH_SIZE,
        shuffle=True,
        # The shuffling generator must live on the default device set earlier.
        generator=torch.Generator(device=device)
    )

    # ---- Update: EPOCHS passes of minibatch PPO over the collected rollout ----
    losses = []
    entropies = []
    for epoch in range(EPOCHS):
        for batch in train_loader:
            # Batch tensors get their own names so the environment `state`
            # used by the next rollout is not overwritten.
            batch_states = batch['state_buffer']
            batch_actions = batch['action_buffer']
            batch_advantages: torch.Tensor = batch['gaes']
            old_log_probs = batch['log_probabilities']
            return_targets = batch['return_targets']

            _, log_prob, entropy = agent.get_action_log_prob_entropy(batch_states, batch_actions)

            # Per-minibatch advantage normalisation.
            advantages_norm = (batch_advantages - batch_advantages.mean()) / (
                batch_advantages.std() + ADVANTAGE_EPS)
            probability_ratio = torch.exp(log_prob - old_log_probs)

            # Clipped surrogate objective (to be maximised).
            L_clip = torch.min(
                probability_ratio * advantages_norm,
                torch.clip(probability_ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages_norm
            )

            # Squared-error value loss; squeeze the trailing size-1 dim to match the targets.
            current_value_pred = agent.critic(batch_states).squeeze(-1)
            L_VF = 0.5 * (current_value_pred - return_targets) ** 2

            loss = torch.mean(-L_clip + c1 * L_VF - c2 * entropy)

            optimizer.zero_grad()
            loss.backward()

            clip_grad_norm_(agent.parameters(), max_norm=max_norm_value)
            optimizer.step()

            losses.append(loss.item())
            entropies.append(entropy.mean().item())

    avg_loss = np.mean(losses)
    avg_entropy = np.mean(entropies)
    # Mean per-step reward over the rollout (not a per-episode return).
    avg_reward = np.mean(buffer._rewards.cpu().numpy())

    avg_losses.append(avg_loss)
    avg_entropies.append(avg_entropy)
    avg_rewards.append(avg_reward)

    print(f"Iteration {iteration + 1}/{NUM_ITERATIONS}, Loss: {avg_loss:.4f}, Avg Entropy: {avg_entropy:.4f}, Avg Reward: {avg_reward:.4f}")
    buffer.reset()
        

In [None]:
# One row of three panels: loss, entropy and reward, each plotted per iteration.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

panels = (
    (avg_losses, 'Average Loss per Iteration', 'Loss'),
    (avg_entropies, 'Average Entropy per Iteration', 'Entropy'),
    (avg_rewards, 'Average Reward per Iteration', 'Reward'),
)

for ax, (series, title, y_label) in zip(axes, panels):
    ax.plot(series)
    ax.set_title(title)
    ax.set_xlabel('Iteration')
    ax.set_ylabel(y_label)

fig.tight_layout()
plt.show()

In [None]:
# Separate environment with on-screen rendering, configured like the training one.
eval_env = gym.make(ENV_NAME, render_mode='human', gravity=-10.0,
               enable_wind=True, wind_power=15.0, turbulence_power=1.5)

agent.eval()
try:
    for episode in range(2):
        state, _ = eval_env.reset()
        done = False
        total_reward = 0.0

        while not done:
            # Actions are sampled from the policy, exactly as during training.
            with torch.no_grad():
                action, _ = agent.get_action_and_log_prob(torch.as_tensor(state, device=device))
            next_state, reward, terminated, truncated, _ = eval_env.step(int(action.item()))
            done = terminated or truncated
            total_reward += float(reward)
            state = next_state

        print(f"Episode {episode + 1}: Total Reward: {total_reward}")
finally:
    # Always release the render window, even if the loop raises or the cell
    # is interrupted.
    eval_env.close()