In [1]:
import numpy as np
import pandas as pd
import altair as alt
from bandits.environment.cascade.context_free import CascadeContextFreeBandit
from bandits.policy.context_free import BernoulliTS, Random
from bandits.plotting import plot_beta_dist
import vegafusion as vf
from typing import TypedDict
from dataclasses import dataclass

In [2]:
# Save chart data to local .json files instead of embedding it in the notebook,
# so the notebook file does not grow large.
# Alternative: alt.data_transformers.disable_max_rows() keeps the data inside
# the notebook, but produces very large notebook files.
alt.data_transformers.enable('json')

DataTransformerRegistry.enable('json')

In [3]:
# Problem size shared by every experiment below:
#   N_ACTIONS -- number of arms the environment offers
#   LEN_LIST  -- number of arms shown to the user at each step (slate length)
N_ACTIONS, LEN_LIST = 20, 5

# Helper functions

In [4]:
def plot_pdf_vs_actual(
    pdf_df: pd.DataFrame,
    actuals_df: pd.DataFrame,
    width: int = 500,
    height: int = 250,
) -> alt.LayerChart:
    """Overlay each arm's density curve with a vertical rule at its true weight.

    Parameters
    ----------
    pdf_df : pd.DataFrame
        Long-format densities with columns ``x`` (grid of θ values), ``pdf``
        (density at ``x``) and ``action`` (arm id), e.g. concatenated
        ``plot_beta_dist`` outputs tagged with ``action``.
    actuals_df : pd.DataFrame
        One row per arm with columns ``w`` (true weight θ) and ``action``.
    width, height : int, default 500 and 250
        Size of the chart in pixels.

    Returns
    -------
    alt.LayerChart
        Density lines with the true-weight rules drawn on top, both coloured by
        arm. There is no legend; arms are identified through the tooltips.
    """
    pdf_charts = alt.Chart(pdf_df).mark_line().encode(
        y=alt.Y('pdf', title='density'),
        x=alt.X('x', title='θ'),
        color=alt.Color('action:N', legend=None),
        # Show the θ grid value and the density under their own names
        # (the density must not be labelled 'θ').
        tooltip=[
            'action',
            alt.Tooltip('x', format='.4', title='θ'),
            alt.Tooltip('pdf', format='.4', title='density'),
        ],
    )

    actual_charts = alt.Chart(actuals_df).mark_rule().encode(
        x=alt.X('w', title='θ'),
        color=alt.Color('action:N', legend=None),
        tooltip=[
            'action',
            alt.Tooltip('w', format='.4', title='θ'),
        ],
    )

    # `+` layers the charts: the result is an alt.LayerChart, not an alt.Chart.
    return (pdf_charts + actual_charts).properties(width=width, height=height)

In [5]:
def plot_actual_vs_predicted(
    policy: BernoulliTS,
    env: CascadeContextFreeBandit
) -> alt.Chart:
    """Dot plot of each arm's empirical click rate next to its true weight.

    Works with any policy exposing ``reward_counts`` and ``action_counts``
    arrays (the notebook calls it with both ``Random`` and ``BernoulliTS``).

    Parameters
    ----------
    policy : BernoulliTS or Random
        Policy after a simulation run. The prediction for arm ``i`` is
        ``reward_counts[i] / action_counts[i]``.
    env : CascadeContextFreeBandit
        Environment whose ``weights`` are the ground truth and whose
        ``optimal_action`` arms are highlighted.

    Returns
    -------
    alt.Chart
        One row per arm, ordered by true weight (highest first). Colour
        distinguishes ``pred`` from ``act``, and shape marks arms that belong to
        ``env.optimal_action``.
    """
    action_counts = np.asarray(policy.action_counts, dtype=float)
    reward_counts = np.asarray(policy.reward_counts, dtype=float)
    # Arms that were never played have no estimate: give them NaN explicitly
    # instead of evaluating 0/0, which emits RuntimeWarnings.
    pred = np.divide(
        reward_counts,
        action_counts,
        out=np.full_like(action_counts, np.nan),
        where=action_counts > 0,
    )
    act = env.weights

    chart_df = pd.DataFrame(dict(pred=pred, act=act)).assign(
        diff=lambda d: (d.act - d.pred).abs(),
        arm=lambda d: d.index,
        optimal_arm=lambda d: d.arm.isin(env.optimal_action),
        # Duplicate of `act` kept as an id column so it survives the melt and
        # can drive the y-axis ordering.
        act_sort=lambda d: d.act,
    )

    # Long format: one row per (arm, variable) with variable in {pred, act}.
    chart_ff_df = chart_df.melt(id_vars=['arm', 'optimal_arm', 'diff', 'act_sort'])

    return alt.Chart(chart_ff_df).mark_point().encode(
        y=alt.Y('arm:O', sort=alt.SortField("act_sort", "descending")),
        x=alt.X('value', title='θ'),
        color=alt.Color('variable'),
        shape=alt.Shape('optimal_arm:N', title='optimal arm'),
        tooltip=[
            'arm',
            'variable',
            alt.Tooltip('value', format='.4'),
            alt.Tooltip('diff', format='.4', title='|act - pred|'),
        ],
    )

