Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
19 changed files
with
1,544 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import os | ||
|
||
from kyoka.utils import pickle_data, unpickle_data | ||
from kyoka.value_function_ import BaseTabularActionValueFunction | ||
from kyoka.algorithm_.rl_algorithm import BaseRLAlgorithm, generate_episode | ||
|
||
class MonteCarlo(BaseRLAlgorithm):
    """Monte Carlo control: back up every visited turn with its undiscounted return."""

    def setup(self, task, policy, value_function):
        """Validate the value function, then run the base-class setup.

        Raises:
            TypeError: propagated from ``validate_value_function`` when the
                value function is not accepted.
        """
        validate_value_function(value_function)
        super(MonteCarlo, self).setup(task, policy, value_function)

    def run_gpi_for_an_episode(self, task, policy, value_function):
        """Generate one episode and back up each of its (state, action) pairs."""
        episode = generate_episode(task, policy, value_function)
        for turn, (state, action, _next_state, _reward) in enumerate(episode):
            target = self._calculate_following_state_reward(turn, episode)
            # "dummy" is only a placeholder to satisfy the backup() signature.
            value_function.backup(state, action, target, alpha="dummy")

    def _calculate_following_state_reward(self, current_turn, episode):
        """Return the plain sum of rewards from ``current_turn`` to the episode end."""
        return sum(reward for _s, _a, _n, reward in episode[current_turn:])
|
||
class MontCarloTabularActionValueFunction(BaseTabularActionValueFunction):
    """Tabular action-value function that averages backup targets per entry.

    Besides the value table it keeps ``update_counter``, a second table that
    stores how many times each (state, action) entry has been backed up.
    """

    SAVE_FILE_NAME = "montecarlo_update_counter.pickle"

    def setup(self):
        """Run the base setup and create the update-counter table."""
        super(MontCarloTabularActionValueFunction, self).setup()
        self.update_counter = self.generate_initial_table()

    def define_save_file_prefix(self):
        """Return the prefix used for this value function's save files."""
        return "montecarlo"

    def save(self, save_dir_path):
        """Save the base data, then pickle the update counter into the same directory."""
        super(MontCarloTabularActionValueFunction, self).save(save_dir_path)
        counter_path = self._gen_update_counter_file_path(save_dir_path)
        pickle_data(counter_path, self.update_counter)

    def load(self, load_dir_path):
        """Load the base data, then restore the update counter.

        Raises:
            IOError: if the update-counter file is missing in ``load_dir_path``.
        """
        super(MontCarloTabularActionValueFunction, self).load(load_dir_path)
        counter_path = self._gen_update_counter_file_path(load_dir_path)
        if not os.path.exists(counter_path):
            raise IOError('The saved data of "MonteCarlo" algorithm is not found in [ %s ]'% load_dir_path)
        self.update_counter = unpickle_data(counter_path)

    def backup(self, state, action, backup_target, alpha):
        """Fold ``backup_target`` into the running average for (state, action).

        ``alpha`` is accepted for interface compatibility and is not used.
        """
        visits = self.fetch_value_from_table(self.update_counter, state, action)
        current_q = self.fetch_value_from_table(self.table, state, action)
        averaged_q = self._calc_average_in_incremental_way(visits, backup_target, current_q)
        self.insert_value_into_table(self.table, state, action, averaged_q)
        self.insert_value_into_table(self.update_counter, state, action, visits + 1)

    def _calc_average_in_incremental_way(self, k, r, Q):
        """Return the mean after adding sample ``r`` to a mean ``Q`` of ``k`` samples."""
        step_size = 1.0 / (k + 1)
        return Q + step_size * (r - Q)

    def _gen_update_counter_file_path(self, dir_path):
        """Return the path of the update-counter pickle inside ``dir_path``."""
        return os.path.join(dir_path, self.SAVE_FILE_NAME)
|
||
def validate_value_function(value_function):
    """Check that ``value_function`` can be used with the MonteCarlo algorithm.

    Raises:
        TypeError: if ``value_function`` is not an instance of
            MontCarloTabularActionValueFunction.
    """
    if not isinstance(value_function, MontCarloTabularActionValueFunction):
        # Adjacent literals are used instead of a backslash continuation so
        # that source indentation cannot leak into the message, and the
        # message names the class this check actually requires.
        err_msg = (
            'MonteCarlo method requires you to use "table" type function. '
            "(child class of [MontCarloTabularActionValueFunction])"
        )
        raise TypeError(err_msg)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from kyoka.utils import build_not_implemented_msg | ||
