In [30]:
import numpy as np
import pandas as pd
import altair as alt
from bandits.environment.cascade.context_free import CascadeContextFreeBandit

# Reproducibility: seed NumPy's global RNG, which the cells below use to draw
# the item weights (np.random.beta) and the random rankings (np.random.choice).
# NOTE(review): whether CascadeContextFreeBandit draws its own randomness
# (e.g. simulated clicks) from this global RNG is not visible here — confirm,
# or seed the environment directly.
SEED = 42
np.random.seed(SEED)

# The final chart plots roughly 2 x max_steps rows (two melted series, one row
# per step), far above Altair's default 5,000-row limit, so lift the limit.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [31]:
# Simulation parameters (previously inline magic numbers).
N_ITEMS = 50  # number of per-item weights handed to the env
# Beta(1, 99) has mean 1 / (1 + 99) = 0.01, so most weights are small.
# NOTE(review): presumably these weights are per-item click/attraction
# probabilities in the cascade model — confirm against CascadeContextFreeBandit.
WEIGHT_BETA_A = 1
WEIGHT_BETA_B = 99
MAX_STEPS = 100_000  # presumably the step count after which `truncated` is set — confirm
LEN_LIST = 5  # number of distinct items in each ranked list the policy submits

env = CascadeContextFreeBandit(
    weights=np.random.beta(a=WEIGHT_BETA_A, b=WEIGHT_BETA_B, size=N_ITEMS),
    max_steps=MAX_STEPS,
    len_list=LEN_LIST,
)

In [32]:
# Begin a new episode, then draw one random ranking: `len_list` distinct item
# indices taken without replacement from 0 .. n_actions - 1.
observation, info = env.reset()
action = np.random.choice(env.n_actions, size=env.len_list, replace=False)

action

array([ 4, 18, 28,  8, 11])

In [33]:
# Roll out one episode under a uniformly random ranking policy, recording the
# submitted ranking, the reward, and the env-reported click probability per step.
# NOTE: this cell continues from the `env.reset()` and first `action` of the
# previous cell; re-run that cell before re-running this one.
reporting = []

while True:
    _, reward, terminated, truncated, info = env.step(action=action)
    reporting.append(dict(
        action=action,
        reward=reward,
        prob_of_click=info["prob_of_click"],
    ))

    # Gymnasium-style 5-tuple: either signal ends the episode and a reset is
    # required before stepping again. Checking only `truncated` would keep
    # stepping an already-terminated episode.
    if terminated or truncated:
        break

    # Passing an int samples from range(n_actions), same as passing the range.
    action = np.random.choice(env.n_actions, size=env.len_list, replace=False)

In [34]:
# One row per step, plus running totals of the observed and optimal click
# probabilities and of the realised reward. A single `.assign` suffices:
# pandas evaluates its keyword arguments in order, so later lambdas can use
# the columns created by earlier ones.
reporting_df = (
    pd.DataFrame(reporting)
    .assign(
        optimal_prob_of_click=env.optimal_reward,
        time_idx=lambda frame: frame.index,
        cumsum_of_prob_of_click=lambda frame: frame["prob_of_click"].cumsum(),
        cumsum_of_optimal_prob_of_click=lambda frame: frame["optimal_prob_of_click"].cumsum(),
        cumsum_of_reward=lambda frame: frame["reward"].cumsum(),
    )
)
reporting_df.head()

Unnamed: 0,action,reward,prob_of_click,optimal_prob_of_click,time_idx,cumsum_of_prob_of_click,cumsum_of_optimal_prob_of_click,cumsum_of_reward
0,"[4, 18, 28, 8, 11]",0,0.018795,0.171845,0,0.018795,0.171845,0
1,"[44, 19, 5, 42, 29]",0,0.059636,0.171845,1,0.07843,0.343689,0
2,"[25, 49, 22, 7, 36]",0,0.109715,0.171845,2,0.188145,0.515534,0
3,"[24, 11, 35, 45, 26]",0,0.036208,0.171845,3,0.224353,0.687378,0
4,"[30, 12, 43, 25, 39]",0,0.056908,0.171845,4,0.281261,0.859223,0


In [37]:
# Long format: one row per (step, series) pair. The series name goes in
# `variable` and its cumulative value in `value`, so the chart can colour
# lines by series.
reporting_ff_df = pd.melt(
    reporting_df,
    id_vars=["time_idx", "action"],
    value_vars=[
        "cumsum_of_prob_of_click",
        "cumsum_of_optimal_prob_of_click",
    ],
)

In [40]:
# Cumulative probability of click: random policy vs. the optimal ranking.
# Only the three encoded columns are passed because Altair embeds the entire
# frame in the chart spec. The unused per-step `action` arrays would only
# inflate the saved notebook.
alt.Chart(reporting_ff_df[["time_idx", "variable", "value"]]).mark_line().encode(
    x=alt.X("time_idx:Q", title="Step"),
    y=alt.Y("value:Q", title="Cumulative probability of click"),
    color=alt.Color("variable:N", title="Series"),
).properties(
    width=800,
    title="Cumulative probability of click: random policy vs. optimal",
)