In [1]:
import numpy as np
import pandas as pd
import altair as alt
from bandits.environment.cascade.contextual import CascadeContextualBandit
from bandits.policy.context_free import BernoulliTS, Random
import vegafusion as vf
from typing import TypedDict
from dataclasses import dataclass

In [None]:
# Store chart data in a local .json file so the notebook itself doesn't get large.
# Alternatively, alt.data_transformers.disable_max_rows() keeps the data inside the
# notebook, but that produces large notebook files.
alt.data_transformers.enable('json')

# Helper functions

In [3]:
class ActionRewardLogging(TypedDict):
    """One logged environment step, as appended to the simulation's report list."""
    # Action that was passed to the environment's step() for this record.
    action: list[int]
    # Reward returned by the environment for that step.
    reward: float
    # Value of info["prob_of_click"] reported by the environment for that step.
    prob_of_click: float
    # Context index the action was chosen for (position of the 1 in the observation).
    context: int


# Environment setup

In [128]:
# Number of contexts: rows of the weight matrix handed to the environment.
N_CONTEXTS = 10
# Number of actions: columns of the weight matrix, and n_actions of every policy.
N_ACTIONS = 50
# Passed to the environment as len_list.
# presumably the number of actions shown per step — TODO confirm
LEN_LIST = 5

In [138]:
# Seed NumPy's global RNG so the sampled weights are the same on every run.
np.random.seed(1234)

# One Beta(1, 99) draw per (context, action) pair; a = 1, b = 99 gives a mean of 0.01.
# presumably each weight is that pair's click probability — TODO confirm
click_weights = np.random.beta(a=1, b=99, size=(N_CONTEXTS, N_ACTIONS))

env = CascadeContextualBandit(
    weights=click_weights,
    max_steps=500_000,
    len_list=LEN_LIST,
)

# Thompson Sampling Policy

## Uninformative Prior

In [139]:
# One Thompson-sampling policy per context index, each created with
# alpha=1 and beta=1 (the uninformed prior of this section).
# NOTE(review): every policy receives the same random_state=1234 — confirm that
# identical seeds across contexts are intended.
policy: dict[int, BernoulliTS] = {
    context_idx: BernoulliTS(
        n_actions=N_ACTIONS,
        alpha=1,
        beta=1,
        len_list=env.len_list,
        random_state=1234,
        batch_size=1,
    )
    for context_idx in range(env.dim)
}

In [None]:
# Display the mapping from context index to its policy object.
policy

In [None]:
# Reset the environment (seed=1234 is forwarded to reset) and show the first observation.
observation, info = env.reset(seed=1234)
observation

In [None]:
# Context index = position (along the first axis) of the first entry equal to 1.
# presumably the observation is a one-hot vector over contexts — TODO confirm
context = np.where(observation==1)[0][0]
context

In [None]:
# Show the environment's optimal_action attribute for reference.
env.optimal_action

In [144]:
# Pick the policy belonging to the current context and let it choose the first action;
# the simulation loop starts from these values.
policy_for_context = policy[context]
action = policy_for_context.select_action()

In [145]:
# Per-step log of the simulation; turned into a DataFrame afterwards.
reporting: list[ActionRewardLogging] = []

# Run until the environment reports truncation. The loop starts from the
# `action`, `context` and `policy_for_context` chosen before this cell.
# NOTE(review): `terminated` is unpacked but never checked — confirm the
# environment only ever ends an episode through truncation.
truncated = False
while not truncated:
    new_observation, reward, terminated, truncated, info = env.step(action=action)

    # Log the step under the context the action was selected for.
    reporting.append(
        ActionRewardLogging(
            action=action,
            reward=reward,
            prob_of_click=info["prob_of_click"],
            context=context,
        )
    )

    # Feed the click position back to the policy that produced the action.
    policy_for_context.cascade_params_update(
        action=action,
        reward_position=info["position_of_click"],
    )

    if not truncated:
        # Move on to the next context and let its policy choose the next action.
        observation = new_observation
        context = np.where(observation == 1)[0][0]
        policy_for_context = policy[context]
        action = policy_for_context.select_action()

In [None]:
# One row per logged step; columns are the keys of the logged records.
reporting_df = pd.DataFrame(reporting)
reporting_df.head()

In [None]:
# Attach, for every step, the environment's optimal reward for that step's context
# so it can be compared with the achieved prob_of_click.
reporting_df['optimal_prob_of_click'] = reporting_df['context'].map(
    lambda context_id: env.optimal_reward[context_id]
)
reporting_df.head()

In [None]:
# Reshape the step log to long format: one row per (step, metric), so the achieved
# and the optimal click probability can be drawn as two lines in the same chart.
reporting_ff_df = reporting_df.assign(
    time_idx=lambda x: x.index
).melt(
    id_vars=['time_idx', 'context'],
    value_vars=['prob_of_click', 'optimal_prob_of_click']
)
# Per-context step counter, starting from 0.
# The counter must be computed within (context, variable): after the melt every
# step appears once per metric, so counting within `context` alone would number the
# two metric rows of the same step consecutively (2k, 2k + 1), shifting the two
# lines against each other and doubling the x-axis range.
reporting_ff_df['context_time_idx'] = (
    reporting_ff_df
    .sort_values(['context', 'time_idx'])
    .groupby(['context', 'variable'])
    .cumcount()
)
reporting_ff_df.head()

In [None]:
# One small line chart per context (two per row): `value` against the per-context
# step counter, one line per `variable`; each facet gets its own y-scale.
alt.Chart(reporting_ff_df).mark_line().encode(
    x=alt.X('context_time_idx'),
    y=alt.Y('value'),
    color=alt.Color('variable'),
    facet=alt.Facet('context:O', columns=2),
).properties(
    width=300,
    height=100,
).resolve_scale(
    y='independent'
)