In [None]:
import gymnasium as gym

import numpy as np
import polars as pl

from collections import defaultdict

import matplotlib.pyplot as plt
import plotly.graph_objects as go

In [None]:
# Shared Blackjack environment; generate_episodes() resets and steps this
# module-level instance for every simulated episode.
env = gym.make("Blackjack-v1", natural=False, sab=False)

In [None]:
class BlackjackAgent():
    """Blackjack player that chooses actions according to a named strategy.

    Supported strategies:

    * ``"random"``      -- pick 0 or 1 uniformly at random.
    * ``"threshold"``   -- return 1 while ``state[0]`` is below ``threshold``,
      otherwise 0 (presumably 1 = hit and 0 = stick, as in Gymnasium's
      Blackjack -- confirm against the environment docs).
    * ``"learning-ES"`` -- placeholder that always returns 1.

    Args:
        threshold: Cut-off compared against ``state[0]``; only read by the
            ``"threshold"`` strategy.
        strategy: One of the strategy names listed above.
    """

    def __init__(self, threshold, strategy):
        self.threshold = threshold
        self.strategy = strategy
        # Per-state action table defaulting to 0; not read by any strategy yet.
        self.policy = defaultdict(int)

    def get_action(self, state):
        """Return the action (0 or 1) for ``state``.

        Args:
            state: Indexable observation whose first entry is compared with
                ``threshold`` by the ``"threshold"`` strategy.

        Raises:
            ValueError: If ``self.strategy`` is not a supported strategy name.
                Previously this case fell through and returned ``None``.
        """
        if self.strategy == "random":
            return np.random.choice([0, 1])

        if self.strategy == "threshold":
            return 1 if state[0] < self.threshold else 0

        if self.strategy == "learning-ES":
            return 1

        raise ValueError(f"unknown strategy: {self.strategy!r}")

    def update(self, state, action, reward, next_state):
        """Learning hook called after every environment step; currently a no-op."""
        pass


In [None]:
def first_visit_mc_prediction(episodes, gamma=1):
    """Estimate state values with first-visit Monte Carlo prediction.

    For every episode the return ``G`` is accumulated backwards from the last
    step. A return is recorded for a state only at the earliest time step at
    which that state occurs in the episode, so each episode contributes at
    most one sample per state.

    Args:
        episodes: Iterable of episodes, each a sequence of
            ``(state, action, reward)`` tuples in chronological order.
            States must be hashable.
        gamma: Discount factor applied to the running return (default 1).

    Returns:
        dict mapping each visited state to the mean of its recorded returns.
        Empty if ``episodes`` is empty.
    """
    returns_by_state = defaultdict(list)

    for sequence in episodes:
        steps = list(sequence)

        # Earliest time step at which each state appears in this episode.
        first_visit = {}
        for t, (state, _action, _reward) in enumerate(steps):
            first_visit.setdefault(state, t)

        G = 0
        for t in range(len(steps) - 1, -1, -1):
            state, _action, reward = steps[t]
            G = gamma * G + reward
            # Later revisits still feed G but are not recorded themselves.
            if first_visit[state] == t:
                returns_by_state[state].append(G)

    return {state: np.mean(returns) for state, returns in returns_by_state.items()}

In [None]:
def convert_value_function_to_df(value_function):
    """Turn a ``{state: value}`` mapping into a tidy polars frame.

    The first two entries of each key become the ``player`` and ``dealer``
    columns and the mapped number becomes ``value``. Rows are returned
    ordered by ``player`` and then ``dealer``.
    """
    records = []
    for state, estimate in value_function.items():
        records.append({"player": state[0], "dealer": state[1], "value": estimate})

    frame = pl.DataFrame(records)
    return frame.sort("player", "dealer")


