Skip to content

Commit

Permalink
Merge e666930 into e537c5b
Browse files Browse the repository at this point in the history
  • Loading branch information
ishikota committed Nov 10, 2016
2 parents e537c5b + e666930 commit eb4c6d6
Show file tree
Hide file tree
Showing 19 changed files with 1,544 additions and 0 deletions.
Empty file added kyoka/algorithm_/__init__.py
Empty file.
64 changes: 64 additions & 0 deletions kyoka/algorithm_/montecarlo.py
@@ -0,0 +1,64 @@
import os

from kyoka.utils import pickle_data, unpickle_data
from kyoka.value_function_ import BaseTabularActionValueFunction
from kyoka.algorithm_.rl_algorithm import BaseRLAlgorithm, generate_episode

class MonteCarlo(BaseRLAlgorithm):
    """Monte Carlo control: each visited (state, action) pair is backed up
    with the total reward observed from that turn to the end of the episode."""

    def setup(self, task, policy, value_function):
        """Validate the value function is the tabular type, then delegate setup."""
        validate_value_function(value_function)
        super(MonteCarlo, self).setup(task, policy, value_function)

    def run_gpi_for_an_episode(self, task, policy, value_function):
        """Sample one episode and back up every visited state-action pair."""
        episode = generate_episode(task, policy, value_function)
        for turn, (state, action, _, _) in enumerate(episode):
            total_return = self._calculate_following_state_reward(turn, episode)
            # alpha is unused by the Monte Carlo value function, hence the dummy.
            value_function.backup(state, action, total_return, alpha="dummy")

    def _calculate_following_state_reward(self, current_turn, episode):
        """Sum of rewards from current_turn (inclusive) to the episode end."""
        return sum(reward for _, _, _, reward in episode[current_turn:])

class MontCarloTabularActionValueFunction(BaseTabularActionValueFunction):
    """Tabular action-value function that averages backup targets incrementally,
    tracking in a parallel table how often each (state, action) was updated."""

    SAVE_FILE_NAME = "montecarlo_update_counter.pickle"

    def setup(self):
        """Initialize the Q-table (base class) and the update-count table."""
        super(MontCarloTabularActionValueFunction, self).setup()
        self.update_counter = self.generate_initial_table()

    def define_save_file_prefix(self):
        """Prefix the base class uses when naming its save files."""
        return "montecarlo"

    def save(self, save_dir_path):
        """Persist the Q-table via the base class, plus the update counters."""
        super(MontCarloTabularActionValueFunction, self).save(save_dir_path)
        counter_path = self._gen_update_counter_file_path(save_dir_path)
        pickle_data(counter_path, self.update_counter)

    def load(self, load_dir_path):
        """Restore the Q-table and the update counters written by `save`."""
        super(MontCarloTabularActionValueFunction, self).load(load_dir_path)
        counter_path = self._gen_update_counter_file_path(load_dir_path)
        if not os.path.exists(counter_path):
            raise IOError('The saved data of "MonteCarlo" algorithm is not found in [ %s ]'% load_dir_path)
        self.update_counter = unpickle_data(counter_path)

    def backup(self, state, action, backup_target, alpha):
        """Fold backup_target into the running average for (state, action).

        `alpha` is part of the common backup signature but unused here: the
        effective step size is 1/(k+1), derived from the update count.
        """
        seen = self.fetch_value_from_table(self.update_counter, state, action)
        current = self.fetch_value_from_table(self.table, state, action)
        averaged = self._calc_average_in_incremental_way(seen, backup_target, current)
        self.insert_value_into_table(self.table, state, action, averaged)
        self.insert_value_into_table(self.update_counter, state, action, seen + 1)

    def _calc_average_in_incremental_way(self, k, r, Q):
        """Incremental mean: Q_{k+1} = Q_k + (r - Q_k) / (k + 1)."""
        return Q + 1.0 / (k + 1) * (r - Q)

    def _gen_update_counter_file_path(self, dir_path):
        """Full path of the pickled update-counter file under dir_path."""
        return os.path.join(dir_path, self.SAVE_FILE_NAME)