In [6]:
# Shape of one record appended to `reporting` by the simulation loops:
#   action        -- the slate of arm ids shown at this step (the loops actually
#                    store the np.ndarray returned by policy.select_action())
#   reward        -- reward returned by env.step for that slate
#   prob_of_click -- info["prob_of_click"] reported by env.step
ActionRewardLogging = TypedDict(
    'ActionRewardLogging',
    {
        'action': list[int],
        'reward': float,
        'prob_of_click': float,
    },
)

def harmonise_reporting(
    reporting: list[ActionRewardLogging],
    env: CascadeContextFreeBandit,
    policy_batch_check: int = 500,
) -> pd.DataFrame:
    """Turn the per-step log of a simulation loop into an analysis-ready frame.

    Parameters
    ----------
    reporting : list of ActionRewardLogging
        One record per environment step, in the order the steps were taken.
    env : CascadeContextFreeBandit
        Environment that produced the log. ``env.optimal_reward`` and
        ``env.optimal_action`` are used as the reference.
    policy_batch_check : int, default 500
        Bucket width (in steps) for the ``policy_batch_check`` column.

    Returns
    -------
    pd.DataFrame
        The logged columns plus:

        * ``optimal_prob_of_click`` -- ``env.optimal_reward`` on every row;
        * ``time_idx`` -- 0-based step index;
        * ``avg_reward`` -- mean reward over steps ``0..time_idx`` inclusive;
        * ``action_as_str`` -- the slate joined with ``"|"`` (hashable, for grouping);
        * ``optimal_action_id`` -- True when the slate equals ``env.optimal_action``
          position by position (order-sensitive);
        * ``policy_batch_check`` -- first ``time_idx`` of the step's bucket.
    """
    reporting_df = pd.DataFrame(reporting).assign(
        optimal_prob_of_click=env.optimal_reward,
        time_idx=lambda x: x.index,
    )
    # Running mean over the time_idx + 1 steps seen so far. Dividing by
    # time_idx alone is off by one: it gives inf/NaN at step 0 and biases
    # every later value upwards.
    reporting_df['avg_reward'] = reporting_df['reward'].cumsum() / (reporting_df['time_idx'] + 1)
    reporting_df['action_as_str'] = reporting_df['action'].apply(
        lambda slate: '|'.join(str(arm) for arm in slate)
    )
    # np.array_equal also copes with a slate of a different length, where an
    # element-wise `==` would fail to broadcast.
    # NOTE(review): if the env's click probability does not depend on slate
    # order, a set comparison may be the better notion of "optimal" -- confirm.
    reporting_df['optimal_action_id'] = reporting_df['action'].apply(
        lambda slate: np.array_equal(slate, env.optimal_action)
    )
    reporting_df['policy_batch_check'] = (
        reporting_df['time_idx'] // policy_batch_check
    ) * policy_batch_check
    return reporting_df

# Environment setup

