In [98]:
import numpy as np
import pandas as pd
import altair as alt
from bandits.environment.cascade.context_free import CascadeContextFreeBandit
from bandits.policy.context_free import BernoulliTS, Random
from bandits.plotting import plot_beta_dist
import vegafusion as vf
from typing import TypedDict
from dataclasses import dataclass

In [None]:
# Store chart data in a local .json file rather than inside the notebook, so the .ipynb does not get large.
alt.data_transformers.enable('json')
# Alternative: alt.data_transformers.disable_max_rows() keeps the data inside the notebook, at the cost of a much larger file.

# Helper functions

In [101]:
def plot_pdf_vs_actual(
    pdf_df: pd.DataFrame,
    actuals_df: pd.DataFrame,
    width: int = 500,
    height: int = 250
) -> alt.Chart:
    """Overlay per-action density curves with a vertical rule at each action's actual weight.

    Args:
        pdf_df: Long-format frame with the columns ``x`` (θ), ``pdf`` and ``action``.
        actuals_df: Frame with the columns ``w`` (actual weight) and ``action``.
        width: Chart width in pixels.
        height: Chart height in pixels.

    Returns:
        The two layers (density lines + actual-weight rules) combined into one chart.
    """
    pdf_charts = alt.Chart(pdf_df).mark_line().encode(
        y=alt.Y('pdf'),
        x=alt.X('x', title='θ'),
        color=alt.Color('action:N', legend=None),
        tooltip=[
            'action',
            # Show the position on the θ axis and the density under their own labels;
            # previously the density value was shown under the title 'θ'.
            alt.Tooltip('x', format='0.4', title='θ'),
            alt.Tooltip('pdf', format='0.4', title='pdf'),
        ]
    )

    actual_charts = alt.Chart(actuals_df).mark_rule().encode(
        x=alt.X('w', title='θ'),
        color=alt.Color('action:N', legend=None),
        tooltip=[
            'action',
            alt.Tooltip('w', format='0.4', title='θ'),
        ]
    )

    # Same colour encoding on both layers, so each rule matches its density curve.
    final_chart = (pdf_charts + actual_charts).properties(
        width=width, height=height
    )

    return final_chart

In [102]:
def plot_actual_vs_predicted(
    policy: BernoulliTS,
    env: CascadeContextFreeBandit
) -> alt.Chart:
    """Plot, per arm, the policy's empirical reward rate next to the environment's weight.

    The predicted value is ``policy.reward_counts / policy.action_counts``; the
    actual value is ``env.weights``. Arms are ordered on the y axis by actual
    weight, descending.
    """
    comparison_df = pd.DataFrame({
        'pred': policy.reward_counts / policy.action_counts,
        'act': env.weights,
    })
    comparison_df['diff'] = (comparison_df['act'] - comparison_df['pred']).abs()
    comparison_df['arm'] = comparison_df.index
    comparison_df['optimal_arm'] = comparison_df['arm'].isin(env.optimal_action)
    # Copy of `act` that survives the melt below, so the chart can still sort by it.
    comparison_df['act_sort'] = comparison_df['act']
    comparison_df = comparison_df.sort_values(['diff'], ascending=False).reset_index(drop=True)

    # Long format: one row per (arm, variable) where variable is 'pred' or 'act'.
    long_df = comparison_df.melt(id_vars=['arm', 'optimal_arm', 'diff', 'act_sort'])

    arm_axis = alt.Y('arm:O', sort=alt.SortField("act_sort", "descending"))
    points = alt.Chart(long_df).mark_point()
    return points.encode(
        y=arm_axis,
        x=alt.X('value'),
        color=alt.Color('variable'),
    )

In [103]:
def plot_observed_optimal_action_prob(
    reporting_df: pd.DataFrame,
    height: int = 275,
    width: int = 675,
) -> tuple[pd.DataFrame, alt.Chart]:
    """Compute and plot, per batch, the share of steps whose action was the optimal one.

    Args:
        reporting_df: Frame with the columns ``policy_batch_check`` and
            ``optimal_action_id`` (as produced by ``harmonise_reporting``).
        height: Chart height in pixels.
        width: Chart width in pixels.

    Returns:
        ``(policy_prob_df, chart)`` where ``policy_prob_df`` has one row per
        batch with ``n_trials``, ``optimal_action_id`` (count) and
        ``prob_of_optimal_action``.
    """
    # A constant column of ones turns the group-wise sum into a row count.
    counted_df = reporting_df.assign(n_trials=1)
    per_batch = counted_df.groupby(['policy_batch_check'], as_index=False)
    batch_totals = per_batch[['n_trials', 'optimal_action_id']].sum()
    policy_prob_df = batch_totals.assign(
        prob_of_optimal_action=lambda x: x['optimal_action_id'] / x['n_trials']
    )

    line = alt.Chart(policy_prob_df).mark_line().encode(
        x='policy_batch_check', y='prob_of_optimal_action'
    )
    chart = line.properties(width=width, height=height)
    return policy_prob_df, chart

