# Code Examples related to Chapter 6
Sutton and Barto's Reinforcement Learning: An Introduction

In [None]:
import gymnasium as gym

import numpy as np
import polars as pl
import numpy as np

from collections import defaultdict

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px

## Example 6.2 Random Walk

In [None]:

def generate_episode(end_node=3):
    """Simulate one symmetric random walk that starts at position 0.

    At every step the walker moves one position left or right with equal
    probability. The walk stops as soon as the absolute position equals
    ``end_node``, or after 100 moves if that never happens.

    Args:
        end_node: Distance from the origin at which the walk terminates.

    Returns:
        A 2-D numpy array with one ``(position, reward)`` row per visited
        position, beginning with ``(0, 0)``. The reward is 1 only on the
        row where the walk terminates on the positive side; every other
        row carries reward 0.
    """
    max_moves = 100
    position = 0
    trajectory = [(position, 0)]

    for _ in range(max_moves):
        position = position + np.random.choice([-1, 1])

        reached_end = np.abs(position) == end_node
        # Only finishing on the positive end pays out.
        payoff = 1 if (reached_end and position > 0) else 0
        trajectory.append((position, payoff))

        if reached_end:
            break

    return np.array(trajectory)

In [None]:
# Sample 1000 random-walk episodes (default end_node); both estimators below
# are run on this same list.
episodes = [generate_episode() for _ in range(1000)]

In [None]:
def convert_to_polars_td(V_history):
    """Flatten a history of value tables into a long-format polars DataFrame.

    Args:
        V_history: Mapping from iteration index to a mapping of
            state -> estimated value.

    Returns:
        A polars DataFrame with one row per (iteration, state) pair and the
        columns ``iteration``, ``state`` and ``value``.
    """
    # Three parallel columns, all walked in the same (iteration, state) order.
    iterations = [step for step, table in V_history.items() for _ in table]
    states = [s for table in V_history.values() for s in table]
    values = [v for table in V_history.values() for v in table.values()]

    return pl.DataFrame(
        {
            "iteration": iterations,
            "state": np.array(states),
            "value": np.array(values),
        }
    )


def temporal_difference(episodes, alpha=0.01, gamma=1):
    """Estimate state values with tabular TD(0) over a batch of episodes.

    The value table covers the integer states -3..3 and starts at 0
    everywhere. It is carried over from one episode to the next, and a
    snapshot is stored after each episode.

    Args:
        episodes: Iterable of episodes. Each episode is a sequence of
            ``(state, reward)`` rows, where the reward in a row is the one
            received on the transition into that row's state.
        alpha: Step size of the TD update.
        gamma: Discount factor.

    Returns:
        Dict mapping the episode index to a copy of the value table
        (state -> value) as it stood after that episode.

    Raises:
        KeyError: If an episode contains a state outside -3..3.
    """
    V_history = {}
    V = {k: 0 for k in np.arange(-3, 3 + 1, 1)}

    for r, sequence in enumerate(episodes):
        # Pair every row with its successor so the update uses the state the
        # episode actually recorded instead of assuming the walk starts at 0.
        # The last row's state is never updated, so its value stays as is.
        for (state, _), (next_state, reward) in zip(sequence[:-1], sequence[1:]):
            V[state] = V[state] + alpha * (reward + gamma * V[next_state] - V[state])

        # Snapshot: later episodes keep mutating V in place.
        V_history[r] = dict(V)

    return V_history

In [None]:
def first_visit_mc_prediction(episodes):
    """Estimate state values with first-visit Monte Carlo prediction.

    For every episode the return is accumulated backwards through the
    trajectory. A return is recorded for a state only at the earliest time
    step at which that state occurs in the episode, so each episode adds at
    most one sample per state. The estimate for a state is the mean of all
    samples collected so far across episodes.

    Args:
        episodes: Iterable of episodes. Each episode is a sequence of
            ``(state, reward)`` rows.

    Returns:
        Dict mapping the episode index to a dict of state -> mean return
        after that episode. States that have not been visited in any
        episode so far are absent from the inner dict.
    """
    gamma = 1

    V_history = {}

    values_by_state = defaultdict(list)

    for r, sequence in enumerate(episodes):
        # Earliest time step at which each state occurs in this episode.
        first_visit = {}
        for t, (state, _) in enumerate(sequence):
            first_visit.setdefault(state, t)

        G = 0
        for t in range(len(sequence) - 1, -1, -1):
            state, reward = sequence[t]
            # The reward stored with a row is folded into the return before
            # that row's state is recorded.
            G = gamma * G + reward
            # The walk is traversed in reverse, so a "seen so far" list would
            # pick the last visit; compare against the precomputed index.
            if first_visit[state] == t:
                values_by_state[state].append(G)

        V_history[r] = {
            state: np.mean(values) for state, values in values_by_state.items()
        }

    return V_history

In [None]:
# Run both estimators on the same batch of episodes so their per-episode
# value tables can be compared side by side.
value_history_mc = first_visit_mc_prediction(episodes)
value_history_td = temporal_difference(episodes)

In [None]:
# Stack both value histories into one long frame, tagging each row with the
# method that produced it. The labels must match the history they are attached
# to: the Monte Carlo history is "mc" and the TD history is "td".
results_combined = (
    pl.concat(
        [
            convert_to_polars_td(value_history_mc).with_columns(
                pl.lit("mc").alias("method")
            ),
            convert_to_polars_td(value_history_td).with_columns(
                pl.lit("td").alias("method")
            ),
        ]
    )
    # Reference value (state + 3) / 6; the modulo maps state 3 to 0, so both
    # end states get a reference value of 0.
    .with_columns((((pl.col("state") + 3) / 6) % 1).alias("value_true"))
    .with_columns((pl.col("value") - pl.col("value_true")).alias("error"))
)

results_combined

In [None]:
# Animate the value estimates over training: one frame per 10th iteration,
# one line per method, with states along the x-axis.
every_tenth = results_combined.filter((pl.col("iteration") % 10) == 0)
plot_frames = every_tenth.sort("iteration", "state")

fig = px.line(
    plot_frames,
    x="state",
    y="value",
    color="method",
    animation_frame="iteration",
)
# Fix the y-axis to [0, 1] so frames stay comparable.
fig.update_yaxes(range=[0, 1])
fig.show()