In [75]:
# Seed NumPy's global RNG so the true arm weights are reproducible.
np.random.seed(1234)

# Each arm's true weight is drawn from Beta(1, 99), i.e. mean 1 / 100 = 0.01.
env = CascadeContextFreeBandit(
    len_list=LEN_LIST,
    max_steps=500_000,
    weights=np.random.beta(a=1, b=99, size=N_ACTIONS),
)

In [76]:
# Enriched reporting frame of each policy run, keyed by policy name.
outputs_of_policies: dict[str, pd.DataFrame] = dict()

# Random Policy

In [77]:
# Baseline: the `Random` policy, sized to match the environment.
policy = Random(
    random_state=1234,
    batch_size=1,
    n_actions=env.n_actions,
    len_list=env.len_list,
)

In [78]:
# Start a fresh, seeded episode and draw the first slate (LEN_LIST arm ids)
# from the policy.
observation, info = env.reset(seed=1234)
action = policy.select_action()
action

array([ 3, 13,  2, 16, 14])

In [79]:
reporting = []

# Step the environment with the policy's slate until the episode ends,
# logging each step and feeding the observed click back to the policy.
while True:
    _, reward, terminated, truncated, info = env.step(action=action)
    reporting.append(ActionRewardLogging(
        action=action,
        reward=reward,
        prob_of_click=info["prob_of_click"],
    ))

    # Pass the click position reported by the env to the cascade update.
    policy.cascade_params_update(
        action=action,
        reward_position=info["position_of_click"],
    )

    # step() returns a (obs, reward, terminated, truncated, info) 5-tuple:
    # stop on either end-of-episode flag, not only on truncation.
    if terminated or truncated:
        break

    action = policy.select_action()

In [80]:
# Enrich the raw per-step log (running average, optimal-slate flag, 500-step buckets).
reporting_df = harmonise_reporting(reporting, env, policy_batch_check=500)

In [81]:
# Keep this run, tagged with its policy name, for cross-policy comparison.
outputs_of_policies.update(random=reporting_df.assign(policy='random'))

In [82]:
# Long format (one row per step and series) so the achieved and optimal
# click probabilities can be drawn as two coloured lines on one chart.
reporting_ff_df = pd.melt(
    reporting_df,
    id_vars=['time_idx', 'action'],
    value_vars=['prob_of_click', 'optimal_prob_of_click'],
)
reporting_ff_df.head()

Unnamed: 0,time_idx,action,variable,value
0,0,"[3, 13, 2, 16, 14]",prob_of_click,0.028834
1,1,"[8, 19, 13, 1, 4]",prob_of_click,0.024114
2,2,"[2, 3, 17, 6, 16]",prob_of_click,0.034702
3,3,"[7, 0, 18, 11, 4]",prob_of_click,0.040325
4,4,"[8, 3, 5, 0, 6]",prob_of_click,0.034135


In [83]:
chart = (
    alt.Chart(reporting_ff_df)
    .mark_line()
    .encode(x=alt.X('time_idx'), y=alt.Y('value'), color='variable')
    .properties(width=600)
)

# Save a static PNG copy (rendered with vegafusion) for use outside the notebook.
vf.save_png(chart, "random_policy__prob_of_click.png")


In [84]:
# Display the chart built (and saved to PNG) in the previous cell.
chart

In [85]:
# Running average reward over time for the random baseline.
chart = (
    alt.Chart(reporting_df)
    .mark_line()
    .encode(y='avg_reward', x='time_idx')
    .properties(width=700)
)
vf.save_png(chart, "random_policy__avg_reward.png")


In [86]:
# Display the chart built (and saved to PNG) in the previous cell.
chart

In [87]:
# Empirical per-arm click rates under the random policy vs the true weights.
chart = plot_actual_vs_predicted(env=env, policy=policy)
vf.save_png(chart, "random_policy__action_probs_actual_vs_predicted.png")

In [88]:
# Display the chart built (and saved to PNG) in the previous cell.
chart

