* mab 는 state 제외된 RL 문제
* 푸는 방법은 많은데 보통 q 벡터 업데이트 하는 방식으로 많이 진행
* 업데이트 방식이 기초 RL 과 거의 동일함

In [110]:
import numpy as np

class MAB:
    def __init__(self, n_arms):
        self.n_arms = n_arms
        self.q_star = np.random.normal(0, 1, n_arms) # q*(a) 일종의 보상. 확률적으로 바꿀 수도 있음.
        self.action_count = np.zeros(n_arms) # N(a)
        self.q_estimates = np.zeros(n_arms) # Q(a)

    def pull(self, action):
        reward = np.random.normal(self.q_star[action], 1)
        self.action_count[action] += 1
        # q learning 업데이트 식과 유사. 정규화항은 alpha 값과 유사
        self.q_estimates[action] += (reward - self.q_estimates[action]) / self.action_count[action]
        return reward

    # 둘 중 하나.
    # epsilon_greedy vs ucb
    def epsilon_greedy(self, epsilon):
        if np.random.rand() < epsilon:
            action = np.random.randint(self.n_arms)
        else:
            action = np.argmax(self.q_estimates)
        return action

    def ucb(self, c):
        ucb_estimates = self.q_estimates + c * np.sqrt(np.log(np.sum(self.action_count)) / self.action_count)
        action = np.argmax(ucb_estimates)
        return action

# test the MAB class
mab = MAB(10)
num_steps = 10000
total_reward = 0

for i in range(num_steps):
    # action = mab.epsilon_greedy(0.1) # or mab.ucb(2)
    action = mab.ucb(2)
    reward = mab.pull(action)
    total_reward += reward

print('Total reward:', total_reward)


Total reward: 4789.373572626772


  ucb_estimates = self.q_estimates + c * np.sqrt(np.log(np.sum(self.action_count)) / self.action_count)
  ucb_estimates = self.q_estimates + c * np.sqrt(np.log(np.sum(self.action_count)) / self.action_count)
  ucb_estimates = self.q_estimates + c * np.sqrt(np.log(np.sum(self.action_count)) / self.action_count)
  ucb_estimates = self.q_estimates + c * np.sqrt(np.log(np.sum(self.action_count)) / self.action_count)


In [111]:
# real reward
mab.q_star

array([-1.60151656, -2.06636753, -0.54086025, -1.23370859,  0.51244845,
       -0.65708575, -0.26845323, -0.07360806, -0.2359048 , -0.78621137])

In [112]:
# 가치 추정
mab.q_estimates

array([-1.79355076, -2.08394222, -0.46405708, -1.47456643,  0.50801201,
       -0.64615912, -0.19272391, -0.02959981, -0.17589417, -0.76903622])

In [113]:
mab.action_count

array([7.000e+00, 7.000e+00, 3.600e+01, 9.000e+00, 9.658e+03, 2.500e+01,
       6.600e+01, 1.040e+02, 6.700e+01, 2.100e+01])