1 file changed: +9 -2 lines changed
Columns: original file line number | diff line number | diff line change
@@ -32,7 +32,14 @@ def update(self, x):
3232 self .p_estimate = ((self .N - 1 )* self .p_estimate + x ) / self .N
3333
3434
35- def experiment ():
def choose_random_argmax(a):
    """Return the index of a maximal element of ``a``, breaking ties at random.

    Unlike ``np.argmax``, which always returns the first maximal position,
    this picks uniformly (via ``np.random.choice``) among every position
    holding the maximum value.

    Args:
        a: Array-like of comparable values (e.g. a list of estimates).

    Returns:
        A flat index into ``a`` (same indexing convention as ``np.argmax``
        without an ``axis`` argument).

    Raises:
        ValueError: If ``a`` is empty (the maximum is undefined).
    """
    values = np.asarray(a)
    # flatnonzero yields flat indices, so the result stays a valid
    # np.argmax-style index even for multi-dimensional input; flattening
    # np.argwhere output would interleave row/column coordinates instead.
    tied = np.flatnonzero(values == values.max())
    return np.random.choice(tied)
38+
39+
40+ def experiment (argmax = choose_random_argmax ):
41+ # argmax can also simply be np.argmax to choose the first argmax in case of ties
42+
3643 bandits = [BanditArm (p ) for p in BANDIT_PROBABILITIES ]
3744
3845 rewards = np .zeros (NUM_TRIALS )
@@ -50,7 +57,7 @@ def experiment():
5057 j = np .random .randint (len (bandits ))
5158 else :
5259 num_times_exploited += 1
53- j = np . argmax ([b .p_estimate for b in bandits ])
60+ j = argmax ([b .p_estimate for b in bandits ])
5461
5562 if j == optimal_j :
5663 num_optimal += 1
You can’t perform that action at this time.
0 commit comments