In [1]:
import numpy as np
import pandas as pd
import helpers.mab_bernoulli as brn
import matplotlib.pyplot as plt
import multiprocess as mp
import time

In [2]:
def single_sim_wrapper(info): 
    """Run one simulation task and summarise cumulative regret for FS and AR.

    Parameters
    ----------
    info : tuple
        ``(idnum, task)`` where ``idnum`` labels the progress messages and
        ``task`` is ``[T, K, delta, data_amt, num_samples]``:
        horizon, number of arms, gap between the best arm and the others,
        amount of data generated per arm, and number of repetitions.

    Returns
    -------
    tuple of list
        ``(fs_row, ar_row)``. Each row is
        ``[K, delta, data_amt, num_samples]`` followed by ``T`` mean cumulative
        regrets, ``T`` lower bounds and ``T`` upper bounds, where the bounds are
        mean -/+ 2 * std / sqrt(num_samples).
    """
    # Imported inside the function, presumably so the function is self-contained
    # when it is shipped to a worker process.
    import numpy as np
    import helpers.mab_bernoulli as brn

    idnum, task = info
    T, K, delta, data_amt, num_samples = task[:5]

    # K-1 identical sub-optimal arms plus one optimal arm, centred on 0.5.
    arms = [0.5 - delta / 2] * (K - 1)
    opt_mean = 0.5 + delta / 2
    arms.append(opt_mean)

    # Report progress every PERC_FREQ percent of the repetitions (at least every one).
    PERC_FREQ = 0.1
    print_intvl = max(1, int(num_samples * PERC_FREQ / 100))

    # NOTE(review): no RNG seeding happens here. If brn draws from NumPy's global
    # generator and the pool forks its workers, every worker may start from the
    # same RNG state — confirm against helpers.mab_bernoulli.
    fs_cum_regret = np.zeros((num_samples, T))
    ar_cum_regret = np.zeros((num_samples, T))
    for i in range(num_samples):
        # Both algorithms are run on the same generated data for this repetition.
        arm_data = [brn.gen_data(arms[k], data_amt) for k in range(K)]
        fs_rewards = brn.FS_path(T, arms, brn.flat_priors(K), arm_data)
        ar_rewards = brn.AR_path(T, arms, brn.flat_priors(K), arm_data)

        fs_cum_regret[i] = np.cumsum(opt_mean - np.array(fs_rewards))
        ar_cum_regret[i] = np.cumsum(opt_mean - np.array(ar_rewards))

        if (i + 1) % print_intvl == 0:
            whitespace = " " * 20
            print("Task #" +str(idnum) + ":" + str(task) + " is " + str((i+1)/num_samples*100) + "% completed." + whitespace , end='\r')

    # Reduce each (num_samples, T) matrix on its own, once. Reducing the Python
    # list of both matrices would stack them into a fresh 3-D copy for every
    # np.mean / np.std call, and the std was previously evaluated twice.
    root_n = np.sqrt(num_samples)
    rows = []
    for cum_regret in (fs_cum_regret, ar_cum_regret):
        avg = np.mean(cum_regret, axis=0)
        half_width = 2 * np.std(cum_regret, axis=0) / root_n
        rows.append(
            [K, delta, data_amt, num_samples]
            + list(avg)
            + list(avg - half_width)
            + list(avg + half_width)
        )

    fs_row, ar_row = rows
    return fs_row, ar_row

def gen_dfs(T):
    """Create two empty result tables, one for FS and one for AR.

    Both frames share the same header: four task-description columns, then the
    integer columns 1..T, then ``low_1..low_T`` and ``high_1..high_T``.

    Parameters
    ----------
    T : int
        Number of time-step columns in each of the three groups.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(fs_df, ar_df)``, two distinct frames with no rows.
    """
    steps = list(range(1, T + 1))
    header = ["NumArms", "ArmGap", "ArmData", "NumSamples"]
    header.extend(steps)
    for prefix in ("low_", "high_"):
        header.extend(prefix + str(step) for step in steps)
    return pd.DataFrame(columns=header), pd.DataFrame(columns=header)

In [3]:
# Simulation parameter grid
T = 5000  # simulation horizon
K_list = [2, 4]  # number of arms (a wider sweep would be np.arange(2, 7, 2))
delta_list = [0.005, 1/80, 3/80]  # reward gap between best and worst arms
data_amt_list = np.arange(0, 35, 10)  # number of data points per arm
num_samples = 20000  # repetitions per task

# One task per (K, delta, data_amt) combination; K varies slowest, data_amt fastest.
tasks = [
    [T, K, delta, data_amt, num_samples]
    for K in K_list
    for delta in delta_list
    for data_amt in data_amt_list
]
# Pair every task with its position in the list; the id labels progress output.
info = list(enumerate(tasks))
len(info)

