# Notes

### High Level Quick Read Through

So this isn't fundamentally a new setup or architecture. It's an improvement on the way we estimate the advantages that weight the policy gradient. Let's go through a bit of background, frame the problem, and then show how GAE is better.

So the context is that in RL we want to find a policy that maximizes the expected cumulative reward. The objective function J(theta) is the expected return over all possible trajectories that could be generated by our policy. Our goal is to find the theta that maximizes this value, where theta is the policy's parameters. To maximize J we perform gradient ascent on theta. The challenging part is calculating the gradient.

The policy gradient theorem gives us a way to compute this. The general / common form is:

$$
g = \nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^T \Psi_t \nabla_\theta \log \pi_\theta(a_t | s_t) \right] \quad
$$

Here the $\nabla_\theta \log \pi_\theta(a_t | s_t)$ term is the score function. This is a vector (of the same shape as our weights theta) holding the gradient of the log probability of the action with respect to the policy's parameters. We then have Psi, which acts as the magnitude and sign of our update. In A2C this is the advantage that we calculate.

This paper is about finding a good choice for Psi. First they go through the different existing choices for Psi and their flaws.
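
To make the two factors concrete, here is a minimal sketch for a softmax policy whose parameters are the logits themselves. That is an assumption made purely for illustration, and the function names are hypothetical. In that case the score function has the closed form `one_hot(action) - probs`, and Psi only scales and flips it.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def score_function(logits, action):
    # Gradient of log pi(action) with respect to the logits: one_hot(action) - probs.
    probs = softmax(logits)
    return [(1.0 if i == action else 0.0) - p for i, p in enumerate(probs)]

def policy_gradient_term(logits, action, psi):
    # One term of the sum in the policy gradient: Psi_t times the score function.
    return [psi * g for g in score_function(logits, action)]

logits = [0.0, 0.0]
print(policy_gradient_term(logits, action=1, psi=2.0))   # [-1.0, 1.0]: action 1 becomes more likely
print(policy_gradient_term(logits, action=1, psi=-2.0))  # [1.0, -1.0]: action 1 becomes less likely
```

The direction of the update comes from the score function; the choice of Psi decides how far to move along it and in which direction.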

1. The total reward of the trajectory
This means that we sum up all the rewards of the whole trajectory and use that same number at every timestep. This has a few different flaws. The major one is the causality issue: rewards collected before an action was taken still count towards that action, even though it could not have caused them. Good or bad things that happened in the past could cause the current action to be reinforced regardless of whether this action was good or bad. This makes the updates very high variance, which makes learning slow and unstable.


2. The reward following action a_t
This is a bit better. Instead of using the sum over the whole trajectory at every step, you look at the reward that you got from now until the end. So if you're calculating Psi for timestep 3 you would sum the rewards from timestep 3 until the end of the episode. This helps to assign credit a bit better.

3. Using a baseline
This is a trick that we use to reduce variance. We can subtract any function b(s_t) that depends only on the state s_t (and not on the action you take) from our reward-to-go without changing the expected value of the gradient (that is to say, without introducing bias). This makes learning more stable. A good baseline would be the average return from the state (which is the value function!).

4. Q-function
This is another valid choice. The only issue is that it still suffers from high variance. For example in a state where all actions are good, all Q values will be large and positive and give a big gradient to reinforce any action in that state. It doesn't give a strong signal about which action is better than others.

5. Advantage Function
This is close to the best possible choice for Psi from a variance perspective. It is the Q-function with the value function subtracted as a baseline. The learning signal is centered around zero: it says whether an action was better or worse than average in that state. This is a much cleaner signal for learning. This paper is actually about finding the best way to estimate this quantity.

6. TD Residual
The TD residual is the one-step estimate of the advantage function. Because we don't have the true value function (that's what the critic is trying to learn), this estimate is biased. We can reduce the bias, at the cost of more variance, by adding in the real rewards of a certain number of time steps.
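
The first three choices are easy to compare on a toy trajectory. The sketch below is illustrative only: the rewards and baseline values are made up, the helper names are hypothetical, and discounting is left out to keep the numbers readable.

```python
def total_reward(rewards):
    # Choice 1: every timestep gets the same weight, the sum over the whole trajectory.
    total = sum(rewards)
    return [total for _ in rewards]

def reward_to_go(rewards):
    # Choice 2: timestep t only gets credit for rewards from t onwards.
    out, running = [], 0.0
    for r in reversed(rewards):
        running += r
        out.append(running)
    return out[::-1]

def with_baseline(returns, baselines):
    # Choice 3: subtract a state-dependent baseline b(s_t) from each reward-to-go.
    return [g - b for g, b in zip(returns, baselines)]

rewards = [1.0, 0.0, 2.0, 1.0]
baselines = [3.5, 3.0, 2.5, 1.0]  # made-up b(s_t) values for illustration

print(total_reward(rewards))                            # [4.0, 4.0, 4.0, 4.0]
print(reward_to_go(rewards))                            # [4.0, 3.0, 3.0, 1.0]
print(with_baseline(reward_to_go(rewards), baselines))  # [0.5, 0.0, 0.5, 0.0]
```

With the baseline subtracted the weights sit close to zero, which is the centering effect the advantage function is after.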

### The core idea of the paper -- generalized advantage estimation

The main issue is that we don't know the true value function. We learn it with a network (the critic network). The critic network is not going to be perfect, obviously, which will introduce bias. The rest of the paper talks about managing the tradeoff between the variance that is introduced by using empirical returns and the bias of using the value function to calculate the advantage.

The simplest possible advantage estimator uses V heavily. It's the TD residual or TD error. Note: this is the value that we use as Psi, i.e. as our advantage estimate.

$$
\delta_t^V = r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t) \quad
$$

This value is defined as the reward we get at a timestep, plus gamma multiplied by what our value function estimates at the next step, minus the value function's estimate at this step. I can see why this uses V very heavily: two of its three terms come from the critic. It might work but will be quite biased. Early in training it will be especially biased, as the value network hasn't really learned anything yet.

This is low variance though. It's low variance because it's composed of fewer random variables than the alternative, the Monte Carlo return, which we'll cover in a second. Note: variance here refers to how much the estimate fluctuates for the same state-action pair across different trajectories. The Monte Carlo return is high variance because it's a function of a long chain of random events: a single different action at a step in the future can lead to a completely different sub-trajectory. The TD residual, on the other hand, is high bias because it relies heavily on the probably inaccurate value function.
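
As a small worked example of the formula above, the sketch below computes the TD residual for every step of a short rollout. The critic outputs are made-up numbers, `td_residuals` is a hypothetical helper, and gamma = 0.5 is chosen only so the arithmetic is easy to check by hand.

```python
def td_residuals(rewards, values, bootstrap_value, gamma=0.99):
    # values[t] is the critic's estimate V(s_t). bootstrap_value is the estimate for
    # the state reached after the last reward (use 0.0 if the episode terminated).
    next_values = values[1:] + [bootstrap_value]
    return [r + gamma * v_next - v for r, v, v_next in zip(rewards, values, next_values)]

rewards = [1.0, 1.0, 1.0]
values = [2.0, 1.5, 1.0]  # made-up critic outputs for illustration

print(td_residuals(rewards, values, bootstrap_value=0.0, gamma=0.5))  # [-0.25, 0.0, 0.0]
```

A residual of zero means the critic's prediction was consistent with what happened in that step; a negative one means the step went worse than the critic expected.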


The next idea to look at is using k-step estimators. This is similar to n-step returns and sits in between the TD residual and the Monte Carlo estimate. They define a family of estimators that look k steps into the future. This differs from the n-step return in what it is used for: the n-step return is a target for the value function, while the k-step advantage estimator is a weight for the policy gradient.

This is the formula for it:

$$
\hat{A}_t^{(k)} = \sum_{l=0}^{k-1} \gamma^l \delta_{t+l}^V
$$

This is essentially summing the TD error (which we defined before) kind of "recursively" down different time steps. For k = 2 we would have

$$
\begin{align}
\hat{A}_t^{(2)} &= \delta_t^V + \gamma \delta_{t+1}^V \\
&= \underbrace{(r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t))}_{\delta_t^V} + \gamma \underbrace{(r_{t+1} + \gamma V_\phi(s_{t+2}) - V_\phi(s_{t+1}))}_{\delta_{t+1}^V} \\
&= r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t) + \gamma r_{t+1} + \gamma^2 V_\phi(s_{t+2}) - \gamma V_\phi(s_{t+1})
\end{align}
$$

If we look closely, the $+\gamma V_\phi(s_{t+1})$ in the middle cancels with the $-\gamma V_\phi(s_{t+1})$ at the end. For larger k these intermediate value terms continue to cancel, leaving us with something like this.

$$
\hat{A}_t^{(k)} = \underbrace{\left( \sum_{l=0}^{k-1} \gamma^l r_{t+l} \right)}_{\text{k-step empirical return}} + \underbrace{\gamma^k V_\phi(s_{t+k})}_{\text{bootstrap value}} - \underbrace{V_\phi(s_t)}_{\text{baseline}}
$$

Which is the general form. Let's look at each of the parts one by one. First the third term, the baseline. That's what our critic network thinks the expected total discounted reward from this state is. We subtract this to center our calculation. The first term, the k-step empirical return, is the ground truth portion of our estimate. It's the actual discounted sum of rewards that we observed. The middle term is our bootstrap value, the estimated part of our new estimate. That's what our network predicts the value of the state is after taking those k steps (we discount it by $\gamma^k$).

Now, what does the generalized advantage estimator do? Instead of picking one value of k in the above formula, it combines all of them using an exponentially weighted average that is controlled by a new hyperparameter $\lambda \in [0, 1]$.
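
The cancellation can also be checked numerically. The sketch below is a toy check with made-up numbers and hypothetical helper names. It evaluates the k-step estimator both ways, as a sum of discounted TD residuals and as empirical return plus bootstrap minus baseline, and prints the two side by side.

```python
def td_residuals(rewards, values, gamma):
    # values holds V(s_0) ... V(s_T), so it has one more entry than rewards.
    return [r + gamma * values[t + 1] - values[t] for t, r in enumerate(rewards)]

def k_step_from_residuals(deltas, t, k, gamma):
    # Sum of discounted TD residuals, as in the definition of the k-step estimator.
    return sum(gamma ** l * deltas[t + l] for l in range(k))

def k_step_from_returns(rewards, values, t, k, gamma):
    # k-step empirical return + bootstrap value - baseline.
    empirical = sum(gamma ** l * rewards[t + l] for l in range(k))
    return empirical + gamma ** k * values[t + k] - values[t]

rewards = [1.0, 0.0, 2.0]
values = [0.5, 1.0, 2.0, 0.0]  # made-up critic outputs for illustration
gamma = 0.9

deltas = td_residuals(rewards, values, gamma)
for k in (1, 2, 3):
    a = k_step_from_residuals(deltas, 0, k, gamma)
    b = k_step_from_returns(rewards, values, 0, k, gamma)
    print(k, round(a, 6), round(b, 6))  # the two columns agree for every k
```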

The definition is

$\hat{A}_t^{\text{GAE}} = (1-\lambda) \sum_{k=1}^\infty \lambda^{k-1} \hat{A}_t^{(k)}$

Which is (1 - lambda) multiplied by the sum of the k-step advantage estimators defined above, each weighted by $\lambda^{k-1}$. The (1 - lambda) factor normalizes the weights so that they sum to 1.

This simplifies to the following formula, which is a lot simpler:

$$
\hat{A}_t^{\text{GAE}(\gamma, \lambda)} = \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}^V \quad (\text{Equation 16})
$$

This equation is the discounted sum of all future TD residuals. For every step we calculate the surprise score (the TD residual defined earlier: the reward we got, plus gamma times the value of the next state, minus the value of the current state) and sum these up. So before we do anything else we go through all the steps into the future and calculate the TD residual for each one.

Then we need to look at the sum and what the l is. The l stands for lookahead: it's how many steps into the future we are looking from our current timestep t. So each term is gamma (our discount factor) multiplied by lambda (our new hyperparameter), raised to the power l (for exponential weighting), multiplied by the TD residual at step t + l.

Obviously this is just math; in reality we can't do an infinite sum. In practice one of two things happens. If our episodes are finite and end (depending on the setup), we can sum until the end of the episode. Otherwise we collect data in finite batches with a fixed number of steps, and the sum stops at the end of our collected data.
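
For reference, the step from the exponentially weighted average to the sum over TD residuals can be sketched as follows. It substitutes the k-step definition, groups the terms by TD residual, and sums the resulting geometric series (assuming $0 \le \lambda < 1$):

$$
\begin{align}
\hat{A}_t^{\text{GAE}(\gamma, \lambda)} &= (1-\lambda) \left( \hat{A}_t^{(1)} + \lambda \hat{A}_t^{(2)} + \lambda^2 \hat{A}_t^{(3)} + \dots \right) \\
&= (1-\lambda) \left( \delta_t^V + \lambda (\delta_t^V + \gamma \delta_{t+1}^V) + \lambda^2 (\delta_t^V + \gamma \delta_{t+1}^V + \gamma^2 \delta_{t+2}^V) + \dots \right) \\
&= (1-\lambda) \left( \delta_t^V (1 + \lambda + \lambda^2 + \dots) + \gamma \delta_{t+1}^V (\lambda + \lambda^2 + \dots) + \gamma^2 \delta_{t+2}^V (\lambda^2 + \lambda^3 + \dots) + \dots \right) \\
&= (1-\lambda) \left( \frac{\delta_t^V}{1-\lambda} + \frac{\gamma \lambda \, \delta_{t+1}^V}{1-\lambda} + \frac{\gamma^2 \lambda^2 \, \delta_{t+2}^V}{1-\lambda} + \dots \right) \\
&= \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}^V
\end{align}
$$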

We have a clever trick with this formula: we can calculate it efficiently with a backward pass. If we expand $A_t$, it contains the definition of $A_{t+1}$ inside it (and the advantage past the end of the collected data is taken to be 0). Therefore this simplifies to


$A_t = \delta_t + \gamma \lambda \cdot A_{t+1}$
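
A minimal sketch of that backward pass, checked against the direct sum truncated at the end of the rollout. The TD residuals here are made-up numbers and both function names are hypothetical.

```python
def gae_direct(deltas, gamma, lam):
    # Direct evaluation of sum_l (gamma * lam)^l * delta_{t+l}, stopping at the end of the rollout.
    n = len(deltas)
    return [sum((gamma * lam) ** l * deltas[t + l] for l in range(n - t)) for t in range(n)]

def gae_backward(deltas, gamma, lam):
    # Same quantity via the recursion A_t = delta_t + gamma * lam * A_{t+1}, with A = 0 past the end.
    advantages = [0.0] * len(deltas)
    next_advantage = 0.0
    for t in reversed(range(len(deltas))):
        next_advantage = deltas[t] + gamma * lam * next_advantage
        advantages[t] = next_advantage
    return advantages

deltas = [0.5, -0.2, 1.0]  # made-up TD residuals for illustration
direct = gae_direct(deltas, gamma=0.99, lam=0.95)
backward = gae_backward(deltas, gamma=0.99, lam=0.95)
print(all(abs(a - b) < 1e-9 for a, b in zip(direct, backward)))  # True
```

The direct version does work proportional to the square of the rollout length, while the backward pass touches each step once.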

Now we have to look at gamma and lambda. They seem to do similar things so it's important to differentiate them.

Gamma is still that same discount factor, it defines how much we care about future rewards. Higher gamma means we have a long horizon, lower means we care more about immediate rewards.

Lambda controls the bias-variance tradeoff of the advantage estimator. It doesn't change the definition of the value function. A higher lambda means that real rewards further into the future get a higher weight (making the estimate more similar to Monte Carlo). This will reduce the bias but increase the variance. At lambda = 0 we get just the one-step TD residual: low variance, but biased unless the value function is accurate. At lambda = 1 (in the untruncated sum) we get the discounted sum of rewards minus the baseline: no bias from the value function, but high variance.
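
The two extremes of lambda can be seen on a tiny rollout. This is an illustrative sketch with made-up rewards and critic values, a hypothetical `gae` helper, and gamma = 0.5 chosen so the numbers are easy to verify by hand. The episode is assumed to terminate after the last step, so the bootstrap value is 0.

```python
def gae(rewards, values, bootstrap_value, gamma, lam):
    # Backward pass: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t),
    #                A_t     = delta_t + gamma * lam * A_{t+1}.
    next_value, next_advantage = bootstrap_value, 0.0
    advantages = [0.0] * len(rewards)
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        next_advantage = delta + gamma * lam * next_advantage
        advantages[t] = next_advantage
        next_value = values[t]
    return advantages

rewards = [1.0, 1.0, 1.0]
values = [1.0, 1.0, 1.0]  # made-up critic outputs for illustration
gamma = 0.5

print(gae(rewards, values, 0.0, gamma, lam=0.0))  # [0.5, 0.5, 0.0]: just the TD residuals
print(gae(rewards, values, 0.0, gamma, lam=1.0))  # [0.75, 0.5, 0.0]: discounted return minus V(s_t)
```

For t = 0 with lambda = 1 the value 0.75 is exactly 1 + 0.5 + 0.25 - 1, the discounted reward-to-go minus the baseline.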

Note: I wonder if there's anything about decaying lambda as you go on?

### Comparing to N-Step Advantage

1. N-step advantage
    The advantage is the outcome over n real steps minus the baseline.
    We go to the end of the rollout, take the bootstrap value and then work backwards,
    discounting and adding in the real rewards.
2. GAE
    The advantage is the exponentially weighted sum of all future TD residuals.
    For each state we first calculate the TD residual, which is defined
    as the reward we got now + the discounted value of the next step - the value of this step.
    Then we do the backwards pass and sum those TD residuals over time, weighted by (gamma * lambda)^l.


So what is the actual difference I have to make in my A2C implementation to implement this? Instead of using the N-step advantage we replace it with GAE.

Note on implementation: in this implementation the GAE sum is not manually truncated to a fixed number of steps; it always goes to the end of the rollout.
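
The worker in the implementation that follows stops collecting as soon as an episode terminates, so a rollout never crosses an episode boundary. If a rollout were allowed to continue into the next episode, a common pattern is to carry a done flag per step and cut both the bootstrap and the recursion at the boundary. The sketch below shows that pattern; it is not part of the implementation in these notes, and the function name and numbers are made up.

```python
def gae_with_dones(rewards, values, dones, bootstrap_value, gamma=0.99, lam=0.95):
    # dones[t] is True when the step at t ended an episode, so nothing after t
    # should leak into the advantage of t.
    advantages = [0.0] * len(rewards)
    next_value, next_advantage = bootstrap_value, 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        next_advantage = delta + gamma * lam * not_done * next_advantage
        advantages[t] = next_advantage
        next_value = values[t]
    return advantages

rewards = [1.0, 1.0, 1.0, 1.0]
values = [1.0, 1.0, 1.0, 1.0]        # made-up critic outputs for illustration
dones = [False, True, False, False]  # the first episode ends at step 1

print(gae_with_dones(rewards, values, dones, bootstrap_value=2.0, gamma=0.5, lam=1.0))
# [0.5, 0.0, 1.0, 1.0]
```

Steps 0 and 1 only see rewards from their own episode, while steps 2 and 3 bootstrap from the value of the state after the rollout.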


### Quick Implementation Plan
1. Take my previous A2C algorithm and replace the n-step return with GAE in order to calculate the advantages

In [2]:
import sys
!{sys.executable} -m pip install "gymnasium[classic-control]"
import gymnasium as gym


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [15]:
import threading

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Use Apple's MPS backend when it is available, otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

class ActorCritic(nn.Module):
    def __init__(self, n_actions, hidden_size):
        # Shared body, followed by separate linear heads for the actor and the critic.
        # Note: n_actions is currently unused; the input size (4) and the number of
        # actions (2) are hard-coded for CartPole-v1.
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(4, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        self.actor_head = nn.Linear(hidden_size, 2)
        self.critic_head = nn.Linear(hidden_size, 1)
        self.to(device)
        
    def forward(self, x):
        x = torch.as_tensor(x, dtype=torch.float32, device=device)
        shared_features = self.shared(x)
        action_distribution = self.actor_head(shared_features)
        critic_value = self.critic_head(shared_features)
        return action_distribution, critic_value

class Worker(threading.Thread):
    def __init__(self, worker_id, task_queue, results_queue, actor_critic):
        super().__init__()
        self.worker_id = worker_id
        self.task_queue = task_queue

        # environment and agent attributes
        self.gym_environment = gym.make("CartPole-v1")
        self.environment_state = self.gym_environment.reset()[0]
        self.actor_critic = actor_critic
        self.results_queue = results_queue
    
    def run(self):
        while True:
            command, data = self.task_queue.get()
            if command == 'collect':
                experience = self.collect_experience(data['n_steps'])
                self.results_queue.put((self.worker_id, experience))
            elif command == 'stop':
                break
            
    def collect_experience(self, n_steps):
        states = []
        actions = []
        rewards = []
        terminates = False
        # State used for the bootstrap value; updated after every step.
        last_state = self.environment_state

        for _ in range(n_steps):
            with torch.no_grad():
                action_logits, _ = self.actor_critic(self.environment_state)
            probs = F.softmax(action_logits, dim=-1)
            action = torch.multinomial(probs, num_samples=1).item()
            next_state, reward, terminated, truncated, info = self.gym_environment.step(action)

            states.append(self.environment_state)
            actions.append(action)
            rewards.append(reward)

            self.environment_state = next_state
            last_state = next_state

            if terminated or truncated:
                # Only a true termination zeroes the bootstrap value. A time-limit
                # truncation still bootstraps from the critic at last_state.
                terminates = terminated
                self.environment_state = self.gym_environment.reset()[0]
                break

        return {
            "states": states,
            "actions": actions,
            "rewards": rewards,
            "advantages": self.calculate_GAE(states, last_state, rewards, terminates),
            }

    def calculate_GAE(self, states, last_state, rewards, terminates):
        lambda_val = 0.95
        gamma_val = 0.99
        # The advantages are used as constants in the loss, so no gradients are needed here.
        with torch.no_grad():
            # Bootstrap value for the state after the last step (zero if the episode terminated).
            next_value = 0.0 if terminates else self.actor_critic(last_state)[1].item()
            # Critic estimate V(s_t) for every visited state.
            values = [self.actor_critic(state)[1].item() for state in states]

        # Work backwards through the rollout:
        #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * A_{t+1}
        advantages = [0.0] * len(states)
        next_advantage = 0.0
        for i in reversed(range(len(states))):
            td_residual = rewards[i] + gamma_val * next_value - values[i]
            next_advantage = td_residual + gamma_val * lambda_val * next_advantage
            advantages[i] = next_advantage
            next_value = values[i]

        # A plain list of floats, so the training loop can extend its batch with it.
        return advantages

In [None]:
# main training loop
import queue
print("starting")

# hyper params
NUM_WORKERS = 6
EPISODES = 600
UPDATE_STEPS = 96
LEARNING_RATE = 0.0003

# setup the global network and optimizer
# CartPole-v1 has 2 discrete actions (the first argument is n_actions)
global_network = ActorCritic(2, 128)
optimizer = optim.Adam(global_network.parameters(), lr=LEARNING_RATE)

workers = []
task_queues = []
results_queue = queue.Queue()
episode_rewards = [] 

for i in range(NUM_WORKERS):
    task_q = queue.Queue()
    worker = Worker(i, task_q, results_queue, global_network)
    workers.append(worker)
    task_queues.append(task_q)
    worker.start()

for episode in range(EPISODES):
    if (episode + 1) % 25 == 0:
        print(f"Episode {episode+1}")

    for q in task_queues:
        q.put(('collect', {"n_steps": UPDATE_STEPS}))

    all_states = []
    all_actions = []
    advantages = []

    for _ in range(NUM_WORKERS):
        worker_id, experience = results_queue.get()
        all_states.extend(experience['states'])
        all_actions.extend(experience['actions'])
        advantages.extend(experience['advantages'])
        episode_rewards.append(sum(experience['rewards']))

    all_states = torch.tensor(all_states, dtype=torch.float32).to(device)
    all_actions = torch.tensor(all_actions, dtype=torch.int64).to(device)
    advantages = torch.tensor(advantages, dtype=torch.float32).to(device)

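    action_logits, critic_values = global_network(all_states)
    # squeeze only the trailing value dimension so the batch dimension always survives
    all_critic_estimates = critic_values.squeeze(-1)
    # critic target: GAE advantage plus the current value estimate
    critic_targets = advantages + all_critic_estimates.detach()

    critic_loss = F.mse_loss(all_critic_estimates, critic_targets)

    # the advantages are constants as far as the policy gradient is concerned
    all_advantages = advantages.detach()

    log_probs = F.log_softmax(action_logits, dim=-1)
    probs = F.softmax(action_logits, dim=-1)
    # entropy bonus to keep the policy from collapsing too early
    entropy = -(probs * log_probs).sum(dim=-1).mean()

    # log probability of the action that was actually taken at each step
    action_log_probs = log_probs.gather(1, all_actions.unsqueeze(1)).squeeze(1)

    policy_loss = -(action_log_probs * all_advantages).mean()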

    total_loss = policy_loss + 0.5 * critic_loss - 0.01 * entropy

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()


for q in task_queues:
    q.put(('stop', None))
for w in workers:
    w.join()


print("done")

starting
Episode 5
Episode 10
Episode 15
Episode 20
Episode 25
Episode 30
Episode 35
Episode 40
Episode 45
Episode 50
Episode 55
Episode 60
Episode 65
Episode 70
Episode 75
Episode 80
Episode 85
Episode 90
Episode 95
Episode 100
Episode 105
Episode 110
Episode 115
Episode 120
Episode 125
Episode 130
Episode 135
Episode 140
Episode 145
Episode 150
Episode 155
Episode 160
Episode 165
Episode 170
Episode 175
Episode 180
Episode 185
Episode 190
Episode 195
Episode 200
Episode 205
Episode 210
Episode 215
Episode 220
Episode 225
Episode 230
Episode 235
Episode 240
Episode 245
Episode 250
Episode 255
Episode 260
Episode 265
Episode 270
Episode 275
Episode 280
Episode 285
Episode 290
Episode 295
Episode 300
Episode 305
Episode 310
Episode 315
Episode 320
Episode 325
Episode 330
Episode 335
Episode 340
Episode 345
Episode 350
Episode 355
Episode 360
Episode 365
Episode 370
Episode 375
Episode 380
Episode 385
Episode 390
Episode 395
Episode 400
Episode 405
Episode 410
Episode 415
Episode 420
Epi

In [23]:
import torch
import time
import torch.nn.functional as F

eval_env = gym.make("CartPole-v1", render_mode="human")

state, info = eval_env.reset()
done = False
total_reward = 0

with torch.no_grad():
    while not done:
        eval_env.render()

        action_logits, _ = global_network.forward(state)
        
        action_probs = F.softmax(action_logits, dim=-1)
        
        action = torch.multinomial(action_probs, num_samples=1).item()
        
        next_state, reward, terminated, truncated, info = eval_env.step(action)
        
        done = terminated or truncated
        
        state = next_state
        total_reward += reward
        

eval_env.close()

print("\n" + "="*40)
print(f"Evaluation episode finished!")
print(f"Total Reward: {total_reward}")
print("="*40)


Evaluation episode finished!
Total Reward: 386.0
