Skip to content

Commit ab26f33

Browse files
author
User
committed
choose random argmax
1 parent 72f124a commit ab26f33

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

ab_testing/epsilon_greedy.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,14 @@ def update(self, x):
3232
self.p_estimate = ((self.N - 1)*self.p_estimate + x) / self.N
3333

3434

35-
def experiment():
35+
def choose_random_argmax(a):
    """Return an index at which `a` attains its maximum, picking randomly among ties.

    Every position whose value equals the maximum of `a` is a candidate;
    one candidate is drawn with ``np.random.choice``. When the maximum is
    unique this gives the same index as ``np.argmax``.
    """
    # Boolean mask marking every entry equal to the largest value.
    is_peak = np.max(a) == a
    # Positions of all tied maxima, collapsed to a flat index array.
    tied_positions = np.argwhere(is_peak).ravel()
    return np.random.choice(tied_positions)
38+
39+
40+
def experiment(argmax=choose_random_argmax):
41+
# argmax can also simply be np.argmax to choose the first argmax in case of ties
42+
3643
bandits = [BanditArm(p) for p in BANDIT_PROBABILITIES]
3744

3845
rewards = np.zeros(NUM_TRIALS)
@@ -50,7 +57,7 @@ def experiment():
5057
j = np.random.randint(len(bandits))
5158
else:
5259
num_times_exploited += 1
53-
j = np.argmax([b.p_estimate for b in bandits])
60+
j = argmax([b.p_estimate for b in bandits])
5461

5562
if j == optimal_j:
5663
num_optimal += 1

0 commit comments

Comments
 (0)