24

In [4]:
# Split the task list into consecutive batches of at most BATCH_SIZE tasks.
BATCH_SIZE = 12
# Stride over `info` rather than counting `len(info) // BATCH_SIZE + 1` batches:
# that count is one too many whenever len(info) is an exact multiple of the
# batch size, which leaves an empty batch at the end.
batches = [info[start:start + BATCH_SIZE] for start in range(0, len(info), BATCH_SIZE)]
num_batches = len(batches)

In [9]:
# Run one batch of tasks in parallel and report the elapsed wall-clock seconds.
start = time.time()
# Use the pool as a context manager so its worker processes are shut down as
# soon as map() has returned; otherwise they stay alive in the kernel.
with mp.Pool(12) as p:
    # batches[1] is the second batch; change the index to run a different one.
    result_rows = p.map(single_sim_wrapper, batches[1])
print()  # move past the '\r'-terminated progress line written by the workers
print(time.time()-start)


4859.802528381348


In [10]:
# Sanity check: p.map returns one (fs_row, ar_row) result per task in the batch.
len(result_rows)

12

In [11]:
# Append this batch's results to the DataFrames stored on disk.
# On a first run, when the CSV files do not exist yet, start from empty frames
# instead of the two read_csv calls below:
#fs_df, ar_df = gen_dfs(T)
fs_df = pd.read_csv("bern_fs_df.csv", index_col=0)
ar_df = pd.read_csv("bern_ar_df.csv",index_col=0)
# Each result is a (fs_row, ar_row) pair. Writing at label len(df) adds a new row
# provided that label is not already in the index (presumably the index is
# 0..n-1 — confirm).
for fs_res, ar_res in result_rows:
    fs_df.loc[len(fs_df)] = np.array(fs_res)
    ar_df.loc[len(ar_df)] = np.array(ar_res)
# Write the frames back to the same files they were read from.
# NOTE(review): re-running this cell with the same result_rows stores those rows again.
fs_df.to_csv("bern_fs_df.csv")
ar_df.to_csv("bern_ar_df.csv")

In [12]:
# Display the updated FS results table (the wide frame is truncated in the rendered output).
fs_df

Unnamed: 0,NumArms,ArmGap,ArmData,NumSamples,1,2,3,4,5,6,...,high_4991,high_4992,high_4993,high_4994,high_4995,high_4996,high_4997,high_4998,high_4999,high_5000
0,2.0,0.2000,0.0,20000.0,0.10305,0.19275,0.28040,0.36300,0.45100,0.52535,...,11.329212,11.336390,11.331847,11.331856,11.327146,11.319418,11.322226,11.324416,11.320070,11.319073
1,2.0,0.2000,5.0,20000.0,0.06570,0.13045,0.19565,0.26775,0.32720,0.38290,...,10.268495,10.274069,10.275897,10.281950,10.283829,10.277579,10.276563,10.275309,10.276945,10.276038
2,2.0,0.2000,10.0,20000.0,0.05330,0.10680,0.15295,0.20775,0.26230,0.30925,...,9.752596,9.749942,9.748894,9.753722,9.755476,9.756147,9.750386,9.749071,9.748872,9.748723
3,2.0,0.2000,15.0,20000.0,0.04435,0.08650,0.12630,0.16725,0.21590,0.25555,...,8.475747,8.475325,8.478413,8.474076,8.476498,8.478826,8.473253,8.471882,8.473843,8.472562
4,2.0,0.2000,20.0,20000.0,0.04190,0.07645,0.10550,0.13825,0.17145,0.20125,...,8.095993,8.097309,8.101420,8.102856,8.100521,8.108739,8.110408,8.104582,8.107620,8.107545
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
424,4.0,0.0125,30.0,20000.0,0.01080,0.01655,0.02715,0.03665,0.04345,0.05380,...,39.607687,39.612950,39.616725,39.623714,39.625492,39.633665,39.633132,39.638172,39.649279,39.653136
425,4.0,0.0375,0.0,20000.0,0.02975,0.05765,0.08195,0.10610,0.13190,0.15970,...,70.841708,70.846436,70.859226,70.867372,70.874819,70.884045,70.896036,70.904035,70.909809,70.920212
426,4.0,0.0375,10.0,20000.0,0.02930,0.05125,0.07140,0.09270,0.11375,0.13585,...,69.047933,69.055787,69.060649,69.061952,69.069222,69.072867,69.073892,69.080190,69.089025,69.094835
427,4.0,0.0375,20.0,20000.0,0.03245,0.05140,0.07190,0.09610,0.12280,0.14560,...,69.002343,69.009585,69.015692,69.022973,69.037280,69.044746,69.051259,69.055473,69.057930,69.065722
