In [1]:
import numpy as np
import random 
from collections import defaultdict
import matplotlib.pyplot as plt

In [2]:
class Bandit:
    """A k-armed bandit whose arms pay normally distributed rewards.

    Arm i has an integer mean drawn from [5, 100) and a standard deviation
    drawn from [0, 1]. `max_i` is the index of the arm with the largest mean.
    """

    def __init__(self, k):
        """
        k: number of bandits (arms)
        """
        self.k = k
        # One (mean, sd) pair per arm, in arm order. Within each pair the mean is
        # drawn from numpy's global RNG and the sd from the stdlib `random` module,
        # so both generators must be seeded for reproducible construction.
        self.mean_sd_list = [
            (np.random.randint(5, 100), random.uniform(0, 1)) for _ in range(k)
        ]
        # max() returns the first arm on ties, and 0 when there are no arms.
        self.max_i = max(range(k), key=lambda arm: self.mean_sd_list[arm][0], default=0)

    def generate_reward(self, i):
        """Draw a single reward from arm i's normal distribution."""
        centre, spread = self.mean_sd_list[i]
        return np.random.normal(centre, spread)

    def generate_optimum_reward(self):
        """Draw a single reward from the arm with the highest mean."""
        return self.generate_reward(self.max_i)

In [3]:
class Environment(object):
    """Wraps a Bandit as a sequential environment for tabular Q-learning.

    An action is the index of the bandit to pull. After each step the next
    state is the index of the bandit with the highest running-mean reward
    estimate so far; `reset` returns a uniformly random starting state.
    """

    def __init__(self, bandit):
        """
        bandit: object to solve; must expose `k` and `generate_reward(i)`
                (e.g. an instance of Bandit)
        """
        self.bandit = bandit
        self._clear_statistics()

    def _clear_statistics(self):
        """Zero the per-bandit pull counts and mean estimates, and empty the action log."""
        arms = self.bandit.k
        self.counts = [0 for _ in range(arms)]
        self.mean_estimate = [0 for _ in range(arms)]
        # NOTE(review): `actions` is reset here but never appended to in this class.
        self.actions = []

    def reset(self):
        """Clear all statistics and return a random starting bandit index."""
        self._clear_statistics()
        return np.random.randint(0, self.bandit.k)

    def step(self, a):
        '''
        Pull bandit `a`, update its running mean, and return (next_state, reward).
        a : chosen bandit index
        '''
        reward = self.bandit.generate_reward(a)
        pulls = self.counts[a]
        # Running mean: fold the new reward into the average of the previous `pulls` samples.
        self.mean_estimate[a] = (self.mean_estimate[a] * pulls + reward) / (pulls + 1)
        self.counts[a] = pulls + 1
        # The next state is the bandit currently believed to be best.
        return np.argmax(self.mean_estimate), reward
        

In [149]:
def main(k, num_episodes = 10, compute_convergence = False, lr = 0.8, y = 0.95,
         num_trials = 10000, epsilon = 0.4, verbose = True):
    """
    Run epsilon-greedy tabular Q-learning on a freshly sampled k-armed bandit.

    k                   : number of bandits (arms). The same k is used as the number
                          of states, because Environment.step returns a bandit index
                          as the next state.
    num_episodes        : number of episodes to run; the Q table is shared across them
    compute_convergence : if True, reset the per-state visit counts at the start of
                          every episode and, at the end of each episode, print Q
                          whenever every state's greedy action is the same bandit
    lr                  : Q-learning step size (alpha)
    y                   : discount factor (gamma)
    num_trials          : number of steps per episode
    epsilon             : probability of choosing a uniformly random action instead
                          of the greedy one
    verbose             : if True, print a header line at the start of each episode

    Prints the sampled bandit parameters, the final Q table and the mean total
    reward per episode. Returns None.
    """
    bandit = Bandit(k)
    print("Gaussian distribution of bandits = \n", bandit.mean_sd_list)

    env = Environment(bandit)

    Q = np.zeros([k, k])
    reward_list = []
    # Number of updates made from each state; accumulates across episodes unless
    # compute_convergence resets it per episode.
    state_transition_count = np.zeros(k)

    for i in range(num_episodes):
        if compute_convergence:
            state_transition_count = np.zeros(k)

        s = env.reset()
        total_reward = 0

        if verbose:
            print("--- Episode {} ---".format(i))

        for _ in range(num_trials):
            # Epsilon-greedy action selection.
            if np.random.random() < epsilon:
                a = np.random.randint(0, k)
            else:
                a = np.argmax(Q[s, :])

            # Get new state and reward from environment
            s1, reward = env.step(a)

            # Q-learning update towards reward + discounted best value of the next state.
            Q[s, a] = Q[s, a] + lr * (reward + y * np.max(Q[s1, :]) - Q[s, a])
            state_transition_count[s] += 1

            total_reward += reward
            s = s1

        if compute_convergence:
            best_bandits = np.argmax(Q, axis = 1)
            # "Converged" here means every state agrees on the same greedy bandit.
            if np.all(best_bandits == best_bandits[0]):
                print("\n\nQ = {}, \nbest_bandit = {}, \nstate_transition_count {}".format(
                    Q, best_bandits, state_transition_count))

        reward_list.append(total_reward)

    print("\nQ value Table : \n", Q)
    # Guard: with zero episodes there is no average to report (would divide by zero).
    if num_episodes > 0:
        print("\nScore over time: ", sum(reward_list)/num_episodes)
    

