In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
class Bandit(object):
    """k-armed Gaussian bandit plus a sample-average agent that plays it.

    On construction each arm's true mean reward is drawn from
    N(q_star_mean, 1); every pull of an arm returns a sample from
    N(arm_mean, 1). Value estimates are plain sample averages of the rewards
    observed on each arm. Greedy choices use ``np.argmax``, so ties go to the
    lowest-index arm.

    Args:
        k: number of arms.
        q_star_mean: mean of the normal distribution the true arm means are
            drawn from.
        policy: arm-selection rule, one of
            - ``'e-greedy-<eps>'`` (e.g. ``'e-greedy-0.1'``, ``'e-greedy-0.01'``):
              with probability ``eps`` pick a uniformly random arm, otherwise
              the arm with the highest estimate;
            - ``'optimistic_init'``: always greedy; arms not yet pulled are
              valued at ``optimistic_init``;
            - ``'softmax'``: sample an arm with probability proportional to
              ``exp(q / softmax_temp)``.
            Any other value makes ``pull()`` raise ``ValueError``.
        optimistic_init: value of unpulled arms under ``'optimistic_init'``
            (every other policy values unpulled arms at 0).
        softmax_temp: temperature of the ``'softmax'`` policy.
    """

    _EPS_GREEDY_PREFIX = 'e-greedy-'

    def __init__(self, k, q_star_mean=5, policy='e-greedy-0.1', optimistic_init=15, softmax_temp=0.5):
        self.k = k
        # True (hidden) mean reward of each arm.
        self.rewards_mean = np.random.normal(loc=q_star_mean, size=k)
        self.policy = policy
        self.softmax_temp = softmax_temp

        # Value estimate reported for arms that have never been pulled.
        self.default = optimistic_init if policy == 'optimistic_init' else 0

        # Running reward total per arm. It is seeded with `default` so unpulled
        # arms report it; the first real reward replaces the seed (_update_q).
        self.sum_r = np.full(k, self.default, dtype=float)

        # Lazily rebuilt vector of value estimates (see _get_q).
        self.q = None
        self.last_q_update = None

        self.pulls = np.zeros(k)  # number of times each arm was pulled
        self.pulls_counter = 0  # total number of pulls over all arms

    def pull(self):
        """Pick an arm with the configured policy, sample its reward and learn from it.

        Returns:
            The sampled reward.

        Raises:
            ValueError: if ``self.policy`` is not a recognised policy.
        """
        self.pulls_counter += 1

        arm = self._pick_arm()
        r = self._reward(arm)
        self._update_q(r, arm)
        self.pulls[arm] += 1
        # Estimates changed: drop the cache so a _get_q() right after this
        # pull (e.g. get_max_avg_reward) includes this reward.
        self.q = None

        return r

    def get_max_avg_reward(self):
        """Return the largest current value estimate over all arms."""
        return np.max(self._get_q())

    def _get_q(self, arm=None):
        """Return sample-average value estimates.

        With ``arm`` given, return that arm's estimate; otherwise return the
        array of estimates for all arms (cached until the next pull). Arms
        that were never pulled are valued at ``self.default``.
        """
        if arm is not None:
            return self.sum_r[arm] / self.pulls[arm] if self.pulls[arm] > 0 else self.sum_r[arm]
        if self.q is None or self.last_q_update != self.pulls_counter:
            # Divide only where an arm has been pulled. Unpulled arms keep
            # sum_r, which still holds self.default. A plain sum_r / pulls would
            # give nan (0/0) or inf (x/0) there, plus RuntimeWarnings.
            self.q = np.divide(self.sum_r, self.pulls, out=self.sum_r.copy(), where=self.pulls > 0)
            self.last_q_update = self.pulls_counter
        return self.q

    def _pick_arm(self):
        """Return the index of the arm to pull next according to ``self.policy``."""
        policy = self.policy

        if isinstance(policy, str) and policy.startswith(self._EPS_GREEDY_PREFIX):
            # 'e-greedy-0.1' -> 0.1; a malformed suffix raises ValueError.
            epsilon = float(policy[len(self._EPS_GREEDY_PREFIX):])
            if random.random() < epsilon:
                return random.randrange(self.k)
            return np.argmax(self._get_q())

        if policy == 'optimistic_init':
            return np.argmax(self._get_q())

        if policy == 'softmax':
            prefs = self._get_q() / self.softmax_temp
            # Shifting by the max leaves the distribution unchanged but keeps
            # exp() from overflowing to inf (which would yield NaN probabilities).
            e_q = np.exp(prefs - prefs.max())
            probs = e_q / e_q.sum()
            return np.random.choice(self.k, p=probs)

        raise ValueError(f"unknown bandit policy: {policy!r}")

    def _reward(self, arm):
        """Sample a reward for ``arm`` from N(rewards_mean[arm], 1)."""
        return np.random.normal(self.rewards_mean[arm])

    def _update_q(self, r, arm):
        """Add reward ``r`` to ``arm``'s running total.

        On an arm's first pull the total is overwritten, discarding the
        ``default`` seed, so the estimate becomes a plain sample average.
        """
        if self.pulls[arm] > 0:
            self.sum_r[arm] += r
        else:
            self.sum_r[arm] = r
            

