In [25]:
import gymnasium as gym
import numpy as np
from tqdm import tqdm
from gymnasium.wrappers import RecordVideo
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions
import torch.optim as optim
import matplotlib.pyplot as plt

env = gym.make(
    "LunarLander-v3",
    continuous=False,     
    gravity=-10.0,        
    enable_wind=False,   
    wind_power=15.0,      
    turbulence_power=1.0, 
    render_mode="rgb_array" 
)

In [26]:
space_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Training parameters (following paper)
gamma = 0.99                # Discount factor
lr = 3e-4                   # Learning rate
lamb = 0.95                 # Generalised Advantage Estimation (GAE) lambda
epsilon = 0.2               # Clipping value
h = 0.01                    # Entropy coefficient
v = 0.5                     # Value loss coefficient
max_timesteps = 1e6         # Maximal number of iterations
eval_epochs = 1000           # Epochs for evaluation
N = 1                       # Number of agents collecting training data
T = 2048                    # Maximal trajectory length
K = 10                      # Number of epoches per update
max_accuracy = 0            # Accuracy for evaluation
max_reward = 0
minibatch_size = 64         # Size of a mini batch
number_minibatches = N * T / minibatch_size     # Number of mini batches
actor_losses = []
critic_losses = []

In [27]:
# The network to select an action
ActorNetwork = nn.Sequential(
    nn.Linear(space_dim, 128),
    nn.LeakyReLU(),
    nn.Linear(128, 128),
    nn.LeakyReLU(),
    nn.Linear(128, action_dim)
)

# The network to get value of a state
CriticNetwork = nn.Sequential(
    nn.Linear(space_dim, 128),
    nn.LeakyReLU(),
    nn.Linear(128, 128),
    nn.LeakyReLU(),
    nn.Linear(128, 1)
)

# Optimizer using Adam Gradient Descent
actor_optimizer = optim.Adam(ActorNetwork.parameters(), lr=lr)
critic_optimizer = optim.Adam(CriticNetwork.parameters(), lr=lr)

In [28]:
"""
GAE estimates the advantage of taking an action in a state
"""
def compute_GAE(next_value, rewards, values, dones):
    advantages = []
    GAE = 0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]      # TD error
        GAE = delta + gamma * lamb * (1 - dones[t]) * GAE
        advantages.insert(0, GAE)
        next_value = values[t].item()
    return advantages

In [29]:
video_env = gym.make(
    "LunarLander-v3",
    continuous=False,     
    gravity=-10.0,        
    enable_wind=False,   
    wind_power=15.0,      
    turbulence_power=1.0, 
    render_mode="rgb_array" 
)

def evaluation(eval_epochs):
    total_reward = 0
    for i in range(eval_epochs):
        state, _ = video_env.reset()
        done = False
        while not done:
            state_tensor = torch.FloatTensor(state)
            action_pred = ActorNetwork(state_tensor)
            action = torch.argmax(action_pred).item()
            state, reward, terminated, truncated, _ = video_env.step(action)
            total_reward += reward
            done = terminated or truncated
    return total_reward / eval_epochs     

In [30]:
state, _ = env.reset()      # Initialize state s_t
timesteps = 0

while timesteps < max_timesteps:
    states, actions, rewards, log_probs, values, dones = [], [], [], [], [], []
    print(timesteps)
    for _ in range(T):              # Collect T timesteps for 1 rollout
        state_tensor = torch.FloatTensor(state)
        value = CriticNetwork(state_tensor)
        action_pred = ActorNetwork(state_tensor)                       # Select action a_t
        dist = distributions.Categorical(logits=action_pred)
        action = dist.sample()
        log_prob = dist.log_prob(action)        # Old policy given action a_t

        next_state, reward, terminated, truncated, _ = env.step(action.item())      # Advance simulation one time step
        done = terminated or truncated

        # Collect training data
        states.append(state)         # Collect states
        actions.append(action)        # Collect actions from Actor Network
        rewards.append(reward)        # Collect rewards
        log_probs.append(log_prob)      # Collect lob_probs
        values.append(value)         # Collect values from Critic Network
        dones.append(done)          # Collect done (0 or 1)

        state = next_state      # Move to next state
        timesteps += 1          # Increase timesteps

        if done:
            state, _ = env.reset()      # If an episode is done, reset the state
    
    # Compute advantages
    next_state_tensor = torch.FloatTensor(state)            # Get the last state reached after last action
    next_value = CriticNetwork(next_state_tensor)           # Calculate the next value
    advantages = compute_GAE(next_value, rewards=rewards, values=values, dones=dones)       # Calculate advantage using GAE
    advantages = torch.FloatTensor(advantages)
    returns = advantages + torch.FloatTensor(values)        # Calculate V-target_t = A_t + V-w(s_t)

    # Optimize policy
    states = torch.FloatTensor(states)
    actions = torch.LongTensor(actions)
    old_log_probs = torch.FloatTensor(log_probs)
    returns = returns.detach()

    for k in range(K):
        for i in torch.randperm(len(states)).split(minibatch_size):        # Sample over batches with size 'minibatch_size'      
            batch_states = states[i]
            batch_actions = actions[i]
            batch_old_log_probs = old_log_probs[i]
            batch_returns = returns[i]
            batch_advantages = advantages[i]

            optimized_value = CriticNetwork(batch_states)
            optimized_action_preds = ActorNetwork(batch_states)
            optimized_dist = distributions.Categorical(logits=optimized_action_preds)
            optimized_log_probs = optimized_dist.log_prob(batch_actions)
        
            ratio = (optimized_log_probs - batch_old_log_probs).exp()
            unclipped_objective = ratio * batch_advantages
            clipped_objective = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * batch_advantages

            L_clip = -torch.min(unclipped_objective, clipped_objective).mean()
            L_v = (optimized_value.squeeze() - batch_returns).pow(2).mean()
            H = optimized_dist.entropy().mean()

            # Loss = L_clip + v * L_v - h * H

            # Actor loss
            actor_loss = L_clip - h * H
            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()

            # Critic loss
            critic_optimizer.zero_grad()
            L_v.backward()
            critic_optimizer.step()

    if timesteps % 2048 == 0:
        try:
            avg_rewards = evaluation(100)
            print(f"In current timestep: {timesteps}, average reward: {avg_rewards}", flush=True)
        except Exception as e:
            print(f"Error in evaluation: {e}", flush=True)
        


0
In current timestep: 2048, average reward: -133.66161433331453
2048
In current timestep: 4096, average reward: -326.0962024724145
4096
In current timestep: 6144, average reward: -294.67870109642917
6144
In current timestep: 8192, average reward: -227.07719927348182
8192
In current timestep: 10240, average reward: -255.0342714625101
10240
In current timestep: 12288, average reward: -275.13861737629065
12288
In current timestep: 14336, average reward: -212.53248713509012
14336
In current timestep: 16384, average reward: -202.45216735494563
16384
In current timestep: 18432, average reward: -199.82398490698728
18432
In current timestep: 20480, average reward: -177.88119927457382
20480
In current timestep: 22528, average reward: -170.68942691179143
22528
In current timestep: 24576, average reward: -146.54267035096208
24576
In current timestep: 26624, average reward: -80.05114126714214
26624
In current timestep: 28672, average reward: -149.3803054471911
28672
In current timestep: 30720, av

KeyboardInterrupt: 