In [151]:
# Train on a 3-armed bandit for 100 episodes, printing Q whenever all states agree on one bandit.
main(k=3, num_episodes=100, compute_convergence=True)

Guassian distribution of bandits = 
 [(57, 0.23216005097060943), (78, 0.9458715020956356), (55, 0.3913833509954069)]
--- Episode 0 ---
--- Episode 1 ---
--- Episode 2 ---
--- Episode 3 ---
--- Episode 4 ---
--- Episode 5 ---
--- Episode 6 ---
--- Episode 7 ---
--- Episode 8 ---
--- Episode 9 ---
--- Episode 10 ---
--- Episode 11 ---
--- Episode 12 ---
--- Episode 13 ---
--- Episode 14 ---
--- Episode 15 ---
--- Episode 16 ---
--- Episode 17 ---
--- Episode 18 ---
--- Episode 19 ---
--- Episode 20 ---
--- Episode 21 ---
--- Episode 22 ---
--- Episode 23 ---
--- Episode 24 ---
--- Episode 25 ---
--- Episode 26 ---
--- Episode 27 ---
--- Episode 28 ---
--- Episode 29 ---
--- Episode 30 ---
--- Episode 31 ---
--- Episode 32 ---
--- Episode 33 ---
--- Episode 34 ---
--- Episode 35 ---
--- Episode 36 ---
--- Episode 37 ---
--- Episode 38 ---
--- Episode 39 ---
--- Episode 40 ---
--- Episode 41 ---
--- Episode 42 ---
--- Episode 43 ---
--- Episode 44 ---
--- Episode 45 ---
--- Episode 46 ---




Q = [[1540.61446692 1560.55002591 1534.16733172]
 [1541.49756047 1564.02363434 1538.66221446]
 [1537.53729932 1561.18623232 1537.41034054]], 
best_bandit = [1 1 1], 
state_tranisition_count [1.000e+00 9.999e+03 0.000e+00]
--- Episode 79 ---


Q = [[1540.61446692 1563.12917033 1534.16733172]
 [1540.74807216 1562.06259723 1536.88269182]
 [1537.53729932 1561.18623232 1537.41034054]], 
best_bandit = [1 1 1], 
state_tranisition_count [1.000e+00 9.999e+03 0.000e+00]
--- Episode 80 ---


Q = [[1540.61446692 1562.24248627 1534.16733172]
 [1539.50734975 1559.3871694  1537.42312835]
 [1540.82522328 1561.18623232 1537.41034054]], 
best_bandit = [1 1 1], 
state_tranisition_count [1.000e+00 9.998e+03 1.000e+00]
--- Episode 81 ---


Q = [[1540.61446692 1562.24248627 1534.16733172]
 [1533.62507985 1555.90047631 1535.21327074]
 [1540.82522328 1561.18623232 1537.41034054]], 
best_bandit = [1 1 1], 
state_tranisition_count [    0. 10000.     0.]
--- Episode 82 ---


Q = [[1540.61446692 1557.4588432  1

In [93]:
# Scratch check: argmax along axis 1 gives, for each row, the column index of its maximum.
Q = [
    [1, 2, 3],
    [4, 5, 6],
]
np.argmax(Q, axis=1)

array([2, 2])

In [80]:
# Scratch fixture: every one of three states visited exactly once.
state_transition_count = [1] * 3

In [81]:
# Scratch check: report when any state has a transition count of zero.
if any(count == 0 for count in state_transition_count):
    print("yes")