from kyoka.policy_ import EpsilonGreedyPolicy | ||
from kyoka.callback_ import EpsilonAnnealer | ||
from kyoka.callback_ import WatchIterationCount | ||
|
||
def generate_episode(task, policy, value_function):
    """Roll out one episode and return its turns.

    Starting from ``task.generate_initial_state()``, actions are chosen by
    ``policy`` until ``task.is_terminal_state`` is true.

    Returns:
        list of ``(state, action, next_state, reward)`` tuples in play order,
        where the reward is ``task.calculate_reward(next_state)``. The list is
        empty when the initial state is already terminal.
    """
    episode = []
    state = task.generate_initial_state()
    while True:
        if task.is_terminal_state(state):
            return episode
        action = policy.choose_action(task, value_function, state)
        successor = task.transit_state(state, action)
        reward = task.calculate_reward(successor)
        episode.append((state, action, successor, reward))
        state = successor
|
||
class BaseRLAlgorithm(object):
    """Base class for algorithms driven by a generalized-policy-iteration loop.

    Subclasses implement ``run_gpi_for_an_episode``; ``run_gpi`` repeats it
    and dispatches the callback hooks around every iteration.
    """

    def setup(self, task, policy, value_function):
        """Store the collaborators and call ``value_function.setup()``."""
        self.task = task
        self.value_function = value_function
        self.value_function.setup()
        self.policy = policy

    def save(self, save_dir_path):
        """Save the value function, then the algorithm-specific state."""
        self.value_function.save(save_dir_path)
        self.save_algorithm_state(save_dir_path)

    def load(self, load_dir_path):
        """Load the value function, then the algorithm-specific state."""
        self.value_function.load(load_dir_path)
        self.load_algorithm_state(load_dir_path)

    def save_algorithm_state(self, save_dir_path):
        """Hook for subclasses with extra state to save; no-op by default."""
        pass

    def load_algorithm_state(self, load_dir_path):
        """Hook for subclasses with extra state to load; no-op by default."""
        pass

    def run_gpi_for_an_episode(self, task, policy, value_function):
        """Run one episode of learning; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        err_msg = build_not_implemented_msg(self, "run_gpi_for_an_episode")
        raise NotImplementedError(err_msg)

    def run_gpi(self, nb_iteration, callbacks=None, verbose=1):
        """Repeat episodes until any callback asks to interrupt.

        Args:
            nb_iteration: target count given to the default
                ``WatchIterationCount`` finish rule.
            callbacks: ``None``, a single callback, or a list/tuple of
                callbacks that run after the default ones.
            verbose: verbosity passed to the default finish rule.

        Raises:
            Exception: if ``setup`` has not been called.
        """
        self.__check_setup_call()
        default_finish_rule = WatchIterationCount(nb_iteration, verbose)
        callbacks = self.__setup_callbacks(default_finish_rule, callbacks)
        for callback in callbacks:
            callback.before_gpi_start(self.task, self.value_function)

        iteration_counter = 1
        while True:
            for callback in callbacks:
                callback.before_update(iteration_counter, self.task, self.value_function)
            self.run_gpi_for_an_episode(self.task, self.policy, self.value_function)
            for callback in callbacks:
                callback.after_update(iteration_counter, self.task, self.value_function)
            for finish_rule in callbacks:
                if finish_rule.interrupt_gpi(iteration_counter, self.task, self.value_function):
                    for callback in callbacks:
                        callback.after_gpi_finish(self.task, self.value_function)
                    # When another rule stopped the loop, still emit the
                    # default rule's finish message.
                    if finish_rule != default_finish_rule:
                        default_finish_rule.log(default_finish_rule.generate_finish_message(iteration_counter))
                    return
            iteration_counter += 1

    def __check_setup_call(self):
        """Raise if any attribute assigned by ``setup`` is missing."""
        required = ("task", "value_function", "policy")
        if not all(hasattr(self, attr) for attr in required):
            raise Exception('You need to call "setup" method before calling "run_gpi" method.')

    def __setup_callbacks(self, default_finish_rule, user_callbacks):
        """Return the default callbacks followed by the user's callbacks."""
        user_callbacks = self.__wrap_item_if_single(user_callbacks)
        default_callbacks = [default_finish_rule]
        if isinstance(self.policy, EpsilonGreedyPolicy) and self.policy.do_annealing:
            default_callbacks.append(EpsilonAnnealer(self.policy))
        return default_callbacks + user_callbacks

    def __wrap_item_if_single(self, item):
        """Normalise ``None`` / a single item / a list or tuple into a list."""
        if item is None:
            return []
        # A tuple of callbacks is treated as a sequence, not as one callback.
        if isinstance(item, (list, tuple)):
            return list(item)
        return [item]
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
from kyoka.utils import pickle_data, unpickle_data | ||
from kyoka.value_function_ import BaseTabularActionValueFunction | ||
from kyoka.algorithm_.rl_algorithm import BaseRLAlgorithm, generate_episode | ||
|
||
class Sarsa(BaseRLAlgorithm):
    """SARSA: on-policy TD control that backs up ``reward + gamma * Q(s', a')``."""

    def __init__(self, alpha=0.1, gamma=0.9):
        """Store the step size ``alpha`` and the discount factor ``gamma``."""
        self.alpha = alpha
        self.gamma = gamma

    def run_gpi_for_an_episode(self, task, policy, value_function):
        """Play one episode, backing up the value function after every transition."""
        state = task.generate_initial_state()
        # Use the terminal-aware helper for the first action as well, so the
        # policy is never asked to act on a terminal initial state.
        action = choose_action(task, policy, value_function, state)
        while not task.is_terminal_state(state):
            next_state = task.transit_state(state, action)
            next_action = choose_action(task, policy, value_function, next_state)
            reward = task.calculate_reward(next_state)
            # predict_value() yields 0 when next_state is terminal.
            next_Q_value = predict_value(value_function, next_state, next_action)
            backup_target = reward + self.gamma * next_Q_value
            value_function.backup(state, action, backup_target, self.alpha)
            state, action = next_state, next_action