In [12]:
# Experiment size: `arms` arms per bandit, `num_pulls` time steps, averaged over
# `num_bandits` independently generated bandits.
arms = 100
num_pulls = 50000
num_bandits = 2000

# Seed both RNGs the Bandit class draws from so Restart & Run All reproduces
# the results:
# - `random` drives epsilon exploration;
# - `np.random` drives arm means, rewards and softmax sampling.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

## Experiment: ε-greedy (ε = 0.1), reward per step averaged over all bandits

In [None]:
# A population of independent bandits, all played with epsilon-greedy (eps = 0.1).
bandits = [Bandit(arms, policy='e-greedy-0.1') for _ in range(num_bandits)]

# avg_rewards[step] = reward obtained at `step`, averaged over every bandit.
avg_rewards = []

for step in range(num_pulls):
    avg_r = sum(bandit.pull() for bandit in bandits) / num_bandits
    avg_rewards.append(avg_r)

    # Periodic progress line.
    if step % 100 == 0:
        print(step, avg_r)



0 5.004028266802621
100 6.301084198403233
200 6.543498298125811
300 6.665562381928693
400 6.808800121770313
500 6.847799694725674
600 6.962459845171583
700 6.9631934344010995
800 6.931498135852731
900 6.988607680077259
1000 7.023015774444564
1100 7.025080971586792
1200 7.109009471193682
1300 7.094447896963357
1400 7.058022485400568
1500 7.106381766502763
1600 7.145300027631724
1700 7.068481958545356
1800 7.12425044611362
1900 7.152412714994709
2000 7.171585283429531
2100 7.083388165705567
2200 7.144736247960285
2300 7.193489488390997
2400 7.179282230387227
2500 7.141778257350164
2600 7.135240925238093
2700 7.1153392520434435
2800 7.161413162338087
2900 7.20584404671059
3000 7.199934123425352
3100 7.201518478629402
3200 7.1729287996859545
3300 7.22003680333905
3400 7.139186504880618
3500 7.1900489264466785
3600 7.202902800296073
3700 7.235932988958584
3800 7.183567353266
3900 7.24353965803939
4000 7.24501646221321
4100 7.18334361574658
4200 7.218597324572581
4300 7.214312246819986
4400 

34700 7.223970824974586
34800 7.233594235520678
34900 7.24455055897669
35000 7.241313254925546
35100 7.265607853832955
35200 7.207392805805486
35300 7.255923888826426
35400 7.269001714214508
35500 7.24800725302077
35600 7.202414115369728
35700 7.226281904348824
35800 7.269945792357526
35900 7.273031997576251
36000 7.239892816130572
36100 7.235568846604561
36200 7.236990465356543
36300 7.2346728886353375
36400 7.2549287285368385
36500 7.256501344956379
36600 7.234272800996341
36700 7.259758257261624
36800 7.228246061668152
36900 7.273467280272444
37000 7.256625515065954
37100 7.250873545990006
37200 7.254409661761771
37300 7.224790007894162
37400 7.26101162885997
37500 7.247837499393961
37600 7.279553245157834
37700 7.232419639753414
37800 7.233431125770954
37900 7.293884620077842
38000 7.1875558626733165
38100 7.250314600124771
38200 7.2701006747544765
38300 7.2344059942944
38400 7.286553042052746
38500 7.264273312575273
38600 7.2829818144750815
38700 7.303145476771132
38800 7.26528398

In [None]:
# One-column frame of the per-step average rewards.
df = pd.DataFrame(avg_rewards, columns=['avg_reward'])
# Best single-step average reward seen during the run.
df['avg_reward'].to_numpy().max()

In [None]:
# Learning curve: the reward at each pull, averaged across all bandits.
ax = df.plot(legend=False)
ax.set(
    title='Average reward per pull',
    xlabel='pull',
    ylabel='average reward across bandits',
)
plt.show()