In [89]:
# Share of steps in each policy_batch_check bucket where the chosen slate was
# exactly the optimal one.
policy_prob_df = (
    reporting_df
    .assign(n_trials=1)
    .groupby('policy_batch_check', as_index=False)[['n_trials', 'optimal_action_id']]
    .sum()
    .assign(prob_of_optimal_action=lambda d: d['optimal_action_id'] / d['n_trials'])
)

chart = (
    alt.Chart(policy_prob_df)
    .mark_line()
    .encode(y='prob_of_optimal_action', x='policy_batch_check')
    .properties(width=700)
)

chart

# Thompson Sampling Policy

## Uninformed Prior

In [57]:
# Thompson sampling with the library's default prior: alpha = beta = 1 for
# every arm, i.e. a uniform Beta(1, 1) prior on each weight.
policy = BernoulliTS(
    random_state=1234,
    batch_size=1,
    n_actions=env.n_actions,
    len_list=env.len_list,
)

In [58]:
# Show the policy's repr, including its per-arm alpha/beta prior parameters.
policy

BernoulliTS(n_actions=20, len_list=5, batch_size=1, random_state=1234, alpha=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1.]), beta=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1.]))

In [67]:
# Start a fresh, seeded episode for the uniform-prior TS run and draw its first slate.
# NOTE(review): the random-policy and pessimistic-prior sections reset with
# seed=1234; confirm that a different seed here is intentional, since it
# presumably changes the env's click draws and weakens the cross-policy comparison.
observation, info = env.reset(seed=34325)
action = policy.select_action()
action

array([ 7,  9, 10,  1, 16])

In [68]:
reporting = []

# Step the environment with the policy's slate until the episode ends,
# logging each step and feeding the observed click back to the policy.
while True:
    _, reward, terminated, truncated, info = env.step(action=action)
    reporting.append(ActionRewardLogging(
        action=action,
        reward=reward,
        prob_of_click=info["prob_of_click"],
    ))

    # Pass the click position reported by the env to the cascade update.
    policy.cascade_params_update(
        action=action,
        reward_position=info["position_of_click"],
    )

    # step() returns a (obs, reward, terminated, truncated, info) 5-tuple:
    # stop on either end-of-episode flag, not only on truncation.
    if terminated or truncated:
        break

    action = policy.select_action()

In [69]:
reporting_df = harmonise_reporting(reporting=reporting, env=env, policy_batch_check=500)
# Keep this run next to the random baseline so the policies can be compared.
outputs_of_policies['ts_uninformed_prior'] = reporting_df.assign(policy='ts_uninformed_prior')

In [70]:
# Long format (one row per step and series) for the two-line click-probability chart.
reporting_ff_df = pd.melt(
    reporting_df,
    id_vars=['time_idx', 'action'],
    value_vars=['prob_of_click', 'optimal_prob_of_click'],
)
reporting_ff_df.head()

Unnamed: 0,time_idx,action,variable,value
0,0,"[7, 9, 10, 1, 16]",prob_of_click,0.080067
1,1,"[7, 9, 10, 1, 3]",prob_of_click,0.080305
2,2,"[7, 9, 10, 1, 3]",prob_of_click,0.080305
3,3,"[7, 9, 10, 1, 16]",prob_of_click,0.080067
4,4,"[7, 9, 10, 1, 16]",prob_of_click,0.080067


In [71]:
chart = (
    alt.Chart(reporting_ff_df)
    .mark_line()
    .encode(x=alt.X('time_idx'), y=alt.Y('value'), color='variable')
    .properties(width=600)
)

# Save a static PNG copy (rendered with vegafusion) for use outside the notebook.
vf.save_png(chart, "ts_policy_uninformed_prior__prob_of_click.png")

In [72]:
# Display the chart built (and saved to PNG) in the previous cell.
chart

In [73]:
# Empirical per-arm click rates under uniform-prior TS vs the true weights.
chart = plot_actual_vs_predicted(env=env, policy=policy)
chart

In [74]:
# Share of steps in each policy_batch_check bucket where the chosen slate was
# exactly the optimal one.
policy_prob_df = (
    reporting_df
    .assign(n_trials=1)
    .groupby('policy_batch_check', as_index=False)[['n_trials', 'optimal_action_id']]
    .sum()
    .assign(prob_of_optimal_action=lambda d: d['optimal_action_id'] / d['n_trials'])
)