|
||
class SarsaTabularActionValueFunction(BaseTabularActionValueFunction):
    """Tabular action-value function that moves Q toward the target by a step ``alpha``."""

    def define_save_file_prefix(self):
        """Return the prefix used for this value function's save files."""
        return "sarsa"

    def backup(self, state, action, backup_target, alpha):
        """Store ``Q + alpha * (backup_target - Q)`` for (state, action)."""
        current_q = self.predict_value(state, action)
        td_error = backup_target - current_q
        updated_q = current_q + alpha * td_error
        self.insert_value_into_table(self.table, state, action, updated_q)
|
||
|
||
# Placeholder returned instead of a real action when the state is terminal.
ACTION_ON_TERMINAL_FLG = "action_on_terminal"


def choose_action(task, policy, value_function, state):
    """Return the policy's action for ``state``, or the terminal placeholder.

    The policy is only consulted when ``task.is_terminal_state(state)`` is false.
    """
    if not task.is_terminal_state(state):
        return policy.choose_action(task, value_function, state)
    return ACTION_ON_TERMINAL_FLG
|
||
def predict_value(value_function, next_state, next_action):
    """Return Q(next_state, next_action), or 0 for the terminal placeholder action."""
    is_terminal_action = ACTION_ON_TERMINAL_FLG == next_action
    return 0 if is_terminal_action else value_function.predict_value(next_state, next_action)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import os | ||
import time | ||
from utils import build_not_implemented_msg | ||
|
||
class BaseCallback(object):
    """No-op base class for the hooks invoked around the GPI loop.

    Subclasses override only the hooks they need.
    """

    def before_gpi_start(self, domain, value_function):
        """Hook called before the first iteration; no-op by default."""
        pass

    def before_update(self, iteration_count, domain, value_function):
        """Hook called before each iteration; no-op by default."""
        pass

    def after_update(self, iteration_count, domain, value_function):
        """Hook called after each iteration; no-op by default."""
        pass

    def after_gpi_finish(self, domain, value_function):
        """Hook called when the loop is finishing; no-op by default."""
        pass

    def interrupt_gpi(self, iteration_count, domain, value_function):
        """Return True to stop the loop; the base implementation never stops it."""
        return False

    def define_log_tag(self):
        """Return the tag prefixed to log lines (the class name by default)."""
        return self.__class__.__name__

    def log(self, message):
        """Print ``[tag] message`` unless ``message`` is empty or None."""
        if message:
            # Parenthesised single-argument form prints identically under the
            # Python 2 print statement and the Python 3 print function.
            print("[%s] %s" % (self.define_log_tag(), message))