def validate_value_function(value_function):
    """Raise TypeError unless value_function is the tabular implementation
    that the MonteCarlo algorithm requires.

    MonteCarlo maintains per-entry update counts, so it only works with
    MontCarloTabularActionValueFunction (or a subclass of it).
    """
    if not isinstance(value_function, MontCarloTabularActionValueFunction):
        # The previous message named unrelated classes
        # (BaseTableStateValueFunction / BaseTableActionValueFunction) and the
        # backslash continuation embedded stray whitespace; name the class the
        # isinstance check actually requires.
        err_msg = ('MonteCarlo method requires you to use "table" type function. '
                   '(child class of [MontCarloTabularActionValueFunction])')
        raise TypeError(err_msg)

78 changes: 78 additions & 0 deletions kyoka/algorithm_/rl_algorithm.py
@@ -0,0 +1,78 @@
from kyoka.utils import build_not_implemented_msg
from kyoka.policy_ import EpsilonGreedyPolicy
from kyoka.callback_ import EpsilonAnnealer
from kyoka.callback_ import WatchIterationCount

def generate_episode(task, policy, value_function):
    """Roll out one complete episode under `policy`.

    Returns a list of (state, action, next_state, reward) tuples in visit
    order, stopping as soon as a terminal state is reached.
    """
    trace = []
    current = task.generate_initial_state()
    while not task.is_terminal_state(current):
        chosen = policy.choose_action(task, value_function, current)
        successor = task.transit_state(current, chosen)
        gain = task.calculate_reward(successor)
        trace.append((current, chosen, successor, gain))
        current = successor
    return trace

class BaseRLAlgorithm(object):
    """Base class for GPI (generalized policy iteration) algorithms.

    Subclasses implement `run_gpi_for_an_episode`; this class owns the
    setup/save/load plumbing and the callback-driven training loop.
    """

    def setup(self, task, policy, value_function):
        # Store collaborators and let the value function build its internal table.
        self.task = task
        self.value_function = value_function
        self.value_function.setup()
        self.policy = policy

    def save(self, save_dir_path):
        # Persist the value function first, then any algorithm-specific state.
        self.value_function.save(save_dir_path)
        self.save_algorithm_state(save_dir_path)

    def load(self, load_dir_path):
        # Mirror of `save`: restore the value function, then algorithm state.
        self.value_function.load(load_dir_path)
        self.load_algorithm_state(load_dir_path)

    def save_algorithm_state(self, save_dir_path):
        # Hook for subclasses with extra state to persist; no-op by default.
        pass

    def load_algorithm_state(self, load_dir_path):
        # Hook for subclasses with extra state to restore; no-op by default.
        pass

    def run_gpi_for_an_episode(self, task, policy, value_function):
        # One iteration of evaluation/improvement; subclasses must override.
        err_msg = build_not_implemented_msg(self, "run_gpi_for_an_episode")
        raise NotImplementedError(err_msg)

    def run_gpi(self, nb_iteration, callbacks=None, verbose=1):
        """Run the GPI loop until a finish-rule callback interrupts it.

        `nb_iteration` bounds the loop via a default WatchIterationCount
        finish rule. `callbacks` may be a single callback or a list; each is
        notified before/after every update and may interrupt the loop through
        `interrupt_gpi`. Raises if `setup` has not been called.
        """
        self.__check_setup_call()
        default_finish_rule = WatchIterationCount(nb_iteration, verbose)
        callbacks = self.__setup_callbacks(default_finish_rule, callbacks)
        [callback.before_gpi_start(self.task, self.value_function) for callback in callbacks]

        iteration_counter = 1
        while True:
            [callback.before_update(iteration_counter, self.task, self.value_function) for callback in callbacks]
            self.run_gpi_for_an_episode(self.task, self.policy, self.value_function)
            [callback.after_update(iteration_counter, self.task, self.value_function) for callback in callbacks]
            # Any callback (not just the default rule) may stop the loop.
            for finish_rule in callbacks:
                if finish_rule.interrupt_gpi(iteration_counter, self.task, self.value_function):
                    [callback.after_gpi_finish(self.task, self.value_function) for callback in callbacks]
                    # When a non-default rule stopped the loop, still emit the
                    # default rule's completion message for visibility.
                    if finish_rule != default_finish_rule:
                        default_finish_rule.log(default_finish_rule.generate_finish_message(iteration_counter))
                    return
            iteration_counter += 1


    def __check_setup_call(self):
        # Fail fast with a clear message when `setup` was skipped.
        if not all([hasattr(self, attr) for attr in ["task", "value_function", "policy"]]):
            raise Exception('You need to call "setup" method before calling "run_gpi" method.')

    def __setup_callbacks(self, default_finish_rule, user_callbacks):
        # Normalize user callbacks to a list and prepend the default finish
        # rule (plus an epsilon annealer when the policy anneals epsilon).
        user_callbacks = self.__wrap_item_if_single(user_callbacks)
        default_callbacks = [default_finish_rule]
        if isinstance(self.policy, EpsilonGreedyPolicy) and self.policy.do_annealing:
            default_callbacks.append(EpsilonAnnealer(self.policy))
        return default_callbacks + user_callbacks

    def __wrap_item_if_single(self, item):
        # Accept either a single callback or a list of callbacks (or None).
        if item is None: item = []
        if not isinstance(item, list): item = [item]
        return item