In [104]:
def plot_pdf_with_actuals(
    policy: BernoulliTS,
    env: CascadeContextFreeBandit,
    width: int = 500,
    height: int = 250,
) -> tuple[pd.DataFrame, pd.DataFrame, alt.Chart]:
    """Build each action's current Beta density and chart it against the actual weights.

    For every action the density is evaluated with
    ``alpha = policy.alpha + reward_counts`` and
    ``beta = policy.beta + (action_counts - reward_counts)``.

    Args:
        policy: Policy exposing ``n_actions``, ``alpha``, ``beta``,
            ``reward_counts`` and ``action_counts``.
        env: Environment exposing ``weights``.
        width: Chart width in pixels.
        height: Chart height in pixels.

    Returns:
        ``(all_pdf_df, actuals_df, chart)``: the stacked per-action densities,
        the actual weights per action, and the layered chart of both.
    """
    actuals_df = pd.DataFrame(dict(w=env.weights)).assign(action=lambda x: x.index)

    all_pdf_df = pd.concat(
        [
            plot_beta_dist(
                alpha=policy.alpha[idx] + policy.reward_counts[idx],
                beta=policy.beta[idx] + (policy.action_counts[idx] - policy.reward_counts[idx]),
            ).assign(action=idx)
            for idx in range(policy.n_actions)
        ],
        axis=0,
    ).reset_index(drop=True)

    # Reuse the shared chart builder instead of repeating the layering here.
    final_chart = plot_pdf_vs_actual(
        pdf_df=all_pdf_df,
        actuals_df=actuals_df,
        width=width,
        height=height,
    )

    return all_pdf_df, actuals_df, final_chart

In [105]:
class ActionRewardLogging(TypedDict):
    """One logged environment step, as appended to ``reporting`` in the simulation loops."""
    # The action passed to env.step for this step.
    action: list[int]
    # The reward returned by env.step.
    reward: float
    # Value of info["prob_of_click"] returned by env.step.
    prob_of_click: float

def harmonise_reporting(
    reporting: list[ActionRewardLogging],
    env: CascadeContextFreeBandit,
    policy_batch_check: int = 500,
) -> pd.DataFrame:
    """Turn the per-step log of a simulation run into a reporting DataFrame.

    Args:
        reporting: One record per environment step with the keys ``action``,
            ``reward`` and ``prob_of_click``.
        env: Environment the run was executed against; its ``optimal_reward``
            and ``optimal_action`` attributes are read.
        policy_batch_check: Batch size used to bucket steps; each row gets the
            first ``time_idx`` of its batch in ``policy_batch_check``.

    Returns:
        The log as a DataFrame with these extra columns:
        ``optimal_prob_of_click``, ``time_idx`` (0-based step),
        ``avg_reward`` (running mean reward up to and including the step),
        ``action_as_str`` (action items joined with ``|``),
        ``optimal_action_id`` (True when the action equals
        ``env.optimal_action`` element by element) and ``policy_batch_check``.
    """
    optimal_action = np.asarray(env.optimal_action)

    reporting_df = pd.DataFrame(reporting).assign(
        optimal_prob_of_click=env.optimal_reward,
        time_idx=lambda x: x.index
    )

    # time_idx is 0-based, so row t has accumulated t + 1 rewards. Dividing by
    # time_idx itself was off by one and produced inf/NaN on the first row.
    steps_so_far = reporting_df['time_idx'] + 1
    reporting_df['avg_reward'] = reporting_df['reward'].cumsum() / steps_so_far

    reporting_df['action_as_str'] = reporting_df['action'].apply(
        lambda x: '|'.join(str(y) for y in x)
    )

    # np.array_equal works for lists and arrays alike and returns False on a
    # shape mismatch, whereas `all(x == optimal)` raises when both are lists.
    reporting_df['optimal_action_id'] = reporting_df['action'].apply(
        lambda x: np.array_equal(np.asarray(x), optimal_action)
    )

    reporting_df['policy_batch_check'] = (
        reporting_df['time_idx'] - reporting_df['time_idx'] % policy_batch_check
    )
    return reporting_df

# Env setup

In [None]:
# Number of actions: sizes the sampled weight vector and the width of the weights chart.
N_ACTIONS = 50
# Passed to the environment as `len_list` — presumably the number of actions shown per step; confirm against the env.
LEN_LIST = 5

