In [1]:
import numpy as np
from scipy.optimize import root_scalar
import sympy as sp
import time
from tqdm import tqdm

# from scipy.optimize import fsolve

# # LUCB algorithm

# num_sim = 500 # for stopping time testing
num_sim = 100 # for runtime testing
num_arm = 4

total_run_time_array = np.zeros(num_sim)

def arm_selector(S_list, alpha):
  n = len(S_list)
  upp_bounds = np.ones(n)
  mus = np.ones(n) * 0.5

  for i in range(n):
    mus[i] = np.mean(S_list[i])
    t = len(S_list[i])
    rad = np.sqrt(np.log(405.5 * n * t ** (1.1) / alpha * np.log(405.5 * n * t ** (1.1) / alpha) ) / (2*t) )
    upp_bounds[i] = mus[i] + rad

  arm_1 = np.argmax(mus)
  upp_bounds[arm_1] = -1E10
  arm_2 = np.argmax(upp_bounds)

  return (arm_1, arm_2)


def e_val(S_list, mus, m):
  val = 0
  for i in range(4):
    val +=  0.25 * np.prod(1 + (np.array(S_list[i]) - np.asarray(m[i])) * (np.array(mus[i]) - np.asarray(m[i])) / 0.26)
  return val


solver= 'brenth'

# simulation for best arm identification using our approach
best_arm_list = []
stop_times = []

#mu = np.arange(4) + 1
#mu = (0.71 - 0.14*(mu - 1))/(0.29 + 0.14*(mu - 1))

mu = [0.29, 0.43, 0.57, 0.71]
horizon_len = 1000 # this is really 2000 samples (i.e. 1000 LUCB iterations)
# mu = [0.01, 0.011, 0.009, 0.013]

