# 🎓 Chapter 3: Estimating Value Functions (Hands-On Lab)

**Complete Reinforcement Learning Journey: From Basics to RLHF**

In this notebook, you will:
1. **Implement** Policy Evaluation (iterative Bellman updates)
2. **Implement** Policy Iteration (evaluate → improve loop)
3. **Implement** Value Iteration (single-step improvement)
4. **Visualize** convergence: watch V(s) evolve sweep by sweep
5. **Compare** PI vs VI across 7 environments
6. **Experiment** with γ, stochasticity, and grid layouts

---

📘 **Companion to Chapter 3 of the book**  
🔗 Interactive web app: [Policy Iteration Visualizer](https://mlnjsh.github.io/rl-book-labs/ch3/)  
🔗 GitHub: [github.com/mlnjsh/rl-book-labs](https://github.com/mlnjsh/rl-book-labs)

## 📦 Install & Import Libraries

In [None]:
# Install required libraries
!pip install gymnasium numpy matplotlib seaborn pandas --quiet

# ┌─────────────────────────────────────────────────────────┐
# │ Libraries used in this notebook:                        │
# │   gymnasium    - RL environments (FrozenLake)           │
# │   numpy        - numerical computation                  │
# │   matplotlib   - plotting and visualization             │
# │   seaborn      - heatmaps                               │
# │   pandas       - data tables                            │
# │   time         - measuring convergence speed            │
# │   IPython      - display utilities (built-in)           │
# └─────────────────────────────────────────────────────────┘

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import pandas as pd
from IPython.display import display, HTML, clear_output
import gymnasium as gym
import time
import warnings
warnings.filterwarnings('ignore')

plt.rcParams['figure.facecolor'] = '#0f172a'
plt.rcParams['axes.facecolor'] = '#1e293b'
plt.rcParams['text.color'] = '#e2e8f0'
plt.rcParams['axes.labelcolor'] = '#e2e8f0'
plt.rcParams['xtick.color'] = '#94a3b8'
plt.rcParams['ytick.color'] = '#94a3b8'
plt.rcParams['font.family'] = 'monospace'

print("✅ All libraries loaded!")

## 🏗️ Environment Setup

We reuse the MDP framework from Chapter 2 and add all 7 environments.
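
As a reading aid for the classes in the next cell, here is a minimal sketch of the dictionary layout they build: `transitions[(s, a)]` holds a list of `(probability, next_state)` pairs and `rewards[(s, a, next_state)]` holds a scalar. The two-state data and the helpers `check_transition_probs` and `expected_reward` are hypothetical and exist only for illustration.

```python
# Hypothetical two-state example mirroring the MDP dictionary layout.
transitions = {
    ("A", "go"):   [(0.8, "B"), (0.2, "A")],   # list of (probability, next_state)
    ("A", "stay"): [(1.0, "A")],
    ("B", "stay"): [(1.0, "B")],
}
rewards = {
    ("A", "go", "B"): 1.0,                     # keyed by (state, action, next_state)
    ("A", "go", "A"): -0.1,
    ("A", "stay", "A"): 0.0,
    ("B", "stay", "B"): 2.0,
}

def check_transition_probs(transitions, tol=1e-9):
    """Return the (state, action) pairs whose outgoing probabilities do not sum to 1."""
    bad = []
    for key, outcomes in transitions.items():
        total = sum(p for p, _ in outcomes)
        if abs(total - 1.0) > tol:
            bad.append(key)
    return bad

def expected_reward(s, a):
    """Expected one-step reward of action a in state s (missing rewards count as 0)."""
    return sum(p * rewards.get((s, a, ns), 0.0) for p, ns in transitions[(s, a)])

print(check_transition_probs(transitions))
print(expected_reward("A", "go"))
```

A check like this is cheap to run on any hand-built transition table before handing it to a solver, since a row that does not sum to 1 silently distorts every value estimate.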

In [None]:
# ==================== MDP BASE ====================
class MDP:
    def __init__(self, states, actions, transitions, rewards, gamma=0.9):
        self.states = states
        self.actions = actions
        self.transitions = transitions
        self.rewards = rewards
        self.gamma = gamma

    def get_transitions(self, s, a):
        results = []
        for prob, ns in self.transitions.get((s, a), []):
            reward = self.rewards.get((s, a, ns), 0)
            results.append((prob, ns, reward))
        return results

# ==================== GRIDWORLD 5x5 ====================
class GridWorld(MDP):
    DELTAS = {0:(0,-1), 1:(1,0), 2:(0,1), 3:(-1,0)}
    ARROWS = {0:'←', 1:'↓', 2:'→', 3:'↑'}

    def __init__(self, gamma=0.9, slip=0.0):
        self.size=5; self.walls={(1,1),(1,3),(3,1)}; self.goal=(4,4); self.pit=(3,4); self.start=(0,0); self.slip=slip
        states=[(r,c) for r in range(5) for c in range(5) if (r,c) not in self.walls]
        am,tr,rw={},{},{}
        for s in states:
            if s==self.goal or s==self.pit: am[s]=[]; continue
            am[s]=[0,1,2,3]
            for a in range(4):
                tl=[]
                for aa in range(4):
                    if slip==0:
                        if aa!=a: continue
                        p=1.0
                    else: p=(1-slip+slip/4) if aa==a else slip/4
                    if p<1e-9: continue
                    dr,dc=self.DELTAS[aa]; nr,nc=s[0]+dr,s[1]+dc
                    if nr<0 or nr>=5 or nc<0 or nc>=5 or (nr,nc) in self.walls: nr,nc=s
                    ns=(nr,nc)
                    found=False
                    for i,(pp,ens) in enumerate(tl):
                        if ens==ns: tl[i]=(pp+p,ns); found=True; break
                    if not found: tl.append((p,ns))
                    rw[(s,a,ns)] = 10.0 if ns==self.goal else (-10.0 if ns==self.pit else -0.1)
                tr[(s,a)]=tl
        super().__init__(states,am,tr,rw,gamma)

# ==================== FROZENLAKE 4x4 ====================
class FrozenLakeMDP(MDP):
    MAP=[['S','F','F','F'],['F','H','F','H'],['F','F','F','H'],['H','F','F','G']]
    DELTAS={0:(0,-1),1:(1,0),2:(0,1),3:(-1,0)}; ARROWS={0:'←',1:'↓',2:'→',3:'↑'}
    def __init__(self, gamma=0.95, is_slippery=True):
        self.size=4; self.is_slippery=is_slippery
        self.holes=set(); self.goal=None; self.start=None
        for r in range(4):
            for c in range(4):
                if self.MAP[r][c]=='H': self.holes.add((r,c))
                elif self.MAP[r][c]=='G': self.goal=(r,c)
                elif self.MAP[r][c]=='S': self.start=(r,c)
        states=[(r,c) for r in range(4) for c in range(4)]
        terminals=self.holes|{self.goal}; am,tr,rw={},{},{}
        for s in states:
            if s in terminals: am[s]=[]; continue
            am[s]=[0,1,2,3]
            for a in range(4):
                tl=[]; possible=[(a-1)%4,a,(a+1)%4] if is_slippery else [a]
                for aa in possible:
                    p=1/3 if is_slippery else 1.0
                    dr,dc=self.DELTAS[aa]; nr,nc=s[0]+dr,s[1]+dc
                    if nr<0 or nr>=4 or nc<0 or nc>=4: nr,nc=s
                    ns=(nr,nc)
                    found=False
                    for i,(pp,ens) in enumerate(tl):
                        if ens==ns: tl[i]=(pp+p,ns); found=True; break
                    if not found: tl.append((p,ns))
                    rw[(s,a,ns)]=1.0 if ns==self.goal else (-1.0 if ns in self.holes else -0.01)
                tr[(s,a)]=tl
        super().__init__(states,am,tr,rw,gamma)

# ==================== TRAFFIC, THERMOSTAT, BANDIT, INVENTORY, ROBOT ====================
class TrafficLightMDP(MDP):
    def __init__(self, gamma=0.9):
        S=[(t,p) for t in ['low','medium','high'] for p in ['green_NS','green_EW']]
        am={s:['keep','switch'] for s in S}; tr={}; rw={}
        for t in ['low','medium','high']:
            for ph in ['green_NS','green_EW']:
                s=(t,ph)
                for a in ['keep','switch']:
                    np2=ph if a=='keep' else('green_EW' if ph=='green_NS' else 'green_NS')
                    if t=='low': tr[(s,a)]=[(0.7,('low',np2)),(0.3,('medium',np2))]; rw[(s,a,('low',np2))]=1.0; rw[(s,a,('medium',np2))]=0.0
                    elif t=='medium':
                        if a=='switch': tr[(s,a)]=[(0.4,('low',np2)),(0.5,('medium',np2)),(0.1,('high',np2))]
                        else: tr[(s,a)]=[(0.2,('low',np2)),(0.4,('medium',np2)),(0.4,('high',np2))]
                        rw[(s,a,('low',np2))]=1.0; rw[(s,a,('medium',np2))]=-0.5; rw[(s,a,('high',np2))]=-2.0
                    else:
                        if a=='switch': tr[(s,a)]=[(0.3,('medium',np2)),(0.5,('high',np2)),(0.2,('low',np2))]
                        else: tr[(s,a)]=[(0.1,('medium',np2)),(0.9,('high',np2))]
                        rw[(s,a,('medium',np2))]=-0.5; rw[(s,a,('high',np2))]=-3.0; rw[(s,a,('low',np2))]=1.0
        super().__init__(S,am,tr,rw,gamma)

class ThermostatMDP(MDP):
    def __init__(self, gamma=0.9):
        S=['cold','comfortable','hot']; am={s:['heat','cool','off'] for s in S}; tr={}; rw={}
        tr[('cold','heat')]=[(0.8,'comfortable'),(0.2,'cold')]
        tr[('cold','cool')]=[(0.95,'cold'),(0.05,'comfortable')]
        tr[('cold','off')]=[(0.7,'cold'),(0.3,'comfortable')]
        tr[('comfortable','heat')]=[(0.6,'comfortable'),(0.4,'hot')]
        tr[('comfortable','cool')]=[(0.6,'comfortable'),(0.4,'cold')]
        tr[('comfortable','off')]=[(0.8,'comfortable'),(0.1,'cold'),(0.1,'hot')]
        tr[('hot','heat')]=[(0.2,'comfortable'),(0.8,'hot')]
        tr[('hot','cool')]=[(0.8,'comfortable'),(0.2,'hot')]
        tr[('hot','off')]=[(0.3,'comfortable'),(0.7,'hot')]
        for s in S:
            for a in ['heat','cool','off']:
                for p,ns in tr[(s,a)]:
                    r=2.0 if ns=='comfortable' else -1.0
                    if a in ['heat','cool']: r-=0.5
                    rw[(s,a,ns)]=r
        super().__init__(S,am,tr,rw,gamma)

class BanditMDP(MDP):
    def __init__(self, gamma=0.9):
        S=['morning','afternoon','evening']; am={s:['A','B','C'] for s in S}; tr={}; rw={}
        wp={('morning','A'):0.7,('morning','B'):0.3,('morning','C'):0.5,('afternoon','A'):0.4,('afternoon','B'):0.6,('afternoon','C'):0.5,('evening','A'):0.2,('evening','B'):0.5,('evening','C'):0.8}
        nx={'morning':'afternoon','afternoon':'evening','evening':'morning'}
        for s in S:
            for a in ['A','B','C']:
                ns=nx[s]; tr[(s,a)]=[(1.0,ns)]; rw[(s,a,ns)]=wp[(s,a)]*10-(1-wp[(s,a)])*2
        super().__init__(S,am,tr,rw,gamma)

class InventoryMDP(MDP):
    def __init__(self, gamma=0.9):
        S=list(range(5)); am={s:['o0','o1','o2'] for s in S}; tr={}; rw={}
        dp={0:0.3,1:0.5,2:0.2}
        for stk in S:
            for a in ['o0','o1','o2']:
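                oq=int(a[1]); ao=min(stk+oq,4); probs={}; exp_r={}
                for d,dpr in dp.items():
                    sold=min(d,ao); unmet=d-sold; ns2=ao-sold
                    r=sold*3-oq*1-ns2*0.5+unmet*(-4)
                    # Different demands can reach the same next stock with different rewards,
                    # so accumulate the probability and the probability-weighted reward.
                    probs[ns2]=probs.get(ns2,0.0)+dpr
                    exp_r[ns2]=exp_r.get(ns2,0.0)+dpr*r
                tr[(stk,a)]=[(p,ns2) for ns2,p in probs.items()]
                # Store the conditional expected reward for each merged next state.
                for ns2,p in probs.items(): rw[(stk,a,ns2)]=exp_r[ns2]/p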
        super().__init__(S,am,tr,rw,gamma)

class RobotRoomsMDP(MDP):
    def __init__(self, gamma=0.9, lock_prob=0.2):
        S=['A','B','C','D']; nb={'A':['B','C'],'B':['A','D'],'C':['A','D'],'D':['B','C']}
        am={}; tr={}; rw={}
        for s in S:
            acts=[f'go_{n}' for n in nb[s]]+['stay']; am[s]=acts
            for a in acts:
                if a=='stay':
                    tr[(s,a)]=[(1.0,s)]; rw[(s,a,s)]=1.0 if s=='C' else -0.1
                else:
                    tgt=a.split('_')[1]
                    tr[(s,a)]=[(1-lock_prob,tgt),(lock_prob,s)]
                    rw[(s,a,tgt)]=10.0 if tgt=='D' else -0.5; rw[(s,a,s)]=-0.5
        super().__init__(S,am,tr,rw,gamma)

print("✅ All 7 environments loaded!")

---
## 📐 Algorithm 1: Policy Evaluation (Iterative)

Given a policy π, compute V^π(s) by repeatedly applying the Bellman expectation equation:

$$V_{k+1}(s) = \sum_{s'} P(s'|s, \pi(s)) \left[ R(s,\pi(s),s') + \gamma V_k(s') \right]$$
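
To make the update concrete, here is a small sketch of the synchronous form of this equation on a hypothetical two-state chain (the states, probabilities, and rewards are made up for illustration). Under the fixed policy, `A` either stays in `A` or moves to `B` with probability 0.5 each, and `B` loops forever. With γ = 0.5 the fixed point can be solved by hand: V(B) = 2 / (1 − 0.5) = 4, and V(A) = 0.5 + 0.25·V(A) + 1 gives V(A) = 2.

```python
GAMMA = 0.5

# Dynamics under the fixed policy: state -> list of (prob, next_state, reward).
# All numbers here are illustrative assumptions.
dynamics = {
    "A": [(0.5, "A", 1.0), (0.5, "B", 0.0)],
    "B": [(1.0, "B", 2.0)],
}

def bellman_backup(V):
    """One synchronous sweep: every new value is computed from the old V."""
    return {
        s: sum(p * (r + GAMMA * V[ns]) for p, ns, r in outcomes)
        for s, outcomes in dynamics.items()
    }

def evaluate(theta=1e-10, max_iter=1000):
    """Repeat the backup until the largest change falls below theta."""
    V = {s: 0.0 for s in dynamics}
    sweeps = 0
    for sweeps in range(1, max_iter + 1):
        new_V = bellman_backup(V)
        delta = max(abs(new_V[s] - V[s]) for s in V)
        V = new_V
        if delta < theta:
            break
    return V, sweeps

V, sweeps = evaluate()
print(V, sweeps)
```

The notebook's `policy_evaluation` updates `V` in place rather than building a fresh dict each sweep. Both variants converge to the same fixed point; the in-place form simply reuses newer values as soon as they are available.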

In [None]:
def policy_evaluation(mdp, policy, theta=1e-8, max_iter=1000, track_history=False):
    """Iterative policy evaluation.
    Returns (V, deltas, history): the value dict, the max |ΔV| of each sweep,
    and the list of V snapshots per sweep (None unless track_history=True)."""
    V = {s: 0.0 for s in mdp.states}
    history = [V.copy()] if track_history else None
    deltas = []

    for k in range(max_iter):
        delta = 0
        for s in mdp.states:
            if s not in policy or not mdp.actions.get(s):
                continue
            a = policy[s]
            new_v = sum(p * (r + mdp.gamma * V[ns]) for p, ns, r in mdp.get_transitions(s, a))
            delta = max(delta, abs(V[s] - new_v))
            V[s] = new_v

        deltas.append(delta)
        if track_history:
            history.append(V.copy())
        if delta < theta:
            break

    return V, deltas, history

# Demo: evaluate a random policy on GridWorld
gw = GridWorld(gamma=0.9, slip=0.0)
random_policy = {s: np.random.choice(gw.actions[s]) for s in gw.states if gw.actions.get(s)}

V, deltas, hist = policy_evaluation(gw, random_policy, track_history=True)

print(f"Converged in {len(deltas)} sweeps")
print(f"Final max delta: {deltas[-1]:.2e}")

# Plot convergence
fig, ax = plt.subplots(figsize=(10, 4))
ax.semilogy(deltas, color='#22d3ee', linewidth=2)
ax.axhline(y=1e-8, color='#ef4444', linestyle='--', alpha=0.5, label='θ threshold')
ax.set_xlabel('Sweep'); ax.set_ylabel('Max |ΔV|')
ax.set_title('Policy Evaluation Convergence', fontweight='bold', color='#22d3ee')
ax.legend(); ax.grid(alpha=0.1)
plt.tight_layout(); plt.show()

### 🎬 Sweep-by-Sweep Convergence

The cell below plots snapshots of the value function at selected sweeps so you can watch it evolve as policy evaluation runs.

In [None]:
def show_gridworld_values(gw, V, policy=None, title=""):
    """Quick grid visualization."""
    fig, ax = plt.subplots(figsize=(6,6))
    ax.set_xlim(-0.5,4.5); ax.set_ylim(-0.5,4.5); ax.set_aspect('equal'); ax.invert_yaxis()
    all_v = [V.get(s,0) for s in gw.states if s not in gw.walls and s!=gw.goal and s!=gw.pit]
    mn,mx = (min(all_v),max(all_v)) if all_v else (0,1)
    arrow_map={0:(0,-0.3),1:(0.3,0),2:(0,0.3),3:(-0.3,0)}
    for r in range(5):
        for c in range(5):
            if (r,c) in gw.walls:
                ax.add_patch(plt.Rectangle((c-0.5,r-0.5),1,1,color='#334155')); ax.text(c,r,'🧱',ha='center',va='center',fontsize=16)
            elif (r,c)==gw.goal:
                ax.add_patch(plt.Rectangle((c-0.5,r-0.5),1,1,color='#064e3b',alpha=0.5)); ax.text(c,r,'🏆',ha='center',va='center',fontsize=16)
            elif (r,c)==gw.pit:
                ax.add_patch(plt.Rectangle((c-0.5,r-0.5),1,1,color='#450a0a',alpha=0.5)); ax.text(c,r,'🕳️',ha='center',va='center',fontsize=16)
            else:
                v=V.get((r,c),0); t=(v-mn)/(mx-mn+1e-8)
                ax.add_patch(plt.Rectangle((c-0.5,r-0.5),1,1,color=plt.cm.RdYlGn(t),alpha=0.4))
                ax.text(c,r+0.3,f"{v:.2f}",ha='center',va='center',fontsize=8,color='white',fontweight='bold')
                if policy and (r,c) in policy:
                    dy,dx=arrow_map[policy[(r,c)]]
                    ax.annotate('',xy=(c+dx,r+dy),xytext=(c,r),arrowprops=dict(arrowstyle='->',color='#22d3ee',lw=2))
            ax.add_patch(plt.Rectangle((c-0.5,r-0.5),1,1,fill=False,edgecolor='#334155',lw=0.3))
    ax.set_title(title,fontsize=12,fontweight='bold',color='#22d3ee',pad=8)
    ax.set_xticks([]); ax.set_yticks([])
    for s in ax.spines.values(): s.set_visible(False)
    plt.tight_layout(); plt.show()

# Show V at sweeps 0, 1, 5, 20, and the final sweep (deduplicated, in order)
sweep_indices = sorted({i for i in [0, 1, 5, 20, len(hist)-1] if i < len(hist)})
for idx in sweep_indices:
    show_gridworld_values(gw, hist[idx], title=f"V after sweep {idx}")

---
## 📐 Algorithm 2: Policy Iteration

1. **Evaluate** current policy → V^π
2. **Improve** policy → greedy w.r.t. V^π
3. Repeat until the policy stops changing
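
The three steps above can be sketched on a hypothetical two-state, two-action model that is small enough to check by hand. The model, the names `MODEL`, `evaluate`, `improve`, and `policy_iteration_sketch`, and the starting policy are all assumptions for illustration; the notebook's own implementation follows in the next cell. For this model the optimal values are V*(0) = 19 and V*(1) = 20 with γ = 0.9.

```python
GAMMA = 0.9
STATES = [0, 1]
ACTIONS = ["stay", "move"]

# Deterministic toy model (illustrative assumption): (state, action) -> (next_state, reward)
MODEL = {
    (0, "stay"): (0, 0.0),
    (0, "move"): (1, 1.0),
    (1, "stay"): (1, 2.0),
    (1, "move"): (0, 0.0),
}

def q_value(V, s, a):
    ns, r = MODEL[(s, a)]
    return r + GAMMA * V[ns]

def evaluate(policy, theta=1e-10):
    """Step 1: iterate the Bellman expectation backup for the current policy."""
    V = {s: 0.0 for s in STATES}
    while True:
        delta = 0.0
        for s in STATES:
            new_v = q_value(V, s, policy[s])
            delta = max(delta, abs(new_v - V[s]))
            V[s] = new_v
        if delta < theta:
            return V

def improve(V):
    """Step 2: act greedily with respect to V (ties go to the first action listed)."""
    return {s: max(ACTIONS, key=lambda a: q_value(V, s, a)) for s in STATES}

def policy_iteration_sketch():
    policy = {s: "stay" for s in STATES}
    rounds = 0
    while True:
        rounds += 1
        V = evaluate(policy)
        new_policy = improve(V)
        if new_policy == policy:      # Step 3: stop once the policy is stable
            return V, policy, rounds
        policy = new_policy

V, policy, rounds = policy_iteration_sketch()
print(policy, rounds)
```

Starting from "always stay", the first improvement step switches state 0 to `move`, and the second round confirms that nothing else changes.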

In [None]:
def policy_iteration(mdp, verbose=True):
    """Full Policy Iteration. Returns V*, π*, and per-iteration tracking info."""
    # Initialize with the first listed action in every non-terminal state (not random)
    policy = {s: mdp.actions[s][0] for s in mdp.states if mdp.actions.get(s)}
    V = {s: 0.0 for s in mdp.states}
    track = []

    for iteration in range(100):
        # Evaluate
        V, deltas, _ = policy_evaluation(mdp, policy)
        eval_sweeps = len(deltas)

        # Improve
        changed = 0
        new_policy = {}
        for s in mdp.states:
            if not mdp.actions.get(s): continue
            best_a, best_v = None, -float('inf')
            for a in mdp.actions[s]:
                q = sum(p*(r+mdp.gamma*V[ns]) for p,ns,r in mdp.get_transitions(s,a))
                if q > best_v: best_v=q; best_a=a
            if best_a != policy.get(s): changed += 1
            new_policy[s] = best_a

        track.append({'iter': iteration+1, 'eval_sweeps': eval_sweeps, 'changes': changed})

        if verbose:
            print(f"  Iter {iteration+1}: {eval_sweeps} eval sweeps, {changed} policy changes")

        if changed == 0:
            if verbose: print(f"  ✅ Converged in {iteration+1} iterations!")
            break
        policy = new_policy

    return V, policy, track

# Run on GridWorld
print("🌍 Policy Iteration on GridWorld 5×5")
print("="*50)
gw = GridWorld(gamma=0.9, slip=0.0)
V_pi, pi_pi, track_pi = policy_iteration(gw)
show_gridworld_values(gw, V_pi, pi_pi, "GridWorld: Optimal Policy (PI)")

---
## 📐 Algorithm 3: Value Iteration

Combines evaluation and improvement in a single step:

$$V_{k+1}(s) = \max_a \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V_k(s')]$$
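
Convergence follows because this backup is a contraction: each sweep shrinks the max-norm distance to V* by at least a factor of γ. The sketch below checks that numerically on a hypothetical two-state model. The model and `V_STAR` (solved by hand for this model) are assumptions for illustration, not part of the notebook.

```python
GAMMA = 0.9
STATES = [0, 1]

# Deterministic toy model (illustrative assumption): (state, action) -> (next_state, reward)
MODEL = {
    (0, "stay"): (0, 0.0),
    (0, "move"): (1, 1.0),
    (1, "stay"): (1, 2.0),
    (1, "move"): (0, 0.0),
}
V_STAR = {0: 19.0, 1: 20.0}   # optimal values for this model, derived by hand

def vi_sweep(V):
    """One synchronous Bellman optimality backup over all states."""
    new_V = {}
    for s in STATES:
        new_V[s] = max(r + GAMMA * V[ns]
                       for (s2, a), (ns, r) in MODEL.items() if s2 == s)
    return new_V

def max_error(V):
    """Max-norm distance between V and the known optimum."""
    return max(abs(V[s] - V_STAR[s]) for s in STATES)

V = {s: 0.0 for s in STATES}
errors = [max_error(V)]
for _ in range(50):
    V = vi_sweep(V)
    errors.append(max_error(V))

# Ratio of consecutive errors; the contraction property says each is at most GAMMA.
ratios = [after / before for before, after in zip(errors, errors[1:])]
print(f"largest error ratio: {max(ratios):.6f} (gamma = {GAMMA})")
```

In practice V* is unknown, which is why the notebook's stopping rule watches the change between sweeps instead of the true error.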

In [None]:
def value_iteration(mdp, theta=1e-8, max_iter=1000, verbose=True):
    """Value Iteration. Returns V*, π*, and the max |ΔV| of each sweep."""
    V = {s: 0.0 for s in mdp.states}
    deltas = []

    for k in range(max_iter):
        delta = 0
        for s in mdp.states:
            if not mdp.actions.get(s): continue
            old_v = V[s]
            V[s] = max(
                sum(p*(r+mdp.gamma*V[ns]) for p,ns,r in mdp.get_transitions(s,a))
                for a in mdp.actions[s]
            )
            delta = max(delta, abs(old_v - V[s]))
        deltas.append(delta)
        if delta < theta:
            break

    # Extract policy
    policy = {}
    for s in mdp.states:
        if not mdp.actions.get(s): continue
        best_a, best_v = None, -float('inf')
        for a in mdp.actions[s]:
            q = sum(p*(r+mdp.gamma*V[ns]) for p,ns,r in mdp.get_transitions(s,a))
            if q > best_v: best_v=q; best_a=a
        policy[s] = best_a

    if verbose:
        print(f"  ✅ Value Iteration converged in {len(deltas)} sweeps")
    return V, policy, deltas

# Run on GridWorld
print("🌍 Value Iteration on GridWorld 5×5")
print("="*50)
gw = GridWorld(gamma=0.9, slip=0.0)
V_vi, pi_vi, deltas_vi = value_iteration(gw)
show_gridworld_values(gw, V_vi, pi_vi, "GridWorld: Optimal Policy (VI)")

---
## 📊 Policy Iteration vs Value Iteration: All Environments
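
Optimal policies need not be unique, so comparing two solvers action by action can report a mismatch wherever two actions tie. A more robust check, sketched here under that assumption, compares the value functions up to a tolerance. The helpers `max_abs_diff` and `same_values` and the sample numbers are hypothetical.

```python
def max_abs_diff(V1, V2):
    """Largest absolute difference between two value dicts over their shared states."""
    return max(abs(V1[s] - V2[s]) for s in V1.keys() & V2.keys())

def same_values(V1, V2, tol=1e-6):
    """True when the two value functions agree within tol on every shared state."""
    return max_abs_diff(V1, V2) <= tol

# Illustrative numbers only: two solvers that agree to within their stopping tolerance.
V_a = {"s0": 19.0, "s1": 20.0}
V_b = {"s0": 19.00000001, "s1": 19.99999999}
print(same_values(V_a, V_b))
```

The tolerance should be looser than the solvers' own `theta`, since each solver stops slightly short of the exact fixed point.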

In [None]:
envs = {
    'GridWorld (det)': GridWorld(0.9, 0.0),
    'GridWorld (sto)': GridWorld(0.9, 0.2),
    'FrozenLake (slip)': FrozenLakeMDP(0.95, True),
    'FrozenLake (det)': FrozenLakeMDP(0.95, False),
    'Traffic Light': TrafficLightMDP(0.9),
    'Thermostat': ThermostatMDP(0.9),
    'Bandit': BanditMDP(0.9),
    'Inventory': InventoryMDP(0.9),
    'Robot Rooms': RobotRoomsMDP(0.9, 0.2),
}

results = []
for name, mdp in envs.items():
    # PI
    t0 = time.time()
    V_pi, pi_pi, track = policy_iteration(mdp, verbose=False)
    pi_time = time.time() - t0
    pi_iters = len(track)

    # VI
    t0 = time.time()
    V_vi, pi_vi, deltas = value_iteration(mdp, verbose=False)
    vi_time = time.time() - t0
    vi_sweeps = len(deltas)

    # Check same policy
    same = all(pi_pi.get(s) == pi_vi.get(s) for s in pi_pi)

    results.append({
        'Environment': name,
        '|S|': len(mdp.states),
        'PI Iters': pi_iters,
        'PI Time (ms)': f"{pi_time*1000:.1f}",
        'VI Sweeps': vi_sweeps,
        'VI Time (ms)': f"{vi_time*1000:.1f}",
        'Same π*?': '✅' if same else '❌'
    })

df = pd.DataFrame(results)
print("\n📊 Policy Iteration vs Value Iteration: Comparison")
print("="*80)
display(df)
print("\n💡 Both converge to the same optimal values; the greedy policies match unless two actions tie.")
print("   PI uses fewer iterations but each iteration has a full evaluation phase.")
print("   VI uses more sweeps but each sweep is simpler (just one Bellman max).")

---
## 🔬 Experiment: How γ Affects Convergence
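
A rough way to anticipate the effect of γ: since each Bellman backup shrinks the worst-case error by a factor of γ, shrinking the error by a factor ε takes about log(ε) / log(γ) sweeps. The helper below (`sweeps_to_shrink` is a hypothetical name) evaluates that bound. It is a worst-case guide only; actual sweep counts depend on the environment and can be much smaller.

```python
import math

def sweeps_to_shrink(gamma, factor):
    """Smallest integer n with gamma**n <= factor, from n >= log(factor) / log(gamma)."""
    return math.ceil(math.log(factor) / math.log(gamma))

for g in [0.5, 0.9, 0.95, 0.99]:
    n = sweeps_to_shrink(g, 1e-8)
    print(f"gamma={g}: about {n} sweeps to shrink the error by a factor of 1e-8")
```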

In [None]:
gammas = [0.1, 0.5, 0.9, 0.95, 0.99]
fig, ax = plt.subplots(figsize=(10, 5))
sweep_counts = {}  # γ -> number of sweeps until max |ΔV| < θ

for g in gammas:
    gw = GridWorld(gamma=g, slip=0.0)
    _, _, deltas = value_iteration(gw, verbose=False)
    sweep_counts[g] = len(deltas)
    ax.semilogy(deltas, label=f'γ={g}', linewidth=2)

ax.set_xlabel('Sweep'); ax.set_ylabel('Max |ΔV|')
ax.set_title('Value Iteration Convergence vs γ', fontweight='bold', color='#22d3ee')
ax.legend(); ax.grid(alpha=0.1)
plt.tight_layout(); plt.show()

print("💡 Higher γ generally means slower convergence: the Bellman backup only guarantees")
print("   that the error shrinks by a factor of γ per sweep. Measured sweep counts:")
for g in gammas:
    print(f"   γ={g}: {sweep_counts[g]} sweeps")

---
## 🔬 Experiment: Optimal Policies in Deterministic vs Stochastic Environments
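
The `slip` parameter of `GridWorld` mixes the intended action with a uniformly random one: the intended action is executed with probability 1 − slip + slip/4, and each of the other three with probability slip/4. A standalone sketch of that distribution (`action_distribution` is a hypothetical helper, not part of the notebook):

```python
def action_distribution(intended, slip, n_actions=4):
    """Probability that each action is actually executed when `intended` is chosen.

    With probability `slip` the executed action is drawn uniformly from all
    actions (including the intended one); otherwise the intended action runs.
    """
    probs = [slip / n_actions] * n_actions
    probs[intended] += 1.0 - slip
    return probs

for slip in [0.0, 0.2, 0.4]:
    dist = action_distribution(intended=2, slip=slip)
    print(f"slip={slip}: {[round(p, 3) for p in dist]}")
```

Because the random draw can land on the intended action, even `slip=0.4` leaves the intended move as the most likely outcome.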

In [None]:
slips = [0.0, 0.1, 0.2, 0.3, 0.4]

for s in slips:
    gw = GridWorld(gamma=0.9, slip=s)
    V, pi, _ = value_iteration(gw, verbose=False)
    show_gridworld_values(gw, V, pi, f"Optimal Policy: slip={s:.1f}")

print("💡 As slip increases:")
print("   - Policy becomes more cautious near pits")
print("   - Values decrease (less certain about reaching goal)")
print("   - Some cells change arrow direction to avoid risky paths")

---
## 📊 Q-Value Analysis: Thermostat and GridWorld
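
Given state values V, the action value is Q(s,a) = Σ P(s'|s,a) [R(s,a,s') + γV(s')], and the gap Q(s,a) − V(s) shows how much an action loses relative to the best one. A worked sketch on a hypothetical model (the model and the values in `V` are assumptions chosen so the arithmetic is easy to check by hand):

```python
GAMMA = 0.9

# Hypothetical model: (state, action) -> list of (prob, next_state, reward)
MODEL = {
    ("s0", "safe"):  [(1.0, "s0", 1.0)],
    ("s0", "risky"): [(0.5, "s1", 4.0), (0.5, "s0", 0.0)],
    ("s1", "safe"):  [(1.0, "s1", 0.0)],
}
# Assumed optimal state values for this model: always "safe" in s0 earns 1 / (1 - 0.9) = 10.
V = {"s0": 10.0, "s1": 0.0}

def q_values(model, V, gamma):
    """One-step lookahead: expected reward plus discounted value of the next state."""
    return {
        (s, a): sum(p * (r + gamma * V[ns]) for p, ns, r in outcomes)
        for (s, a), outcomes in model.items()
    }

Q = q_values(MODEL, V, GAMMA)
for (s, a), q in Q.items():
    print(f"Q({s}, {a}) = {q:.2f}, advantage = {q - V[s]:+.2f}")
```

Here `risky` looks attractive because of its one-off reward of 4, but the lookahead shows it gives up the steady stream of rewards available in `s0`.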

In [None]:
def compute_q_values(mdp, V):
    Q = {}
    for s in mdp.states:
        for a in mdp.actions.get(s, []):
            Q[(s,a)] = sum(p*(r+mdp.gamma*V[ns]) for p,ns,r in mdp.get_transitions(s,a))
    return Q

# Q-values for Thermostat
print("🌡️ Thermostat Q-Values")
thermo = ThermostatMDP(0.9)
V_th, pi_th, _ = value_iteration(thermo, verbose=False)
Q_th = compute_q_values(thermo, V_th)

rows = []
for s in ['cold','comfortable','hot']:
    row = {'State': s}
    for a in ['heat','cool','off']:
        q = Q_th.get((s,a), 0)
        best = q == max(Q_th.get((s,a2),0) for a2 in ['heat','cool','off'])
        row[a] = f"{q:.2f}" + (" ★" if best else "")
    rows.append(row)
display(pd.DataFrame(rows).set_index('State'))

print(f"\nOptimal: cold→{pi_th['cold']}, comfortable→{pi_th['comfortable']}, hot→{pi_th['hot']}")

# Q-values for GridWorld start state
print("\n🌍 GridWorld Q-Values at START (0,0)")
gw = GridWorld(0.9, 0.0)
V_gw, pi_gw, _ = value_iteration(gw, verbose=False)
Q_gw = compute_q_values(gw, V_gw)

for a in range(4):
    q = Q_gw.get(((0,0),a), 0)
    best = q == max(Q_gw.get(((0,0),i),0) for i in range(4))
    print(f"  {GridWorld.ARROWS[a]} Q((0,0),{a}) = {q:.4f}{'  ★ BEST' if best else ''}")

---
## 📝 Summary

### Algorithms Implemented
| Algorithm | What It Does | Key Equation |
|-----------|-------------|-------------|
| Policy Evaluation | Computes V^π for a given π | V(s) ← Σ P[R + γV(s')] |
| Policy Iteration | Finds π* via evaluate→improve loop (guaranteed to converge) | π(s) ← argmax_a Σ P[R + γV^π(s')] |
| Value Iteration | Finds V* via Bellman optimality | V(s) ← max_a Σ P[R + γV(s')] |

### Key Findings
1. **Both PI and VI find the same optimal policy** for all 7 environments
2. **Higher γ → slower convergence** but better long-term planning
3. **Stochastic environments → cautious policies** that avoid risky states
4. **Q-values** reveal exactly why each action is good or bad
5. **These are DP methods**: they need the full model P(s'|s,a)

### What's Next?
- **Chapter 4**: Monte Carlo methods, which learn V without knowing P!
- **Chapter 5**: TD Learning ‚Äî learn from incomplete episodes

---
📘 **Book**: Complete Reinforcement Learning Journey  
🔗 [Interactive Labs](https://mlnjsh.github.io/rl-book-labs/)  
🔗 [GitHub](https://github.com/mlnjsh/rl-book-labs)