<a href="https://colab.research.google.com/github/lizhieffe/llm_knowledge/blob/main/examples/rl/%5BRL%5D_Actor_Critic_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> Tutorial: https://github.com/pytorch/examples/blob/main/reinforcement_learning/actor_critic.py

Terminology
- Actor == Policy
- Critic == Value
- Reward: the immediate feedback
- Return: the total discounted reward,  starting from the current timestep.

In [149]:
# @title Imports
import argparse
import gymnasium
import numpy as np
from itertools import count
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [150]:
# Cart Pole Gym

seed = 543
gamma = 0.99
log_interval = 10

env = gymnasium.make('CartPole-v1')
env.reset(seed=seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x78abdff42530>

In [151]:
# @title Libs - Policy model

SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

class Policy(nn.Module):
  """Implement both actor and critic in one model."""

  def __init__(self):
    super().__init__()

    hidden_dim = 128

    # common layer
    self.affine1 = nn.Linear(4, hidden_dim)   # 4 is the state space.

    # actor's head
    self.actor_head = nn.Linear(hidden_dim, 2)

    # critic head
    self.value_head = nn.Linear(hidden_dim, 1)

    # action & reward buffer
    self.saved_actions = []
    self.rewards = []

  def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Forward for both actor and critic.

    Args:
      x: the current state. It should contain 4 floats to represent the state.
    """
    assert x.shape == (4,)
    x = self.affine1(x)     # [H]
    x = F.relu(x)

    action_logits = self.actor_head(x)
    action_prob = F.softmax(action_logits, dim=-1)   # [2]

    state_value = self.value_head(x)                  # [1]

    return action_prob, state_value

In [152]:
# @title Libs - selection action

bs = 2
policy = Policy()
state = torch.randn((bs, 4))
state = np.random.randint((bs, 4))
# output_action_probs, output_state_values = policy(input)

# print(f"{output_action_probs=}, {output_state_values=}")

def select_action(policy: Policy, state: int):
  state = torch.from_numpy(state).float()
  action_prob, state_value = policy(state)

  # Sample from the probs and emit the enum int of the next action
  sampler = Categorical(action_prob)
  next_action = sampler.sample()               # [1]

  # save the action to buffer
  #
  # log_prob is the log of the probability of the selected action.
  log_prob = sampler.log_prob(next_action)     # [1]
  assert torch.allclose(log_prob, action_prob[next_action].log(), atol=1e-5)
  policy.saved_actions.append(SavedAction(log_prob, state_value))

  return next_action.item()

state, _ = env.reset()
next_action = select_action(policy, state)
print(f"{next_action=}, prob={policy.saved_actions[-1].log_prob.exp():.3f}, pred value={policy.saved_actions[-1].value.item():.3f}")
print(f"# saved actions = {len(policy.saved_actions)}")

next_action=1, prob=0.499, pred value=0.085
# saved actions = 1


## Training

Question - In the training loop below, it does **many FWD passes** to collect a full trajectory, and then do a **single backprop** for all these fwd passes together. How is this possible this is not 1-to-1 mapping? Specifically, each **FWD pass' activations** need to be remembered to be used in backprop to calculate **grad**, but will a later FWD overwrite its previous FWD's activation?

> Yes this works! The **activations from FWD passes** are stored in what's called a **computation graph**. When you do FWD or loss calculation, these operations (including activation) are stored into the CG; when you do **multiple FWDs** and accumulate their results (like summing up the losses) into a single tensor, PyTorch extends the **same computation graph**. During backprop, the stored activations are used to calculate the grad. And importantly, when loss.backward() is called, PyTorch **destroys this entire computation graph** by default to free up memory.


In [153]:
def finish_episode(optimizer, policy: Policy, gamma: float, eps=np.finfo(np.float32).eps.item()):
  """Training code. Calculates actor and critic loss and performs backprop.

  Args:
    policy: the policy model.
    gamma: The discount factor (gamma) determines how much an agent prioritizes
      future rewards over immediate ones.
  """

  # 1. Init vars
  saved_actions = policy.saved_actions
  actor_losses = []       # Save the actor (policy) losses
  critic_losses = []      # Save the critic (value) losses

  # Reward: for a single step
  # Return: for the aggregated
  R = 0                   # Return for a single trajectory
  returns = []            # List to save the true values for a single trajectory
  n_steps = len(policy.rewards)

  # 2. Calculate the true value using rewards returned from the env
  for r in policy.rewards[::-1]:  # iterate backwards
    # calculate the discounted value
    R = r + gamma * R
    returns.insert(0, R)

  returns = torch.tensor(returns)
  returns = (returns - returns.mean()) / (returns.std() + eps)
  # print(f"{returns=}")

  assert len(returns) == len(policy.saved_actions)

  # 3. Collect losses from the full trajectory
  for (log_prob, value), R in zip(policy.saved_actions, returns):
    advantage = R - value.item()

    # actor loss
    actor_losses.append(-1 * log_prob * advantage)

    # critic loss
    critic_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))

  # 4. Backprop
  optimizer.zero_grad()
  loss = torch.stack(actor_losses).sum() + torch.stack(critic_losses).sum()
  loss_val = loss.detach().item()
  loss.backward()
  optimizer.step()

  # 5. Clear the rewards and action buffer
  del policy.rewards[:]
  del policy.saved_actions[:]

  return loss_val, n_steps

# optimizer = optim.Adam(policy.parameters(), lr=3e-2)
# eps = np.finfo(np.float32).eps.item()

# finish_episode(optimizer=optimizer, policy=policy, gamma=0.9, eps=eps)


## Main

In [154]:
running_reward = 0

policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=3e-2)
eps = np.finfo(np.float32).eps.item()
reward_threshold = env.spec.reward_threshold

print(f"Start running... Target reward threshold: {reward_threshold}")

# run infinitely many episodes, until the reward threshold is met.
for i_episode in count(1):
  # reset the env
  state, _ = env.reset()
  ep_reward = 0

  # for each episode, only run 9999 steps to avoid infinite loop
  for t in range(1, 10000):
    action = select_action(policy=policy, state=state)

    state, reward, done, _, _ = env.step(action)

    policy.rewards.append(reward)
    ep_reward += reward

    if done:
      break

  # Backprop
  loss, n_steps = finish_episode(optimizer=optimizer, policy=policy, gamma=gamma, eps=eps)

  if i_episode % log_interval == 0:
    print(f"Episod: {i_episode}, Total Loss: {loss:.3f}, Avg Loss: {loss/n_steps:.3f}, Running Return: {running_reward:.3f}")

  running_reward = 0.05 * ep_reward + 0.95 * running_reward
  if running_reward > reward_threshold:
    print(f"Solved after {i_episode} episods! Running reward: {running_reward} > reward threshold: {reward_threshold}")
    break


Start running... Target reward threshold: 475.0
Episod: 10, Total Loss: 1.802, Avg Loss: 0.225, Running Return: 4.478
Episod: 20, Total Loss: 0.215, Avg Loss: 0.027, Running Return: 6.406
Episod: 30, Total Loss: 0.247, Avg Loss: 0.031, Running Return: 7.524
Episod: 40, Total Loss: 0.083, Avg Loss: 0.009, Running Return: 8.144
Episod: 50, Total Loss: 0.236, Avg Loss: 0.030, Running Return: 8.609
Episod: 60, Total Loss: 0.328, Avg Loss: 0.033, Running Return: 8.757
Episod: 70, Total Loss: 0.062, Avg Loss: 0.007, Running Return: 9.109
Episod: 80, Total Loss: 0.525, Avg Loss: 0.066, Running Return: 9.063
Episod: 90, Total Loss: 0.355, Avg Loss: 0.044, Running Return: 9.253
Episod: 100, Total Loss: 0.023, Avg Loss: 0.003, Running Return: 9.167
Episod: 110, Total Loss: 0.176, Avg Loss: 0.018, Running Return: 9.288
Episod: 120, Total Loss: 0.017, Avg Loss: 0.002, Running Return: 9.328
Episod: 130, Total Loss: 0.853, Avg Loss: 0.071, Running Return: 9.377
Episod: 140, Total Loss: 0.060, Avg Lo