for j in tqdm(range(num_sim)):
  partitions = np.zeros(4)
    
  # set seed within simulation
  np.random.seed(j)
    
  S_list = [np.random.binomial(1, p, 1) for p in mu] # the last 1 represents the size
  # S_list = [[np.random.beta(1, p, 1)] for p in mu] # the last 1 represents the size
  mus = [[0.5]] * num_arm
  alpha = 0.05
  t = 4

  for i in range(horizon_len): # for runtime testing
  #while (np.sum(partitions)<3): # for stopping time testing
    which_arm = arm_selector(S_list, alpha)
    h = which_arm[0]
    l = which_arm[1]

    # update mus
    mus[h] = np.append(mus[h], [((1/2 + sum(S_list[h])) / (len(S_list[h])  + 1))])
    # print((1/2 + sum(S_list[h])) / (len(S_list[h])  + 1))
    mus[l] = np.append(mus[l], [((1/2 + sum(S_list[l])) / (len(S_list[l])  + 1))])
    # mus[l] = np.asarray(mus[l])

    # update samples
    # S_list[h] = np.append(S_list[h], np.random.beta(1, mu[h], 1))
    S_list[h] = np.append(S_list[h], np.random.binomial(1, mu[h], 1))
    # S_list[h] = np.asarray(S_list[h])
    # S_list[l] = np.append(S_list[l], np.random.beta(1, mu[l], 1))
    S_list[l] = np.append(S_list[l], np.random.binomial(1, mu[l], 1))
    # S_list[l] = np.asarray(S_list[l])

    ######### start timing
    start_time = time.time()


    # find the global minima in our set
    min_values = np.zeros(num_arm)
    coeffs = np.zeros(num_arm)
    for i in range(num_arm):
      S = np.asarray(S_list[i])
      mu_hat = np.asarray(mus[i])
      def equation(x):
        return sum((2*x-mu_hat-S)/(0.26 + (S-x)*(mu_hat-x)))
      root = root_scalar(equation, method=solver, bracket=[0,1], xtol=0.1**10)
      # print(root)
      min_values[i] = root.root

    # determine which arm is currently best
    ranks = np.argsort(min_values)[::-1] # Get the indices that would sort the array in descending order

    best_arm = ranks[0]

    #print(min_values)

    # minimum e-value globally corresponds to partition with best arm.
    test_values = np.zeros(num_arm)
    test_values[best_arm] = e_val(S_list, mus, min_values)

    # for the second best arm, we only need to project second best, best
    second_best = ranks[1]

    # get coefficient function
    def coeffs(x, S_list, mus, a):
      return np.prod(1/0.26* (0.26 + (np.asarray(S_list[a], dtype=np.float64) - x) * (np.asarray(mus[a], dtype=np.float64)-x ) ))

    def equation2(x): # we want to do element wise subtraction from x
      return coeffs(x, S_list, mus, best_arm) *  sum((2*x-np.asarray(mus[best_arm])-np.asarray(S_list[best_arm]))/(0.26 + (np.asarray(S_list[best_arm])-x)*(np.asarray(mus[best_arm])-x))) + \
      coeffs(x, S_list, mus, second_best) * sum((2*x-np.asarray(mus[second_best])-np.asarray(S_list[second_best]))/(0.26 + (np.asarray(S_list[second_best])-x)*(np.asarray(mus[second_best])-x)))
    val = root_scalar(equation2, method=solver, bracket=[0,1], xtol=0.1**10)
    temp = min_values.copy()
    temp[best_arm] = np.minimum(val.root, min_values[best_arm])
    temp[second_best] = val.root
    #print(temp)
    test_values[second_best] = e_val(S_list, mus, temp)


    # for the third best arm, we need to project third best, second best, best
    third_best = ranks[2]
    def equation3(x):
      return coeffs(x, S_list, mus, best_arm) * sum((2*x-np.asarray(mus[best_arm])-np.asarray(S_list[best_arm]))/(0.26 + (np.asarray(S_list[best_arm])-x)*(np.asarray(mus[best_arm])-x))) + \
       coeffs(x, S_list, mus, second_best) * (x < min_values[second_best]) *  sum((2*x-np.asarray(mus[second_best])-np.asarray(S_list[second_best]))/(0.26 + (np.asarray(S_list[second_best])-x)*(np.asarray(mus[second_best])-x))) + \
        coeffs(x, S_list, mus, third_best) * sum((2*x-np.asarray(mus[third_best])-np.asarray(S_list[third_best]))/(0.26 + (np.asarray(S_list[third_best])-x)*(np.asarray(mus[third_best])-x)))
    val = root_scalar(equation3, method=solver, bracket=[0,1], xtol=0.1**5)

    temp = min_values.copy()
    temp[best_arm] = np.minimum(val.root, min_values[best_arm])
    temp[second_best] = np.minimum(val.root, min_values[second_best])
    temp[third_best] = val.root
    #print(temp)
    test_values[third_best] = e_val(S_list, mus, temp)

    # for the fourth best arm, we project all values
    fourth_best = ranks[3]
    def equation4(x):
      return coeffs(x, S_list, mus, best_arm) * sum((2*x-np.asarray(mus[best_arm])-np.asarray(S_list[best_arm]))/(0.26 + (np.asarray(S_list[best_arm])-x) * (np.asarray(mus[best_arm])-x))) + \
        coeffs(x, S_list, mus, second_best) * (x < min_values[second_best]) *  sum((2*x-np.asarray(mus[second_best])-np.asarray(S_list[second_best]))/(0.26 + (np.asarray(S_list[second_best])-x)*(np.asarray(mus[second_best])-x))) + \
        coeffs(x, S_list, mus, third_best) * (x < min_values[third_best]) * sum((2*x-np.asarray(mus[third_best])-np.asarray(S_list[third_best]))/(0.26 + (np.asarray(S_list[third_best])-x)*(np.asarray(mus[third_best])-x))) + \
        coeffs(x, S_list, mus, fourth_best) * sum((2*x-np.asarray(mus[fourth_best])-np.asarray(S_list[fourth_best]))/(0.26 + (np.asarray(S_list[fourth_best])-x)*(np.asarray(mus[fourth_best])-x)))
    val = root_scalar(equation4, method=solver, bracket=[0,1], xtol=0.1**10)

    temp = min_values.copy()
    temp[best_arm] = np.minimum(val.root, min_values[best_arm])
    temp[second_best] = np.minimum(val.root, min_values[second_best])
    temp[third_best] = np.minimum(val.root, min_values[third_best])
    temp[fourth_best] = val.root
    #print(temp)
    test_values[fourth_best] = e_val(S_list, mus, temp)
    # print(test_values)
    #print(test_values)

    #check to see which arms can be eliminated
    reject = (test_values >= 1/alpha)
    # print(reject)
    partitions = np.maximum(reject, partitions)

    end_time = time.time()
    total_run_time_array[j] += end_time - start_time

    ########### end timing

    # update the number of pulls
    t += 2


  best_arm_list.append(np.where(partitions==0)[0][0])
  stop_times.append(t)
  print(t)


