# Sarsa(λ) – Forward View (λ‑return)

In [None]:

import numpy as np, matplotlib.pyplot as plt, pickle, pathlib
from env import BlackjackEnv
from tqdm import trange
from collections import defaultdict

# Blackjack environment used by every experiment below (constructed with seed=1).
env = BlackjackEnv(seed=1)

# Reference action values ("MC ground truth") read from a local pickle file.
# NOTE(review): pickle executes arbitrary code on load -- only use a file
# produced by a trusted run of this project.
with pathlib.Path('mc_Q_star.pkl').open('rb') as f:
    Q_star = pickle.load(f)

def mse(Q_hat, Q_ref=None):
    """Return the mean-squared error between ``Q_hat`` and reference values.

    Parameters
    ----------
    Q_hat : mapping
        Learned action values, keyed by *state*; each value is indexable by
        action (e.g. an array of two action values).
    Q_ref : mapping, optional
        Reference values keyed by ``(state, action)`` tuples.  Defaults to
        the module-level ``Q_star``.

    Returns
    -------
    float
        Sum of squared differences divided by ``len(Q_ref)``.  State-action
        pairs whose state is missing from ``Q_hat`` are scored as 0.0.
        Non-tuple keys contribute no error but still count in the divisor.

    Raises
    ------
    ZeroDivisionError
        If the reference mapping is empty.

    Notes
    -----
    NOTE(review): assumes the state part of each reference key has the same
    format as the keys of ``Q_hat`` -- confirm against how ``Q_star`` was
    produced.
    """
    if Q_ref is None:
        Q_ref = Q_star
    err = 0.0
    for s_a, q_star in Q_ref.items():
        if not isinstance(s_a, tuple):
            continue
        state, action = s_a[0], s_a[1]
        # Q_hat is keyed by state, not by (state, action): looking up the
        # whole key would always miss and silently score 0 for every pair.
        # .get() avoids inserting entries when Q_hat is a defaultdict.
        q_row = Q_hat.get(state)
        q_est = 0.0 if q_row is None else q_row[action]
        err += (q_est - q_star) ** 2
    return err / len(Q_ref)


### Helper: λ‑return

In [None]:

def lambda_return(rewards, t, lam, q_boot=None):
    """Return the undiscounted λ-return for time step ``t`` of one episode.

    With ``N = len(rewards) - 1 - t`` steps remaining, the result is

        (1 - lam) * sum_{n=1}^{N-1} lam**(n-1) * G_n  +  lam**(N-1) * G_N

    where ``G_n = rewards[t+1] + ... + rewards[t+n]`` (plus ``q_boot[t+n]``
    when ``q_boot`` is given) and ``G_N`` is the full return to the end of
    the episode.

    Parameters
    ----------
    rewards : sequence of float
        Episode rewards with a placeholder at index 0, so ``rewards[k]`` is
        the reward received on the k-th step (k >= 1).
    t : int
        Time step whose return is computed; expected ``0 <= t < len(rewards) - 1``.
    lam : float
        Trace-decay parameter λ in [0, 1].
    q_boot : sequence of float, optional
        ``q_boot[k]`` is the value estimate used to bootstrap an n-step
        return that stops at step ``k``.  When omitted, n-step returns are
        plain truncated reward sums (no bootstrapping).

    Returns
    -------
    float
        The λ-return; 0.0 when no steps remain after ``t``.
    """
    rewards = np.asarray(rewards, dtype=float)
    steps_left = len(rewards) - 1 - t
    if steps_left <= 0:
        return 0.0

    G = 0.0
    lam_pow = 1.0   # lam ** (n - 1)
    partial = 0.0   # running sum rewards[t+1] + ... + rewards[t+n]
    # Truncated n-step returns, n = 1 .. N-1, each weighted (1-lam)*lam^(n-1).
    for n in range(1, steps_left):
        partial += rewards[t + n]
        G_n = partial
        if q_boot is not None:
            G_n += q_boot[t + n]
        G += lam_pow * (1 - lam) * G_n
        lam_pow *= lam
    # The remaining weight lam^(N-1) goes to the *full* return, not just to
    # the final reward; the two only coincide when all earlier rewards are 0.
    partial += rewards[t + steps_left]
    G += lam_pow * partial
    return float(G)