48 changes: 48 additions & 0 deletions kyoka/algorithm_/sarsa.py
@@ -0,0 +1,48 @@
import os
from kyoka.utils import pickle_data, unpickle_data
from kyoka.value_function_ import BaseTabularActionValueFunction
from kyoka.algorithm_.rl_algorithm import BaseRLAlgorithm, generate_episode

class Sarsa(BaseRLAlgorithm):
    """On-policy TD(0) control (SARSA) with a constant step size."""

    def __init__(self, alpha=0.1, gamma=0.9):
        # alpha: learning rate; gamma: discount factor for future rewards.
        self.alpha = alpha
        self.gamma = gamma

    def run_gpi_for_an_episode(self, task, policy, value_function):
        """Walk one episode, updating Q(s, a) toward r + gamma * Q(s', a')."""
        state = task.generate_initial_state()
        action = policy.choose_action(task, value_function, state)
        while not task.is_terminal_state(state):
            successor = task.transit_state(state, action)
            # Module-level helpers handle the terminal-state sentinel.
            successor_action = choose_action(task, policy, value_function, successor)
            observed_reward = task.calculate_reward(successor)
            bootstrap = predict_value(value_function, successor, successor_action)
            target = observed_reward + self.gamma * bootstrap
            value_function.backup(state, action, target, self.alpha)
            state, action = successor, successor_action

class SarsaTabularActionValueFunction(BaseTabularActionValueFunction):
    """Tabular Q-function updated with a constant-step-size TD rule."""

    def define_save_file_prefix(self):
        """Prefix the base class uses when naming its save files."""
        return "sarsa"

    def backup(self, state, action, backup_target, alpha):
        """Move Q(state, action) a fraction `alpha` of the way to backup_target."""
        current = self.predict_value(state, action)
        updated = current + alpha * (backup_target - current)
        self.insert_value_into_table(self.table, state, action, updated)


# Sentinel returned in place of a real action once the episode has ended.
ACTION_ON_TERMINAL_FLG = "action_on_terminal"

def choose_action(task, policy, value_function, state):
    """Ask the policy for an action, or return the terminal sentinel."""
    if task.is_terminal_state(state):
        return ACTION_ON_TERMINAL_FLG
    return policy.choose_action(task, value_function, state)

def predict_value(value_function, next_state, next_action):
    """Q(next_state, next_action), treating the terminal sentinel as 0."""
    if next_action == ACTION_ON_TERMINAL_FLG:
        return 0
    return value_function.predict_value(next_state, next_action)

179 changes: 179 additions & 0 deletions kyoka/callback_.py
@@ -0,0 +1,179 @@
import os
import time
from utils import build_not_implemented_msg