In [None]:
def generate_episodes(agent, n_episodes=100_000, environment=None):
    """Roll out complete episodes with ``agent`` and record every step.

    Args:
        agent: Object exposing ``get_action(state)`` and
            ``update(state, action, reward, next_state)``.
        n_episodes: Number of episodes to simulate.
        environment: Gymnasium-style environment providing ``reset()`` and
            ``step(action)``. Defaults to the module-level ``env``.

    Returns:
        List of episodes. Each episode is a chronological list of
        ``(state[:2], action, reward)`` tuples, where ``state`` is the
        observation the action was chosen in. Only the first two observation
        entries are recorded (used downstream as player / dealer).
    """
    if environment is None:
        environment = env

    episodes = []

    for _ in range(n_episodes):
        state, _info = environment.reset()
        sequence = []
        done = False

        while not done:
            action = agent.get_action(state)
            next_state, reward, terminated, truncated, _info = environment.step(action)

            # Hand the agent the real transition: the state acted in and the
            # state that resulted (both used to be the post-step state).
            agent.update(state, action, reward, next_state)

            sequence.append((state[:2], action, reward))

            state = next_state
            # Stop on either end-of-episode signal so a truncated episode is
            # never stepped again.
            done = terminated or truncated

        episodes.append(sequence)

    return episodes

In [None]:
def plot_surface(value_function, title=""):
    """Show a 3D surface of state values over (dealer, player).

    Args:
        value_function: polars DataFrame with ``player``, ``dealer`` and
            ``value`` columns, one row per (player, dealer) pair.
        title: Figure title.

    The axis coordinates are read from the data instead of being hard-coded,
    so the surface stays correctly labelled when some player or dealer values
    are absent. Missing (player, dealer) combinations become empty cells.
    """
    # Pivot creates one column per dealer value in order of first appearance,
    # so sort by dealer first; then order the rows by player.
    grid = (
        value_function.sort("dealer")
        .pivot(index="player", on="dealer", values="value")
        .sort("player")
    )

    x = value_function["dealer"].unique().sort().to_numpy()
    y = grid["player"].to_numpy()
    # Rows of z follow y (player), columns follow x (dealer).
    z = grid.drop("player").to_numpy()

    fig = go.Figure(data=[go.Surface(x=x, y=y, z=z)])

    fig.update_layout(
        title=title,
        width=500,
        height=500,
        margin=dict(l=65, r=50, b=65, t=90),
        scene=dict(
            xaxis_title="Dealer Showing",
            yaxis_title="Player Sum",
            zaxis_title="Avg. Reward",
        ),
    )
    fig.show()

In [None]:
# Evaluate the threshold strategy with threshold=17 (the agent returns 1 while
# the first state entry is below 17, otherwise 0) over the default number of
# episodes, then plot the estimated state values.
episodes = generate_episodes(BlackjackAgent(threshold=17, strategy="threshold"))
value_function_raw = first_visit_mc_prediction(episodes)
value_function = convert_value_function_to_df(value_function_raw)
plot_surface(value_function, title="First Visit MC Prediction, stay on 17")

In [None]:
# threshold=21: the agent returns 1 for every sum below 21 and 0 only on 21,
# so the title names 21 (the previous title said 20).
# NOTE(review): if the intent was to stay on 20 or 21, use threshold=20 and
# restore the "stay on 20" title.
episodes = generate_episodes(BlackjackAgent(threshold=21, strategy="threshold"))
value_function_raw = first_visit_mc_prediction(episodes)
value_function = convert_value_function_to_df(value_function_raw)
plot_surface(value_function, title="First Visit MC Prediction, stay on 21")


In [None]:
# Uniformly random behaviour: the "random" strategy never reads `threshold`,
# so the title describes the random policy rather than a stay-on-17 rule.
episodes = generate_episodes(
    BlackjackAgent(threshold=17, strategy="random"), n_episodes=100_000
)
value_function_raw = first_visit_mc_prediction(episodes)
value_function = convert_value_function_to_df(value_function_raw)
plot_surface(value_function, title="First Visit MC Prediction, random policy")


## MC Control

### Example 5.4: Off-policy Estimation of a Blackjack State Value

