In [41]:
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import helpers.mab_bernoulli as brn
import matplotlib.pyplot as plt

In [59]:
# Experiment configuration: every knob of the simulation sweep lives here.
# NOTE(review): the outputs saved with this notebook show 3 arm counts
# (2, 8, 14) and 3 data amounts (0, 20, 40), which does not match the grids
# below -- rerun the notebook before relying on the stored results.

# Simulation horizon.
T = 1_000
# Number of simulated sample paths averaged per parameter setting.
num_samples = 1_000

# Grid over the number of arms.
K_list = np.arange(2, 16, 3)
# Grid over the reward gap between the best arm and the other arms.
delta_list = [0.2]
# Grid over the number of data points generated per arm.
data_amt_list = np.arange(0, 50, 10)

In [60]:
# Column layout shared by both result tables:
#   - four descriptor columns identifying the parameter setting,
#   - the mean cumulative regret after each round t = 1..T (integer labels),
#   - the matching lower ("low_t") and upper ("high_t") confidence bounds.
df_cols = ["NumArms", "ArmGap", "ArmData", "NumSamples"]
c = list(range(1, T + 1))
df_cols = df_cols + c + ["low_"+str(n) for n in c] + ["high_"+str(n) for n in c]

# Rows are collected in plain lists and turned into DataFrames once at the
# end; growing a DataFrame one `.loc` row at a time re-allocates the frame on
# every insert and leaves the column dtypes to implicit coercion.
fs_rows = []
ar_rows = []

for K in tqdm(K_list, desc = "Number of Arms"):
    for delta in tqdm(delta_list, leave = False, desc = "Arm Gap"):
        # K-1 identical arms with mean 0.5 - delta/2 plus one best arm with
        # mean 0.5 + delta/2, stored last.
        arms = [0.5 - delta/2 for _ in range(K-1)]
        opt_mean = 0.5 + delta / 2
        arms.append(opt_mean)
        for data_amt in tqdm(data_amt_list, leave = False, desc = "Data per arm"):

            # regret_vecs[0] holds the FS paths, regret_vecs[1] the AR paths;
            # each row is the cumulative regret of one simulated run.
            regret_vecs = [np.zeros((num_samples, T)) for _ in range(2)]
            for i in tqdm(range(num_samples), leave = False):
                # The arm index is named `k` so it cannot be confused with the
                # sample index `i` of the enclosing loop.
                arm_data = [brn.gen_data(arms[k], data_amt) for k in range(K)]
                # Both policies are run on the same generated data set.
                fs_rewards = brn.FS_path(T, arms, brn.flat_priors(K), arm_data)
                ar_rewards = brn.AR_path(T, arms, brn.flat_priors(K), arm_data)

                # Per-round regret is measured against the best arm's mean.
                fs_regrets = opt_mean - np.array(fs_rewards)
                ar_regrets = opt_mean - np.array(ar_rewards)

                regret_vecs[0][i] = np.cumsum(fs_regrets)
                regret_vecs[1][i] = np.cumsum(ar_regrets)

            # Average over runs (axis 1) and form a band of two standard
            # errors around the mean; the half-width is computed only once.
            avgs = np.mean(regret_vecs, axis=1)
            half_widths = 2*np.std(regret_vecs, axis=1)/np.sqrt(num_samples)
            lower_CBs = avgs - half_widths
            upper_CBs = avgs + half_widths

            setting = [K, delta, data_amt, num_samples]
            fs_rows.append(setting + list(avgs[0]) + list(lower_CBs[0]) + list(upper_CBs[0]))
            ar_rows.append(setting + list(avgs[1]) + list(lower_CBs[1]) + list(upper_CBs[1]))

# Build each table in one shot. Values are stored as floats (the descriptor
# columns included, as in the previously saved output); the reshape keeps the
# construction valid even when a parameter grid is empty and no rows exist.
fs_df = pd.DataFrame(np.asarray(fs_rows, dtype=float).reshape(-1, len(df_cols)), columns=df_cols)
ar_df = pd.DataFrame(np.asarray(ar_rows, dtype=float).reshape(-1, len(df_cols)), columns=df_cols)

Number of Arms:   0%|          | 0/3 [00:00<?, ?it/s]

Arm Gap:   0%|          | 0/1 [00:00<?, ?it/s]

Data per arm:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Arm Gap:   0%|          | 0/1 [00:00<?, ?it/s]

Data per arm:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Arm Gap:   0%|          | 0/1 [00:00<?, ?it/s]

Data per arm:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [61]:
# Display the FS summary table: one row per (NumArms, ArmGap, ArmData) setting.
fs_df

Unnamed: 0,NumArms,ArmGap,ArmData,NumSamples,1,2,3,4,5,6,...,high_991,high_992,high_993,high_994,high_995,high_996,high_997,high_998,high_999,high_1000
0,2.0,0.2,0.0,1000.0,0.124,0.215,0.291,0.387,0.474,0.539,...,9.383029,9.353586,9.355911,9.363193,9.347976,9.334835,9.316342,9.344965,9.356432,9.370311
1,2.0,0.2,20.0,1000.0,0.035,0.076,0.086,0.106,0.152,0.205,...,5.714818,5.711069,5.70527,5.727089,5.729891,5.739044,5.740509,5.752502,5.757989,5.748878
2,2.0,0.2,40.0,1000.0,0.027,0.073,0.118,0.129,0.159,0.185,...,3.685315,3.702985,3.715149,3.709391,3.710877,3.712272,3.701473,3.693508,3.707075,3.718037
3,8.0,0.2,0.0,1000.0,0.175,0.353,0.507,0.67,0.868,1.049,...,54.614349,54.638959,54.650966,54.651836,54.676058,54.694476,54.726676,54.742591,54.744858,54.749152
4,8.0,0.2,20.0,1000.0,0.066,0.155,0.259,0.364,0.472,0.596,...,33.442853,33.451405,33.466481,33.473299,33.480007,33.468919,33.491863,33.507872,33.510522,33.522242
5,8.0,0.2,40.0,1000.0,0.057,0.146,0.216,0.275,0.325,0.4,...,21.424195,21.437711,21.428095,21.464551,21.457804,21.483657,21.499838,21.49577,21.465301,21.496614
6,14.0,0.2,0.0,1000.0,0.177,0.383,0.586,0.82,1.018,1.205,...,93.697242,93.73136,93.810269,93.836835,93.876321,93.91986,93.932141,93.986775,94.04085,94.085011
7,14.0,0.2,20.0,1000.0,0.123,0.233,0.354,0.465,0.597,0.712,...,57.309964,57.301015,57.331748,57.350967,57.356745,57.381175,57.42859,57.426764,57.443803,57.50118
8,14.0,0.2,40.0,1000.0,0.095,0.21,0.309,0.385,0.465,0.539,...,38.003216,38.013784,38.042674,38.058471,38.07332,38.105585,38.117769,38.12787,38.112338,38.118261


In [62]:
# Save the FS summary table to CSV in the working directory; with to_csv's
# defaults the row index is written as the first (unnamed) column.
fs_df.to_csv("bern_fs_df.csv")