class BaseCallback(object):
    """No-op base class for GPI lifecycle callbacks.

    Subclasses override any subset of the hooks; `interrupt_gpi` may return
    True to stop the training loop.
    """

    def before_gpi_start(self, domain, value_function):
        """Called once before the GPI loop starts."""
        pass

    def before_update(self, iteration_count, domain, value_function):
        """Called before every GPI iteration."""
        pass

    def after_update(self, iteration_count, domain, value_function):
        """Called after every GPI iteration."""
        pass

    def after_gpi_finish(self, domain, value_function):
        """Called once after the GPI loop finishes."""
        pass

    def interrupt_gpi(self, iteration_count, domain, value_function):
        """Return True to stop the GPI loop; default never interrupts."""
        return False

    def define_log_tag(self):
        """Tag prepended to log lines; defaults to the concrete class name."""
        return self.__class__.__name__

    def log(self, message):
        """Print a tagged log line; empty/falsy messages are skipped."""
        # `if message` already rejects None and "", so the redundant
        # `len(message) != 0` check was dropped. The parenthesized print of a
        # single argument behaves identically on Python 2 and Python 3.
        if message:
            print("[%s] %s" % (self.define_log_tag(), message))

class BasePerformanceWatcher(BaseCallback):
    """Callback that periodically measures and logs agent performance.

    Subclasses implement `run_performance_test` and
    `define_performance_test_interval`; results accumulate in
    `self.performance_log`.
    """

    def setUp(self, domain, value_function):
        """Optional hook invoked once before the GPI loop starts."""
        pass

    def tearDown(self, domain, value_function):
        """Optional hook invoked once after the GPI loop finishes."""
        pass

    def define_performance_test_interval(self):
        """Return how many iterations to wait between performance tests."""
        err_msg = build_not_implemented_msg(self, "define_performance_test_interval")
        raise NotImplementedError(err_msg)

    def run_performance_test(self, domain, value_function):
        """Measure and return the agent's current performance."""
        err_msg = build_not_implemented_msg(self, "run_performance_test")
        raise NotImplementedError(err_msg)

    def define_log_message(self, iteration_count, domain, value_function, test_result):
        """Format the line logged after each performance test."""
        base_msg = "Performance test result : %s (nb_iteration=%d)"
        return base_msg % (test_result, iteration_count)

    def before_gpi_start(self, domain, value_function):
        self.performance_log = []
        self.test_interval = self.define_performance_test_interval()
        self.setUp(domain, value_function)

    def after_update(self, iteration_count, domain, value_function):
        # Only run the test on iterations that are multiples of the interval.
        if iteration_count % self.test_interval != 0:
            return
        result = self.run_performance_test(domain, value_function)
        self.performance_log.append(result)
        self.log(self.define_log_message(iteration_count, domain, value_function, result))

    def after_gpi_finish(self, domain, value_function):
        self.tearDown(domain, value_function)

class EpsilonAnnealer(BaseCallback):
    """Anneals an epsilon-greedy policy's epsilon after every iteration and
    logs once when epsilon reaches its minimum."""

    def __init__(self, epsilon_greedy_policy):
        self.policy = epsilon_greedy_policy
        self.anneal_finished = False

    def define_log_tag(self):
        return "EpsilonGreedyAnnealing"

    def before_gpi_start(self, _domain, _value_function):
        start_msg = "Anneal epsilon from %s to %s." % (self.policy.eps, self.policy.min_eps)
        self.log(start_msg)

    def after_update(self, iteration_count, _domain, _value_function):
        self.policy.anneal_eps()
        # Log the completion message only once.
        if self.anneal_finished:
            return
        if self.policy.eps == self.policy.min_eps:
            self.anneal_finished = True
            self.log("Annealing has finished at %d iteration." % iteration_count)

class BaseFinishRule(BaseCallback):
    """Callback that stops the GPI loop when a subclass condition holds."""

    def check_condition(self, iteration_count, domain, value_function):
        """Return True when the GPI loop should stop."""
        err_msg = build_not_implemented_msg(self, "check_condition")
        raise NotImplementedError(err_msg)

    def generate_start_message(self):
        """Message logged when the GPI loop starts."""
        err_msg = build_not_implemented_msg(self, "generate_start_message")
        raise NotImplementedError(err_msg)

    def generate_finish_message(self, iteration_count):
        """Message logged when this rule stops the loop."""
        err_msg = build_not_implemented_msg(self, "generate_finish_message")
        raise NotImplementedError(err_msg)

    def before_gpi_start(self, domain, value_function):
        self.log(self.generate_start_message())

    def interrupt_gpi(self, iteration_count, domain, value_function):
        should_stop = self.check_condition(iteration_count, domain, value_function)
        if should_stop:
            self.log(self.generate_finish_message(iteration_count))
        return should_stop