In [None]:
def enumerate_events(episodes):
    """Flatten nested episodes into a polars frame with one row per step.

    Each ``(state, action, reward)`` step becomes a row holding the index of
    its episode, its position inside that episode (``time``), the first two
    state entries (``player``, ``dealer``), the action and the reward.
    """
    column_names = ["episode", "time", "player", "dealer", "action", "reward"]

    rows = [
        (episode_id, time, state[0], state[1], action, reward)
        for episode_id, sequence in enumerate(episodes)
        for time, (state, action, reward) in enumerate(sequence)
    ]

    return pl.DataFrame(rows, orient="row", schema=column_names)

In [None]:
def get_action_probabilities(episodes):
    """Empirical probability of each action in each (player, dealer) state.

    Takes the per-step frame (columns ``player``, ``dealer``, ``action``),
    counts how often every action was taken in every state and divides by the
    number of rows observed for that state. Returns the columns ``player``,
    ``dealer``, ``action`` and ``probability``.
    """
    state_keys = ["player", "dealer"]

    counts = episodes.group_by([*state_keys, "action"]).agg(pl.len().alias("count"))
    with_totals = counts.with_columns(
        pl.sum("count").over(state_keys).alias("total"),
    )
    with_probability = with_totals.with_columns(
        (pl.col("count") / pl.col("total")).alias("probability")
    )

    return with_probability.select("player", "dealer", "action", "probability")

In [None]:
# Sample episodes under both strategies: "random" and "threshold" (threshold=17).
# NOTE(review): no random seed is set in the visible cells, so the sampled
# episodes (and everything derived from them) differ between runs.
episodes_random = generate_episodes(BlackjackAgent(threshold=17, strategy="random"))
episodes_threshold = generate_episodes(
    BlackjackAgent(threshold=17, strategy="threshold")
)

# One row per step: episode, time, player, dealer, action, reward.
episodes_random_df = enumerate_events(episodes_random)
episodes_threshold_df = enumerate_events(episodes_threshold)

# Empirical action probabilities per (player, dealer) state, estimated from
# the sampled episodes of each strategy.
probabilties_random = get_action_probabilities(episodes_random_df)
probabilties_threshold = get_action_probabilities(episodes_threshold_df)


In [None]:
# Per-(player, dealer, action) ratio between the threshold-strategy probability
# ("probability_policy") and the random-strategy probability ("probability").
# The full join keeps combinations seen under only one of the two strategies;
# their ratio would be null and is replaced by 0.
isr_table = (
    probabilties_random.join(
        probabilties_threshold,
        on=["player", "dealer", "action"],
        suffix="_policy",
        how="full",
    )
    .drop(["player_policy", "dealer_policy", "action_policy"])
    .with_columns(
        (pl.col("probability_policy") / pl.col("probability"))
        .fill_null(0)
        .alias("importance_sampling_ration")
    )
)

# Coverage check: no joined row may lack a random-strategy probability, i.e.
# every combination seen under the threshold strategy was also seen under the
# random one.
assert (
    isr_table.null_count()["probability"][0] == 0
), "control policy must have coverage"

# importance_sampling_ration corresponds to pi(A|S) / b(A|S) in the book
isr_table


In [None]:
# Attach the per-step ratio pi(A|S) / b(A|S) to every step of the random episodes.
episodes_random_df_with_isr = episodes_random_df.join(
    isr_table.drop("probability", "probability_policy"),
    on=["player", "dealer", "action"],
    how="left",
)

# For every step t of every episode compute
#   * the return G_t: the sum of the rewards from step t to the end of the
#     episode (undiscounted), and
#   * the importance-sampling ratio rho_{t:T-1}: the PRODUCT of the per-step
#     ratios from step t to the end of the episode.
# Both are reverse cumulative aggregates within an episode. The previous
# row-by-row loop used the single-step reward as G_t and summed the ratios.
# The return keeps the column name "reward" because the aggregation that
# follows reads it under that name.
results_df = (
    episodes_random_df_with_isr.sort("episode", "time")
    .with_columns(
        pl.col("reward").cum_sum(reverse=True).over("episode").alias("reward"),
        pl.col("importance_sampling_ration")
        .cum_prod(reverse=True)
        .over("episode")
        .alias("rho_t_T"),
    )
    .select(
        pl.col("episode").alias("episode_id"),
        "player",
        "dealer",
        "reward",
        "rho_t_T",
    )
)