In [106]:
# Seed NumPy's global RNG so the sampled weights below are reproducible.
np.random.seed(1234)

env = CascadeContextFreeBandit(
    # One weight per action, drawn from Beta(1, 99) (mean 0.01).
    weights=np.random.beta(a=1, b=99, size=N_ACTIONS),
    max_steps=1_000_000,
    len_list=LEN_LIST,
)

In [107]:
# Collects each run's reporting frame, keyed by policy name.
outputs_of_policies = {}

In [None]:
# Environment weight per action, with the action id (row index) as an explicit column for plotting.
actuals_df = pd.DataFrame(dict(w=env.weights)).assign(action=lambda x: x.index)
actuals_df.head()

In [None]:
# Scale the chart width with the number of actions (13 px per action).
width = N_ACTIONS * 13
actual_charts = alt.Chart(actuals_df).mark_rule().encode(
    x=alt.X(
        'action:N',
        sort=alt.SortField("w", order='descending'),  # actions ordered from largest to smallest weight
        axis=alt.Axis(orient='bottom', labelAngle=0)
    ),
    color=alt.Color('action:N', legend=None),
    y=alt.Y('w', title='θ'),
    text=alt.Text('w', format='0.3'),
    tooltip=[
        'action',
        alt.Tooltip('w', format='0.3', title='θ')
    ]
)

# Layer the same encoding three ways: rule, filled point, and a rotated value label.
final_chart = (
    actual_charts +
    actual_charts.mark_point(filled=True, size=50) + 
    actual_charts.mark_text(align='left', angle=45*7, dx=5)  # 45*7 = 315 degrees
).properties(
    width=width, height=225
)

final_chart

# Random Policy

In [110]:
# Random policy sized to match the environment (presumably a uniform-random baseline; confirm in bandits.policy).
policy = Random(
    n_actions=env.n_actions,
    len_list=env.len_list,
    random_state=1234,
    batch_size=1,
)

In [None]:
# Reset the environment with a fixed seed and draw the first action for the loop below.
observation, info = env.reset(seed=1234)
action = policy.select_action()
action

In [112]:
# Per-step log of this run; later turned into a DataFrame by harmonise_reporting.
reporting = []

while True:
    _, reward, terminated, truncated, info = env.step(action=action)
    reporting.append(dict(
        action=action,
        reward=reward,
        prob_of_click=info["prob_of_click"],        
    ))

    # Feed the click position reported by the environment back into the policy.
    policy.cascade_params_update(
        action=action,
        reward_position=info["position_of_click"]
    )

    # NOTE(review): only `truncated` ends the run; `terminated` is ignored — confirm the env never sets it.
    if truncated:
        break    

    action = policy.select_action()

In [113]:
# Convert this run's step log into the reporting frame used by the charts below.
reporting_df = harmonise_reporting(
    reporting=reporting,
    env=env,
    policy_batch_check=500,
)

In [114]:
# policy_name tags the stored frame and prefixes the PNG file names written below.
policy_name = 'random'
outputs_of_policies[policy_name] = reporting_df.assign(policy=policy_name)


In [None]:
# Long format: one row per (step, variable) so both click probabilities can share one colour-encoded chart.
reporting_ff_df = reporting_df.melt(
    id_vars=['time_idx', 'action'],
    value_vars=['prob_of_click', 'optimal_prob_of_click']
)

reporting_ff_df.head()

In [116]:
# Keep only every `plot_every`-th step when plotting, to limit the number of points drawn.
plot_every = 1_000
mask_df = (reporting_df['time_idx'] % plot_every) == 0
mask_ff_df = (reporting_ff_df['time_idx'] % plot_every) == 0

In [None]:
# Click probability of the chosen action vs. the optimal one over time (subsampled).
chart = alt.Chart(reporting_ff_df[mask_ff_df]).mark_line().encode(
    y=alt.Y('value'),
    x=alt.X('time_idx'),
    color='variable'
).properties(width=600)

# Write the chart as a PNG under context_free_outputs/ (NOTE(review): presumably the directory must already exist — verify).
vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__prob_of_click.png",
)

chart


In [None]:
# Running average reward over time (subsampled).
chart = alt.Chart(reporting_df[mask_df]).mark_line().encode(
    x='time_idx',
    y='avg_reward',
).properties(
    width=700, 
)

# Write the chart as a PNG under context_free_outputs/.
vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__avg_reward.png",
)

chart


In [None]:
# Per-arm comparison of the policy's empirical reward rate with the environment weight.
chart = plot_actual_vs_predicted(policy=policy, env=env)
vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__action_probs_actual_vs_predicted.png",
)

