# Bellman Equations: The Foundation of All RL Algorithms

Welcome to the most important mathematical concept in RL! The Bellman equations are the foundation upon which nearly all reinforcement learning algorithms are built.

## What You'll Learn

By the end of this notebook, you'll understand:
- The recursive nature of value (with a treasure map analogy!)
- Bellman expectation equations for V and Q
- Bellman optimality equations (finding the best!)
- Policy evaluation: computing values for a given policy
- Value iteration: finding optimal values directly
- Why these equations matter for RL

**Prerequisites:** Notebooks 1-4 (especially policies and value functions)

**Time:** ~35 minutes

---
## The Big Picture: The Treasure Map Analogy

Imagine you're a pirate with a treasure map:

```
    ┌────────────────────────────────────────────────────────────────┐
    │          THE BELLMAN EQUATION: A TREASURE MAP ANALOGY          │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  You're at Island A, trying to reach the treasure at Island X │
    │                                                                │
    │       (A) ──$10──> (B) ──$20──> (C) ──$100──> (X) TREASURE!   │
    │                                                                │
    │  Question: What's the value of being at Island A?             │
    │                                                                │
    │  NAIVE APPROACH:                                               │
    │    Calculate the entire path: $10 + $20 + $100 = $130         │
    │    But this requires knowing the ENTIRE path!                 │
    │                                                                │
    │  BELLMAN'S INSIGHT:                                           │
    │    V(A) = Reward(A→B) + V(B)                                  │
    │         = $10 + (value of being at B)                         │
    │                                                                │
    │    "My value = What I get now + Value of where I end up"      │
    │                                                                │
    │  This is RECURSIVE! We break a big problem into smaller ones. │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

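**The Bellman equation says: the value of a state equals the immediate reward plus the (discounted) value of the next state.** The box above ignores discounting (γ = 1), which is why the full path adds up to $130; the figure below uses γ = 0.9.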
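To make the recursion concrete, here is a minimal sketch for the three-island chain, assuming γ = 0.9 and a terminal treasure island with V(X) = 0. It works backward from X using V(s) = r + γ × V(next), then checks the result against the direct discounted sum r₀ + γr₁ + γ²r₂ along the whole path.

```python
# Rewards along the path A -> B -> C -> X (X is terminal, so V(X) = 0)
path = ["A", "B", "C", "X"]
rewards = [10, 20, 100]
gamma = 0.9

# Bellman recursion, working backward from the terminal island:
#   V(s) = r + gamma * V(next_state)
V = {"X": 0.0}
for i in reversed(range(len(rewards))):
    V[path[i]] = rewards[i] + gamma * V[path[i + 1]]

# Direct computation over the entire path: r0 + gamma*r1 + gamma^2*r2
direct = sum(gamma**t * r for t, r in enumerate(rewards))

for name in path:
    print(f"V({name}) = {V[name]:.1f}")
print(f"Direct discounted sum from A: {direct:.1f}")
```

Both routes give V(A) = 109 (that is, 10 + 0.9 × 20 + 0.81 × 100). The recursive version never needed the whole path at once: each island only looked one step ahead.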

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Circle, FancyArrowPatch, Rectangle
from matplotlib.colors import LinearSegmentedColormap

# Visualize the treasure map analogy
fig, ax = plt.subplots(figsize=(14, 6))
ax.set_xlim(0, 14)
ax.set_ylim(0, 8)
ax.axis('off')

ax.text(7, 7.5, 'The Bellman Equation: Recursive Value Calculation', 
        ha='center', fontsize=16, fontweight='bold')

# Islands (states); values follow V(s) = r + 0.9 × V(next), with V(X) = 0 at the treasure
islands = [
    {'name': 'A', 'x': 1.5, 'color': '#bbdefb', 'value': 109},  # 10 + 0.9 × 110
    {'name': 'B', 'x': 4.5, 'color': '#c8e6c9', 'value': 110},  # 20 + 0.9 × 100
    {'name': 'C', 'x': 7.5, 'color': '#fff3e0', 'value': 100},  # 100 + 0.9 × 0
    {'name': 'X', 'x': 10.5, 'color': '#ffeb3b', 'value': 0},   # Treasure! (terminal)
]

# Draw islands
for island in islands:
    circle = Circle((island['x'], 4), 0.8, facecolor=island['color'],
                     edgecolor='black', linewidth=2)
    ax.add_patch(circle)
    ax.text(island['x'], 4.2, island['name'], ha='center', va='center',
            fontsize=16, fontweight='bold')
    if island['name'] == 'X':
        ax.text(island['x'], 3.5, 'TREASURE!', ha='center', fontsize=8, color='#795548')
    else:
        ax.text(island['x'], 3.5, f'V={island["value"]}', ha='center', fontsize=9)

# Draw arrows with rewards
rewards = [('A', 'B', 10), ('B', 'C', 20), ('C', 'X', 100)]
for i, (start, end, reward) in enumerate(rewards):
    x1, x2 = islands[i]['x'] + 0.8, islands[i+1]['x'] - 0.8
    ax.annotate('', xy=(x2, 4), xytext=(x1, 4),
                arrowprops=dict(arrowstyle='->', lw=2, color='#666'))
    ax.text((x1 + x2) / 2, 4.6, f'+${reward}', ha='center', fontsize=11,
            color='#388e3c', fontweight='bold')

# Bellman equation explanation
ax.text(7, 1.8, 'BELLMAN EQUATION:', ha='center', fontsize=13, fontweight='bold')
ax.text(7, 1.2, 'V(A) = Reward(A→B) + γ × V(B)', ha='center', fontsize=12)
ax.text(7, 0.6, '109 = 10 + 0.9 × 110', ha='center', fontsize=11, color='#666')
ax.text(7, 0.1, '"My value = Immediate reward + Discounted future value"', 
        ha='center', fontsize=10, style='italic', color='#888')

plt.tight_layout()
plt.show()

print("\n" + "="*70)
print("THE KEY INSIGHT: RECURSION")
print("="*70)
print("""
Instead of calculating the ENTIRE path's value:
  V(A) = r₀ + γr₁ + γ²r₂ + γ³r₃ + ...

We use RECURSION:
  V(A) = r + γ × V(next_state)
       = "what I get now" + "discounted value of where I end up"

This breaks a complex problem into simpler sub-problems!
""")
print("="*70)

---
## The Bellman Expectation Equation

The equation above assumed a single action per state and a deterministic next state. The full version also handles:
- **Stochastic policies** (a probability distribution over actions)
- **Stochastic transitions** (a probability distribution over next states)

```
    ┌────────────────────────────────────────────────────────────────┐
    │           BELLMAN EXPECTATION EQUATION FOR V^π                 │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  V^π(s) = Σ_a π(a|s) × Σ_s' P(s'|s,a) × [R(s,a,s') + γ×V^π(s')]│
    │           ─────────   ──────────────   ─────────────────────── │
    │           sum over    sum over all     immediate   discounted  │
    │           actions     next states      reward    + future value│
    │                                                                │
    │  In plain English:                                            │
    │    "The value of state s under policy π equals the            │
    │     EXPECTED immediate reward plus discounted future value,   │
    │     averaged over all actions and possible next states."      │
    │                                                                │
    │  Simplified (deterministic case):                             │
    │    V(s) = R(s, π(s)) + γ × V(s')                              │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
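A single application of this equation (one "Bellman backup") can be sketched in a few lines. The state names, probabilities, rewards, and current value estimates below are hypothetical numbers chosen for illustration; the `P[s][a]` layout of `(probability, next_state, reward)` tuples matches the grid-world class used later in this notebook.

```python
gamma = 0.9

# Hypothetical example: state "s" with two actions and stochastic transitions.
# P[s][a] is a list of (probability, next_state, reward) tuples.
P = {
    "s": {
        "left":  [(0.8, "s1", 1.0), (0.2, "s2", 0.0)],
        "right": [(1.0, "s2", 2.0)],
    }
}
policy = {"s": {"left": 0.5, "right": 0.5}}   # pi(a|s)
V = {"s": 0.0, "s1": 10.0, "s2": 5.0}         # current estimates of V^pi

def bellman_expectation_v(s, V, policy, P, gamma):
    """One backup: V(s) = sum_a pi(a|s) * sum_s' P(s'|s,a) * [R + gamma * V(s')]."""
    total = 0.0
    for a, action_prob in policy[s].items():          # outer sum over actions
        for trans_prob, next_s, reward in P[s][a]:    # inner sum over next states
            total += action_prob * trans_prob * (reward + gamma * V[next_s])
    return total

print(bellman_expectation_v("s", V, policy, P, gamma))
```

By hand: "left" contributes 0.8 × (1 + 0.9 × 10) + 0.2 × (0 + 0.9 × 5) = 8.9, "right" contributes 2 + 0.9 × 5 = 6.5, and the policy averages them to 0.5 × 8.9 + 0.5 × 6.5 = 7.7.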

There's also a Bellman equation for Q:

```
    ┌────────────────────────────────────────────────────────────────┐
    │           BELLMAN EXPECTATION EQUATION FOR Q^π                 │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  Q^π(s,a) = Σ_s' P(s'|s,a) × [R + γ × Σ_a' π(a'|s') × Q^π(s',a')]│
    │                                                                │
    │  In plain English:                                            │
    │    "The value of action a in state s equals the expected      │
    │     immediate reward plus the discounted expected future Q."  │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
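The Q version nests the policy average *inside* the bracket: for each possible next state s', it averages Q^π(s', a') over the policy's actions there. A minimal sketch with hypothetical Q estimates, policy, and transitions:

```python
gamma = 0.9

# Hypothetical current Q^pi estimates at the possible next states
Q = {
    ("s1", "left"): 8.0, ("s1", "right"): 12.0,
    ("s2", "left"): 4.0, ("s2", "right"): 6.0,
}
policy = {"s1": {"left": 0.5, "right": 0.5},
          "s2": {"left": 0.5, "right": 0.5}}
# Transitions for taking action "left" in state "s": (prob, next_state, reward)
transitions = [(0.8, "s1", 1.0), (0.2, "s2", 0.0)]

def expected_next_q(s_next):
    """Inner term sum_a' pi(a'|s') * Q^pi(s', a'), which equals V^pi(s')."""
    return sum(p * Q[(s_next, a)] for a, p in policy[s_next].items())

# Q^pi(s, a) = sum_s' P(s'|s,a) * [R + gamma * sum_a' pi(a'|s') * Q^pi(s', a')]
q_s_left = sum(p * (r + gamma * expected_next_q(s_next))
               for p, s_next, r in transitions)
print(q_s_left)
```

The inner averages come out to 10 at s1 and 5 at s2, which are exactly V^π(s1) and V^π(s2). This is the link between the two equations: V^π(s) = Σ_a π(a|s) × Q^π(s,a).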

In [None]:
# Visualize the Bellman equation structure

fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')

ax.text(7, 9.5, 'Bellman Equation: Breaking Down V(s)', 
        ha='center', fontsize=16, fontweight='bold')

# Current state
state_circle = Circle((2, 5), 0.8, facecolor='#bbdefb', edgecolor='#1976d2', linewidth=3)
ax.add_patch(state_circle)
ax.text(2, 5, 's', ha='center', va='center', fontsize=18, fontweight='bold')
ax.text(2, 3.8, 'V(s) = ?', ha='center', fontsize=11, color='#1976d2', fontweight='bold')

# Actions
actions = [
    {'name': 'a₁', 'y': 7, 'prob': 'π(a₁|s)'},
    {'name': 'a₂', 'y': 5, 'prob': 'π(a₂|s)'},
    {'name': 'a₃', 'y': 3, 'prob': 'π(a₃|s)'}
]

for action in actions:
    # Action box
    box = FancyBboxPatch((4.5, action['y']-0.4), 1.5, 0.8, boxstyle="round,pad=0.05",
                          facecolor='#fff3e0', edgecolor='#f57c00', linewidth=2)
    ax.add_patch(box)
    ax.text(5.25, action['y'], action['name'], ha='center', va='center', fontsize=12, fontweight='bold')
    
    # Arrow from state to action
    ax.annotate('', xy=(4.4, action['y']), xytext=(2.85, 5),
                arrowprops=dict(arrowstyle='->', lw=1.5, color='#666'))
    ax.text(3.5, (action['y'] + 5)/2 + 0.3, action['prob'], fontsize=9, color='#f57c00')

# Next states (for action a₂)
next_states = [
    {'name': "s'₁", 'y': 6.5, 'prob': "P(s'₁|s,a₂)", 'value': "V(s'₁)"},
    {'name': "s'₂", 'y': 5, 'prob': "P(s'₂|s,a₂)", 'value': "V(s'₂)"},
    {'name': "s'₃", 'y': 3.5, 'prob': "P(s'₃|s,a₂)", 'value': "V(s'₃)"}
]

for ns in next_states:
    # Next state circle
    circle = Circle((9, ns['y']), 0.6, facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
    ax.add_patch(circle)
    ax.text(9, ns['y'], ns['name'], ha='center', va='center', fontsize=11, fontweight='bold')
    ax.text(10, ns['y'], ns['value'], ha='left', fontsize=10, color='#388e3c')
    
    # Arrow from action to next state
    ax.annotate('', xy=(8.35, ns['y']), xytext=(6.1, 5),
                arrowprops=dict(arrowstyle='->', lw=1.5, color='#666'))
    ax.text(7.2, ns['y'] + 0.3, ns['prob'], fontsize=8, color='#388e3c')

# Rewards
ax.text(7, 7.5, 'R', fontsize=11, color='#d32f2f', fontweight='bold')
ax.text(7, 5.5, 'R', fontsize=11, color='#d32f2f', fontweight='bold')
ax.text(7, 4, 'R', fontsize=11, color='#d32f2f', fontweight='bold')

# Equation at bottom
ax.text(7, 1.5, 'V(s) = Σₐ π(a|s) × Σₛ\' P(s\'|s,a) × [R + γ × V(s\')]', 
        ha='center', fontsize=14, fontweight='bold',
        bbox=dict(boxstyle='round', facecolor='#e3f2fd', edgecolor='#1976d2'))

ax.text(7, 0.5, '"Average over actions, then average over next states"',
        ha='center', fontsize=11, style='italic', color='#666')

plt.tight_layout()
plt.show()

---
## Building a Grid World MDP

Let's create a concrete example to work with:

In [None]:
class GridWorldMDP:
    """
    4x4 Grid World MDP for demonstrating Bellman equations.
    
    Layout:
        ┌───┬───┬───┬───┐
        │ S │   │   │   │   S = Start (0,0)
        ├───┼───┼───┼───┤
        │   │   │   │   │
        ├───┼───┼───┼───┤
        │   │   │   │   │
        ├───┼───┼───┼───┤
        │   │   │   │ G │   G = Goal (3,3)
        └───┴───┴───┴───┘
    
    Actions: 0=UP, 1=RIGHT, 2=DOWN, 3=LEFT
    Rewards: -1 per step, +10 at goal
    """
    
    def __init__(self):
        self.size = 4
        self.n_states = 16
        self.n_actions = 4
        self.goal = (3, 3)
        self.gamma = 0.9  # Discount factor
        
        # All possible states and actions
        self.states = [(r, c) for r in range(4) for c in range(4)]
        self.actions = [0, 1, 2, 3]  # up, right, down, left
        self.action_names = ['UP', 'RIGHT', 'DOWN', 'LEFT']
        self.action_symbols = ['↑', '→', '↓', '←']
        
        # Build transition model
        self.P = self._build_transitions()
    
    def _next_state(self, state, action):
        """Get next state given current state and action."""
        row, col = state
        
        if action == 0:    # UP
            row = max(0, row - 1)
        elif action == 1:  # RIGHT
            col = min(3, col + 1)
        elif action == 2:  # DOWN
            row = min(3, row + 1)
        elif action == 3:  # LEFT
            col = max(0, col - 1)
        
        return (row, col)
    
    def _build_transitions(self):
        """
        Build the transition probability dictionary.
        
        P[s][a] = list of (probability, next_state, reward) tuples
        """
        P = {}
        
        for s in self.states:
            P[s] = {}
            
            for a in self.actions:
                next_s = self._next_state(s, a)
                reward = 10 if next_s == self.goal else -1
                
                # Deterministic: probability 1.0 of reaching next_s
                P[s][a] = [(1.0, next_s, reward)]
        
        return P


# Create MDP
mdp = GridWorldMDP()

print("GRID WORLD MDP")
print("="*50)
print(f"Grid size: {mdp.size}x{mdp.size}")
print(f"Number of states: {mdp.n_states}")
print(f"Number of actions: {mdp.n_actions}")
print(f"Goal state: {mdp.goal}")
print(f"Discount factor (γ): {mdp.gamma}")
print(f"\nRewards:")
print(f"  -1 for each step")
print(f"  +10 for reaching the goal")
print("="*50)

---
## Policy Evaluation: Computing V^π

**Goal:** Given a policy π, compute V^π(s) for all states.

**Method:** Iteratively apply the Bellman expectation equation until the values converge!

```
    ┌────────────────────────────────────────────────────────────────┐
    │                   POLICY EVALUATION                            │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  1. Initialize V(s) = 0 for all states                        │
    │                                                                │
    │  2. Repeat until convergence:                                 │
    │                                                                │
    │     For each state s:                                         │
    │       V_new(s) = Σ_a π(a|s) × Σ_s' P(s'|s,a) × [R + γ×V(s')] │
    │                                                                │
    │  3. Return V                                                  │
    │                                                                │
    │  Convergence: When max|V_new(s) - V_old(s)| < threshold       │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

**Analogy:** Like a ripple spreading through water - values propagate from the goal backward!
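The box describes a *synchronous* sweep: every V_new(s) is computed from the old V. For a fixed policy, the Bellman expectation equation is also just a linear system, V = R^π + γ P^π V, so small problems can be solved exactly. The sketch below checks that both approaches agree on a hypothetical 4-cell corridor (cells 0 to 3, cell 3 is the terminal goal) under a random left/right policy, reusing this notebook's rewards (-1 per step, +10 for entering the goal) and γ = 0.9.

```python
import numpy as np

gamma = 0.9
n = 3                                  # non-terminal cells 0, 1, 2; cell 3 is the terminal goal
policy = {-1: 0.5, +1: 0.5}            # random policy: pi(left) = pi(right) = 0.5

def step(s, a):
    """Hypothetical corridor: a = -1 (left) or +1 (right); the left wall clamps at 0."""
    s_next = min(max(s + a, 0), n)
    return s_next, (10.0 if s_next == n else -1.0)

# Synchronous (two-array) policy evaluation: each V_new(s) uses only the OLD V
V = np.zeros(n + 1)                    # V[n] stays 0 because the goal is terminal
for _ in range(1000):
    V_new = V.copy()
    for s in range(n):
        total = 0.0
        for a, p in policy.items():
            s_next, r = step(s, a)
            total += p * (r + gamma * V[s_next])
        V_new[s] = total
    delta = np.max(np.abs(V_new - V))
    V = V_new
    if delta < 1e-10:
        break

# Exact solution of the linear system (I - gamma * P_pi) V = R_pi
P_pi = np.zeros((n, n))
R_pi = np.zeros(n)
for s in range(n):
    for a, p in policy.items():
        s_next, r = step(s, a)
        R_pi[s] += p * r               # expected immediate reward under pi
        if s_next < n:                 # moves into the terminal goal add no future value
            P_pi[s, s_next] += p
V_exact = np.linalg.solve(np.eye(n) - gamma * P_pi, R_pi)

print("iterative:", np.round(V[:n], 4))
print("exact:    ", np.round(V_exact, 4))
```

The notebook's `policy_evaluation` below instead updates `V` in place, so states later in a sweep already see fresh values; both variants converge to the same V^π. The direct solve is only practical when the state space is small enough to build P^π as a matrix.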

In [None]:
def policy_evaluation(mdp, policy, threshold=1e-6, max_iterations=1000, verbose=False):
    """
    Compute V^π by iteratively applying the Bellman expectation equation.
    
    Args:
        mdp: The MDP
        policy: dict mapping state -> dict of action -> probability
                e.g., policy[s][a] = probability of action a in state s
        threshold: Stop when max value change < threshold
        max_iterations: Maximum iterations
        verbose: Print progress
    
    Returns:
        V: dict mapping state -> value
        history: list of V at each iteration (for visualization)
    """
    # ========================================
    # STEP 1: Initialize all values to zero
    # ========================================
    V = {s: 0.0 for s in mdp.states}
    history = []
    
    for iteration in range(max_iterations):
        delta = 0  # Track maximum change
        
        # ========================================
        # STEP 2: Update each state's value
        # ========================================
        for s in mdp.states:
            if s == mdp.goal:
                continue  # Terminal state has value 0
            
            v_old = V[s]
            
            # BELLMAN EXPECTATION EQUATION:
            # V(s) = Σ_a π(a|s) × Σ_s' P(s'|s,a) × [R + γ×V(s')]
            v_new = 0
            for a in mdp.actions:
                action_prob = policy[s][a]  # π(a|s)
                
                for (trans_prob, next_s, reward) in mdp.P[s][a]:
                    # P(s'|s,a) × [R + γ×V(s')]
                    v_new += action_prob * trans_prob * (reward + mdp.gamma * V[next_s])
            
            # In-place update: states later in this sweep already see the new V[s]
            V[s] = v_new
            delta = max(delta, abs(v_old - v_new))
        
        # Save history for visualization
        history.append({s: V[s] for s in mdp.states})
        
        # ========================================
        # STEP 3: Check for convergence
        # ========================================
        if delta < threshold:
            if verbose:
                print(f"Converged after {iteration + 1} iterations (Δ < {threshold})")
            break
    
    return V, history


# Create a RANDOM policy (equal probability for all actions)
random_policy = {s: {a: 0.25 for a in mdp.actions} for s in mdp.states}

print("POLICY EVALUATION: Random Policy")
print("="*60)
print("\nPolicy: π(a|s) = 0.25 for all actions (random)")
print("\nRunning policy evaluation...")

V_random, history_random = policy_evaluation(mdp, random_policy, verbose=True)

print("\nValue Function V^π(s):")
print("-"*35)
for row in range(4):
    values = [V_random[(row, col)] for col in range(4)]
    print(" ".join([f"{v:8.2f}" for v in values]))
print("-"*35)
print("(Goal state at bottom-right has value 0)")

In [None]:
# Visualize how values propagate over iterations

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

iterations_to_show = [0, 2, 5, 10, 20, 50, 100, len(history_random)-1]
iterations_to_show = [i for i in iterations_to_show if i < len(history_random)]

# Fill remaining with last iteration
while len(iterations_to_show) < 8:
    iterations_to_show.append(len(history_random)-1)

for idx, ax in enumerate(axes.flat):
    it = iterations_to_show[idx]
    V_it = history_random[it]
    
    # Create grid
    V_grid = np.array([[V_it[(r, c)] for c in range(4)] for r in range(4)])
    
    im = ax.imshow(V_grid, cmap='RdYlGn', vmin=-10, vmax=5)
    
    for i in range(4):
        for j in range(4):
            color = 'white' if V_grid[i, j] < -2 else 'black'
            ax.text(j, i, f'{V_grid[i, j]:.1f}', ha='center', va='center', 
                    fontsize=10, fontweight='bold', color=color)
    
    ax.set_title(f'Iteration {it + 1}', fontsize=12, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('Policy Evaluation: Values Propagating Over Iterations\n(Random Policy)', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Watch how the values 'ripple' backward from the goal!")
print("States closer to the goal get higher values first.")

In [None]:
# Compare random policy vs optimal policy

def create_optimal_policy(mdp):
    """
    Create a deterministic shortest-path policy: move RIGHT until reaching the
    goal column, then DOWN. (One of several optimal policies for this grid.)
    """
    policy = {}
    
    for s in mdp.states:
        row, col = s
        policy[s] = {a: 0.0 for a in mdp.actions}
        
        # Determine best action to move toward goal (3, 3)
        if col < 3:  # Not at goal column
            best_action = 1  # RIGHT
        elif row < 3:  # At goal column, not at goal row
            best_action = 2  # DOWN
        else:
            best_action = 0  # At goal, any action
        
        policy[s][best_action] = 1.0
    
    return policy


# Create and evaluate optimal policy
optimal_policy = create_optimal_policy(mdp)

print("COMPARING POLICIES")
print("="*60)

print("\n1. RANDOM POLICY (equal probability for all actions):")
print("-"*60)
for row in range(4):
    values = [V_random[(row, col)] for col in range(4)]
    print(" ".join([f"{v:8.2f}" for v in values]))

print("\n2. OPTIMAL POLICY (always move toward goal):")
V_optimal, _ = policy_evaluation(mdp, optimal_policy, verbose=True)
print("-"*60)
for row in range(4):
    values = [V_optimal[(row, col)] for col in range(4)]
    print(" ".join([f"{v:8.2f}" for v in values]))

print("\n" + "="*60)
print("OBSERVATION:")
print(f"  Random policy: V(start) = {V_random[(0,0)]:.2f}")
print(f"  Optimal policy: V(start) = {V_optimal[(0,0)]:.2f}")
print(f"  Difference: {V_optimal[(0,0)] - V_random[(0,0)]:.2f}")
print("\nThe optimal policy gives MUCH higher values!")
print("="*60)

---
## The Bellman Optimality Equation

What if we want to find the **best possible** values, not just for a given policy?

```
    ┌────────────────────────────────────────────────────────────────┐
    │             BELLMAN OPTIMALITY EQUATION                        │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  EXPECTATION (for policy π):                                  │
    │    V^π(s) = Σ_a π(a|s) × Σ_s' P(s'|s,a) × [R + γ×V^π(s')]     │
    │             ───────────                                        │
    │             AVERAGE over actions (weighted by policy)          │
    │                                                                │
    │  OPTIMALITY (for optimal policy π*):                          │
    │    V*(s) = MAX_a Σ_s' P(s'|s,a) × [R + γ×V*(s')]               │
    │            ─────                                               │
    │            MAXIMUM over actions (pick the best!)              │
    │                                                                │
    │  The key difference:                                          │
    │    - Expectation: Average over what π DOES                    │
    │    - Optimality: Maximum over what's POSSIBLE                 │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```

**For Q*:**
```
Q*(s,a) = Σ_s' P(s'|s,a) × [R + γ × max_a' Q*(s', a')]
```

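Q-learning uses a sampled version of this equation: instead of summing over s', it takes one observed transition (s, a, r, s') and moves Q(s,a) toward the target r + γ × max_a' Q(s', a').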
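Here is a minimal sketch of both ideas on a hypothetical deterministic 4-cell corridor (cells 0 to 3, cell 3 is the terminal goal; -1 per step, +10 for entering the goal, γ = 0.9). The first part repeatedly applies the Q* equation to every state-action pair (Q-value iteration, which needs the model); the second part performs a single Q-learning style update from one sampled transition. The learning rate `alpha = 0.5` is an arbitrary illustrative choice.

```python
import numpy as np

gamma = 0.9
n_states, actions = 4, [-1, +1]    # hypothetical corridor; actions: left, right

def step(s, a):
    """Deterministic corridor move (walls clamp); +10 for entering the goal, -1 otherwise."""
    s_next = min(max(s + a, 0), n_states - 1)
    return s_next, (10.0 if s_next == n_states - 1 else -1.0)

# Q-value iteration: repeatedly apply Q(s,a) <- R + gamma * max_a' Q(s', a')
Q = np.zeros((n_states, len(actions)))   # row 3 (terminal goal) stays 0
for _ in range(1000):
    Q_new = Q.copy()
    for s in range(n_states - 1):
        for i, a in enumerate(actions):
            s_next, r = step(s, a)
            Q_new[s, i] = r + gamma * Q[s_next].max()
    delta = np.max(np.abs(Q_new - Q))
    Q = Q_new
    if delta < 1e-10:
        break

print("Q* =\n", np.round(Q, 3))
print("V*(s) = max_a Q*(s,a):", np.round(Q.max(axis=1), 3))

# Q-learning: same target, but built from ONE sampled transition (s, a, r, s')
Q_learn = np.zeros_like(Q)         # a fresh table, as at the start of learning
alpha = 0.5
s, i = 2, 1                        # observed: in cell 2, took action "right"
s_next, r = step(s, actions[i])
target = r + gamma * Q_learn[s_next].max()
Q_learn[s, i] += alpha * (target - Q_learn[s, i])
print("Q_learn[2, right] after one update:", Q_learn[s, i])
```

Taking the max over each row of Q* recovers V* = [6.2, 8, 10, 0], which ties the Q* equation back to V*(s) = max_a Q*(s,a).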

In [None]:
# Visualize the difference between expectation and optimality

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Bellman Expectation (averaging)
ax1 = axes[0]
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
ax1.axis('off')
ax1.set_title('Bellman EXPECTATION\n(Average over policy)', fontsize=14, fontweight='bold')

# State
circle = Circle((2, 5), 0.6, facecolor='#bbdefb', edgecolor='#1976d2', linewidth=2)
ax1.add_patch(circle)
ax1.text(2, 5, 's', ha='center', va='center', fontsize=14, fontweight='bold')

# Actions with values
action_values = [
    {'y': 7, 'q': 5.0, 'prob': 0.25},
    {'y': 5.5, 'q': 8.0, 'prob': 0.25},
    {'y': 4, 'q': 3.0, 'prob': 0.25},
    {'y': 2.5, 'q': 6.0, 'prob': 0.25},
]

for av in action_values:
    box = FancyBboxPatch((5, av['y']-0.3), 2, 0.6, boxstyle="round,pad=0.05",
                          facecolor='#fff3e0', edgecolor='#f57c00', linewidth=1)
    ax1.add_patch(box)
    ax1.text(6, av['y'], f'Q={av["q"]}', ha='center', va='center', fontsize=10)
    ax1.annotate('', xy=(4.9, av['y']), xytext=(2.7, 5),
                 arrowprops=dict(arrowstyle='->', lw=1, color='#666'))
    ax1.text(3.8, av['y'] + 0.3, f'{av["prob"]}', fontsize=9, color='#666')

# Result
ax1.text(5, 1.5, 'V(s) = 0.25×5 + 0.25×8 + 0.25×3 + 0.25×6', ha='center', fontsize=10)
ax1.text(5, 0.8, '= 5.5 (average)', ha='center', fontsize=12, fontweight='bold', color='#1976d2')

# Right: Bellman Optimality (maximum)
ax2 = axes[1]
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.axis('off')
ax2.set_title('Bellman OPTIMALITY\n(Maximum over actions)', fontsize=14, fontweight='bold')

# State
circle = Circle((2, 5), 0.6, facecolor='#c8e6c9', edgecolor='#388e3c', linewidth=2)
ax2.add_patch(circle)
ax2.text(2, 5, 's', ha='center', va='center', fontsize=14, fontweight='bold')

# Actions with values (highlight max)
for i, av in enumerate(action_values):
    is_max = av['q'] == 8.0
    color = '#c8e6c9' if is_max else '#f5f5f5'
    edge = '#388e3c' if is_max else '#999'
    lw = 3 if is_max else 1
    
    box = FancyBboxPatch((5, av['y']-0.3), 2, 0.6, boxstyle="round,pad=0.05",
                          facecolor=color, edgecolor=edge, linewidth=lw)
    ax2.add_patch(box)
    ax2.text(6, av['y'], f'Q={av["q"]}' + (' ★' if is_max else ''), 
             ha='center', va='center', fontsize=10, fontweight='bold' if is_max else 'normal')
    ax2.annotate('', xy=(4.9, av['y']), xytext=(2.7, 5),
                 arrowprops=dict(arrowstyle='->', lw=2 if is_max else 1, 
                                color='#388e3c' if is_max else '#ccc'))

# Result
ax2.text(5, 1.5, 'V*(s) = max(5, 8, 3, 6)', ha='center', fontsize=10)
ax2.text(5, 0.8, '= 8.0 (maximum!)', ha='center', fontsize=12, fontweight='bold', color='#388e3c')

plt.tight_layout()
plt.show()

print("\nThe OPTIMALITY equation finds the BEST possible value!")
print("It assumes we'll always pick the best action.")

---
## Value Iteration: Finding V* Directly

**Idea:** Iteratively apply the Bellman optimality equation to find V*.

```
    ┌────────────────────────────────────────────────────────────────┐
    │                   VALUE ITERATION                              │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  1. Initialize V(s) = 0 for all states                        │
    │                                                                │
    │  2. Repeat until convergence:                                 │
    │                                                                │
    │     For each state s:                                         │
    │       V(s) = MAX_a Σ_s' P(s'|s,a) × [R + γ×V(s')]             │
    │              ─────                                             │
    │              Take the BEST action's value!                    │
    │                                                                │
    │  3. Extract optimal policy:                                   │
    │       π*(s) = argmax_a Σ_s' P(s'|s,a) × [R + γ×V*(s')]        │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
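The same update can be written with whole arrays instead of per-state loops, which is handy once transitions are stochastic. This sketch stores the model as NumPy arrays `P[a, s, s']` and `R[a, s]` (an assumption of this example; the notebook's grid world uses nested dicts) on a hypothetical "slippery" 4-cell corridor where a move fails with probability 0.2 and the agent stays put. The goal is made absorbing with zero reward, so it needs no special case.

```python
import numpy as np

gamma, slip = 0.9, 0.2         # hypothetical: with prob. 0.2 the move fails and the agent stays put
n_states, goal = 4, 3
actions = [-1, +1]             # left, right

# P[a, s, s'] = transition probability; R[a, s] = expected immediate reward
P = np.zeros((len(actions), n_states, n_states))
R = np.zeros((len(actions), n_states))
for i, a in enumerate(actions):
    for s in range(n_states):
        if s == goal:                        # absorbing terminal: stay, reward 0
            P[i, s, s] = 1.0
            continue
        target = min(max(s + a, 0), n_states - 1)
        for s_next, p in [(target, 1 - slip), (s, slip)]:
            P[i, s, s_next] += p
            R[i, s] += p * (10.0 if s_next == goal else -1.0)

V = np.zeros(n_states)
for _ in range(10_000):
    Q = R + gamma * P @ V        # Q[a, s] = sum_s' P(s'|s,a) * [R + gamma * V(s')]
    V_new = Q.max(axis=0)        # Bellman optimality: MAX over actions
    done = np.max(np.abs(V_new - V)) < 1e-10
    V = V_new
    if done:
        break

policy = Q.argmax(axis=0)        # greedy policy; the entry for the goal is arbitrary
print("V* =", np.round(V, 3))
print("pi* =", [["LEFT", "RIGHT"][i] for i in policy])
```

Every non-terminal cell should choose RIGHT. Slipping makes each cell worth a bit less than in the deterministic corridor: for cell 2, V* solves V = 0.8 × 10 + 0.2 × (-1 + 0.9 × V), giving V = 7.8 / 0.82 ≈ 9.51 instead of 10.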

In [None]:
def value_iteration(mdp, threshold=1e-6, max_iterations=1000, verbose=False):
    """
    Find V* using value iteration.
    
    Returns:
        V: Optimal value function
        policy: Optimal policy
        history: Value function at each iteration
    """
    # ========================================
    # STEP 1: Initialize values to zero
    # ========================================
    V = {s: 0.0 for s in mdp.states}
    history = []
    
    for iteration in range(max_iterations):
        delta = 0
        
        # ========================================
        # STEP 2: Bellman OPTIMALITY update
        # ========================================
        for s in mdp.states:
            if s == mdp.goal:
                continue
            
            v_old = V[s]
            
            # Compute Q(s, a) for each action
            action_values = []
            for a in mdp.actions:
                q_value = 0
                for (prob, next_s, reward) in mdp.P[s][a]:
                    q_value += prob * (reward + mdp.gamma * V[next_s])
                action_values.append(q_value)
            
            # V(s) = MAX over actions (Bellman optimality!)
            V[s] = max(action_values)
            delta = max(delta, abs(v_old - V[s]))
        
        history.append({s: V[s] for s in mdp.states})
        
        if delta < threshold:
            if verbose:
                print(f"Converged after {iteration + 1} iterations")
            break
    
    # ========================================
    # STEP 3: Extract optimal policy
    # ========================================
    policy = {}
    for s in mdp.states:
        if s == mdp.goal:
            policy[s] = 0  # Doesn't matter at goal
            continue
        
        # Find action with highest Q-value
        action_values = []
        for a in mdp.actions:
            q_value = 0
            for (prob, next_s, reward) in mdp.P[s][a]:
                q_value += prob * (reward + mdp.gamma * V[next_s])
            action_values.append(q_value)
        
        policy[s] = int(np.argmax(action_values))  # ties go to the lowest action index
    
    return V, policy, history


# Run value iteration
print("VALUE ITERATION")
print("="*60)
print("\nFinding optimal value function V* and policy π*...")

V_star, pi_star, history_vi = value_iteration(mdp, verbose=True)

print("\nOptimal Value Function V*:")
print("-"*35)
for row in range(4):
    values = [V_star[(row, col)] for col in range(4)]
    print(" ".join([f"{v:8.2f}" for v in values]))

print("\nOptimal Policy π*:")
print("-"*35)
for row in range(4):
    actions = [mdp.action_symbols[pi_star[(row, col)]] for col in range(4)]
    print("    ".join(actions))
print("-"*35)
print("(Arrows show the best action in each state)")

In [None]:
# Visualize V* and π* together

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Value function heatmap
ax1 = axes[0]
V_grid = np.array([[V_star[(r, c)] for c in range(4)] for r in range(4)])

im = ax1.imshow(V_grid, cmap='RdYlGn')
for i in range(4):
    for j in range(4):
        color = 'white' if V_grid[i, j] < np.mean(V_grid) else 'black'
        ax1.text(j, i, f'{V_grid[i, j]:.1f}', ha='center', va='center', 
                 fontsize=12, fontweight='bold', color=color)

ax1.set_title('Optimal Value Function V*', fontsize=14, fontweight='bold')
ax1.set_xticks(range(4))
ax1.set_yticks(range(4))
plt.colorbar(im, ax=ax1, label='Value')

# Right: Policy visualization
ax2 = axes[1]

# Draw grid
for row in range(4):
    for col in range(4):
        color = '#c8e6c9' if (row, col) == mdp.goal else '#e3f2fd' if (row, col) == (0, 0) else 'white'
        rect = Rectangle((col, 3 - row), 1, 1, facecolor=color, edgecolor='black', linewidth=2)
        ax2.add_patch(rect)

# Draw arrows for policy
arrow_dx = [0, 0.35, 0, -0.35]  # up, right, down, left
arrow_dy = [0.35, 0, -0.35, 0]

for row in range(4):
    for col in range(4):
        if (row, col) == mdp.goal:
            ax2.text(col + 0.5, 3 - row + 0.5, 'GOAL', ha='center', va='center', 
                     fontsize=10, fontweight='bold', color='#388e3c')
        else:
            a = pi_star[(row, col)]
            cx, cy = col + 0.5, 3 - row + 0.5
            ax2.arrow(cx - arrow_dx[a]/2, cy - arrow_dy[a]/2, 
                      arrow_dx[a], arrow_dy[a],
                      head_width=0.15, head_length=0.1, 
                      fc='#f44336', ec='#f44336', linewidth=2)

ax2.set_xlim(0, 4)
ax2.set_ylim(0, 4)
ax2.set_aspect('equal')
ax2.axis('off')
ax2.set_title('Optimal Policy π*\n(Arrows = Best Actions)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print("\nThe optimal policy always takes the shortest path to the goal!")

---
## Why Bellman Equations Matter

The Bellman equations are the foundation of virtually ALL RL algorithms:

```
    ┌────────────────────────────────────────────────────────────────┐
    │          BELLMAN EQUATIONS → RL ALGORITHMS                     │
    ├────────────────────────────────────────────────────────────────┤
    │                                                                │
    │  ALGORITHM                 USES WHICH BELLMAN EQUATION?       │
    │  ─────────────────────────────────────────────────────        │
    │                                                                │
    │  Policy Evaluation         Bellman Expectation for V^π        │
    │  Policy Iteration          Bellman Expectation + greedy       │
    │  Value Iteration           Bellman Optimality for V*          │
    │                                                                │
    │  TD Learning               Bellman Expectation (sampled)      │
    │  Q-Learning                Bellman Optimality for Q*          │
    │  SARSA                     Bellman Expectation for Q^π        │
    │                                                                │
    │  DQN                       Bellman Optimality for Q*          │
    │  Actor-Critic              Bellman Expectation (critic)       │
    │                                                                │
    │  The Bellman equations give us a way to break down the        │
    │  value of a state into immediate + future components.         │
    │  This RECURSIVE structure makes RL tractable!                 │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘
```
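
Each row of the table comes down to which one-step *target* an algorithm moves its estimate toward. The sketch below writes those targets for a single observed transition. The function names, the dictionary of next-state Q-values and the numbers are hypothetical. Real implementations add details not shown here, such as exploration and function approximation. Notice that SARSA and Q-learning differ only in how they value the next state: by the action actually taken (expectation) or by the best action (optimality).

```python
# Minimal sketch: the one-step targets behind TD(0), SARSA and Q-learning for a
# single observed transition. All names and numbers here are hypothetical.

def td0_target(r, v_next, done, gamma=0.9):
    """Sampled Bellman expectation backup for V^pi: r + gamma * V(s')."""
    return r if done else r + gamma * v_next

def sarsa_target(r, q_next, a_next, done, gamma=0.9):
    """Sampled Bellman expectation backup for Q^pi: value s' by the action actually taken."""
    return r if done else r + gamma * q_next[a_next]

def q_learning_target(r, q_next, done, gamma=0.9):
    """Sampled Bellman optimality backup for Q*: value s' by its best action."""
    return r if done else r + gamma * max(q_next.values())

def td_update(estimate, target, alpha=0.5):
    """Move the current estimate a fraction alpha of the way toward the target."""
    return estimate + alpha * (target - estimate)

# One hypothetical transition: reward 1.0, then s' has these Q-values.
r = 1.0
q_next = {"left": 2.0, "right": 5.0}

print(td0_target(r, v_next=3.0, done=False))               # 1 + 0.9 * 3.0
print(sarsa_target(r, q_next, a_next="left", done=False))  # 1 + 0.9 * 2.0 (next action was "left")
print(q_learning_target(r, q_next, done=False))            # 1 + 0.9 * 5.0 (best next action)
print(td_update(0.0, q_learning_target(r, q_next, done=False)))  # Q(s, a) moves halfway from 0.0 to the target
```

On a terminal transition (`done=True`) every target collapses to the reward alone, because there is no future value to bootstrap from.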

In [None]:
# Summary comparison: Random vs Optimal policies

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Random policy values
ax1 = axes[0]
V_rand_grid = np.array([[V_random[(r, c)] for c in range(4)] for r in range(4)])
im1 = ax1.imshow(V_rand_grid, cmap='RdYlGn', vmin=-5, vmax=10)
for i in range(4):
    for j in range(4):
        ax1.text(j, i, f'{V_rand_grid[i, j]:.1f}', ha='center', va='center', 
                 fontsize=11, fontweight='bold')
ax1.set_title('V^π (Random Policy)\nAverage value', fontsize=12, fontweight='bold')
ax1.set_xticks([])
ax1.set_yticks([])

# Optimal policy values
ax2 = axes[1]
V_opt_grid = np.array([[V_optimal[(r, c)] for c in range(4)] for r in range(4)])
im2 = ax2.imshow(V_opt_grid, cmap='RdYlGn', vmin=-5, vmax=10)
for i in range(4):
    for j in range(4):
        ax2.text(j, i, f'{V_opt_grid[i, j]:.1f}', ha='center', va='center', 
                 fontsize=11, fontweight='bold')
ax2.set_title('V^π* (Optimal Policy)\nBetter values!', fontsize=12, fontweight='bold')
ax2.set_xticks([])
ax2.set_yticks([])

# Difference
ax3 = axes[2]
diff_grid = V_opt_grid - V_rand_grid
im3 = ax3.imshow(diff_grid, cmap='Blues')
for i in range(4):
    for j in range(4):
        ax3.text(j, i, f'{diff_grid[i, j]:+.1f}', ha='center', va='center',  # '+' format flag prints the sign correctly
                 fontsize=11, fontweight='bold')
ax3.set_title('Improvement\n(Optimal - Random)', fontsize=12, fontweight='bold')
ax3.set_xticks([])
ax3.set_yticks([])

plt.tight_layout()
plt.show()

print("\nThe optimal policy's values are at least as high as the random policy's in EVERY state (V* ≥ V^π)!")
print(f"At start (0,0): Random = {V_random[(0,0)]:.2f}, Optimal = {V_optimal[(0,0)]:.2f}")
print(f"Improvement: {V_optimal[(0,0)] - V_random[(0,0)]:+.2f}")

---
## Summary: Key Takeaways

### The Core Insight

**Value is recursive:** V(s) = Immediate reward + Discounted future value
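
To check that the recursive form gives the same answer as summing the whole path, here is a small sketch using the treasure-map chain from the start of this notebook (A → B → C → X with rewards $10, $20 and $100). The discount γ = 0.9 and the helper names are assumptions made for illustration.

```python
# Treasure-map chain from the start of this notebook: A -> B -> C -> X (terminal).
# Rewards for the three moves are $10, $20 and $100; gamma = 0.9 is an assumption.
states = ["A", "B", "C", "X"]
rewards = [10, 20, 100]          # reward for A->B, B->C, C->X

def values_by_recursion(states, rewards, gamma):
    """Work backward from the terminal state with V(s) = r + gamma * V(s')."""
    V = {states[-1]: 0.0}                          # terminal state is worth 0
    for i in reversed(range(len(rewards))):
        V[states[i]] = rewards[i] + gamma * V[states[i + 1]]
    return V

def return_by_direct_sum(rewards, gamma):
    """Discounted return from A computed in one go: sum_k gamma^k * r_k."""
    return sum(gamma ** k * r for k, r in enumerate(rewards))

V = values_by_recursion(states, rewards, gamma=0.9)
print(V)                                         # V(C)=100, V(B)=110, V(A)=109 (up to float rounding)
print(return_by_direct_sum(rewards, gamma=0.9))  # also ~109
print(values_by_recursion(states, rewards, gamma=1.0)["A"])  # 130.0, the naive total
```

With γ = 1 the recursion reproduces the naive $130 total; with γ < 1 both methods agree on the discounted value. The recursive version never needs the whole path at once, only the next state's value.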

### Two Types of Bellman Equations

| Type | Equation | Use |
|------|----------|-----|
| **Expectation** | V^π(s) = Σ_a π(a\|s) × [R + γV^π(s')] | Evaluate a given policy |
| **Optimality** | V*(s) = max_a [R + γV*(s')] | Find the best policy |
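
The only difference between the two rows is *average* versus *max*. A minimal sketch of a single backup at one state, assuming deterministic transitions (as in the formulas above) and made-up next-state values; `outcomes`, `policy` and the numbers are hypothetical:

```python
# Hypothetical deterministic example: from state s, each action leads to one
# next state with a known reward. The next-state values are assumed given.
GAMMA = 0.9

# action -> (reward R, value of the next state s')
outcomes = {"left": (0.0, 4.0), "right": (1.0, 10.0)}
policy = {"left": 0.5, "right": 0.5}   # pi(a|s): uniformly random

def expectation_backup(outcomes, policy, gamma=GAMMA):
    """V^pi(s) = sum_a pi(a|s) * [R + gamma * V^pi(s')]: average over actions."""
    return sum(policy[a] * (r + gamma * v_next) for a, (r, v_next) in outcomes.items())

def optimality_backup(outcomes, gamma=GAMMA):
    """V*(s) = max_a [R + gamma * V*(s')]: take the best action."""
    return max(r + gamma * v_next for r, v_next in outcomes.values())

print(expectation_backup(outcomes, policy))  # 0.5*(0 + 0.9*4) + 0.5*(1 + 0.9*10) = 6.8
print(optimality_backup(outcomes))           # max(3.6, 10.0) = 10.0
```

Because a max is never smaller than a weighted average of the same numbers, the optimality backup is always at least as large as the expectation backup, mirroring V*(s) ≥ V^π(s).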

### Two Fundamental Algorithms

| Algorithm | What it does | Uses |
|-----------|--------------|------|
| **Policy Evaluation** | Compute V^π for policy π | Bellman Expectation |
| **Value Iteration** | Find V* directly | Bellman Optimality |

### Why This Matters

- Virtually ALL RL algorithms are built on Bellman equations
- The recursive structure makes RL tractable
- Q-learning uses the Bellman optimality equation for Q*
- TD learning uses a sampled Bellman expectation equation

---
## Test Your Understanding

**1. What is the key insight of the Bellman equation?**
<details>
<summary>Click to reveal answer</summary>
The value of a state can be decomposed into immediate reward plus discounted future value: V(s) = R + γV(s'). This recursive structure breaks a complex problem (total future reward) into simpler sub-problems.
</details>

**2. What's the difference between Bellman expectation and optimality equations?**
<details>
<summary>Click to reveal answer</summary>
- Expectation: V^π(s) = Σ_a π(a|s) × [R + γV^π(s')], which AVERAGES over actions according to policy π
- Optimality: V*(s) = max_a [R + γV*(s')], which takes the MAXIMUM over actions

Expectation evaluates a given policy; optimality finds the best possible value.
</details>

**3. How does policy evaluation work?**
<details>
<summary>Click to reveal answer</summary>
1. Initialize V(s) = 0 for all states
2. Repeatedly apply Bellman expectation: V(s) = Σ_a π(a|s) × [R + γV(s')]
3. Stop when values converge (change < threshold)

It finds V^π by iteratively "propagating" values backward from terminal states.
</details>
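
The three steps above can be sketched in a few lines of Python. This is a minimal illustration on a hypothetical 4-state corridor, not the gridworld used earlier in this notebook; the environment, the `step` helper and the parameter values (γ = 0.9, `theta` = 1e-8) are assumptions.

```python
# Minimal sketch of iterative policy evaluation on a hypothetical 4-state corridor:
# states 0..3, state 3 is terminal, "left"/"right" move one cell (the left wall
# keeps you in state 0), and entering state 3 pays +1.
N_STATES, TERMINAL, GAMMA = 4, 3, 0.9
ACTIONS = {"left": -1, "right": +1}

def step(s, a):
    """Deterministic transition: return (next_state, reward)."""
    s_next = min(max(s + ACTIONS[a], 0), N_STATES - 1)
    return s_next, (1.0 if s_next == TERMINAL else 0.0)

def policy_evaluation(policy, theta=1e-8):
    """Apply the Bellman expectation backup until the largest change is below theta."""
    V = [0.0] * N_STATES                          # 1. initialize V(s) = 0
    while True:
        delta = 0.0
        for s in range(N_STATES):
            if s == TERMINAL:
                continue                           # terminal value stays 0
            # 2. V(s) = sum_a pi(a|s) * [R + gamma * V(s')]
            new_v = 0.0
            for a, prob in policy[s].items():
                s_next, r = step(s, a)
                new_v += prob * (r + GAMMA * V[s_next])
            delta = max(delta, abs(new_v - V[s]))
            V[s] = new_v                           # in-place update (uses fresh values)
        if delta < theta:                          # 3. stop once values have converged
            return V

random_policy = {s: {"left": 0.5, "right": 0.5} for s in range(N_STATES)}
V_pi = policy_evaluation(random_policy)
print([round(v, 3) for v in V_pi])   # values rise toward the goal; terminal stays 0.0
```

Updating `V` in place lets later states in the same sweep use values computed earlier in that sweep; keeping a separate copy of the old values also converges, usually a little more slowly.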

**4. How does value iteration find V*?**
<details>
<summary>Click to reveal answer</summary>
1. Initialize V(s) = 0 for all states
2. Repeatedly apply Bellman optimality: V(s) = max_a [R + γV(s')]
3. Stop when values converge
4. Extract policy: π*(s) = argmax_a [R + γV*(s')]

It finds V* directly without needing a policy, then derives the optimal policy from V*.
</details>
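
Value iteration with policy extraction can be sketched on the same kind of hypothetical corridor (an assumption, redefined here so the cell runs on its own). The only change from policy evaluation is replacing the policy-weighted sum with a `max` over actions, then reading off the greedy action once the values have converged.

```python
# Minimal sketch of value iteration on a hypothetical 4-state corridor:
# states 0..3, state 3 is terminal, entering it pays +1, gamma = 0.9.
N_STATES, TERMINAL, GAMMA = 4, 3, 0.9
ACTIONS = {"left": -1, "right": +1}

def step(s, a):
    """Deterministic transition: return (next_state, reward)."""
    s_next = min(max(s + ACTIONS[a], 0), N_STATES - 1)
    return s_next, (1.0 if s_next == TERMINAL else 0.0)

def q_value(V, s, a):
    """One-step lookahead: R + gamma * V(s') for taking action a in state s."""
    s_next, r = step(s, a)
    return r + GAMMA * V[s_next]

def value_iteration(theta=1e-8):
    V = [0.0] * N_STATES                               # 1. initialize V(s) = 0
    while True:
        delta = 0.0
        for s in range(N_STATES):
            if s == TERMINAL:
                continue
            new_v = max(q_value(V, s, a) for a in ACTIONS)  # 2. Bellman optimality backup
            delta = max(delta, abs(new_v - V[s]))
            V[s] = new_v
        if delta < theta:                              # 3. stop when values converge
            break
    # 4. extract the greedy policy: pi*(s) = argmax_a [R + gamma * V*(s')]
    policy = {s: max(ACTIONS, key=lambda a: q_value(V, s, a))
              for s in range(N_STATES) if s != TERMINAL}
    return V, policy

V_star, pi_star = value_iteration()
print([round(v, 3) for v in V_star])   # [0.81, 0.9, 1.0, 0.0]
print(pi_star)                          # {0: 'right', 1: 'right', 2: 'right'}
```

Here V* shrinks by a factor of γ for each extra step from the goal, and the greedy policy always moves right, toward the goal.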

**5. Which RL algorithms use Bellman equations?**
<details>
<summary>Click to reveal answer</summary>
Virtually ALL of them!
- Q-learning: Bellman optimality for Q*
- SARSA: Bellman expectation for Q^π
- TD learning: Sampled Bellman expectation
- DQN: Bellman optimality for Q* (with neural networks)
- Actor-Critic: Critic uses Bellman expectation

The Bellman equations are the mathematical foundation of RL.
</details>

---
## Congratulations!

You've completed the **RL Fundamentals** section! You now understand:

- The RL paradigm and agent-environment interaction
- MDPs as the mathematical framework
- Rewards, returns, and discounting
- Policies and value functions (V and Q)
- **Bellman equations**: the foundation of virtually all RL algorithms

**Next:** Move on to [Classic Algorithms](../classic-algorithms/) to learn Q-learning, SARSA, and Monte Carlo methods!

---

*You now have the mathematical foundation to understand ANY RL algorithm!*