|
||
class BasePerformanceWatcher(BaseCallback):
    """Callback that runs a performance test at a fixed iteration interval.

    Subclasses must implement ``define_performance_test_interval`` and
    ``run_performance_test``; every result is appended to ``performance_log``.
    """

    def setUp(self, domain, value_function):
        """Hook run once from ``before_gpi_start``; no-op by default."""
        pass

    def tearDown(self, domain, value_function):
        """Hook run once from ``after_gpi_finish``; no-op by default."""
        pass

    def define_performance_test_interval(self):
        """Return how many iterations lie between tests; must be overridden."""
        raise NotImplementedError(
            build_not_implemented_msg(self, "define_performance_test_interval"))

    def run_performance_test(self, domain, value_function):
        """Run the test and return its result; must be overridden."""
        raise NotImplementedError(
            build_not_implemented_msg(self, "run_performance_test"))

    def define_log_message(self, iteration_count, domain, value_function, test_result):
        """Return the log line describing one test result."""
        return "Performance test result : %s (nb_iteration=%d)" % (test_result, iteration_count)

    def before_gpi_start(self, domain, value_function):
        """Reset the result log, read the interval, and call ``setUp``."""
        self.performance_log = []
        self.test_interval = self.define_performance_test_interval()
        self.setUp(domain, value_function)

    def after_update(self, iteration_count, domain, value_function):
        """Run, record and log the test when the iteration is a multiple of the interval."""
        if iteration_count % self.test_interval != 0:
            return
        result = self.run_performance_test(domain, value_function)
        self.performance_log.append(result)
        self.log(self.define_log_message(iteration_count, domain, value_function, result))

    def after_gpi_finish(self, domain, value_function):
        """Call ``tearDown`` once the loop is finishing."""
        self.tearDown(domain, value_function)
|
||
class EpsilonAnnealer(BaseCallback):
    """Callback that anneals an epsilon-greedy policy once per iteration."""

    def __init__(self, epsilon_greedy_policy):
        """Keep the policy to anneal and mark annealing as not yet finished."""
        self.policy = epsilon_greedy_policy
        self.anneal_finished = False

    def define_log_tag(self):
        """Return the tag prefixed to this callback's log lines."""
        return "EpsilonGreedyAnnealing"

    def before_gpi_start(self, _domain, _value_function):
        """Log the epsilon range the policy will be annealed over."""
        start_msg = "Anneal epsilon from %s to %s." % (self.policy.eps, self.policy.min_eps)
        self.log(start_msg)

    def after_update(self, iteration_count, _domain, _value_function):
        """Anneal epsilon and log once when it has reached its minimum."""
        self.policy.anneal_eps()
        # `<=` rather than `==`: exact float equality would never fire if the
        # annealing step moves epsilon past the floor instead of landing on it.
        if not self.anneal_finished and self.policy.eps <= self.policy.min_eps:
            self.anneal_finished = True
            finish_msg = "Annealing has finished at %d iteration." % iteration_count
            self.log(finish_msg)
|
||
class BaseFinishRule(BaseCallback):
    """Callback that decides when the GPI loop stops.

    Subclasses implement ``check_condition`` and the two message generators;
    this class logs those messages at start and when the condition holds.
    """

    def check_condition(self, iteration_count, domain, value_function):
        """Return a truthy value when the loop should stop; must be overridden."""
        raise NotImplementedError(build_not_implemented_msg(self, "check_condition"))

    def generate_start_message(self):
        """Return the message logged before the loop starts; must be overridden."""
        raise NotImplementedError(build_not_implemented_msg(self, "generate_start_message"))

    def generate_finish_message(self, iteration_count):
        """Return the message logged when the rule fires; must be overridden."""
        raise NotImplementedError(build_not_implemented_msg(self, "generate_finish_message"))

    def before_gpi_start(self, domain, value_function):
        """Log the start message."""
        start_message = self.generate_start_message()
        self.log(start_message)

    def interrupt_gpi(self, iteration_count, domain, value_function):
        """Evaluate the condition, log the finish message if it holds, and return it."""
        should_finish = self.check_condition(iteration_count, domain, value_function)
        if should_finish:
            self.log(self.generate_finish_message(iteration_count))
        return should_finish