print(np.mean(total_run_time_array))
print(np.std(total_run_time_array))

#   print(stop_times)
#   print(best_arm_list)

print(np.mean(stop_times))
print(np.std(stop_times))




  return coeffs(x, S_list, mus, best_arm) *  sum((2*x-np.asarray(mus[best_arm])-np.asarray(S_list[best_arm]))/(0.26 + (np.asarray(S_list[best_arm])-x)*(np.asarray(mus[best_arm])-x))) + \
  return coeffs(x, S_list, mus, best_arm) * sum((2*x-np.asarray(mus[best_arm])-np.asarray(S_list[best_arm]))/(0.26 + (np.asarray(S_list[best_arm])-x)*(np.asarray(mus[best_arm])-x))) + \
  return coeffs(x, S_list, mus, best_arm) * sum((2*x-np.asarray(mus[best_arm])-np.asarray(S_list[best_arm]))/(0.26 + (np.asarray(S_list[best_arm])-x) * (np.asarray(mus[best_arm])-x))) + \
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  1%|▍                                          | 1/100 [00:05<08:40,  5.25s/it]

2004


  2%|▊                                          | 2/100 [00:10<08:34,  5.25s/it]

2004


  3%|█▎                                         | 3/100 [00:15<08:39,  5.36s/it]

2004


  4%|█▋                                         | 4/100 [00:21<08:27,  5.29s/it]

2004


  5%|██▏                                        | 5/100 [00:26<08:17,  5.24s/it]

2004


  6%|██▌                                        | 6/100 [00:31<08:20,  5.32s/it]

2004


  7%|███                                        | 7/100 [00:37<08:15,  5.33s/it]

2004


  8%|███▍                                       | 8/100 [00:42<08:09,  5.32s/it]

2004


  9%|███▊                                       | 9/100 [00:47<08:07,  5.36s/it]

2004


 10%|████▏                                     | 10/100 [00:53<08:00,  5.34s/it]

2004


 11%|████▌                                     | 11/100 [00:58<07:51,  5.30s/it]

2004


 12%|█████                                     | 12/100 [01:03<07:32,  5.14s/it]

2004


 13%|█████▍                                    | 13/100 [01:08<07:24,  5.11s/it]

2004


 14%|█████▉                                    | 14/100 [01:13<07:18,  5.09s/it]

2004


 15%|██████▎                                   | 15/100 [01:18<07:18,  5.16s/it]

2004


 16%|██████▋                                   | 16/100 [01:23<07:09,  5.12s/it]

2004


 17%|███████▏                                  | 17/100 [01:28<07:08,  5.16s/it]

2004


 18%|███████▌                                  | 18/100 [01:34<07:05,  5.19s/it]

2004


 19%|███████▉                                  | 19/100 [01:39<07:07,  5.27s/it]

2004


 20%|████████▍                                 | 20/100 [01:45<07:06,  5.33s/it]

2004


 21%|████████▊                                 | 21/100 [01:50<06:58,  5.30s/it]

2004


 22%|█████████▏                                | 22/100 [01:55<06:51,  5.28s/it]

2004


 23%|█████████▋                                | 23/100 [02:00<06:47,  5.29s/it]

2004


 24%|██████████                                | 24/100 [02:06<06:43,  5.31s/it]

2004


 25%|██████████▌                               | 25/100 [02:11<06:28,  5.18s/it]

