<a href="https://colab.research.google.com/github/howakuro/TreeBandit/blob/master/TreeBandit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 各種インポート

In [0]:
import numpy as np
import matplotlib.pyplot as plt
import random
import math

## ツリーバンディット環境
![ツリー構造](https://i.imgur.com/IRjxJMu.jpg)<br>
上記のような木構造のバンディットを探索する。状態1で腕を選ぶと、選んだ腕に応じて状態2〜5のいずれかへ遷移し、そこでもう一度腕を選ぶと1エピソードが終わる。合計期待値が最も高いルートは、状態1で確率0.4の腕を選び、状態2で確率0.9の腕を選ぶものである（合計期待値 0.4 + 0.9 = 1.3。他のルートは最大でも 1.2）。


In [0]:
class TreeBandit():
    """Two-level, tree-structured Bernoulli bandit.

    There are five states (indices 0-4, called 状態1-状態5 in the notes) with
    four arms each. Pulling arm ``a`` in state ``s`` pays 1.0 with
    probability ``_arm_probability[s][a]`` (otherwise 0.0) and moves to
    state ``s + a + 1``. An episode ends after the second pull, at which
    point the returned next state is ``None``.
    """

    def __init__(self):
        self._arm_num = 4    # arms per state
        self._state_num = 5  # number of states
        self._playcount = 0  # pulls made since the last reset()
        # Per-arm win probabilities, one row per state.
        rows = (
            (0.4, 0.5, 0.6, 0.7),  # state 1 (root, index 0)
            (0.9, 0.8, 0.1, 0.2),  # state 2
            (0.7, 0.1, 0.3, 0.6),  # state 3
            (0.6, 0.4, 0.1, 0.3),  # state 4
            (0.5, 0.2, 0.3, 0.1),  # state 5
        )
        self._arm_probability = [list(row) for row in rows]

    def reset(self):
        """Start a new episode and return the root state index (0)."""
        self._playcount = 0
        return 0

    def get_probability(self):
        """Return the (mutable) per-state arm probability table."""
        return self._arm_probability

    def step(self, state, action):
        """Pull ``action`` in ``state``.

        Returns:
            ``(next_state, reward, done, info)``. ``next_state`` is ``None``
            and ``done`` is True once two pulls have been made; ``info`` is
            always ``None``.
        """
        hit = random.random() <= self._arm_probability[state][action]
        reward = 1.0 if hit else 0.0
        self._playcount += 1
        done = self._playcount == 2
        next_state = None if done else state + action + 1
        return next_state, reward, done, None

## 価値関数クラス

## 期待値クラス

In [0]:
class Expected_Value():
    """Value estimator that tracks the empirical mean reward per (state, action).

    ``E[s][a]`` is total reward divided by number of pulls of arm ``a`` in
    state ``s``. ``next_state`` is ignored, so no value is propagated between
    states. The tables only exist after ``sim_reset()`` has been called.
    """

    def __init__(self, state_num, action_num):
        self.state_num = state_num    # number of states
        self.action_num = action_num  # number of actions per state

    def episode_reset(self):
        """No per-episode state; present so Agent can call it uniformly."""

    def get_value(self):
        """Return the state x action table of mean rewards."""
        return self.E

    def update(self, state, action, reward, next_state):
        """Record one pull and refresh the running mean for that arm."""
        pulls = self.play_count[state]
        hits = self.hit_count[state]
        pulls[action] += 1
        hits[action] += reward
        self.E[state][action] = hits[action] / pulls[action]

    def sim_reset(self):
        """Reset pull counts, reward sums and means to zero."""
        def zero_table():
            # Fresh rows each time, so the three tables never alias.
            return [[0] * self.action_num for _ in range(self.state_num)]

        self.play_count = zero_table()
        self.hit_count = zero_table()
        self.E = zero_table()

## Q関数クラス

In [0]:
class Q_Learning():
    """Tabular one-step Q-learning value estimator.

    Update rule::

        Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))

    The bootstrap term is 0 when ``next_state is None`` (end of episode).
    Play and reward counts are also kept for reporting. The tables only
    exist after ``sim_reset()`` has been called.

    Args:
        state_num: number of states.
        action_num: number of actions per state.
        alpha: learning rate.
        gamma: discount factor.
    """

    def __init__(self, state_num, action_num, alpha=0.1, gamma=0.99):
        self.state_num = state_num    # number of states
        self.action_num = action_num  # number of actions per state
        self.alpha = alpha
        self.gamma = gamma

    def episode_reset(self):
        """No per-episode state; present so Agent can call it uniformly."""

    def get_value(self):
        """Return the state x action Q table."""
        return self.Q

    def update(self, state, action, reward, next_state):
        """Apply one Q-learning update for the transition (s, a, r, s')."""
        self.play_count[state][action] += 1
        self.hit_count[state][action] += reward
        # A terminal transition has no successor to bootstrap from. Use an
        # identity check: state 0 is a valid (falsy) successor.
        next_q = 0 if next_state is None else max(self.Q[next_state])
        target = reward + self.gamma * next_q
        td_error = target - self.Q[state][action]
        self.Q[state][action] += self.alpha * td_error

    def sim_reset(self):
        """Reset counts and the Q table to zero."""
        self.play_count = [[0 for j in range(self.action_num)] for i in range(self.state_num)]
        self.hit_count = [[0 for j in range(self.action_num)] for i in range(self.state_num)]
        self.Q = [[0 for j in range(self.action_num)] for i in range(self.state_num)]

## 方策関数

### ε減衰型ε-Greedyクラス

In [0]:
class Decay_E_Greedy():
    """ε-greedy action selection with linearly decaying ε.

    With probability ε a uniformly random action is chosen; otherwise a
    greedy action is chosen, with ties broken uniformly at random. ε is
    lowered by ``decay_step`` on every ``select_action`` call and by
    ``decay_episode`` on every ``episode_reset`` call, and is clamped at 0.

    Args:
        epsilon: initial exploration rate, restored by ``sim_reset``.
        decay_step: amount subtracted from ε per action selection.
        decay_episode: amount subtracted from ε per episode reset.

    Note: the explore/exploit coin uses Python's ``random`` module, while the
    action draw uses ``np.random``. Seed both for reproducible runs.
    """

    def __init__(self, epsilon, decay_step, decay_episode):
        self.start_epsilon = epsilon
        self.decay_step = decay_step
        self.decay_episode = decay_episode

    def sim_reset(self):
        """Restore ε to its initial value."""
        self.epsilon = self.start_epsilon

    def episode_reset(self):
        """Apply the per-episode ε decay."""
        # Clamping does not change which branch select_action takes (any
        # ε <= 0 always exploits), but it keeps ε a valid probability.
        self.epsilon = max(0.0, self.epsilon - self.decay_episode)

    def select_action(self, value_func, state):
        """Choose an action for ``state`` from the value table ``value_func``.

        ``value_func[state]`` is a sequence of per-action values. Its length
        defines the action count, so any number of arms is supported rather
        than a hard-coded 4.
        """
        self.epsilon = max(0.0, self.epsilon - self.decay_step)
        # Convert to an array so the equality test below is elementwise
        # whether the table holds Python lists or numpy arrays.
        values = np.asarray(value_func[state], dtype=float)
        if self.epsilon <= random.random():
            # Exploit: pick uniformly among all maximal actions.
            best = np.where(values == values.max())[0]
            return np.random.choice(best)
        # Explore: uniform over every available action.
        return np.random.choice(len(values))

## エージェントクラス

In [0]:
class Agent():
    """Pair a value estimator with an action-selection policy.

    Args:
        value: estimator exposing get_value / update / sim_reset /
            episode_reset.
        policy: selector exposing select_action(value_table, state) /
            sim_reset / episode_reset.
    """

    def __init__(self, value, policy):
        self.value = value    # value-estimation component
        self.policy = policy  # action-selection component

    def select_action(self, state):
        """Ask the policy for an action given the current value table."""
        table = self.value.get_value()
        return self.policy.select_action(table, state)

    def update(self, state, action, reward, next_state):
        """Forward the observed transition to the value estimator."""
        self.value.update(state, action, reward, next_state)

    def sim_reset(self):
        """Reset the value estimator and then the policy for a new simulation."""
        for part in (self.value, self.policy):
            part.sim_reset()

    def episode_reset(self):
        """Reset the value estimator and then the policy for a new episode."""
        for part in (self.value, self.policy):
            part.episode_reset()

## シミュレーション関数

In [0]:
def simulation(env, agent, agent_name, sim_num, episode_num, step_num):
    """Train ``agent`` on ``env`` and print its final value table and play counts.

    Runs ``sim_num`` simulations of ``episode_num`` episodes each. An episode
    lasts at most ``step_num`` steps and stops early when ``env.step``
    reports done. ``agent.sim_reset()`` is called at the start of every
    simulation, so the printed tables reflect the last one. Reads
    ``agent.value.play_count`` directly, so the value estimator must expose
    that attribute.
    """
    def play_one_episode():
        # Roll out a single episode, updating the agent after each step.
        state = env.reset()
        agent.episode_reset()
        for _ in range(step_num):
            action = agent.select_action(state)
            next_state, reward, done, _ = env.step(state, action)
            agent.update(state, action, reward, next_state)
            if done:
                return
            state = next_state

    print("\n【Agent_Name】\n", agent_name)
    for _ in range(sim_num):
        agent.sim_reset()
        for _ in range(episode_num):
            play_one_episode()
    estimator = agent.value
    print("【Q_Array】\n", estimator.get_value())
    print("【Play_count】\n", estimator.play_count)

## メイン

In [31]:
# Shared environment and experiment settings.
ENV = TreeBandit()
SIMULATION_NUM = 1   # independent simulations per agent
EPISODE_NUM = 10000  # episodes per simulation
STEP_NUM = 2         # max steps per episode (TreeBandit ends after 2 pulls)
ARM_NUM = 4          # arms per state
STATE_NUM = 5        # number of states

# (display name, agent) pairs, evaluated in this order.
_experiments = [
    ("Expected,Decay_E_Greedy",
     Agent(Expected_Value(STATE_NUM, ARM_NUM), Decay_E_Greedy(0.1, 0.0, 0.0))),
    ("Q_Learning,Decay_E_Greedy",
     Agent(Q_Learning(STATE_NUM, ARM_NUM), Decay_E_Greedy(1.0, 0.0, (1.0 / 2000)))),
]
Agent_list = [agent for _, agent in _experiments]
Agent_Name = [name for name, _ in _experiments]

for name, agent in _experiments:
    simulation(ENV, agent, name, SIMULATION_NUM, EPISODE_NUM, STEP_NUM)


【Agent_Name】
 Expected,Decay_E_Greedy
【Q_Array】
 [[0.4065040650406504, 0.48917748917748916, 0.5970287836583101, 0.6946483542505328], [0.8741721854304636, 0.8360655737704918, 0.2, 0.0], [0.7282051282051282, 0.0, 0.3333333333333333, 0.5833333333333334], [0.6035751840168244, 0.38461538461538464, 0.08333333333333333, 0.2894736842105263], [0.5023678484576987, 0.25, 0.3119266055045872, 0.09004739336492891]]
【Play_count】
 [[246, 231, 1077, 8446], [151, 61, 30, 4], [195, 6, 6, 24], [951, 26, 24, 76], [7813, 204, 218, 211]]

【Agent_Name】
 Q_Learning,Decay_E_Greedy
【Q_Array】
 [[1.1621912852037761, 0.8848813113536277, 0.8810154362488262, 0.8783852348744439], [0.8458403503849545, 0.5423465579229474, 0.13598598991865526, 0.24412542702876458], [0.6318868386267796, 0.049508018578181545, 0.2944766993525858, 0.40519820283099767], [0.45089553254764725, 0.3817999537161761, 0.031381059609000006, 0.23005347496038836], [0.41170759204180923, 0.1272340660266844, 0.06448682975971405, 0.040405263754163374]]
【P