Understanding: Recreate the RLHF loop in your own words: first supervised fine-tune (SFT) a model, then train a reward model on comparisons, then run PPO. Why is the KL divergence term (keeping $\pi_\theta$ close to a reference model $\pi_{\text{ref}}$) crucial? What happens if $\beta$ (the KL penalty coefficient) is set too low or zero?

1. RLHF loop => pre-training -> SFT -> train a reward model on human comparisons, so that it scores outputs the way humans rank them -> run PPO on the policy (the LM) so that it produces outputs that maximize the reward-model score while staying close to the frozen SFT model $\pi_{\text{ref}}$, which keeps the outputs fluent and sensible.

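2. KL divergence: the KL term is a penalty subtracted from the reward-model score. It is computed from the difference between the log-probabilities that the policy $\pi_\theta$ and the frozen SFT model $\pi_{\text{ref}}$ assign to the generated tokens, scaled by $\beta$: $r = r_{\text{RM}} - \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$. This guards against reward hacking. The reward model was trained on comparisons of outputs resembling the SFT model's, so its scores are unreliable far from that distribution; the penalty stops the policy from outputting just anything that exploits the reward model and keeps its answers close to what the SFT model would have produced. If $\beta$ is too low or zero, the penalty has little or no effect, so the policy can drift far from $\pi_{\text{ref}}$ and is likely to learn to reward hack, for example by producing degenerate or repetitive text that the reward model happens to score highly.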
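Both ingredients of the loop above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the exact scheme of any particular RLHF library: the reward model is assumed to be trained with a Bradley-Terry style pairwise loss $-\log\sigma(r_{\text{chosen}} - r_{\text{rejected}})$ (a common choice for comparison data), and the PPO reward is assumed to be a per-token KL penalty $-\beta(\log\pi_\theta - \log\pi_{\text{ref}})$ with the reward-model score added on the final token. The function names and the log-prob numbers are hypothetical.

```python
import math


def softplus(x):
    # log(1 + exp(x)), computed without overflow for large |x|
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def reward_model_pairwise_loss(r_chosen, r_rejected):
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over a batch of comparisons.

    r_chosen[i] / r_rejected[i] are the reward-model scores for the response
    the labeler preferred / rejected for the same prompt.
    """
    # -log(sigmoid(m)) == softplus(-m)
    losses = [softplus(-(rc - rr)) for rc, rr in zip(r_chosen, r_rejected)]
    return sum(losses) / len(losses)


def kl_shaped_rewards(rm_score, logprobs_policy, logprobs_ref, beta):
    """Per-token PPO rewards for one sampled response (assumed scheme).

    Each token gets -beta * (log pi_theta(y_t) - log pi_ref(y_t)), a
    single-sample estimate of the KL penalty; the reward-model score for
    the whole response is added on the last token only.
    """
    if not logprobs_policy or len(logprobs_policy) != len(logprobs_ref):
        raise ValueError("need equal-length, non-empty log-prob sequences")
    rewards = [-beta * (lp - lr) for lp, lr in zip(logprobs_policy, logprobs_ref)]
    rewards[-1] += rm_score
    return rewards


# Usage with made-up log-probabilities for a 3-token response
policy_lp = [-0.5, -0.2, -1.0]
ref_lp = [-0.7, -0.2, -2.0]
print(kl_shaped_rewards(1.0, policy_lp, ref_lp, beta=0.1))
print(kl_shaped_rewards(1.0, policy_lp, ref_lp, beta=0.0))  # penalty vanishes
```

With `beta=0.0` every per-token penalty is zero and only the reward-model score remains, which is exactly the case described in point 2: nothing pulls the policy back toward $\pi_{\text{ref}}$.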



Next, compute advantages with the Generalized Advantage Estimator (GAE). Given a sequence of rewards $r_t$ and value estimates $V(s_t)$, first compute the TD error at each step, then accumulate the TD errors backwards in time:

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

$$A_t = \delta_t + \gamma \lambda A_{t+1}$$
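Unrolling the recursion gives the closed form $A_t = \sum_{k=0}^{T-1-t} (\gamma\lambda)^k \delta_{t+k}$. With $\lambda = 0$ this reduces to $\delta_t$, and with $\lambda = 1$ the value terms telescope, leaving the discounted return minus $V(s_t)$ (assuming $V(s_T) = 0$). The sketch below, using a hypothetical helper `gae_unrolled`, evaluates the sum directly in $O(T^2)$; it is only a cross-check for the $O(T)$ backward recursion, not a replacement for it.

```python
def gae_unrolled(rewards, values, gamma, lam):
    """A_t = sum_{k=0}^{T-1-t} (gamma*lam)^k * delta_{t+k}, evaluated directly.

    Assumes one value per reward and V(s_T) = 0 after the final step.
    O(T^2): meant only as a cross-check of the backward recursion.
    """
    T = len(rewards)
    next_values = list(values[1:T]) + [0.0]  # V(s_{t+1}), terminal value 0
    deltas = [rewards[t] + gamma * next_values[t] - values[t] for t in range(T)]
    return [
        sum((gamma * lam) ** k * deltas[t + k] for k in range(T - t))
        for t in range(T)
    ]


# Same 4-step example as the verification cell further down;
# should agree with its advantages up to floating-point rounding
print(gae_unrolled([1.0, 2.0, 3.0, 4.0], [10.0, 11.0, 12.0, 13.0], 0.9, 0.95))
```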

In [1]:
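def GAE(rewards, values, gamma, lamda):
    """Generalized Advantage Estimation via the backward recursion.

    rewards[t] and values[t] = V(s_t) for t = 0..T-1. If `values` has an extra
    final entry it is used as V(s_T); otherwise V(s_T) = 0 (terminal state).
    `lamda` is spelled this way because `lambda` is a Python keyword.
    """
    advantages = []
    # A_T = 0: there is no advantage beyond the last step
    advantage = 0
    # iterate backwards from t = T-1 down to t = 0
    for i in range(len(rewards) - 1, -1, -1):
        next_value = values[i + 1] if i + 1 < len(values) else 0  # V(s_{t+1})
        delta = rewards[i] + gamma * next_value - values[i]  # TD error
        advantage = delta + gamma * lamda * advantage  # A_t = delta_t + gamma*lambda*A_{t+1}
        advantages.append(advantage)
    # advantages were collected in reverse time order
    return advantages[::-1]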
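`GAE` above treats the step after the last reward as terminal ($V(s_T) = 0$) unless `values` carries an extra final entry. In RLHF a sampled response is usually treated as one complete episode, so that assumption typically fits; for a rollout cut off mid-episode one would instead bootstrap from the critic's estimate of the next state. The variant below is a sketch that makes this explicit through a hypothetical `last_value` argument and also returns value-function targets $A_t + V(s_t)$, which a PPO critic is commonly regressed toward. It is an illustration, not the author's implementation.

```python
def gae_with_bootstrap(rewards, values, gamma, lam, last_value=0.0):
    """GAE where V(s_T) after the final step is supplied as `last_value`.

    last_value=0.0 reproduces the terminal-state assumption of GAE above;
    a non-zero value bootstraps a rollout that was cut off mid-episode.
    Returns (advantages, returns) with returns_t = A_t + V(s_t).
    """
    if len(rewards) != len(values):
        raise ValueError("expected one value estimate per reward")
    advantages = [0.0] * len(rewards)
    next_value = last_value   # V(s_{t+1}) for the current t
    next_advantage = 0.0      # A_{t+1}; zero past the end of the rollout
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        next_advantage = delta + gamma * lam * next_advantage
        advantages[t] = next_advantage
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

With the default `last_value=0.0` and equal-length inputs, the advantages match those of `GAE`.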



In [2]:
# Verification: Test GAE implementation with synthetic sequence
import numpy as np

# Test case: 4 timesteps
rewards = [1.0, 2.0, 3.0, 4.0]
values = [10.0, 11.0, 12.0, 13.0]  # V(s_t) for each state
gamma = 0.9
lamda = 0.95
terminal_value = 0.0  # V(s_4) = 0 for terminal state (assumed by function)

# Calculate advantages using our implementation
advantages = GAE(rewards, values, gamma, lamda)

print("Test Case:")
print(f"Rewards: {rewards}")
print(f"Values: {values}")
print(f"Gamma: {gamma}, Lambda: {lamda}")
print(f"Terminal value: {terminal_value}")
print(f"\nCalculated Advantages: {[f'{a:.6f}' for a in advantages]}")
print()

# Manual calculation for verification
# We'll calculate backwards from t=3 to t=0
print("Manual Verification (backwards computation):")
print("-" * 60)

# Initialize
manual_advantages = [0] * len(rewards)
advantage_next = 0  # A_4 = 0 (terminal state)

# Calculate backwards (from t=3 to t=0)
for t in range(len(rewards) - 1, -1, -1):
    # Get next state value
    if t + 1 < len(values):
        next_val = values[t + 1]
    else:
        next_val = terminal_value
    
    # Calculate TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    delta = rewards[t] + gamma * next_val - values[t]
    
    # Calculate advantage: A_t = delta_t + (gamma * lambda) * A_{t+1}
    # advantage_next holds A_{t+1} from previous iteration
    A_t_plus_1 = advantage_next  # Store before updating
    advantage = delta + gamma * lamda * advantage_next
    manual_advantages[t] = advantage
    advantage_next = advantage  # Update for next iteration
    
    print(f"t={t}:")
    print(f"  delta_{t} = r_{t} + γ * V(s_{t+1}) - V(s_{t})")
    print(f"  delta_{t} = {rewards[t]} + {gamma} * {next_val} - {values[t]}")
    print(f"  delta_{t} = {delta:.6f}")
    print(f"  A_{t} = delta_{t} + (γ * λ) * A_{t+1}")
    print(f"  A_{t} = {delta:.6f} + ({gamma} * {lamda}) * {A_t_plus_1:.6f}")
    print(f"  A_{t} = {advantage:.6f}")
    print()

print("Comparison:")
print(f"Implementation: {[f'{a:.6f}' for a in advantages]}")
print(f"Manual calc:    {[f'{a:.6f}' for a in manual_advantages]}")
print(f"Match: {np.allclose(advantages, manual_advantages)}")

Test Case:
Rewards: [1.0, 2.0, 3.0, 4.0]
Values: [10.0, 11.0, 12.0, 13.0]
Gamma: 0.9, Lambda: 0.95
Terminal value: 0.0

Calculated Advantages: ['-1.212470', '-2.470725', '-4.995000', '-9.000000']

Manual Verification (backwards computation):
------------------------------------------------------------
t=3:
  delta_3 = r_3 + γ * V(s_4) - V(s_3)
  delta_3 = 4.0 + 0.9 * 0.0 - 13.0
  delta_3 = -9.000000
  A_3 = delta_3 + (γ * λ) * A_4
  A_3 = -9.000000 + (0.9 * 0.95) * 0.000000
  A_3 = -9.000000

t=2:
  delta_2 = r_2 + γ * V(s_3) - V(s_2)
  delta_2 = 3.0 + 0.9 * 13.0 - 12.0
  delta_2 = 2.700000
  A_2 = delta_2 + (γ * λ) * A_3
  A_2 = 2.700000 + (0.9 * 0.95) * -9.000000
  A_2 = -4.995000

t=1:
  delta_1 = r_1 + γ * V(s_2) - V(s_1)
  delta_1 = 2.0 + 0.9 * 12.0 - 11.0
  delta_1 = 1.800000
  A_1 = delta_1 + (γ * λ) * A_2
  A_1 = 1.800000 + (0.9 * 0.95) * -4.995000
  A_1 = -2.470725

t=0:
  delta_0 = r_0 + γ * V(s_1) - V(s_0)
  delta_0 = 1.0 + 0.9 * 11.0 - 10.0
  delta_0 = 0.900000
  A_0 = delta_0 + (γ * λ) * A_1
  A_0 = 0.900000 + (0.9 * 0.95) * -2.470725
  A_0 = -1.212470

In [3]:
# Additional test: Simple case with known answer
print("=" * 60)
print("Additional Test: Simple case")
print("=" * 60)

# Simple test case: lambda=0 (should give TD errors)
rewards_simple = [1.0, 2.0]
values_simple = [0.0, 1.0]
gamma_simple = 0.5
lamda_simple = 0.0  # When lambda=0, GAE reduces to TD error

advantages_simple = GAE(rewards_simple, values_simple, gamma_simple, lamda_simple)

print(f"Rewards: {rewards_simple}")
print(f"Values: {values_simple}")
print(f"Gamma: {gamma_simple}, Lambda: {lamda_simple}")
print(f"\nAdvantages (should equal TD errors when lambda=0): {advantages_simple}")

# Expected: A_0 = r_0 + gamma*V(s_1) - V(s_0) = 1 + 0.5*1 - 0 = 1.5
#           A_1 = r_1 + gamma*V(s_2) - V(s_1) = 2 + 0.5*0 - 1 = 1.0
expected_simple = [1.5, 1.0]
print(f"Expected (TD errors): {expected_simple}")
print(f"Match: {np.allclose(advantages_simple, expected_simple)}")
print()

# Test case: lambda=1 (Monte Carlo return)
print("=" * 60)
print("Test: Lambda=1 (Monte Carlo)")
print("=" * 60)

rewards_mc = [1.0, 2.0, 3.0]
values_mc = [0.0, 0.0, 0.0]
gamma_mc = 1.0
lamda_mc = 1.0  # When lambda=1, GAE gives Monte Carlo advantages

advantages_mc = GAE(rewards_mc, values_mc, gamma_mc, lamda_mc)
print(f"Rewards: {rewards_mc}")
print(f"Values: {values_mc}")
print(f"Gamma: {gamma_mc}, Lambda: {lamda_mc}")
print(f"\nAdvantages: {advantages_mc}")

# Expected with lambda=1 and gamma=1 (all values 0, terminal V(s_3) = 0):
# A_2 = delta_2 = 3 + 1*0 - 0 = 3
# A_1 = delta_1 + gamma*lambda*A_2 = (2 + 1*0 - 0) + 1*1*3 = 5
# A_0 = delta_0 + gamma*lambda*A_1 = (1 + 1*0 - 0) + 1*1*5 = 6
# i.e. each A_t is the reward-to-go sum_{k>=t} r_k minus V(s_t) = 0
expected_mc = [6.0, 5.0, 3.0]
print(f"Expected: {expected_mc}")
print(f"Match: {np.allclose(advantages_mc, expected_mc)}")

Additional Test: Simple case
Rewards: [1.0, 2.0]
Values: [0.0, 1.0]
Gamma: 0.5, Lambda: 0.0

Advantages (should equal TD errors when lambda=0): [1.5, 1.0]
Expected (TD errors): [1.5, 1.0]
Match: True

Test: Lambda=1 (Monte Carlo)
Rewards: [1.0, 2.0, 3.0]
Values: [0.0, 0.0, 0.0]
Gamma: 1.0, Lambda: 1.0

Advantages: [6.0, 5.0, 3.0]
Expected: [6.0, 5.0, 3.0]
Match: True
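The $\lambda = 1$ test uses all-zero values and $\gamma = 1$, so it cannot distinguish GAE from a plain reward-to-go sum. A stronger check, sketched here with hypothetical helper names and made-up numbers, uses non-zero values and $\gamma < 1$: by the telescoping argument, GAE with $\lambda = 1$ should equal the discounted return $G_t = \sum_{k \ge t} \gamma^{k-t} r_k$ minus $V(s_t)$. The helper `gae` repeats the backward recursion of `GAE` so the block runs on its own, and the `_chk` names avoid overwriting the notebook's earlier variables.

```python
def gae(rewards, values, gamma, lam):
    # Same backward recursion as GAE above; V(s_T) = 0 after the last step
    adv, running = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        next_v = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_v - values[t]
        running = delta + gamma * lam * running
        adv[t] = running
    return adv


def discounted_returns(rewards, gamma):
    # G_t = r_t + gamma * G_{t+1}, with G_T = 0
    out, g = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        out[t] = g
    return out


rewards_chk = [1.0, 2.0, 3.0]
values_chk = [0.5, -1.0, 2.0]  # arbitrary non-zero critic estimates
gamma_chk = 0.9

gae_adv = gae(rewards_chk, values_chk, gamma_chk, 1.0)
mc_adv = [g - v for g, v in zip(discounted_returns(rewards_chk, gamma_chk), values_chk)]
print(gae_adv)
print(mc_adv)
print(all(abs(a - b) < 1e-9 for a, b in zip(gae_adv, mc_adv)))  # prints True
```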
