# Lab 11: A2C on CartPole with n-Step Bootstrapping

In this experiment, we implement a minimal yet fully functional version of the **Advantage Actor-Critic (A2C)** algorithm and apply it to the classic **CartPole-v1** control task.  
A2C is one of the core policy-gradient methods where a learned value function is used as a **baseline** to reduce variance, while the actor learns a stochastic policy that maximizes expected cumulative rewards.

Unlike simple Monte-Carlo policy gradient, A2C supports **bootstrapped n-step returns**, which combine short-horizon rewards with value estimates of future states. Bootstrapping trades a small amount of bias for lower variance, which typically makes training more stable and more sample-efficient. In this lab, you will:

- Implement an Actor-Critic network with shared representation.
- Collect trajectories from the CartPole environment.
- Compute **n-step bootstrapped targets** $R_t^{(n)}$.
- Compute the **advantage** $A_t = R_t^{(n)} - V(s_t)$.
- Update both policy and value networks using stochastic gradient descent.

This version of A2C is intentionally kept small and readable so that you can fully understand how policy gradients and value baselines work together. Once you are familiar with the single-environment version, you can extend it to multi-environment parallel A2C, GAE(λ), or even PPO.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gymnasium as gym 

In [10]:
GAMMA = 0.99
LR = 1e-3
MAX_EPISODES = 1000
PRINT_INTERVAL = 10
N_STEPS = 20  

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 🧠 Step 1: The Actor-Critic Neural Network

In this lab, the policy (actor) and the value function (critic) are implemented as a **single neural network** whose first layers are shared.  
This design is simple and efficient, and it lets both heads build on the same learned features.

---

## 2.1  Shared Feature Representation

The first part of the network is a shared feed-forward module that takes in the state \( s_t \) and produces a latent feature vector.  
This shared representation serves three purposes:

- It reduces the total number of parameters  
- It allows both actor and critic to learn from the same extracted features  
- It can stabilize training, because both heads build on one consistent representation of the environment

---

## 2.2  Policy Head (Actor)

The policy head maps the shared features to a vector of **action logits**, which represent the unnormalized preferences for each action.

From these logits, a categorical policy \( \pi(a|s) \) is formed.

The actor's job:

- Output a probability distribution over actions  
- Encourage exploration through stochastic sampling  
- Adjust action probabilities based on advantage values during training  

The outputs of the actor determine which action the agent will take at each step.
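
As a minimal sketch of how logits turn into a sampled action, the snippet below builds a categorical policy for a single state. The logit values are made up for illustration; in the lab they would come from the policy head.

```python
import torch
from torch.distributions import Categorical

# Hypothetical logits for one CartPole state with two actions (left, right).
logits = torch.tensor([[0.2, 1.0]])   # shape (batch=1, n_actions=2)

dist = Categorical(logits=logits)     # softmax is applied internally
action = dist.sample()                # shape (1,), a stochastic action index
log_prob = dist.log_prob(action)      # shape (1,), log pi(a|s) of the sampled action

print(dist.probs)                     # normalized action probabilities
print(action.item(), log_prob.item())
```

Sampling, rather than taking the argmax, is what gives the agent its exploration; the stored `log_prob` is the quantity the actor loss later differentiates.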

---

## 2.3  Value Head (Critic)

The critic head maps the shared features to a single scalar value \( V(s_t) \).

The critic's role:

- Estimate the expected discounted return from the current state  
- Provide the **baseline** used to compute advantages  
- Stabilize the actor updates by reducing variance  

A good critic makes the policy gradient updates far more sample-efficient.

---

## 2.4  Why Use a Shared Network?

Sharing the lower layers of the network is beneficial because:

- The actor and critic often rely on similar state features  
- It reduces computational cost and parameter count  
- It speeds up training  
- It encourages the two components to learn a consistent internal representation  

In practice, this shared architecture is widely used in A2C, PPO, IMPALA, and other modern policy-gradient algorithms.
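
To make the parameter-count argument concrete, here is a rough comparison between a shared trunk and two fully separate networks. The sizes assume CartPole-v1 (4 observations, 2 actions) and a hidden width of 128; the `separate` variant and the `count_params` helper are hypothetical, included only for comparison.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Total number of trainable scalars in a module."""
    return sum(p.numel() for p in module.parameters())

obs_dim, n_actions, hidden = 4, 2, 128

# Shared trunk with two small heads.
shared = nn.ModuleDict({
    "trunk": nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU()),
    "policy_head": nn.Linear(hidden, n_actions),
    "value_head": nn.Linear(hidden, 1),
})

# Hypothetical alternative: actor and critic each own a full trunk.
separate = nn.ModuleDict({
    "actor": nn.Sequential(
        nn.Linear(obs_dim, hidden), nn.ReLU(), nn.Linear(hidden, n_actions)
    ),
    "critic": nn.Sequential(
        nn.Linear(obs_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
    ),
})

print(count_params(shared), count_params(separate))  # 1027 1667
```

For a network this small the saving is modest; the effect grows with the size of the shared trunk.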

In [11]:
class ActorCritic(nn.Module):
    def __init__(self, obs_dim, n_actions):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
        )
        self.policy_head = nn.Linear(128, n_actions)
        self.value_head = nn.Linear(128, 1)

    def forward(self, x):
        # x: (batch, obs_dim)
        x = self.shared(x)
        logits = self.policy_head(x)           # (batch, n_actions)
        value = self.value_head(x).squeeze(-1) # (batch,)
        # Build the distribution from logits directly; Categorical applies the
        # softmax internally, which is numerically more stable.
        dist = Categorical(logits=logits)
        return dist, value

## 🔍 Step 4: Computing n-step Targets, Advantages, and Loss Functions

In this section, you will complete the core learning logic of the A2C algorithm.  
After collecting **N steps** of rollout data, your task is to compute the necessary learning signals and perform a single update of both the actor and the critic.

---

### 4.1  Compute n-step Bootstrapped Returns

You need to compute the n-step return for each timestep in the rollout.  
This return should:

- Start from the immediate reward,
- Discount future rewards by γ,
- And optionally bootstrap from the value function of the next state (unless the episode has ended).

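Mathematically, writing $G_t$ for the n-step target $R_t^{(n)}$ from the introduction:

$$
G_t = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \dots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n})
$$

Within a rollout of $T$ transitions, step $t$ uses all remaining rewards, so $n = T - t$.

If the episode ends within the rollout, the bootstrap value should be zero.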

You should compute these returns **backwards from the last transition**.
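
One possible way to structure the backward pass is sketched below. It uses plain Python lists for readability; the function name `n_step_returns` is an illustrative choice, and in the lab you would apply the same recursion to the `rewards`, `dones`, and `next_value` tensors.

```python
def n_step_returns(rewards, dones, next_value, gamma=0.99):
    """Backward pass over one rollout.

    rewards[t] and dones[t] describe transition t. next_value is the critic's
    estimate for the state after the last transition (0.0 if the episode ended).
    """
    returns = [0.0] * len(rewards)
    running = next_value
    for t in reversed(range(len(rewards))):
        # A done flag cuts the recursion: nothing after a terminal step counts.
        running = rewards[t] + gamma * running * (1.0 - dones[t])
        returns[t] = running
    return returns

# Three steps, no termination, bootstrap value 10.0, gamma = 0.5 for easy arithmetic.
print(n_step_returns([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 10.0, gamma=0.5))
# [3.0, 4.0, 6.0]
```

Checking the first entry by hand: $1 + 0.5 \cdot 1 + 0.25 \cdot 1 + 0.125 \cdot 10 = 3.0$, which matches the formula above with $n = 3$.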

---

### 4.2  Compute Advantages

Once you have the n-step targets, compute the advantage:

$$
A_t = G_t - V(s_t)
$$

The value $V(s_t)$ should come from your critic network, but you should **detach** it so that the actor loss does not backpropagate into the critic through the advantage term.

Advantages act as the learning signal for the actor.
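
A small sketch of the detach step, using made-up numbers in place of real critic outputs and targets:

```python
import torch

# Hypothetical three-step rollout.
values = torch.tensor([0.5, 1.0, 2.0], requires_grad=True)  # stand-in for V(s_t)
returns = torch.tensor([3.0, 4.0, 6.0])                      # n-step targets, constants

advantages = returns - values.detach()   # no gradient path back to the critic
print(advantages)                        # tensor([2.5000, 3.0000, 4.0000])
print(advantages.requires_grad)          # False
```

Because `advantages` carries no gradient, it acts as a fixed per-step weight when it multiplies the log-probabilities in the actor loss.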

---

### 4.3  Actor Loss (Policy Gradient with Baseline)

The actor should maximize:

$$
J_{\text{actor}} = \mathbb{E}[ A_t \log \pi(a_t|s_t) ]
$$

In practice, you will compute the negative of this expectation because optimizers perform gradient descent.

Your loss should ensure:
- Actions with **positive advantage** become more likely,
- Actions with **negative advantage** become less likely.
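
The following sketch checks both properties on a toy case. The logits, actions, and advantages are invented for illustration, and the policy starts out uniform so the direction of the update is easy to read.

```python
import torch
from torch.distributions import Categorical

logits = torch.zeros(2, 2, requires_grad=True)  # two states, two actions, uniform policy
actions = torch.tensor([0, 1])                  # hypothetical actions taken
advantages = torch.tensor([1.0, -1.0])          # hypothetical, already detached

dist = Categorical(logits=logits)
log_probs = dist.log_prob(actions)              # shape (2,)
actor_loss = -(advantages * log_probs).mean()   # negated because optimizers minimize
actor_loss.backward()

# A gradient-descent step moves the logits along the negative gradient.
step = -logits.grad
print(step)
```

In the first row the logit of the taken action (index 0, positive advantage) goes up; in the second row the logit of the taken action (index 1, negative advantage) goes down.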

---

### 4.4  Critic Loss (Value Function Regression)

The critic should predict the n-step return.  
Optimize it by minimizing the squared error:

$$
J_{\text{critic}} = (V(s_t) - G_t)^2
$$

This gives the value network a stable regression target.
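
As a sketch, the squared error can be averaged over the rollout with `torch.nn.functional.mse_loss`. The numbers are illustrative; the targets are treated as constants, so gradients flow only into the value predictions.

```python
import torch
import torch.nn.functional as F

values = torch.tensor([0.5, 1.0, 2.0], requires_grad=True)  # stand-in for V(s_t)
returns = torch.tensor([3.0, 4.0, 6.0])                      # n-step targets G_t

critic_loss = F.mse_loss(values, returns)   # mean over the rollout of (V - G)^2
critic_loss.backward()

print(critic_loss.item())
print(values.grad)   # negative entries: descent pushes each V(s_t) up toward its target
```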

---

### 4.5  Combine Losses and Update the Network

Your final loss combines the negated actor objective with the critic loss:

$$
L = -J_{\text{actor}} + 0.5 \, J_{\text{critic}}
$$

Then:

1. Zero out previous gradients  
2. Backpropagate through the combined loss  
3. Update the model parameters once  
4. Clear the rollout buffer  
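
The four steps can be sketched end to end on a toy problem. Everything here is a stand-in: the single linear layer replaces the lab's network, the rollout data is random, and the `rollout` list is a hypothetical buffer. The actor term is negated so that the combined quantity can be minimized.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Tiny stand-in model: one linear layer producing 2 logits and 1 value.
model = nn.Linear(4, 3)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Hypothetical rollout data: 5 states, the actions taken, and n-step targets.
states = torch.randn(5, 4)
actions = torch.randint(0, 2, (5,))
returns = torch.randn(5)
rollout = [states, actions, returns]

out = model(states)
logits, values = out[:, :2], out[:, 2]
log_probs = torch.log_softmax(logits, dim=-1).gather(1, actions.unsqueeze(1)).squeeze(1)

advantages = returns - values.detach()
actor_loss = -(advantages * log_probs).mean()
critic_loss = (values - returns).pow(2).mean()
loss = actor_loss + 0.5 * critic_loss

before = [p.detach().clone() for p in model.parameters()]
optimizer.zero_grad()   # 1. zero out previous gradients
loss.backward()         # 2. backpropagate through the combined loss
optimizer.step()        # 3. update the model parameters once
rollout.clear()         # 4. clear the rollout buffer

changed = any(not torch.equal(b, p) for b, p in zip(before, model.parameters()))
print(loss.item(), changed)
```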

---

### ✅ Summary of What You Must Implement

- Compute **backward n-step targets** with discounting and bootstrap  
- Compute **advantages**  
- Construct the **actor loss** using policy gradient + baseline  
- Construct the **critic loss** using MSE  
- Perform a **single gradient update**  
- Reset the rollout buffer


In [12]:
env = gym.make("CartPole-v1")
obs_dim = env.observation_space.shape[0]
n_actions = env.action_space.n

ac = ActorCritic(obs_dim, n_actions).to(device)
optimizer = optim.Adam(ac.parameters(), lr=LR)

episode_rewards = []

for episode in range(1, MAX_EPISODES + 1):

    episode_length = 0  # counts environment steps taken in this episode
    
    obs, info = env.reset()  # older gym versions: obs = env.reset()
    obs = torch.tensor(obs, dtype=torch.float32, device=device)

    total_reward = 0.0
    done = False

    obs_buf = []
    logprob_buf = []
    value_buf = []
    reward_buf = []
    done_buf = []

    while not done:
        episode_length+=1
        dist, value = ac(obs.unsqueeze(0))  # (1, obs_dim)
        action = dist.sample()
        log_prob = dist.log_prob(action)


        next_obs, reward, terminated, truncated, info = env.step(
            int(action.item())
        ) 
        done = terminated or truncated

        obs_buf.append(obs)
        logprob_buf.append(log_prob.squeeze(0))  # scalar, so the stacked shape is (T,)
        value_buf.append(value.squeeze(0))
        reward_buf.append(reward)
        done_buf.append(float(done))  

        total_reward += reward
        obs = torch.tensor(next_obs, dtype=torch.float32, device=device)

        if len(reward_buf) == N_STEPS or done:
            with torch.no_grad():
                if done:
                    next_value = torch.tensor(0.0, device=device)
                else:
                    _, nv = ac(obs.unsqueeze(0))
                    next_value = nv.squeeze(0)

            values = torch.stack(value_buf)  # (T,)
            log_probs = torch.stack(logprob_buf)  # (T,)
            rewards = torch.tensor(
                reward_buf, dtype=torch.float32, device=device
            )  
            dones = torch.tensor(
                done_buf, dtype=torch.float32, device=device
            ) 


            # TODO (your turn): implement Sections 4.1 to 4.5 here using
            # values, log_probs, rewards, dones, and next_value.
            



            obs_buf.clear()
            logprob_buf.clear()
            value_buf.clear()
            reward_buf.clear()
            done_buf.clear()

    episode_rewards.append(total_reward)

    if episode % PRINT_INTERVAL == 0:
        avg_r = np.mean(episode_rewards[-PRINT_INTERVAL:])
        print(
            f"Episode {episode:4d} | "
            f"avg_reward (last {PRINT_INTERVAL}) = {avg_r:.1f} | "
            f"episode length = {episode_length}"
        )

env.close()

Episode   10 | avg_reward (last 10) = 22.8 | episode length = 13
Episode   20 | avg_reward (last 10) = 19.2 | episode length = 16
Episode   30 | avg_reward (last 10) = 20.5 | episode length = 17
Episode   40 | avg_reward (last 10) = 24.2 | episode length = 33
Episode   50 | avg_reward (last 10) = 17.4 | episode length = 28
Episode   60 | avg_reward (last 10) = 17.7 | episode length = 18
Episode   70 | avg_reward (last 10) = 19.7 | episode length = 15
Episode   80 | avg_reward (last 10) = 15.6 | episode length = 10
Episode   90 | avg_reward (last 10) = 16.4 | episode length = 10
Episode  100 | avg_reward (last 10) = 18.7 | episode length = 25
Episode  110 | avg_reward (last 10) = 17.1 | episode length = 15
Episode  120 | avg_reward (last 10) = 17.1 | episode length = 14
Episode  130 | avg_reward (last 10) = 17.6 | episode length = 17
Episode  140 | avg_reward (last 10) = 19.9 | episode length = 13
Episode  150 | avg_reward (last 10) = 24.0 | episode length = 29
Episode  160 | avg_reward

KeyboardInterrupt: 