Merge pull request #14 from ishikota/update-policy-api
Update policy api
ishikota committed Oct 23, 2016
2 parents c537864 + 33b29e4 commit e504ecf
Showing 33 changed files with 124 additions and 132 deletions.
4 changes: 2 additions & 2 deletions kyoka/algorithm/base_rl_algorithm.py
@@ -21,11 +21,11 @@ def GPI(self, domain, policy, value_function, finish_rules, debug=False):
[callback.after_gpi_finish(domain, value_function) for callback in self.callbacks]
return finish_msg

def generate_episode(self, domain, policy):
def generate_episode(self, domain, value_function, policy):
state = domain.generate_initial_state()
episode = []
while not domain.is_terminal_state(state):
action = policy.choose_action(state)
action = policy.choose_action(domain, value_function, state)
next_state = domain.transit_state(state, action)
reward = domain.calculate_reward(next_state)
episode.append((state, action, next_state, reward))
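For readers porting code to the new policy API, here is a minimal, self-contained sketch of the episode loop above. CountUpDomain and RandomPolicy are hypothetical stand-ins, not classes from this repository; only the choose_action(domain, value_function, state) signature mirrors this commit.

import random

class CountUpDomain(object):
    # Toy domain: start at 0, add 1 or 2 per step, terminate once the total reaches 3.
    def generate_initial_state(self):
        return 0
    def generate_possible_actions(self, state):
        return [1, 2]
    def transit_state(self, state, action):
        return state + action
    def calculate_reward(self, state):
        return 1 if state >= 3 else 0
    def is_terminal_state(self, state):
        return state >= 3

class RandomPolicy(object):
    def choose_action(self, domain, value_function, state):
        # value_function is accepted but unused here; a greedy policy would score
        # each candidate action with it instead of sampling uniformly.
        return random.choice(domain.generate_possible_actions(state))

def generate_episode(domain, value_function, policy):
    # Mirrors BaseRLAlgorithm.generate_episode after this change.
    state, episode = domain.generate_initial_state(), []
    while not domain.is_terminal_state(state):
        action = policy.choose_action(domain, value_function, state)
        next_state = domain.transit_state(state, action)
        reward = domain.calculate_reward(next_state)
        episode.append((state, action, next_state, reward))
        state = next_state
    return episode

print(generate_episode(CountUpDomain(), None, RandomPolicy()))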
10 changes: 5 additions & 5 deletions kyoka/algorithm/montecarlo/montecarlo.py
@@ -12,14 +12,14 @@ def update_value_function(self, domain, policy, value_function):
self.__validate_value_function(value_function)
self.__initialize_update_counter_if_needed(value_function)
update_counter = value_function.get_additinal_data(self.__KEY_ADDITIONAL_DATA)
episode = self.generate_episode(domain, policy)
episode = self.generate_episode(domain, value_function, policy)
for idx, turn_info in enumerate(episode):
if isinstance(value_function, BaseActionValueFunction):
self.__update_action_value_function(\
domain, policy, value_function, update_counter, episode, idx, turn_info)
domain, value_function, update_counter, episode, idx, turn_info)
elif isinstance(value_function, BaseStateValueFunction):
self.__update_state_value_function(\
domain, policy, value_function, update_counter, episode, idx, turn_info)
domain, value_function, update_counter, episode, idx, turn_info)
value_function.set_additinal_data(self.__KEY_ADDITIONAL_DATA, update_counter)

def __validate_value_function(self, value_function):
@@ -38,7 +38,7 @@ def __initialize_update_counter_if_needed(self, value_function):
value_function.set_additinal_data(self.__KEY_ADDITIONAL_DATA, update_counter)

def __update_action_value_function(\
self, domain, policy, value_function, update_counter, episode, idx, turn_info):
self, domain, value_function, update_counter, episode, idx, turn_info):
state, action, _next_state, _reward = turn_info
Q_value = value_function.calculate_value(state, action)
update_count = value_function.fetch_value_from_table(update_counter, state, action)
@@ -48,7 +48,7 @@ def __update_action_value_function(\
value_function.update_function(state, action, new_Q_value)

def __update_state_value_function(\
self, domain, policy, value_function, update_counter, episode, idx, turn_info):
self, domain, value_function, update_counter, episode, idx, turn_info):
state, action, next_state, reward = turn_info
Q_value = value_function.calculate_value(next_state)
update_count = value_function.fetch_value_from_table(update_counter, next_state)
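The update_counter threaded through this method stores per-(state, action) visit counts. The averaging formula itself lives in lines not expanded in this hunk, so the sketch below is an assumption: the textbook incremental-mean Monte Carlo update that such a counter typically supports, not necessarily kyoka's exact implementation.

def incremental_mean_update(Q_value, update_count, following_reward):
    # Q_{n+1} = Q_n + (G - Q_n) / (n + 1), where G is the return observed after
    # visiting (state, action) in this episode and n is the prior visit count.
    return Q_value + (following_reward - Q_value) / (update_count + 1.0)

# Third visit to a pair whose current estimate is 0.5, with observed return 2.0:
print(incremental_mean_update(0.5, 2, 2.0))  # -> 1.0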
12 changes: 6 additions & 6 deletions kyoka/algorithm/td_learning/q_lambda.py
@@ -12,18 +12,18 @@ def __init__(self, alpha=0.1, gamma=0.9, eligibility_trace=None):
BaseTDMethod.__init__(self)
self.alpha = alpha
self.gamma = gamma
self.greedy_policy = GreedyPolicy()
self.trace = eligibility_trace if eligibility_trace else self.__generate_default_trace()

def update_action_value_function(self, domain, policy, value_function):
self.__setup_trace(value_function)
greedy_policy = GreedyPolicy(domain, value_function)
current_state = domain.generate_initial_state()
current_action = policy.choose_action(current_state)
current_action = policy.choose_action(domain, value_function, current_state)
while not domain.is_terminal_state(current_state):
next_state = domain.transit_state(current_state, current_action)
reward = domain.calculate_reward(next_state)
next_action = self.__choose_action(domain, policy, next_state)
greedy_action = self.__choose_action(domain, greedy_policy, next_state)
next_action = self.__choose_action(domain, policy, value_function, next_state)
greedy_action = self.__choose_action(domain, self.greedy_policy, value_function, next_state)
delta = self.__calculate_delta(value_function,\
current_state, current_action, next_state, greedy_action, reward)
self.trace.update(current_state, current_action)
@@ -48,11 +48,11 @@ def __calculate_new_Q_value(self,\
Q_value = value_function.calculate_value(state, action)
return Q_value + self.alpha * delta * eligibility

def __choose_action(self, domain, policy, state):
def __choose_action(self, domain, policy, value_function, state):
if domain.is_terminal_state(state):
return self.ACTION_ON_TERMINAL_FLG
else:
return policy.choose_action(state)
return policy.choose_action(domain, value_function, state)

def __calculate_value(self, value_function, next_state, next_action):
if self.ACTION_ON_TERMINAL_FLG == next_action:
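The loop above spreads the TD error delta over recently visited pairs through the eligibility trace (Q_value + self.alpha * delta * eligibility). The trace class itself is not shown in this hunk, so the sketch below is a generic accumulating trace written under that assumption; it is not kyoka's eligibility_trace implementation.

class AccumulatingTrace(object):
    def __init__(self, gamma=0.9, lambda_=0.9):
        self.gamma, self.lambda_, self.table = gamma, lambda_, {}

    def update(self, state, action):
        # Bump the eligibility of the pair just visited.
        self.table[(state, action)] = self.table.get((state, action), 0.0) + 1.0

    def decay_all(self):
        # Standard lambda-return decay applied after every step.
        for key in self.table:
            self.table[key] *= self.gamma * self.lambda_

def apply_delta(Q, trace, alpha, delta):
    # Spread the TD error over every recently visited (state, action) pair,
    # weighted by its eligibility, then decay the trace.
    for (state, action), eligibility in trace.table.items():
        Q[(state, action)] = Q.get((state, action), 0.0) + alpha * delta * eligibility
    trace.decay_all()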
12 changes: 6 additions & 6 deletions kyoka/algorithm/td_learning/q_learning.py
@@ -9,16 +9,16 @@ def __init__(self, alpha=0.1, gamma=0.9):
BaseTDMethod.__init__(self)
self.alpha = alpha
self.gamma = gamma
self.greedy_policy = GreedyPolicy()

def update_action_value_function(self, domain, policy, value_function):
greedy_policy = GreedyPolicy(domain, value_function)
state = domain.generate_initial_state()
action = policy.choose_action(state)
action = policy.choose_action(domain, value_function, state)
while not domain.is_terminal_state(state):
next_state = domain.transit_state(state, action)
reward = domain.calculate_reward(next_state)
next_action = self.__choose_action(domain, policy, next_state)
greedy_action = self.__choose_action(domain, greedy_policy, next_state)
next_action = self.__choose_action(domain, policy, value_function, next_state)
greedy_action = self.__choose_action(domain, self.greedy_policy, value_function, next_state)
new_Q_value = self.__calculate_new_Q_value(\
value_function, state, action, next_state, greedy_action, reward)
value_function.update_function(state, action, new_Q_value)
@@ -31,11 +31,11 @@ def __calculate_new_Q_value(self,\
greedy_Q_value = self.__calculate_value(value_function, next_state, greedy_action)
return Q_value + self.alpha * (reward + self.gamma * greedy_Q_value - Q_value)

def __choose_action(self, domain, policy, state):
def __choose_action(self, domain, policy, value_function, state):
if domain.is_terminal_state(state):
return self.ACTION_ON_TERMINAL_FLG
else:
return policy.choose_action(state)
return policy.choose_action(domain, value_function, state)

def __calculate_value(self, value_function, next_state, next_action):
if self.ACTION_ON_TERMINAL_FLG == next_action:
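For reference, the rule implemented by __calculate_new_Q_value above, restated as a standalone function: off-policy Q-learning bootstraps on the greedy action's value even though the behaviour actions come from the policy passed to GPI (typically epsilon-greedy).

def q_learning_target(Q_value, greedy_Q_value, reward, alpha=0.1, gamma=0.9):
    # Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
    return Q_value + alpha * (reward + gamma * greedy_Q_value - Q_value)

print(q_learning_target(Q_value=0.0, greedy_Q_value=1.0, reward=1.0))
# -> 0.19, i.e. 0.1 * (1.0 + 0.9 * 1.0 - 0.0)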
8 changes: 4 additions & 4 deletions kyoka/algorithm/td_learning/sarsa.py
@@ -11,11 +11,11 @@ def __init__(self, alpha=0.1, gamma=0.9):

def update_action_value_function(self, domain, policy, value_function):
state = domain.generate_initial_state()
action = policy.choose_action(state)
action = policy.choose_action(domain, value_function, state)
while not domain.is_terminal_state(state):
next_state = domain.transit_state(state, action)
reward = domain.calculate_reward(next_state)
next_action = self.__choose_action(domain, policy, next_state)
next_action = self.__choose_action(domain, policy, value_function, next_state)
new_Q_value = self.__calculate_new_Q_value(\
value_function, state, action, next_state, next_action, reward)
value_function.update_function(state, action, new_Q_value)
@@ -28,11 +28,11 @@ def __calculate_new_Q_value(self,\
next_Q_value = self.__calculate_value(value_function, next_state, next_action)
return Q_value + self.alpha * (reward + self.gamma * next_Q_value - Q_value)

def __choose_action(self, domain, policy, state):
def __choose_action(self, domain, policy, value_function, state):
if domain.is_terminal_state(state):
return self.ACTION_ON_TERMINAL_FLG
else:
return policy.choose_action(state)
return policy.choose_action(domain, value_function, state)

def __calculate_value(self, value_function, next_state, next_action):
if self.ACTION_ON_TERMINAL_FLG == next_action:
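Same shape as the Q-learning rule, but on-policy: SARSA bootstraps on next_action, the action the behaviour policy actually chose for the next state. A standalone restatement of __calculate_new_Q_value above:

def sarsa_target(Q_value, next_Q_value, reward, alpha=0.1, gamma=0.9):
    # Q(s,a) <- Q(s,a) + alpha * (r + gamma * Q(s',a') - Q(s,a))
    return Q_value + alpha * (reward + gamma * next_Q_value - Q_value)

print(sarsa_target(Q_value=0.5, next_Q_value=0.0, reward=1.0))  # -> 0.55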
8 changes: 4 additions & 4 deletions kyoka/algorithm/td_learning/sarsa_lambda.py
@@ -16,11 +16,11 @@ def __init__(self, alpha=0.1, gamma=0.9, eligibility_trace=None):
def update_action_value_function(self, domain, policy, value_function):
self.__setup_trace(value_function)
current_state = domain.generate_initial_state()
current_action = policy.choose_action(current_state)
current_action = policy.choose_action(domain, value_function, current_state)
while not domain.is_terminal_state(current_state):
next_state = domain.transit_state(current_state, current_action)
reward = domain.calculate_reward(next_state)
next_action = self.__choose_action(domain, policy, next_state)
next_action = self.__choose_action(domain, policy, value_function, next_state)
delta = self.__calculate_delta(\
value_function, current_state, current_action, next_state, next_action, reward)
self.trace.update(current_state, current_action)
@@ -43,11 +43,11 @@ def __calculate_new_Q_value(self, value_function, state, action, eligibility, de
Q_value = self.__calculate_value(value_function, state, action)
return Q_value + self.alpha * delta * eligibility

def __choose_action(self, domain, policy, state):
def __choose_action(self, domain, policy, value_function, state):
if domain.is_terminal_state(state):
return self.ACTION_ON_TERMINAL_FLG
else:
return policy.choose_action(state)
return policy.choose_action(domain, value_function, state)

def __calculate_value(self, value_function, next_state, next_action):
if self.ACTION_ON_TERMINAL_FLG == next_action:
12 changes: 4 additions & 8 deletions kyoka/policy/base_policy.py
@@ -3,18 +3,14 @@

class BasePolicy(object):

def __init__(self, domain, value_function):
self.domain = domain
self.value_function = value_function

def choose_action(self, state, action=None):
def choose_action(self, domain, value_function, state, action=None):
err_msg = self.__build_err_msg("choose_action")
raise NotImplementedError(err_msg)

def pack_arguments_for_value_function(self, state, action):
if isinstance(self.value_function, BaseStateValueFunction):
def pack_arguments_for_value_function(self, value_function, state, action):
if isinstance(value_function, BaseStateValueFunction):
return [state]
elif isinstance(self.value_function, BaseActionValueFunction):
elif isinstance(value_function, BaseActionValueFunction):
return [state, action]
else:
raise ValueError("Invalid value function is set")
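With the constructor gone, a custom policy under the new interface only needs to implement choose_action(self, domain, value_function, state, action=None). A minimal sketch: FirstActionPolicy is hypothetical, and the import path is assumed from the file layout shown above.

from kyoka.policy.base_policy import BasePolicy

class FirstActionPolicy(BasePolicy):
    def choose_action(self, domain, value_function, state, action=None):
        # Deterministically pick the first legal action. A real policy would
        # score actions via pack_arguments_for_value_function(value_function,
        # state, action) and value_function.calculate_value(...).
        return domain.generate_possible_actions(state)[0]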
17 changes: 8 additions & 9 deletions kyoka/policy/epsilon_greedy_policy.py
@@ -3,22 +3,21 @@

class EpsilonGreedyPolicy(BasePolicy):

def __init__(self, domain, value_function, eps=0.05, rand=None):
super(EpsilonGreedyPolicy, self).__init__(domain, value_function)
def __init__(self, eps=0.05, rand=None):
self.eps = eps
self.rand = rand if rand else random

def choose_action(self, state):
actions = self.domain.generate_possible_actions(state)
best_action = self.__choose_best_action(state)
def choose_action(self, domain, value_function, state):
actions = domain.generate_possible_actions(state)
best_action = self.__choose_best_action(domain, value_function, state)
probs = self.__calc_select_probability(best_action, actions)
selected_action_idx = self.__roulette(probs)
return actions[selected_action_idx]

def __choose_best_action(self, state):
actions = self.domain.generate_possible_actions(state)
pack = lambda state, action: self.pack_arguments_for_value_function(state, action)
calc_Q_value = lambda packed_arg: self.value_function.calculate_value(*packed_arg)
def __choose_best_action(self, domain, value_func, state):
actions = domain.generate_possible_actions(state)
pack = lambda state, action: self.pack_arguments_for_value_function(value_func, state, action)
calc_Q_value = lambda packed_arg: value_func.calculate_value(*packed_arg)
Q_value_for_actions = [calc_Q_value(pack(state, action)) for action in actions]
max_Q_value = max(Q_value_for_actions)
Q_act_pair = zip(Q_value_for_actions, actions)
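__calc_select_probability is not expanded in this hunk; assuming the textbook epsilon-greedy scheme, the probabilities it would produce look like the sketch below: every action shares eps uniformly and the greedy action receives the remaining mass.

def epsilon_greedy_probs(best_action, actions, eps=0.05):
    # P(best) = 1 - eps + eps/|A|, P(other) = eps/|A|; the list sums to 1.
    base = eps / len(actions)
    return [base + (1 - eps) if a == best_action else base for a in actions]

print(epsilon_greedy_probs("up", ["up", "down", "left", "right"], eps=0.2))
# -> roughly [0.85, 0.05, 0.05, 0.05]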
11 changes: 5 additions & 6 deletions kyoka/policy/greedy_policy.py
@@ -3,14 +3,13 @@

class GreedyPolicy(BasePolicy):

def __init__(self, domain, value_function, rand=None):
super(GreedyPolicy, self).__init__(domain, value_function)
def __init__(self, rand=None):
self.rand = rand if rand else random

def choose_action(self, state):
actions = self.domain.generate_possible_actions(state)
pack = lambda state, action: self.pack_arguments_for_value_function(state, action)
calc_Q_value = lambda packed_arg: self.value_function.calculate_value(*packed_arg)
def choose_action(self, domain, value_function, state):
actions = domain.generate_possible_actions(state)
pack = lambda state, action: self.pack_arguments_for_value_function(value_function, state, action)
calc_Q_value = lambda packed_arg: value_function.calculate_value(*packed_arg)
Q_value_for_actions = [calc_Q_value(pack(state, action)) for action in actions]
max_Q_value = max(Q_value_for_actions)
Q_act_pair = zip(Q_value_for_actions, actions)
4 changes: 2 additions & 2 deletions sample/maze/maze_helper.py
@@ -8,11 +8,11 @@ def visualize_maze(self, maze):

@classmethod
def measure_performance(self, domain, value_function, step_limit=10000):
policy = GreedyPolicy(domain, value_function)
policy = GreedyPolicy()
state = domain.generate_initial_state()
step_counter = 0
while not domain.is_terminal_state(state):
action = policy.choose_action(state)
action = policy.choose_action(domain, value_function, state)
state = domain.transit_state(state, action)
step_counter += 1
if step_counter >= step_limit:
2 changes: 1 addition & 1 deletion sample/maze/notebook/blocking_maze_peformance_test.ipynb
@@ -82,7 +82,7 @@
" domain.read_maze(maze_file_path(maze_type))\n",
" value_func = MazeTableValueFunction(domain.get_maze_shape())\n",
" value_func.setUp()\n",
" policy = EpsilonGreedyPolicy(domain, value_func, eps=epsilon)\n",
" policy = EpsilonGreedyPolicy(eps=epsilon)\n",
" callbacks = gen_callbacks(maze_type, transform_timing)\n",
" [rl_algo.set_gpi_callback(callback) for callback in callbacks]\n",
" rl_algo.GPI(domain, policy, value_func, finish_rules)\n",
2 changes: 1 addition & 1 deletion sample/maze/notebook/dyna_maze_peformance_test.ipynb
@@ -89,7 +89,7 @@
" domain.read_maze(maze_file_path(maze_type))\n",
" value_func = MazeTableValueFunction(domain.get_maze_shape())\n",
" value_func.setUp()\n",
" policy = EpsilonGreedyPolicy(domain, value_func, eps=epsilon)\n",
" policy = EpsilonGreedyPolicy(eps=epsilon)\n",
" callbacks = gen_callbacks(maze_type, transform_timing)\n",
" [rl_algo.set_gpi_callback(callback) for callback in callbacks]\n",
" rl_algo.GPI(domain, policy, value_func, finish_rules)\n",
2 changes: 1 addition & 1 deletion sample/maze/notebook/maze_keras_value_function_test.ipynb
@@ -89,7 +89,7 @@
" domain.read_maze(maze_file_path(maze_type))\n",
" value_func = MazeKerasValueFunction(domain)\n",
" value_func.setUp()\n",
" policy = EpsilonGreedyPolicy(domain, value_func, eps=epsilon)\n",
" policy = EpsilonGreedyPolicy(eps=epsilon)\n",
" callbacks = gen_callbacks(maze_type, transform_timing)\n",
" [rl_algo.set_gpi_callback(callback) for callback in callbacks]\n",
" rl_algo.GPI(domain, policy, value_func, finish_rules)\n",
2 changes: 1 addition & 1 deletion sample/maze/notebook/shortcut_maze_performance_test.ipynb
@@ -82,7 +82,7 @@
" domain.read_maze(maze_file_path(maze_type))\n",
" value_func = MazeTableValueFunction(domain.get_maze_shape())\n",
" value_func.setUp()\n",
" policy = EpsilonGreedyPolicy(domain, value_func, eps=epsilon)\n",
" policy = EpsilonGreedyPolicy(eps=epsilon)\n",
" callbacks = gen_callbacks(maze_type, transform_timing)\n",
" [rl_algo.set_gpi_callback(callback) for callback in callbacks]\n",
" rl_algo.GPI(domain, policy, value_func, finish_rules)\n",
2 changes: 1 addition & 1 deletion sample/maze/script/measure_performance
@@ -59,7 +59,7 @@ if os.path.isfile(VALUE_FUNC_SAVE_PATH):
log.info("finished loading value function")

TEST_LENGTH = 100
policy = EpsilonGreedyPolicy(domain, value_func, eps=0.1)
policy = EpsilonGreedyPolicy(eps=0.1)
watch_iteration = WatchIterationCount(target_count=TEST_LENGTH, log_interval=10000)
finish_rules = [watch_iteration]
callbacks = [MazePerformanceLogger()]
@@ -112,7 +112,7 @@
" finish_rules = [watch_iteration]\n",
" value_func = TickTackToeKerasValueFunction()\n",
" value_func.setUp()\n",
" policy = EpsilonGreedyPolicy(domain, value_func, eps=epsilon)\n",
" policy = EpsilonGreedyPolicy(eps=epsilon)\n",
" callback = gen_performance_logger()\n",
" rl_algo.set_gpi_callback(callback)\n",
" rl_algo.GPI(domain, policy, value_func, finish_rules)\n",
@@ -112,7 +112,7 @@
" finish_rules = [watch_iteration]\n",
" value_func = TickTackToeKerasValueFunction()\n",
" value_func.setUp()\n",
" policy = EpsilonGreedyPolicy(domain, value_func, eps=epsilon)\n",
" policy = EpsilonGreedyPolicy(eps=epsilon)\n",
" callback = gen_performance_logger()\n",
" rl_algo.set_gpi_callback(callback)\n",
" rl_algo.GPI(domain, policy, value_func, finish_rules)\n",
@@ -113,7 +113,7 @@
" finish_rules = [watch_iteration]\n",
" value_func = TickTackToeTableValueFunction()\n",
" value_func.setUp()\n",
" policy = EpsilonGreedyPolicy(domain, value_func, eps=epsilon)\n",
" policy = EpsilonGreedyPolicy(eps=epsilon)\n",
" callback = gen_performance_logger()\n",
" rl_algo.set_gpi_callback(callback)\n",
" rl_algo.GPI(domain, policy, value_func, finish_rules)\n",
2 changes: 1 addition & 1 deletion sample/ticktacktoe/script/measure_performance
@@ -49,7 +49,7 @@ TEST_INTERVAL = 10
domain = TickTackToeDomain(is_first_player=is_first_player)
value_func = TickTackToeTableValueFunction()
value_func.setUp()
policy = EpsilonGreedyPolicy(domain, value_func, eps=0.7)
policy = EpsilonGreedyPolicy(eps=0.7)
watch_iteration = WatchIterationCount(target_count=TEST_LENGTH, log_interval=10)
finish_rules = [watch_iteration]

9 changes: 5 additions & 4 deletions sample/ticktacktoe/script/play_ticktacktoe
@@ -55,19 +55,20 @@ domains = [TickTackToeDomain(is_first_player=is_first) for is_first in [True, Fa
player_builder = defaultdict(lambda : GreedyPolicy)
player_builder["human"] = TickTackToeManualPolicy
player_builder["minimax"] = TickTackToePerfectPolicy
builders = [player_builder[algo] for algo in algos]
players = [builder(domain, func) for builder, domain, func in zip(builders, domains, value_funcs)]
players = [player_builder[algo]() for algo in algos]

next_is_first_player = lambda state: bin(state[0]|state[1]).count("1") % 2 == 0
next_player = lambda state: players[0] if next_is_first_player(state) else players[1]
next_player_idx = lambda state: 0 if next_is_first_player(state) else 1
show_board = lambda state: log.info("\n" + TickTackToeHelper.visualize_board(state))

log.info("start the game (%s vs %s)" % tuple(algos))
domain = domains[0]
state = domain.generate_initial_state()
show_board(state)
while not domain.is_terminal_state(state):
action = next_player(state).choose_action(state)
idx = next_player_idx(state)
domain, player, value_func = domains[idx], players[idx], value_funcs[idx]
action = player.choose_action(domain, value_func, state)
state = domain.transit_state(state, action)
show_board(state)
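The turn logic above treats the state as a pair of bitboards, one per player: an even total number of placed marks means the first player moves next. A quick check of that parity test, using hypothetical example boards:

next_is_first_player = lambda state: bin(state[0] | state[1]).count("1") % 2 == 0

print(next_is_first_player((0b000000000, 0b000000000)))  # empty board -> True
print(next_is_first_player((0b000010000, 0b000000000)))  # one mark placed -> False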

4 changes: 2 additions & 2 deletions sample/ticktacktoe/script/training_ticktacktoe_agent
@@ -60,7 +60,7 @@ if os.path.isfile(VALUE_FUNC_SAVE_PATH):
value_func.load(VALUE_FUNC_SAVE_PATH)
log.info("finished loading value function")

policy = EpsilonGreedyPolicy(domain, value_func, eps=0.7)
policy = EpsilonGreedyPolicy(eps=0.7)
watch_iteration = WatchIterationCount(target_count=100000, log_interval=10000)
manual_interruption = ManualInterruption(monitor_file_path=INTERRUPTION_MONITOR_FILE_PATH, log_interval=10000)
finish_rules = [watch_iteration, manual_interruption]
@@ -69,6 +69,6 @@ log.info("started GPI iteration...")
RL_algo.GPI(domain, policy, value_func, finish_rules)

log.info("saving value function into %s" % VALUE_FUNC_SAVE_PATH)
value_func.save(VALUE_FUNC_SAVE_PATH)
#value_func.save(VALUE_FUNC_SAVE_PATH)
log.info("finished saving value function")