|
||
class ManualInterruption(BaseFinishRule):
    """Finish rule that stops the loop when a monitored file contains ``TARGET_WARD``.

    The file is inspected at most once every ``watch_interval`` seconds.
    """

    TARGET_WARD = "stop"

    def __init__(self, monitor_file_path, watch_interval=30):
        """Remember the file to watch and the polling interval in seconds."""
        self.monitor_file_path = monitor_file_path
        self.watch_interval = watch_interval
        # Initialised here so check_condition() cannot hit an AttributeError
        # when it runs before generate_start_message().
        self.last_check_time = time.time()

    def check_condition(self, _iteration_count, _domain, _value_function):
        """Return True when the interval has elapsed and the stop word is in the file."""
        current_time = time.time()
        if current_time - self.last_check_time >= self.watch_interval:
            self.last_check_time = current_time
            return self.__order_found_in_monitoring_file(self.monitor_file_path, self.TARGET_WARD)
        else:
            return False

    def generate_start_message(self):
        """Restart the polling clock and return the usage instructions."""
        self.last_check_time = time.time()
        base_first_msg ='Write word "%s" on file "%s" will finish the GPI'
        base_second_msg = "(Stopping GPI may take about %s seconds. Because we check target file every %s seconds.)"
        first_msg = base_first_msg % (self.TARGET_WARD, self.monitor_file_path)
        second_msg = base_second_msg % (self.watch_interval, self.watch_interval)
        return "\n".join([first_msg, second_msg])

    def generate_finish_message(self, iteration_count):
        """Return the message logged when the stop word was found."""
        base_msg = "Interrupt GPI after %d iterations because interupption order found in [ %s ]."
        return base_msg % (iteration_count, self.monitor_file_path)

    def __order_found_in_monitoring_file(self, filepath, target_word):
        """Return True when ``filepath`` is a file whose content contains ``target_word``."""
        return os.path.isfile(filepath) and self.__found_target_ward_in_file(filepath, target_word)

    def __found_target_ward_in_file(self, filepath, target_word):
        """Return True when the non-empty file content contains ``target_word``."""
        src = self.__read_data(filepath)
        return target_word in src if src else False

    def __read_data(self, filepath):
        """Return the file content as ``str``."""
        with open(filepath, 'rb') as f:
            data = f.read()
        # On Python 3 a binary read yields bytes, and `"stop" in bytes` raises
        # TypeError; decode so the substring test works on both major versions.
        if not isinstance(data, str):
            data = data.decode("utf-8", "replace")
        return data
|
||
class WatchIterationCount(BaseFinishRule):
    """Finish rule that stops after ``target_count`` iterations and logs progress."""

    def __init__(self, target_count, verbose=1):
        """Store the target count and verbosity; timers start at 0."""
        self.target_count = target_count
        self.verbose = verbose
        self.start_time = 0
        self.last_update_time = 0

    def define_log_tag(self):
        """Return the tag prefixed to this rule's log lines."""
        return "Progress"

    def check_condition(self, iteration_count, domain, value_function):
        """Return True once ``iteration_count`` has reached the target."""
        return iteration_count >= self.target_count

    def generate_start_message(self):
        """Start both timers and return the start message."""
        now = time.time()
        self.start_time = now
        self.last_update_time = now
        return "Start GPI iteration for %d times" % self.target_count

    def generate_finish_message(self, iteration_count):
        """Return the finish message with the time elapsed since the start."""
        elapsed = time.time() - self.start_time
        return "Completed GPI iteration for %d times. (total time: %ds)" % (iteration_count, elapsed)

    def before_update(self, iteration_count, domain, value_function):
        """Record when the iteration starts."""
        super(WatchIterationCount, self).before_update(iteration_count, domain, value_function)
        self.last_update_time = time.time()

    def after_update(self, iteration_count, domain, value_function):
        """Log how long the iteration took when ``verbose`` is positive."""
        super(WatchIterationCount, self).after_update(iteration_count, domain, value_function)
        if self.verbose > 0:
            now = time.time()
            took = now - self.last_update_time
            progress = "Finished %d / %d iterations (%.1fs)" % (iteration_count, self.target_count, took)
            self.last_update_time = now
            self.log(progress)
|
Oops, something went wrong.