### Run experiments for λ ∈ {0,0.1,…,1}

In [None]:

# Experiment configuration.
lam_grid = np.linspace(0, 1, 11)  # λ ∈ {0, 0.1, …, 1}
episodes = 1000                   # training episodes per λ
alpha = 0.01                      # constant step size
epsilon = 0.1                     # constant exploration rate

mse_per_lam = []  # final MSE to Q_star, one entry per λ in lam_grid
for lam in lam_grid:
    # Fresh action-value table for every λ: state -> array of 2 action values.
    Q_hat = defaultdict(lambda: np.zeros(2))
    mses = []  # MSE after each episode for the current λ
    for ep in range(episodes):
        s = env.reset()
        # episode_rewards starts with a placeholder so that
        # episode_rewards[k] is the reward received on the k-th step.
        episode_states, episode_actions, episode_rewards = [], [], [0]
        done = False
        # generate episode
        while not done:
            # ε-greedy with constant ε, ties between greedy actions broken at random
            if np.random.rand() < epsilon:
                a = np.random.randint(2)
            else:
                q_vals = Q_hat[(s[0], s[1])]
                a = np.random.choice(np.flatnonzero(q_vals == q_vals.max()))
            episode_states.append(s)
            episode_actions.append(a)
            s_next, r, done = env.step(a)
            episode_rewards.append(r)
            s = s_next if not done else (None, None)

        # forward view updates (offline: all targets are computed from the
        # estimates held *before* this episode's updates are applied)
        T = len(episode_states)
        rewards = np.array(episode_rewards, dtype=float)
        keys = [(st[0], st[1]) for st in episode_states]
        q_taken = np.array([Q_hat[k][a_k] for k, a_k in zip(keys, episode_actions)])
        targets = []
        for t in range(T):
            # Reward part of the λ-return ...
            G_lam = lambda_return(rewards, t, lam)
            # ... plus the bootstrap part.  Each n-step return that stops
            # before the terminal step must add Q(S_{t+n}, A_{t+n}); the
            # λ-return is linear in its n-step returns, so that contribution
            # is (1-λ) * Σ_{n=1}^{T-t-1} λ^(n-1) * Q(S_{t+n}, A_{t+n}).
            # Without it the target at λ=0 is only R_{t+1}, which is not Sarsa.
            n_boot = T - t - 1
            if n_boot > 0:
                boot_weights = (1 - lam) * lam ** np.arange(n_boot)
                G_lam += float(boot_weights @ q_taken[t + 1:])
            targets.append(G_lam)
        for k, a_k, G_lam in zip(keys, episode_actions, targets):
            Q_hat[k][a_k] += alpha * (G_lam - Q_hat[k][a_k])

        # log mse
        mses.append(mse(Q_hat))

    # With episodes == 0 nothing was logged; score the untrained table instead.
    final_mse = mses[-1] if mses else mse(Q_hat)
    mse_per_lam.append(final_mse)
    print(f"λ={lam:.1f}  final MSE={final_mse:.4f}")


### Plot MSE vs λ

In [None]:

# Plot the final MSE reached for each λ and save the figure under plots/.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(lam_grid, mse_per_lam, marker='o')
ax.set_xlabel('λ')
ax.set_ylabel('MSE to Q★')
# Take the episode count from the experiment setting so the title cannot
# go stale when `episodes` is changed.
ax.set_title(f'Mean‑squared error after {episodes} episodes')
pathlib.Path('plots').mkdir(exist_ok=True)
fig.savefig('plots/mse_vs_lambda.png')
plt.show()