2004


 26%|██████████▉                               | 26/100 [02:16<06:31,  5.30s/it]

2004


 27%|███████████▎                              | 27/100 [02:21<06:22,  5.24s/it]

2004


 28%|███████████▊                              | 28/100 [02:27<06:20,  5.28s/it]

2004


 29%|████████████▏                             | 29/100 [02:32<06:18,  5.33s/it]

2004


 30%|████████████▌                             | 30/100 [02:38<06:17,  5.39s/it]

2004


 31%|█████████████                             | 31/100 [02:43<06:05,  5.29s/it]

2004


 32%|█████████████▍                            | 32/100 [02:48<05:57,  5.26s/it]

2004


 33%|█████████████▊                            | 33/100 [02:53<05:45,  5.16s/it]

2004


 34%|██████████████▎                           | 34/100 [02:58<05:34,  5.07s/it]

2004


 35%|██████████████▋                           | 35/100 [03:03<05:37,  5.19s/it]

2004


 36%|███████████████                           | 36/100 [03:09<05:36,  5.26s/it]

2004


 37%|███████████████▌                          | 37/100 [03:14<05:36,  5.34s/it]

2004


 38%|███████████████▉                          | 38/100 [03:19<05:32,  5.36s/it]

2004


 39%|████████████████▍                         | 39/100 [03:25<05:23,  5.30s/it]

2004


 40%|████████████████▊                         | 40/100 [03:30<05:17,  5.29s/it]

2004


 41%|█████████████████▏                        | 41/100 [03:35<05:07,  5.21s/it]

2004


 42%|█████████████████▋                        | 42/100 [03:41<05:11,  5.36s/it]

2004


 43%|██████████████████                        | 43/100 [03:46<05:10,  5.44s/it]

2004


 44%|██████████████████▍                       | 44/100 [03:52<05:07,  5.49s/it]

2004


 45%|██████████████████▉                       | 45/100 [03:57<04:59,  5.44s/it]

2004


 46%|███████████████████▎                      | 46/100 [04:03<04:55,  5.47s/it]

2004


 47%|███████████████████▋                      | 47/100 [04:08<04:43,  5.34s/it]

2004


 48%|████████████████████▏                     | 48/100 [04:13<04:36,  5.32s/it]

2004


 49%|████████████████████▌                     | 49/100 [04:18<04:32,  5.34s/it]

2004


 50%|█████████████████████                     | 50/100 [04:23<04:22,  5.24s/it]

2004


 51%|█████████████████████▍                    | 51/100 [04:29<04:18,  5.27s/it]

2004


 52%|█████████████████████▊                    | 52/100 [04:34<04:18,  5.38s/it]

2004


 53%|██████████████████████▎                   | 53/100 [04:40<04:09,  5.30s/it]

2004


 54%|██████████████████████▋                   | 54/100 [04:45<04:06,  5.37s/it]

2004


 55%|███████████████████████                   | 55/100 [04:51<04:04,  5.43s/it]

2004


 56%|███████████████████████▌                  | 56/100 [04:56<03:59,  5.45s/it]

2004


 57%|███████████████████████▉                  | 57/100 [05:02<03:53,  5.44s/it]

2004


 58%|████████████████████████▎                 | 58/100 [05:07<03:48,  5.45s/it]

2004


 59%|████████████████████████▊                 | 59/100 [05:13<03:44,  5.47s/it]

2004


 60%|█████████████████████████▏                | 60/100 [05:18<03:34,  5.36s/it]

2004


 61%|█████████████████████████▌                | 61/100 [05:23<03:32,  5.44s/it]

2004


 62%|██████████████████████████                | 62/100 [05:28<03:24,  5.38s/it]

2004


 63%|██████████████████████████▍               | 63/100 [05:34<03:18,  5.35s/it]

2004


 64%|██████████████████████████▉               | 64/100 [05:39<03:06,  5.17s/it]

2004


 65%|███████████████████████████▎              | 65/100 [05:44<03:03,  5.25s/it]

2004


 66%|███████████████████████████▋              | 66/100 [05:50<03:01,  5.34s/it]