results_df

In [None]:
# Off-policy value estimates per (player, dealer) state:
#   value_is  - ordinary importance sampling: sum(rho * G) / number of visits
#   value_wis - weighted importance sampling: sum(rho * G) / sum(rho)
state_values = (
    results_df.with_columns(
        (pl.col("reward") * pl.col("rho_t_T")).alias("reward_contribution")
    )
    .group_by("player", "dealer")
    .agg(
        pl.count("reward").alias("visit_count_state"),
        pl.sum("rho_t_T").alias("weighted_denominator"),
        pl.sum("reward_contribution").alias("reward"),
    )
    .with_columns(
        (pl.col("reward") / pl.col("visit_count_state")).alias("value_is"),
        # A state whose total weight is zero would give 0 / 0 = NaN; the
        # weighted estimate is defined as 0 in that case.
        pl.when(pl.col("weighted_denominator") != 0)
        .then(pl.col("reward") / pl.col("weighted_denominator"))
        .otherwise(0.0)
        .alias("value_wis"),
    )
    .sort("player", "dealer")
)

state_values

In [None]:
def plot_surface_from_df(df, x_col, y_col, z_col, title=""):
    """Show a 3D surface of ``z_col`` over the (``x_col``, ``y_col``) grid.

    Args:
        df: polars DataFrame with one row per (x, y) pair.
        x_col: Column plotted along the x axis.
        y_col: Column plotted along the y axis.
        z_col: Column providing the surface height.
        title: Figure title.

    Axis coordinates and titles are taken from the data and the column names.
    Previously the axis ranges were hard-coded, a column literally named
    "player" was dropped regardless of ``x_col``, and the x/y titles were
    swapped relative to the plotted data. Missing (x, y) combinations become
    empty cells in the surface.
    """
    # Pivot creates one column per x value in order of first appearance, so
    # sort by x first; then order the rows by y.
    grid = (
        df.sort(x_col)
        .pivot(index=y_col, on=x_col, values=z_col)
        .sort(y_col)
    )

    x = df[x_col].unique().sort().to_numpy()
    y = grid[y_col].to_numpy()
    # Rows of z follow y, columns follow x, as go.Surface expects.
    z = grid.drop(y_col).to_numpy()

    fig = go.Figure(data=[go.Surface(x=x, y=y, z=z)])

    fig.update_layout(
        title=title,
        width=500,
        height=500,
        margin=dict(l=65, r=50, b=65, t=90),
        scene=dict(
            xaxis_title=x_col,
            yaxis_title=y_col,
            zaxis_title=z_col,
        ),
    )
    fig.show()


In [None]:
# Plot both off-policy estimates over the (player, dealer) grid.
x_col = "player"
y_col = "dealer"
# NOTE(review): z_col is assigned but not used; the calls below pass the
# value column names as literals.
z_col = "value_is"

plot_surface_from_df(
    state_values, x_col, y_col, "value_is", title="Importance Sampling"
)

plot_surface_from_df(
    state_values, x_col, y_col, "value_wis", title="Weighted Importance Sampling"
)

In [None]:
import plotly.express as px

In [None]:
# Distribution of the gap between the two estimators: per-state difference
# (value_is - value_wis), rounded to two decimals, counted per rounded value
# and drawn as a line over the sorted deltas.
px.line(
    state_values.with_columns((pl.col("value_is") - pl.col("value_wis")).alias("delta"))
    .group_by(pl.col("delta").round(2))
    .len()
    .sort("delta"),
    x="delta",
    y="len",
)
