In [1]:
from typing import Tuple, Dict

# data
import numpy as np
import matplotlib.pyplot as plt

# torch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical, Normal
from torch.utils.data import DataLoader

# gym
import gymnasium as gym

In [2]:
%matplotlib inline
# Seeding is disabled, so training runs are not reproducible. Uncomment both
# lines for repeatable torch/numpy randomness; note the Gymnasium env has its
# own RNG, which is not seeded anywhere in this notebook (env.reset() is
# always called without a seed).
#np.random.seed(0)
#torch.manual_seed(0)

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
class Policy(nn.Module):
    """Actor-critic network with separate (non-shared) actor and critic MLPs.

    Calling the module returns ``(value, dist)``:
      * ``value`` -- critic output whose last dimension has size 1.
      * ``dist`` -- a ``Categorical`` over ``n_outputs`` actions, or, when
        ``continuous_actions`` is true, a ``Normal`` whose means are the actor
        outputs and whose std is a learned, state-independent parameter.

    Accepts a single observation of shape ``(n_inputs,)`` or a batch of shape
    ``(..., n_inputs)``; action probabilities are normalised over the last
    (action) dimension, never over the batch dimension.
    """

    def __init__(self, n_inputs: int, n_outputs: int, hidden_size: int = 128, continuous_actions: bool = False):
        super().__init__()
        self.actor = nn.Sequential(
            nn.Linear(n_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_outputs),
        )
        self.critic = nn.Sequential(
            nn.Linear(n_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

        self.continuous_actions = continuous_actions

        if continuous_actions:
            # Log of the per-dimension std; exp() in forward keeps std positive.
            self.log_std = nn.Parameter(torch.zeros(n_outputs))
        else:
            # Normalise over the action axis (last dim). dim=0 would mix
            # samples whenever the input is batched. Kept as an attribute for
            # backward compatibility; forward() hands the raw logits to
            # Categorical, which applies the same softmax more stably.
            self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``(critic value, action distribution)`` for observation(s) ``x``."""
        value = self.critic(x)
        # Action logits (discrete) or distribution means (continuous).
        logits = self.actor(x)

        if self.continuous_actions:
            std = torch.exp(self.log_std).expand_as(logits)
            dist = Normal(logits, std)
        else:
            # Equivalent to Categorical(probs=softmax(logits, -1)) but
            # log_prob is computed via log-softmax instead of log(probs).
            dist = Categorical(logits=logits)

        return value, dist

In [36]:
# render_mode="human" opens a window and renders every step, which slows the
# long training loop below considerably; NOTE(review): consider
# render_mode=None for training and a separate env for visual evaluation.
# NOTE(review): newer Gymnasium releases replace "LunarLander-v2" with
# "LunarLander-v3" -- confirm against the installed version.
env = gym.make("LunarLander-v2", render_mode="human")
# Initial reset. `observation`/`info` are not used by the visible training
# code: sample_episodes calls env.reset() itself before every episode.
observation, info = env.reset()

In [34]:
# Discrete-action policy: one input per observation feature, one logit per action.
model = Policy(
    env.observation_space.shape[0],
    env.action_space.n,
    continuous_actions=False,
).to(device)

optimiser = optim.Adam(params=model.parameters(), lr=3e-4)
# Regression loss used by fit_model for the critic's value targets.
loss_mse = nn.MSELoss().to(device)

In [23]:
# Discount factor (gamma) for future rewards; compute_gae's default is also 0.99.
DISCOUNT_FACTOR = 0.99
# PPO clipping range: the probability ratio is clamped to [1 - EPSILON, 1 + EPSILON].
EPSILON = 0.2
# Weight of the critic (value) loss in the total loss -- a coefficient, not a discount.
CRITIC_DISCOUNT = 0.5
# Weight of the entropy bonus, which is subtracted from the total loss.
ENTROPY_COEFF = 0.01
# Number of complete episodes collected per update (sample_episodes), not a minibatch size.
BATCH_SIZE = 32
# Number of collect-then-update iterations performed per train() call.
BATCH_UPDATES = 10

In [8]:
def play_step(model, observation) -> Dict[str, float | torch.Tensor | int | np.ndarray]:
    """Sample one action from ``model`` and advance the module-level ``env`` one step.

    Args:
        model: module returning ``(value, dist)`` for a single observation.
        observation: the current observation as a numpy array.

    Returns:
        dict with keys:
          "state"     -- ``observation`` as a tensor on ``device``;
          "action"    -- the sampled action;
          "value"     -- critic estimate for ``state``;
          "reward"    -- 1-element float tensor;
          "log_prob"  -- log-probability of the sampled action;
          "mask"      -- 1-element float tensor, 0 if the episode ended here, else 1;
          "new_state" -- next observation, or the first observation of a new
                         episode when the env was reset.
    """
    state = torch.from_numpy(observation).to(device)

    # Rollout outputs are only used later as fixed data (old log-probs and
    # value snapshots), so do not build an autograd graph for every step.
    with torch.no_grad():
        value, dist = model(state)
        action = dist.sample()
        log_prob = dist.log_prob(action)

    new_observation, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
    episode_over = terminated or truncated
    if episode_over:
        # NOTE(review): truncation (time limit) is treated like termination,
        # so the final step is not bootstrapped from the critic.
        new_observation, _ = env.reset()

    return {
        "state": state,
        "action": action,
        "value": value,
        "reward": torch.FloatTensor([reward]).to(device),
        "log_prob": log_prob,
        "mask": torch.FloatTensor([0 if episode_over else 1]).to(device),
        "new_state": new_observation,
    }

In [9]:
def compute_gae(rewards, values, masks, discount_factor=0.99, gae_lambda=0.95):
    """Compute GAE(lambda)-based return targets for one trajectory.

    For each step t (walking backwards)::

        delta_t = r_t + gamma * V_{t+1} * mask_t - V_t
        gae_t   = delta_t + gamma * lambda * mask_t * gae_{t+1}
        ret_t   = gae_t + V_t

    Args:
        rewards: per-step rewards (numbers or tensors).
        values: per-step critic estimates, same length as ``rewards``. Not modified.
        masks: per-step continuation flags, 0 where the episode ended, else 1.
        discount_factor: gamma.
        gae_lambda: lambda, trading bias for variance in the advantage estimate.

    Returns:
        list of ``ret_t`` (advantage + value) in forward time order.
    """
    # Append a zero "next value" for the last step; a copy, so the caller's
    # list is untouched (tuples are accepted too).
    padded_values = list(values) + [0]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + discount_factor * padded_values[step + 1] * masks[step] - padded_values[step]
        gae = delta + discount_factor * gae_lambda * masks[step] * gae
        returns.append(gae + padded_values[step])
    # Built back-to-front; one reverse is O(n) versus O(n^2) for insert(0, ...).
    returns.reverse()
    return returns

In [10]:
def normalize(x, eps=1e-8):
    """Return a zero-mean, unit-std copy of tensor ``x``; ``x`` itself is not modified.

    Args:
        x: tensor to standardise.
        eps: added to the std to avoid division by zero.

    Returns:
        ``(x - mean) / (std + eps)`` using torch's default (unbiased) std. For
        fewer than two elements the spread is undefined, so zeros are returned
        instead of NaNs.
    """
    if x.numel() < 2:
        return torch.zeros_like(x)
    # Out-of-place so the caller's tensor is left intact (and tensors that
    # require grad are not hit by an illegal in-place op).
    return (x - x.mean()) / (x.std() + eps)

In [11]:
def compute_advantage(returns, gae) -> torch.Tensor:
    """Normalised advantage estimates ``normalize(returns - baseline)``.

    Args:
        returns: bootstrapped return targets (the output of compute_gae, which
            sample_episodes passes in as ``gaes``).
        gae: despite its name, the value baseline -- the critic's estimates
            for the same states (sample_episodes passes ``values`` here). The
            parameter name is kept for backward compatibility.

    Returns:
        tensor of standardised advantages, positive where the outcome beat
        the critic's expectation.
    """
    # A = R - V. The opposite order (V - R) negates every advantage and makes
    # the policy gradient push *away* from better-than-expected actions.
    return normalize(returns - gae)

In [12]:
def play_episode(model, observation) -> Dict[str, torch.Tensor]:
    """Roll out one episode from ``observation`` until play_step reports it ended.

    Args:
        model: policy passed through to play_step.
        observation: starting observation (numpy array).

    Returns:
        dict of tensors stacked along the time axis (length T), all detached
        from the autograd graph:
          "states", "actions",
          "values"    -- flat (T,) critic estimates recorded during the rollout,
          "log_probs" -- log-probabilities of the sampled actions,
          "gae"       -- flat (T,) return targets from compute_gae
                         (advantage + value, despite the key name).
    """
    states, actions, values, rewards, log_probs, masks = [], [], [], [], [], []

    while True:
        step = play_step(model, observation)
        observation = step["new_state"]
        states.append(step["state"])
        actions.append(step["action"])
        values.append(step["value"])
        rewards.append(step["reward"])
        log_probs.append(step["log_prob"])
        masks.append(step["mask"])
        # mask == 0 marks the step where the episode ended (play_step has
        # already reset the env at that point).
        if masks[-1] == 0:
            break

    returns = compute_gae(rewards, values, masks, discount_factor=DISCOUNT_FACTOR)

    # Per-step values/returns are small tensors (expected 1-element); flatten
    # and concatenate them on-device into an explicit (T,) shape rather than
    # rebuilding them through torch.tensor(list_of_tensors).
    def _flat(seq):
        return torch.cat([t.reshape(-1) for t in seq]).detach()

    return {
        "states": torch.stack(states),
        "actions": torch.stack(actions),
        "values": _flat(values),
        "log_probs": torch.stack(log_probs).detach(),
        "gae": _flat(returns),
    }

In [14]:
def sample_episodes(model, batch_size, env) -> Dict[str, torch.Tensor]:
    """Collect ``batch_size`` complete episodes and flatten them into one batch.

    Args:
        model: policy used to act (passed through to play_episode).
        batch_size: number of whole episodes to collect; must be >= 1.
        env: environment reset before each episode. NOTE(review): the actual
            stepping happens in play_step, which uses the module-level ``env``
            -- pass that same object.

    Returns:
        dict of tensors concatenated along the time axis: "states", "actions",
        "log_probs", "values", "gaes" (return targets) and "advantage".

    Raises:
        ValueError: if ``batch_size`` < 1 (an empty batch cannot be normalised).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    episodes = []
    for _ in range(batch_size):
        # play_episode does not return the post-reset observation from its
        # final step, so each episode starts from an explicit reset.
        observation, _ = env.reset()
        episodes.append(play_episode(model, observation))

    # Concatenate once at the end rather than re-allocating on every episode
    # (quadratic copying); this also keeps each tensor's own dtype, e.g.
    # integer action indices are not promoted to float.
    def _concat(key):
        return torch.cat([episode[key] for episode in episodes])

    states = _concat("states")
    actions = _concat("actions")
    log_probs = _concat("log_probs")
    values = _concat("values")
    gaes = _concat("gae")
    advantage = compute_advantage(gaes, values)

    return {
        "states": states,
        "actions": actions,
        "log_probs": log_probs,
        "values": values,
        "gaes": gaes,
        "advantage": advantage,
    }

In [31]:
def fit_model(data) -> Dict[str, float]:
    """Take one optimiser step on a batch produced by sample_episodes.

    Loss = clipped PPO surrogate (actor)
         + CRITIC_DISCOUNT * MSE(fresh critic predictions, return targets)
         - ENTROPY_COEFF * entropy,
    with every term averaged over all transitions, so the coefficients keep
    the same relative meaning however many steps the batch contains.

    Uses the module-level ``model``, ``optimiser``, ``loss_mse``, ``EPSILON``,
    ``CRITIC_DISCOUNT`` and ``ENTROPY_COEFF``.

    Args:
        data: dict with "states", "actions", "log_probs" (rollout
            log-probabilities), "gaes" (return targets) and "advantage".

    Returns:
        ``{"loss": float}`` -- the total loss of this step.
    """
    # Rollout quantities are fixed targets: make sure no gradient flows into them.
    old_log_probs = data["log_probs"].detach()
    advantages = data["advantage"].detach()
    returns = data["gaes"].detach().reshape(-1)

    actor_terms, entropy_terms, value_preds = [], [], []
    # States are evaluated one at a time, the same way the rollout evaluated them.
    for state, action, old_log_prob, advantage in zip(
        data["states"], data["actions"], old_log_probs, advantages
    ):
        value, dist = model(state)
        ratio = (dist.log_prob(action) - old_log_prob).exp()
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) * advantage
        actor_terms.append(-torch.min(surr1, surr2).reshape(-1))
        entropy_terms.append(dist.entropy().mean())
        # Fresh critic prediction *with* gradient: data["values"] come from
        # the rollout and carry no graph, so they cannot train the critic.
        value_preds.append(value.reshape(-1))

    # NOTE(review): train() calls this once per freshly sampled batch, so the
    # policy still equals the rollout policy here and ratio == 1; clipping only
    # matters if several steps are taken on the same data.
    actor_loss = torch.cat(actor_terms).mean()
    entropy = torch.stack(entropy_terms).mean()
    critic_loss = loss_mse(torch.cat(value_preds), returns)
    total_loss = actor_loss + CRITIC_DISCOUNT * critic_loss - ENTROPY_COEFF * entropy

    optimiser.zero_grad()
    total_loss.backward()
    optimiser.step()

    return {"loss": total_loss.item()}
        

In [32]:
def train(model, env, batch_size, current_epoch, batch_updates) -> None:
    """Run ``batch_updates`` collect-then-update iterations and log progress.

    ``current_epoch`` is an epoch index: update counters run from
    ``current_epoch * batch_updates``, so successive calls with epoch
    0, 1, 2, ... get non-overlapping, globally increasing step numbers in the
    log. Every 5th global step prints the mean loss of this call so far.

    Args:
        model: policy used for data collection. NOTE(review): fit_model
            updates the module-level ``model``/``optimiser`` -- pass the same object.
        env: environment passed to sample_episodes.
        batch_size: episodes collected per update.
        current_epoch: index of this call in the outer training loop.
        batch_updates: number of collect-then-update iterations.
    """
    losses = []
    first_update = current_epoch * batch_updates
    for batch_update in range(first_update, first_update + batch_updates):
        data = sample_episodes(model, batch_size, env)
        losses.append(fit_model(data)["loss"])
        if batch_update % 5 == 0:
            print(f"Epoch [{current_epoch} - {batch_update}] | Loss: {np.mean(losses)}")

In [None]:
# Main training loop: 1000 train() calls, each doing BATCH_UPDATES updates of
# BATCH_SIZE full episodes -- a very long run, especially with on-screen rendering.
for epoch in range(1000):
    train(model, env, BATCH_SIZE, epoch, BATCH_UPDATES)

In [30]:
# Release the environment's resources (including the render window) after training.
env.close()