2004


 67%|████████████████████████████▏             | 67/100 [05:55<02:57,  5.39s/it]

2004


 68%|████████████████████████████▌             | 68/100 [06:00<02:51,  5.37s/it]

2004


 69%|████████████████████████████▉             | 69/100 [06:05<02:44,  5.30s/it]

2004


 70%|█████████████████████████████▍            | 70/100 [06:11<02:41,  5.39s/it]

2004


 71%|█████████████████████████████▊            | 71/100 [06:17<02:37,  5.41s/it]

2004


 72%|██████████████████████████████▏           | 72/100 [06:22<02:29,  5.34s/it]

2004


 73%|██████████████████████████████▋           | 73/100 [06:27<02:21,  5.22s/it]

2004


 74%|███████████████████████████████           | 74/100 [06:32<02:19,  5.35s/it]

2004


 75%|███████████████████████████████▌          | 75/100 [06:37<02:11,  5.26s/it]

2004


 76%|███████████████████████████████▉          | 76/100 [06:43<02:07,  5.32s/it]

2004


 77%|████████████████████████████████▎         | 77/100 [06:48<02:04,  5.42s/it]

2004


 78%|████████████████████████████████▊         | 78/100 [06:54<01:57,  5.35s/it]

2004


 79%|█████████████████████████████████▏        | 79/100 [06:59<01:52,  5.36s/it]

2004


 80%|█████████████████████████████████▌        | 80/100 [07:05<01:47,  5.39s/it]

2004


 81%|██████████████████████████████████        | 81/100 [07:10<01:40,  5.31s/it]

2004


 82%|██████████████████████████████████▍       | 82/100 [07:15<01:35,  5.33s/it]

2004


 83%|██████████████████████████████████▊       | 83/100 [07:20<01:30,  5.34s/it]

2004


 84%|███████████████████████████████████▎      | 84/100 [07:26<01:25,  5.33s/it]

2004


 85%|███████████████████████████████████▋      | 85/100 [07:31<01:20,  5.37s/it]

2004


 86%|████████████████████████████████████      | 86/100 [07:36<01:14,  5.30s/it]

2004


 87%|████████████████████████████████████▌     | 87/100 [07:41<01:08,  5.28s/it]

2004


 88%|████████████████████████████████████▉     | 88/100 [07:47<01:02,  5.20s/it]

2004


 89%|█████████████████████████████████████▍    | 89/100 [07:51<00:56,  5.12s/it]

2004


 90%|█████████████████████████████████████▊    | 90/100 [07:57<00:52,  5.23s/it]

2004


 91%|██████████████████████████████████████▏   | 91/100 [08:02<00:47,  5.29s/it]

2004


 92%|██████████████████████████████████████▋   | 92/100 [08:08<00:42,  5.33s/it]

2004


 93%|███████████████████████████████████████   | 93/100 [08:13<00:36,  5.27s/it]

2004


 94%|███████████████████████████████████████▍  | 94/100 [08:18<00:31,  5.32s/it]

2004


 95%|███████████████████████████████████████▉  | 95/100 [08:24<00:26,  5.30s/it]

2004


 96%|████████████████████████████████████████▎ | 96/100 [08:28<00:20,  5.17s/it]

2004


 97%|████████████████████████████████████████▋ | 97/100 [08:33<00:15,  5.11s/it]

2004


 98%|█████████████████████████████████████████▏| 98/100 [08:38<00:10,  5.05s/it]

2004


 99%|█████████████████████████████████████████▌| 99/100 [08:43<00:04,  4.96s/it]

2004


100%|█████████████████████████████████████████| 100/100 [08:48<00:00,  5.28s/it]

2004
5.212039794921875
0.22697138652251098
2004.0
0.0





In [17]:
# for stoppping times
print(np.mean(stop_times))
print(np.std(stop_times))
print(np.mean(np.equal(best_arm_list, 3)))


692.04
238.42650523798733
1.0


In [2]:
# for runtime testing\
print(np.mean(total_run_time_array))
print(np.std(total_run_time_array))

5.212039794921875
0.22697138652251098
