In [4]:
import numpy as np
import matplotlib.pyplot as plt

class Environment:
    """Run a set of agents against a shared set of drifting bandit arms.

    Each call to :meth:`step` does the following:
    - lets every agent pull one arm;
    - feeds the sampled reward back to that agent;
    - records whether the pulled arm was the arm with the highest current
      mean;
    - lets every bandit drift.

    Args:
        bandits: sequence of arms exposing ``mean``, ``reward()`` and
            ``update()``.
        agents: sequence of learners exposing ``select_bandit()`` and
            ``update_q(reward, arm)``.
    """

    def __init__(self, bandits, agents):
        self.n = 0  # number of completed steps
        self.bandits = bandits
        self.agents = agents
        # Per-agent count of steps on which the agent picked the best arm.
        self.optimal_selections = [0] * len(agents)
        # Per-agent history: after step k (1-based), entry k-1 holds the
        # fraction of the first k steps on which the agent picked the best arm.
        self.results = {idx: [] for idx in range(len(agents))}

    def optimal_selection(self):
        """Return the index of the bandit with the highest current mean.

        Ties resolve to the lowest index. Returns None when there are no
        bandits.
        """
        best_val, best_idx = None, None
        for idx, bandit in enumerate(self.bandits):
            if best_val is None or bandit.mean > best_val:
                best_val = bandit.mean
                best_idx = idx
        return best_idx

    def step(self):
        """Advance the simulation by one time step for every agent."""
        self.n += 1
        # Bandit means only change in the drift loop below, after all agents
        # have acted, so the best arm is the same for every agent this step.
        best_arm = self.optimal_selection()
        for idx, agent in enumerate(self.agents):
            arm = agent.select_bandit()
            reward = self.bandits[arm].reward()
            agent.update_q(reward, arm)
            if arm == best_arm:
                self.optimal_selections[idx] += 1
            self.results[idx].append(self.optimal_selections[idx] / self.n)

        for bandit in self.bandits:
            bandit.update()
        
        

class Bandit:
    """A single arm whose mean reward follows a Gaussian random walk.

    Args:
        mean: initial mean reward.
        stddev: standard deviation of the reward noise around the mean.
        walk_stddev: standard deviation of the per-step drift that
            :meth:`update` adds to the mean.
    """

    def __init__(self, mean=1, stddev=0.1, walk_stddev=0.01):
        self.mean = mean
        self.stddev = stddev
        self.walk_stddev = walk_stddev

    def update(self):
        """Drift the mean by one N(0, walk_stddev) step (nonstationary arm)."""
        disturb = np.random.normal(0, self.walk_stddev)
        self.mean += disturb

    def reward(self):
        """Sample a reward from N(mean, stddev)."""
        return np.random.normal(self.mean, self.stddev)

class Agent:
    """Epsilon-greedy action-value learner.

    Args:
        av_meth: update rule called as ``av_meth(Q_old, reward, n)``. It
            returns the arm's new estimate; ``n`` is how many times that arm
            has now been pulled (already incremented).
        arm_count: number of arms.
        eps: probability of picking an arm uniformly at random instead of
            greedily.
        initial_q: starting value for every action-value estimate.
    """

    def __init__(self, av_meth, arm_count, eps, initial_q=0):
        self.av_meth = av_meth
        self.Q_table = [initial_q] * arm_count
        self.n_arm = [0] * arm_count
        self.eps = eps
        self.optimal_rate = None  # not read anywhere in the visible code

    def update_q(self, r, arm):
        """Record reward ``r`` for ``arm`` and refresh its estimate via av_meth."""
        self.n_arm[arm] += 1
        self.Q_table[arm] = self.av_meth(self.Q_table[arm], r, self.n_arm[arm])

    def select_bandit(self):
        """Return an arm index chosen epsilon-greedily.

        Greedy ties are broken uniformly at random.
        """
        if np.random.uniform() > self.eps:
            q = np.asarray(self.Q_table)
            # np.argmax always returns the first maximum, which would bias
            # every tie (e.g. the all-equal initial table) toward arm 0.
            best = np.flatnonzero(q == q.max())
            return int(np.random.choice(best))
        return np.random.randint(0, len(self.Q_table))
        
def sample_av(Q_old, reward, n):
    """Incremental sample-average action-value update.

    Moves the estimate toward ``reward`` by 1/n of the prediction error,
    so the estimate tracks the running mean of the rewards received.
    ``n`` (the pull count including this reward) must be non-zero.
    """
    error = reward - Q_old
    return Q_old + error / n


def step_av(Q_old, reward, n, alpha=0.1):
    """Constant step-size action-value update.

    Computes ``Q_old + alpha * (reward - Q_old)``. This weights recent rewards
    more heavily, which suits nonstationary arms.

    ``n`` is accepted only so the signature matches ``sample_av``; it is
    unused. ``alpha`` defaults to the previously hard-coded 0.1.
    """
    return Q_old + alpha * (reward - Q_old)


(1, array([ 1.,  1.]))
(101, array([ 0.48514851,  0.02970297]))
(201, array([ 0.30348259,  0.0199005 ]))
(301, array([ 0.25249169,  0.01328904]))
(401, array([ 0.26932668,  0.00997506]))
(501, array([ 0.23752495,  0.00798403]))
(601, array([ 0.19966722,  0.00665557]))
(701, array([ 0.17403709,  0.00713267]))
(801, array([ 0.18851436,  0.00749064]))
(901, array([ 0.26304107,  0.00776915]))
(1001, array([ 0.32167832,  0.00799201]))
(1101, array([ 0.30245232,  0.00817439]))
(1201, array([ 0.32223147,  0.00832639]))
(1301, array([ 0.35818601,  0.00999231]))
(1401, array([ 0.39685939,  0.00927909]))
(1501, array([ 0.41572285,  0.02531646]))
(1601, array([ 0.43660212,  0.03372892]))
(1701, array([ 0.42386831,  0.03233392]))
(1801, array([ 0.40421988,  0.03109384]))
(1901, array([ 0.41451867,  0.03156234]))
(2001, array([ 0.43828086,  0.02998501]))
(2101, array([ 0.45978106,  0.02855783]))
(2201, array([ 0.47932758,  0.02771468]))
(2301, array([ 0.48891786,  0.02651021]))
(2401, array([ 0.468

In [None]:
# Nonstationary 10-armed testbed: every bandit starts with the same true mean,
# and after each step each bandit's mean drifts by an independent
# np.random.normal(0, 0.01) increment.
# Compare the sample-average update against the constant step-size update by
# plotting, for each agent, the running fraction of steps on which it picked
# the currently optimal arm. (The exercise specifies 10k steps; this run uses
# total_steps.)

total_steps = 50000
bandit_count = 10
eps = 0.1

bandits = [Bandit() for _ in range(bandit_count)]
agents = [
    Agent(sample_av, bandit_count, eps),
    Agent(step_av, bandit_count, eps),
]
env = Environment(bandits, agents)

for _ in range(total_steps):
    env.step()

# env.results[i][k] is agent i's optimal-pick fraction after step k + 1.
steps = np.arange(1, total_steps + 1)

for idx, agent in enumerate(env.agents):
    plt.plot(steps, env.results[idx], label=agent.av_meth.__name__)
plt.xlabel("step")
plt.ylabel("fraction of optimal selections")
leg = plt.legend()
plt.show()