class ManualInterruption(BaseFinishRule):
    """Finish rule that stops GPI when a keyword appears in a watched file.

    Every `watch_interval` seconds it re-reads `monitor_file_path` and
    interrupts the loop once the file contains TARGET_WARD.
    """

    TARGET_WARD = "stop"

    def __init__(self, monitor_file_path, watch_interval=30):
        self.monitor_file_path = monitor_file_path
        self.watch_interval = watch_interval
        # Fix: previously `last_check_time` was assigned only inside
        # `generate_start_message`, so calling `check_condition` before
        # `before_gpi_start` raised AttributeError. Initialize it here;
        # `generate_start_message` still resets it at loop start.
        self.last_check_time = time.time()

    def check_condition(self, _iteration_count, _domain, _value_function):
        """Throttled check: only re-read the file every watch_interval seconds."""
        current_time = time.time()
        if current_time - self.last_check_time >= self.watch_interval:
            self.last_check_time = current_time
            return self.__order_found_in_monitoring_file(self.monitor_file_path, self.TARGET_WARD)
        else:
            return False

    def generate_start_message(self):
        # Reset the throttle clock when the GPI loop actually starts.
        self.last_check_time = time.time()
        base_first_msg ='Write word "%s" on file "%s" will finish the GPI'
        base_second_msg = "(Stopping GPI may take about %s seconds. Because we check target file every %s seconds.)"
        first_msg = base_first_msg % (self.TARGET_WARD, self.monitor_file_path)
        second_msg = base_second_msg % (self.watch_interval, self.watch_interval)
        return "\n".join([first_msg, second_msg])

    def generate_finish_message(self, iteration_count):
        base_msg = "Interrupt GPI after %d iterations because interupption order found in [ %s ]."
        return base_msg % (iteration_count, self.monitor_file_path)

    def __order_found_in_monitoring_file(self, filepath, target_word):
        # A missing file simply means "no interruption requested yet".
        return os.path.isfile(filepath) and self.__found_target_ward_in_file(filepath, target_word)

    def __found_target_ward_in_file(self, filepath, target_word):
        search_word = lambda src, target: target in src
        src = self.__read_data(filepath)
        return search_word(src, target_word) if src else False

    def __read_data(self, filepath):
        # NOTE(review): reads in binary mode; under Python 3 this yields bytes,
        # so `target in src` would need a bytes target — fine on the Python 2
        # this file targets. Confirm before porting.
        with open(filepath, 'rb') as f: return f.read()

class WatchIterationCount(BaseFinishRule):
    """Finish rule that stops GPI after a fixed number of iterations and,
    when verbose, logs per-iteration progress with timing."""

    def __init__(self, target_count, verbose=1):
        self.target_count = target_count
        self.start_time = self.last_update_time = 0
        self.verbose = verbose

    def define_log_tag(self):
        return "Progress"

    def check_condition(self, iteration_count, domain, value_function):
        return iteration_count >= self.target_count

    def generate_start_message(self):
        # Loop start doubles as the timing origin.
        self.start_time = self.last_update_time = time.time()
        return "Start GPI iteration for %d times" % self.target_count

    def generate_finish_message(self, iteration_count):
        base_msg = "Completed GPI iteration for %d times. (total time: %ds)"
        return base_msg % (iteration_count, time.time() - self.start_time)

    def before_update(self, iteration_count, domain, value_function):
        super(WatchIterationCount, self).before_update(iteration_count, domain, value_function)
        self.last_update_time = time.time()

    def after_update(self, iteration_count, domain, value_function):
        super(WatchIterationCount, self).after_update(iteration_count, domain, value_function)
        if self.verbose > 0:
            now = time.time()
            elapsed = now - self.last_update_time
            self.last_update_time = now
            self.log("Finished %d / %d iterations (%.1fs)" %
                     (iteration_count, self.target_count, elapsed))

0 comments on commit eb4c6d6

Please sign in to comment.