chart = (
    alt.Chart(policy_prob_df)
    .mark_line()
    .encode(y='prob_of_optimal_action', x='policy_batch_check')
    .properties(width=700)
)

chart

In [None]:
# `action_as_str` and `optimal_action_id` were already added by
# harmonise_reporting, so recomputing them here is unnecessary.
reporting_df.tail()

In [None]:
# harmonise_reporting(..., policy_batch_check=500) already bucketed the steps;
# check the bucket sizes (every full bucket should hold 500 steps).
reporting_df.groupby('policy_batch_check').size()

In [113]:
policy_prob_df = (
    reporting_df
    .assign(n_trials=1)
    .groupby('policy_batch_check', as_index=False)[['n_trials', 'optimal_action_id']]
    .sum()
    .assign(prob_of_optimal_action=lambda d: d['optimal_action_id'] / d['n_trials'])
)

In [None]:
(
    alt.Chart(policy_prob_df)
    .mark_line()
    .encode(y='prob_of_optimal_action', x='policy_batch_check')
    .properties(width=700)
)

In [None]:
# True weight `w` of every arm, keyed by arm id.
actuals_df = pd.DataFrame({'w': env.weights, 'action': range(len(env.weights))})
actuals_df.head()


In [None]:
# Posterior density of every arm: prior (alpha, beta) updated with that arm's
# observed clicks (reward_counts) and non-clicks (action_counts - reward_counts).
all_pdf = [
    plot_beta_dist(
        alpha=policy.alpha[arm] + policy.reward_counts[arm],
        beta=policy.beta[arm] + (policy.action_counts[arm] - policy.reward_counts[arm]),
    ).assign(action=arm)
    for arm in range(policy.n_actions)
]

all_pdf_df = pd.concat(all_pdf, axis=0).reset_index(drop=True)
all_pdf_df.head()


In [None]:
# Posterior densities vs true weights. This reuses the helper defined at the
# top instead of a copy-pasted chart spec.
final_chart = plot_pdf_vs_actual(
    pdf_df=all_pdf_df,
    actuals_df=actuals_df,
    width=500,
    height=250,
)

final_chart



## Pessimistic Prior (Beta(1, 99) on every arm)

In [130]:
# Beta(1, 99) prior on every arm: prior mean 1 / (1 + 99) = 0.01, far below the
# uniform prior's 0.5. This is the same distribution the true weights were
# drawn from in the environment setup.
policy = BernoulliTS(
    n_actions=env.n_actions,
    len_list=env.len_list,
    random_state=1234,
    batch_size=1,
    alpha=np.full(env.n_actions, 1.0),
    beta=np.full(env.n_actions, 99.0),
)

In [None]:
# Density of the Beta(1, 99) prior, plus each arm's true weight for reference.
prior_df = plot_beta_dist(beta=99, alpha=1)
actuals_df = pd.DataFrame({'w': env.weights}).assign(action=lambda d: d.index)
actuals_df.head()

In [None]:
# First five rows of the prior density grid.
prior_df.iloc[:5]

In [None]:
# Shaded prior density with a translucent rule at every arm's true weight.
priors_chart = alt.Chart(prior_df).mark_area(opacity=0.5).encode(x='x', y=alt.Y('pdf'))
actuals_chart = alt.Chart(actuals_df).mark_rule(opacity=0.5).encode(x='w')

final_chart = (priors_chart + actuals_chart).properties(height=200, width=500)

final_chart

In [None]:
# Start a fresh episode with the same env seed as the random-policy run and
# draw the first slate from the pessimistic-prior policy.
observation, info = env.reset(seed=1234)
action = policy.select_action()
action

In [135]:
reporting = []