chart

In [None]:
# Share of steps per batch in which the chosen action was the optimal one.
# NOTE(review): unlike the Thompson Sampling sections, this chart is not saved to PNG — confirm that is intended.
policy_prob_df, chart = plot_observed_optimal_action_prob(
    reporting_df=reporting_df,
    height=275,
    width=675
)

chart

In [None]:
# Distribution of the per-batch count of optimal-action steps.
policy_prob_df['optimal_action_id'].value_counts()

# Thompson Sampling Policy

## Uninformative prior

In [122]:
# Thompson Sampling policy with the library's default alpha/beta (the next cells treat this as Beta(1, 1) — confirm the defaults).
policy = BernoulliTS(
    n_actions=env.n_actions,
    len_list=env.len_list,
    random_state=1234,
    batch_size=1,
)

In [None]:
# Display the policy object to inspect its configuration.
policy

In [None]:
# Density of the Beta(1, 1) prior, overlaid with the environment weights.
prior_df = plot_beta_dist(alpha=1, beta=1)
actuals_df = pd.DataFrame(dict(w=env.weights)).assign(action=lambda x: x.index)

# One vertical rule per action at its weight.
actuals_chart = alt.Chart(actuals_df).mark_rule(opacity=0.5).encode(
    x=alt.X('w', title='θ'),
)
priors_chart = alt.Chart(prior_df).mark_area(opacity=0.5).encode(
    x=alt.X('x', title='θ'),
    y=alt.Y('pdf')
)

final_chart = (
    priors_chart + actuals_chart
).properties(
    width=500, height=200
)

final_chart

In [None]:
# Reset the environment and draw the first action for the loop below.
# NOTE(review): this run resets with seed 34325 while the other two runs use 1234 — confirm the difference is intentional.
observation, info = env.reset(seed=34325)
action = policy.select_action()
action

In [126]:
# Per-step log of this run; later turned into a DataFrame by harmonise_reporting.
reporting = []

while True:
    _, reward, terminated, truncated, info = env.step(action=action)
    reporting.append(dict(
        action=action,
        reward=reward,
        prob_of_click=info["prob_of_click"],        
    ))

    # Feed the click position reported by the environment back into the policy.
    policy.cascade_params_update(
        action=action,
        reward_position=info["position_of_click"]
    )

    # NOTE(review): only `truncated` ends the run; `terminated` is ignored — confirm the env never sets it.
    if truncated:
        break    
    
    action = policy.select_action()

In [127]:
# Convert this run's step log into the reporting frame used by the charts below.
reporting_df = harmonise_reporting(reporting=reporting, env=env, policy_batch_check=500)

In [128]:
# policy_name tags the stored frame and prefixes the PNG file names written below.
policy_name = 'ts-priors-Beta(1, 1)'
outputs_of_policies[policy_name] = reporting_df.assign(policy=policy_name)

In [None]:
# Long format: one row per (step, variable) so both click probabilities can share one colour-encoded chart.
reporting_ff_df = reporting_df.melt(
    id_vars=['time_idx', 'action'],
    value_vars=['prob_of_click', 'optimal_prob_of_click']
)

reporting_ff_df.head()

In [130]:
# Keep only every `plot_every`-th step when plotting (finer than the 1_000 used for the random policy).
plot_every = 100
mask_df = (reporting_df['time_idx'] % plot_every) == 0
mask_ff_df = (reporting_ff_df['time_idx'] % plot_every) == 0

In [None]:
# Click probability of the chosen action vs. the optimal one over time (subsampled).
chart = alt.Chart(reporting_ff_df[mask_ff_df]).mark_line().encode(
    y=alt.Y('value'),
    x=alt.X('time_idx'),
    color='variable'
).properties(width=600)

# Write the chart as a PNG under context_free_outputs/.
vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__prob_of_click.png",
)

chart

In [None]:
# Running average reward over time (subsampled).
chart = alt.Chart(reporting_df[mask_df]).mark_line().encode(
    x='time_idx',
    y='avg_reward',
).properties(
    width=700, 
)

# Write the chart as a PNG under context_free_outputs/.
vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__avg_reward.png",
)

chart

In [None]:
# Per-arm comparison of the policy's empirical reward rate with the environment weight.
chart = plot_actual_vs_predicted(policy=policy, env=env)

vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__action_probs_actual_vs_predicted.png",
)

chart

In [None]:
# Share of steps per batch in which the chosen action was the optimal one.
policy_prob_df, chart = plot_observed_optimal_action_prob(
    reporting_df=reporting_df,
    height=275,
    width=675
)


vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__observed_optimal_action_distribution.png",
)


chart

In [None]:
# Per-action Beta densities after the run, overlaid with the environment weights.
all_pdf_df, actuals_df, chart = plot_pdf_with_actuals(env=env, policy=policy)


vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__action_beta_distributions.png",
)


chart

## Pessimistic priors

In [88]:
# Thompson Sampling with an explicit Beta(1, 99) prior on every action — the same distribution the environment weights were drawn from.
policy = BernoulliTS(
    n_actions=env.n_actions,
    len_list=env.len_list,
    random_state=1234,
    batch_size=1,
    alpha=np.ones(env.n_actions) * 1,
    beta=np.ones(env.n_actions) * 99,
)

In [None]:
# Density of the Beta(1, 99) prior, overlaid with the environment weights.
prior_df = plot_beta_dist(alpha=1, beta=99)
actuals_df = pd.DataFrame(dict(w=env.weights)).assign(action=lambda x: x.index)

# One vertical rule per action at its weight.
actuals_chart = alt.Chart(actuals_df).mark_rule(opacity=0.5).encode(
    x=alt.X('w', title='θ'),
)
priors_chart = alt.Chart(prior_df).mark_area(opacity=0.5).encode(
    x=alt.X('x', title='θ'),
    y=alt.Y('pdf')
)

final_chart = (
    priors_chart + actuals_chart
).properties(
    width=500, height=200
)

final_chart

In [None]:
# Reset the environment with a fixed seed and draw the first action for the loop below.
observation, info = env.reset(seed=1234)
action = policy.select_action()
action

In [91]:
# Per-step log of this run.
reporting = []

while True:
    _, reward, terminated, truncated, info = env.step(action=action)
    reporting.append(dict(
        action=action,
        reward=reward,
        prob_of_click=info["prob_of_click"],        
    ))

    # Feed the click position reported by the environment back into the policy.
    policy.cascade_params_update(
        action=action,
        reward_position=info["position_of_click"]
    )

    # NOTE(review): only `truncated` ends the run; `terminated` is ignored — confirm the env never sets it.
    if truncated:
        break    
    
    action = policy.select_action()

In [None]:
policy_name = 'ts-priors-Beta(1, 99)'

# Rebuild the reporting frame from THIS run's log. Without this step,
# `reporting_df` still holds the previous (Beta(1, 1)) run, so the stored
# output and every chart in this section would show the wrong policy's data.
reporting_df = harmonise_reporting(reporting=reporting, env=env, policy_batch_check=500)
outputs_of_policies[policy_name] = reporting_df.assign(policy=policy_name)

# Long format: one row per (step, variable) so both click probabilities can share one colour-encoded chart.
reporting_ff_df = reporting_df.melt(
    id_vars=['time_idx', 'action'],
    value_vars=['prob_of_click', 'optimal_prob_of_click']
)

reporting_ff_df.head()

In [None]:
# Keep only every `plot_every`-th step when plotting, to limit the number of points drawn.
plot_every = 100
mask_df = (reporting_df['time_idx'] % plot_every) == 0
mask_ff_df = (reporting_ff_df['time_idx'] % plot_every) == 0


# Click probability of the chosen action vs. the optimal one over time (subsampled).
chart = alt.Chart(reporting_ff_df[mask_ff_df]).mark_line().encode(
    y=alt.Y('value'),
    x=alt.X('time_idx'),
    color='variable'
).properties(width=600)

# Write the chart as a PNG under context_free_outputs/.
vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__prob_of_click.png",
)
chart


In [None]:
# Running average reward over time (subsampled).
chart = alt.Chart(reporting_df[mask_df]).mark_line().encode(
    x='time_idx',
    y='avg_reward',
).properties(
    width=700, 
)

# Write the chart as a PNG under context_free_outputs/.
vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__avg_reward.png",
)

chart

In [None]:
# Per-arm comparison of the policy's empirical reward rate with the environment weight.
chart = plot_actual_vs_predicted(policy=policy, env=env)

vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__action_probs_actual_vs_predicted.png",
)

chart

In [None]:
# Share of steps per batch in which the chosen action was the optimal one.
policy_prob_df, chart = plot_observed_optimal_action_prob(
    reporting_df=reporting_df,
    height=275,
    width=675
)


vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__observed_optimal_action_distribution.png",
)


chart

In [None]:
# Per-action Beta densities after the run, overlaid with the environment weights.
all_pdf_df, actuals_df, chart = plot_pdf_with_actuals(env=env, policy=policy)


vf.save_png(
    chart, 
    f"context_free_outputs/{policy_name}__action_beta_distributions.png",
)


chart