# Step the environment with the policy's slate until the episode ends,
# logging each step and feeding the observed click back to the policy.
while True:
    _, reward, terminated, truncated, info = env.step(action=action)
    reporting.append(ActionRewardLogging(
        action=action,
        reward=reward,
        prob_of_click=info["prob_of_click"],
    ))

    # Pass the click position reported by the env to the cascade update.
    policy.cascade_params_update(
        action=action,
        reward_position=info["position_of_click"],
    )

    # step() returns a (obs, reward, terminated, truncated, info) 5-tuple:
    # stop on either end-of-episode flag, not only on truncation.
    if terminated or truncated:
        break

    action = policy.select_action()

In [None]:
# Same post-processing as the other sections, via the shared helper rather than
# an inline re-implementation.
reporting_df = harmonise_reporting(reporting=reporting, env=env, policy_batch_check=500)
# Keep this run next to the other policies for comparison.
outputs_of_policies['ts_pessimistic_prior'] = reporting_df.assign(policy='ts_pessimistic_prior')

reporting_df.head()

In [None]:
# Long format (one row per step and series) for the two-line click-probability chart.
reporting_ff_df = pd.melt(
    reporting_df,
    id_vars=['time_idx', 'action'],
    value_vars=['prob_of_click', 'optimal_prob_of_click'],
)
reporting_ff_df.head()

In [None]:
# Keep only every `plot_every`-th step so the line chart stays light.
plot_every = 100
mask = reporting_ff_df['time_idx'].mod(plot_every).eq(0)
mask.sum()

In [None]:
(
    alt.Chart(reporting_ff_df[mask])
    .mark_line()
    .encode(x=alt.X('time_idx'), y=alt.Y('value'), color='variable')
    .properties(width=600)
)

In [None]:
# Empirical per-arm click rates under pessimistic-prior TS vs the true weights,
# using the shared helper instead of a copy-pasted chart spec.
plot_actual_vs_predicted(policy=policy, env=env)

In [None]:
plot_every = 10
# Thin the series to every `plot_every`-th step. Build the mask from
# reporting_df itself, not from the long-format reporting_ff_df, which has two
# rows per step and a different index.
avg_reward_mask = (reporting_df['time_idx'] % plot_every) == 0
(
    alt.Chart(reporting_df[avg_reward_mask])
    .mark_line()
    .encode(x='time_idx', y='avg_reward')
    .properties(width=700)
)

In [142]:
# Bucket the steps into windows of `policy_batch_check` and compute, per bucket,
# the share of steps where the chosen slate was exactly the optimal one.
policy_batch_check = 500
reporting_df['policy_batch_check'] = (reporting_df['time_idx'] // policy_batch_check) * policy_batch_check
policy_prob_df = (
    reporting_df
    .assign(n_trials=1)
    .groupby('policy_batch_check', as_index=False)[['n_trials', 'optimal_action_id']]
    .sum()
    .assign(prob_of_optimal_action=lambda d: d['optimal_action_id'] / d['n_trials'])
)

In [None]:
(
    alt.Chart(policy_prob_df)
    .mark_line()
    .encode(y='prob_of_optimal_action', x='policy_batch_check')
    .properties(width=700)
)

In [None]:
# True weight `w` of every arm, keyed by arm id.
actuals_df = pd.DataFrame(dict(w=env.weights))
actuals_df['action'] = actuals_df.index
actuals_df.head()

In [145]:
# Posterior density of every arm: prior (alpha, beta) updated with that arm's
# observed clicks (reward_counts) and non-clicks (action_counts - reward_counts).
all_pdf = [
    plot_beta_dist(
        alpha=policy.alpha[arm] + policy.reward_counts[arm],
        beta=policy.beta[arm] + (policy.action_counts[arm] - policy.reward_counts[arm]),
    ).assign(action=arm)
    for arm in range(policy.n_actions)
]

all_pdf_df = pd.concat(all_pdf, axis=0).reset_index(drop=True)

In [None]:
# Posterior densities vs true weights. This reuses the helper defined at the
# top instead of a copy-pasted chart spec.
final_chart = plot_pdf_vs_actual(
    pdf_df=all_pdf_df,
    actuals_df=actuals_df,
    width=500,
    height=250,
)

final_chart