Merge pull request #15 from ishikota/update-rlalgo_api
Update api of base RL algorithm class
ishikota committed Oct 24, 2016
2 parents e504ecf + 1d84b4b commit 51783b2
Showing 32 changed files with 430 additions and 316 deletions.
47 changes: 33 additions & 14 deletions kyoka/algorithm/base_rl_algorithm.py
@@ -1,24 +1,46 @@
from kyoka.finish_rule.watch_iteration_count import WatchIterationCount

class BaseRLAlgorithm(object):

def __init__(self):
self.callbacks = []
def setUp(self, domain, policy, value_function):
self.domain = domain
self.value_function = value_function
self.value_function.setUp()
self.policy = policy

def save(self, save_dir_path):
self.value_function.save(save_dir_path)
self.save_algorithm_state(save_dir_path)

def load(self, load_dir_path):
self.value_function.load(load_dir_path)
self.load_algorithm_state(load_dir_path)

def save_algorithm_state(self, save_dir_path):
pass

def load_algorithm_state(self, load_dir_path):
pass

def update_value_function(self, domain, policy, value_function):
err_msg = self.__build_err_msg("update_value_function")
raise NotImplementedError(err_msg)

def GPI(self, domain, policy, value_function, finish_rules, debug=False):
def run_gpi(self, nb_iteration, finish_rules=[], callbacks=[]):
callbacks = self.__wrap_item_if_single(callbacks)
finish_rules = self.__wrap_item_if_single(finish_rules)
finish_rules.append(WatchIterationCount(nb_iteration))
iteration_counter = 0
[callback.before_gpi_start(domain, value_function) for callback in self.callbacks]
[callback.before_gpi_start(self.domain, self.value_function) for callback in callbacks]
while True:
[callback.before_update(iteration_counter, domain, value_function) for callback in self.callbacks]
self.update_value_function(domain, policy, value_function)
[callback.after_update(iteration_counter, domain, value_function) for callback in self.callbacks]
[callback.before_update(iteration_counter, self.domain, self.value_function) for callback in callbacks]
self.update_value_function(self.domain, self.policy, self.value_function)
[callback.after_update(iteration_counter, self.domain, self.value_function) for callback in callbacks]
iteration_counter += 1
for finish_rule in self.__wrap_rule_if_single(finish_rules):
for finish_rule in finish_rules:
if finish_rule.satisfy_condition(iteration_counter):
finish_msg = finish_rule.generate_finish_message(iteration_counter)
[callback.after_gpi_finish(domain, value_function) for callback in self.callbacks]
[callback.after_gpi_finish(self.domain, self.value_function) for callback in callbacks]
return finish_msg

def generate_episode(self, domain, value_function, policy):
@@ -32,12 +54,9 @@ def generate_episode(self, domain, value_function, policy):
state = next_state
return episode

def set_gpi_callback(self, callback):
self.callbacks.append(callback)


def __wrap_rule_if_single(self, finish_rule):
return [finish_rule] if not isinstance(finish_rule, list) else finish_rule
def __wrap_item_if_single(self, item):
return [item] if not isinstance(item, list) else item

def __build_err_msg(self, msg):
return "Accessed [ {0} ] method of BaseRLAlgorithm which should be overridden".format(msg)
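The hunk above replaces the old entry point GPI(domain, policy, value_function, finish_rules, debug) with a two-step API: setUp(domain, policy, value_function) binds the domain, policy, and value function to the algorithm instance (and calls value_function.setUp()), while run_gpi(nb_iteration, finish_rules, callbacks) drives the GPI loop and always terminates after nb_iteration iterations via WatchIterationCount. Callbacks are now passed per run instead of being registered with set_gpi_callback, and save/load delegate to the value function plus the new save_algorithm_state/load_algorithm_state hooks. A minimal usage sketch of the new flow, in which MyDomain, MyPolicy, and MyTableValueFunction are hypothetical user-defined classes (not part of this diff) and MonteCarlo stands in for any concrete BaseRLAlgorithm subclass:

# Sketch of the new BaseRLAlgorithm flow; MyDomain, MyPolicy and
# MyTableValueFunction are hypothetical user classes, not defined here.
from kyoka.algorithm.montecarlo.montecarlo import MonteCarlo

domain = MyDomain()                      # your domain implementation
policy = MyPolicy()                      # e.g. an epsilon-greedy policy
value_function = MyTableValueFunction()  # child of BaseTableActionValueFunction

algorithm = MonteCarlo()
algorithm.setUp(domain, policy, value_function)                 # replaces passing them to GPI()
finish_msg = algorithm.run_gpi(nb_iteration=100, callbacks=[])  # replaces GPI()
print(finish_msg)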
43 changes: 31 additions & 12 deletions kyoka/algorithm/montecarlo/montecarlo.py
@@ -4,23 +4,36 @@
from kyoka.value_function.base_table_action_value_function import BaseTableActionValueFunction
from kyoka.value_function.base_table_state_value_function import BaseTableStateValueFunction

import os
import pickle

class MonteCarlo(BaseRLAlgorithm):

__KEY_ADDITIONAL_DATA = "additinal_data_key_montecarlo_update_counter"
SAVE_FILE_NAME = "montecarlo_algorithm_state.pickle"

def update_value_function(self, domain, policy, value_function):
def setUp(self, domain, policy, value_function):
super(MonteCarlo, self).setUp(domain, policy, value_function)
self.__validate_value_function(value_function)
self.__initialize_update_counter_if_needed(value_function)
update_counter = value_function.get_additinal_data(self.__KEY_ADDITIONAL_DATA)
self.update_counter = value_function.generate_initial_table()

def update_value_function(self, domain, policy, value_function):
episode = self.generate_episode(domain, value_function, policy)
for idx, turn_info in enumerate(episode):
if isinstance(value_function, BaseActionValueFunction):
self.__update_action_value_function(\
domain, value_function, update_counter, episode, idx, turn_info)
domain, value_function, self.update_counter, episode, idx, turn_info)
elif isinstance(value_function, BaseStateValueFunction):
self.__update_state_value_function(\
domain, value_function, update_counter, episode, idx, turn_info)
value_function.set_additinal_data(self.__KEY_ADDITIONAL_DATA, update_counter)
domain, value_function, self.update_counter, episode, idx, turn_info)

def save_algorithm_state(self, save_dir_path):
self.__pickle_data(self.__gen_save_file_path(save_dir_path), self.update_counter)

def load_algorithm_state(self, load_dir_path):
if not os.path.exists(self.__gen_save_file_path(load_dir_path)):
raise IOError('The saved data of "MonteCarlo" algorithm is not found in [ %s ]'% load_dir_path)
self.update_counter = self.__unpickle_data(self.__gen_save_file_path(load_dir_path))


def __validate_value_function(self, value_function):
valid_type = isinstance(value_function, BaseTableActionValueFunction) or \
@@ -32,11 +45,6 @@ def __build_type_error_message(self):
return 'MonteCarlo method requires you to use "table" type function.\
(child class of [BaseTableStateValueFunction or BaseTableActionValueFunction])'

def __initialize_update_counter_if_needed(self, value_function):
if value_function.get_additinal_data(self.__KEY_ADDITIONAL_DATA) is None:
update_counter = value_function.generate_initial_table()
value_function.set_additinal_data(self.__KEY_ADDITIONAL_DATA, update_counter)

def __update_action_value_function(\
self, domain, value_function, update_counter, episode, idx, turn_info):
state, action, _next_state, _reward = turn_info
@@ -68,3 +76,14 @@ def __calculate_new_Q_value(self, Q_val_average, update_count, update_reward):
def __calc_average_in_incremental_way(self, k, r, Q):
return Q + 1.0 / (k + 1) * (r - Q)

def __gen_save_file_path(self, base_dir_path):
return os.path.join(base_dir_path, self.SAVE_FILE_NAME)

def __pickle_data(self, file_path, data):
with open(file_path, "wb") as f:
pickle.dump(data, f)

def __unpickle_data(self, file_path):
with open(file_path, "rb") as f:
return pickle.load(f)

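The Monte Carlo update above keeps a per-(state, action) visit counter and folds each new return into the stored Q value with __calc_average_in_incremental_way(k, r, Q) = Q + (r - Q) / (k + 1), the incremental form of the sample mean: with k samples already averaged into Q, the expression equals the mean of all k + 1 samples without storing the history. A standalone check of that identity (plain Python, no library code):

# Standalone check: the incremental update used by MonteCarlo reproduces
# the plain sample mean of the observed rewards.
def calc_average_in_incremental_way(k, r, Q):
    return Q + 1.0 / (k + 1) * (r - Q)

rewards = [1.0, 0.0, 5.0, 2.0]
Q = 0.0
for k, r in enumerate(rewards):
    Q = calc_average_in_incremental_way(k, r, Q)

assert abs(Q - sum(rewards) / len(rewards)) < 1e-12   # Q == 2.0
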
@@ -42,11 +42,13 @@ def clear(self):
self.eligibility_holder = self.__generate_action_eligibility_holder()

def dump(self):
return self.get_eligibilities()
return (self.update_type, self.discard_threshold,\
self.gamma, self.lambda_, self.get_eligibilities())

def load(self, serial):
self.clear()
for state, action, eligibility in serial:
self.update_type, self.discard_threshold, self.gamma, self.lambda_, eligibilities = serial
for state, action, eligibility in eligibilities:
self.__update(state, action, eligibility)

def __validate_update_type(self, update_type):
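The hunk above, apparently from the ActionEligibilityTrace class imported by QLambda and SarsaLambda below, widens dump() from returning only the eligibilities to returning the trace configuration as well (update_type, discard_threshold, gamma, lambda_), and load() now unpacks that tuple before replaying the per-(state, action) eligibilities. A round-trip sketch using only the one-argument constructor that appears in this diff; decay and update calls on the trace are omitted because their signatures are not shown here:

# Round-trip sketch for the extended dump()/load() serialization format.
from kyoka.algorithm.td_learning.eligibility_trace.action_eligibility_trace \
        import ActionEligibilityTrace as EligibilityTrace

trace = EligibilityTrace(EligibilityTrace.TYPE_ACCUMULATING)
serial = trace.dump()   # (update_type, discard_threshold, gamma, lambda_, eligibilities)

restored = EligibilityTrace(EligibilityTrace.TYPE_ACCUMULATING)
restored.load(serial)   # configuration and eligibilities are both recovered
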
39 changes: 27 additions & 12 deletions kyoka/algorithm/td_learning/q_lambda.py
@@ -3,20 +3,36 @@
import ActionEligibilityTrace as EligibilityTrace
from kyoka.policy.greedy_policy import GreedyPolicy

import os
import pickle

class QLambda(BaseTDMethod):

__KEY_ADDITIONAL_DATA = "additinal_data_key_q_lambda_eligibility_trace"
SAVE_FILE_NAME = "qlambda_algorithm_state.pickle"
ACTION_ON_TERMINAL_FLG = "action_on_terminal"

def __init__(self, alpha=0.1, gamma=0.9, eligibility_trace=None):
BaseTDMethod.__init__(self)
self.alpha = alpha
self.gamma = gamma
self.greedy_policy = GreedyPolicy()
self.trace = eligibility_trace if eligibility_trace else self.__generate_default_trace()
self.trace = eligibility_trace

def setUp(self, domain, policy, value_function):
super(QLambda, self).setUp(domain, policy, value_function)
if self.trace is None:
self.trace = EligibilityTrace(EligibilityTrace.TYPE_ACCUMULATING)

def save_algorithm_state(self, save_dir_path):
self.__pickle_data(self.__gen_save_file_path(save_dir_path), self.trace.dump())

def load_algorithm_state(self, load_dir_path):
new_trace = EligibilityTrace(EligibilityTrace.TYPE_ACCUMULATING)
trace_serial = self.__unpickle_data(self.__gen_save_file_path(load_dir_path))
new_trace.load(trace_serial)
self.trace = new_trace

def update_action_value_function(self, domain, policy, value_function):
self.__setup_trace(value_function)
current_state = domain.generate_initial_state()
current_action = policy.choose_action(domain, value_function, current_state)
while not domain.is_terminal_state(current_state):
@@ -35,7 +51,6 @@ def update_action_value_function(self, domain, policy, value_function):
if greedy_action != next_action:
self.trace.clear()
current_state, current_action = next_state, next_action
self.__save_trace(value_function)

def __calculate_delta(self,\
value_function, state, action, next_state, greedy_action, reward):
@@ -60,14 +75,14 @@ def __calculate_value(self, value_function, next_state, next_action):
else:
return value_function.calculate_value(next_state, next_action)

def __generate_default_trace(self):
return EligibilityTrace(EligibilityTrace.TYPE_ACCUMULATING)
def __gen_save_file_path(self, base_dir_path):
return os.path.join(base_dir_path, self.SAVE_FILE_NAME)

def __setup_trace(self, value_function):
trace_dump = value_function.get_additinal_data(self.__KEY_ADDITIONAL_DATA)
if trace_dump:
self.trace.load(trace_dump)
def __pickle_data(self, file_path, data):
with open(file_path, "wb") as f:
pickle.dump(data, f)

def __save_trace(self, value_function):
value_function.set_additinal_data(self.__KEY_ADDITIONAL_DATA, self.trace.dump())
def __unpickle_data(self, file_path):
with open(file_path, "rb") as f:
return pickle.load(f)

39 changes: 27 additions & 12 deletions kyoka/algorithm/td_learning/sarsa_lambda.py
@@ -2,19 +2,35 @@
from kyoka.algorithm.td_learning.eligibility_trace.action_eligibility_trace\
import ActionEligibilityTrace as EligibilityTrace

import os
import pickle

class SarsaLambda(BaseTDMethod):

__KEY_ADDITIONAL_DATA = "additinal_data_key_sarsa_lambda_eligibility_trace"
SAVE_FILE_NAME = "sarsalambda_algorithm_state.pickle"
ACTION_ON_TERMINAL_FLG = "action_on_terminal"

def __init__(self, alpha=0.1, gamma=0.9, eligibility_trace=None):
BaseTDMethod.__init__(self)
self.alpha = alpha
self.gamma = gamma
self.trace = eligibility_trace if eligibility_trace else self.__generate_default_trace()
self.trace = eligibility_trace

def setUp(self, domain, policy, value_function):
super(SarsaLambda, self).setUp(domain, policy, value_function)
if self.trace is None:
self.trace = EligibilityTrace(EligibilityTrace.TYPE_ACCUMULATING)

def save_algorithm_state(self, save_dir_path):
self.__pickle_data(self.__gen_save_file_path(save_dir_path), self.trace.dump())

def load_algorithm_state(self, load_dir_path):
new_trace = EligibilityTrace(EligibilityTrace.TYPE_ACCUMULATING)
trace_serial = self.__unpickle_data(self.__gen_save_file_path(load_dir_path))
new_trace.load(trace_serial)
self.trace = new_trace

def update_action_value_function(self, domain, policy, value_function):
self.__setup_trace(value_function)
current_state = domain.generate_initial_state()
current_action = policy.choose_action(domain, value_function, current_state)
while not domain.is_terminal_state(current_state):
@@ -30,7 +46,6 @@ def update_action_value_function(self, domain, policy, value_function):
value_function.update_function(state, action, new_Q_value)
self.trace.decay(state, action)
current_state, current_action = next_state, next_action
self.__save_trace(value_function)


def __calculate_delta(self,\
@@ -55,14 +70,14 @@ def __calculate_delta(self,\
else:
return value_function.calculate_value(next_state, next_action)

def __generate_default_trace(self):
return EligibilityTrace(EligibilityTrace.TYPE_ACCUMULATING)
def __gen_save_file_path(self, base_dir_path):
return os.path.join(base_dir_path, self.SAVE_FILE_NAME)

def __setup_trace(self, value_function):
trace_dump = value_function.get_additinal_data(self.__KEY_ADDITIONAL_DATA)
if trace_dump:
self.trace.load(trace_dump)
def __pickle_data(self, file_path, data):
with open(file_path, "wb") as f:
pickle.dump(data, f)

def __save_trace(self, value_function):
value_function.set_additinal_data(self.__KEY_ADDITIONAL_DATA, self.trace.dump())
def __unpickle_data(self, file_path):
with open(file_path, "rb") as f:
return pickle.load(f)

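With this change, QLambda and SarsaLambda no longer stash the eligibility trace inside the value function via set_additinal_data; the trace serial is pickled to its own file (qlambda_algorithm_state.pickle / sarsalambda_algorithm_state.pickle) by save_algorithm_state and restored by load_algorithm_state, which BaseRLAlgorithm.save/load call alongside the value function's own save/load. A hedged checkpoint-and-resume sketch; MyDomain, MyPolicy, and MyTableValueFunction are placeholder user classes, and the checkpoint directory is assumed to exist already:

# Sketch: checkpoint and resume a SarsaLambda run under the new API.
# MyDomain / MyPolicy / MyTableValueFunction are hypothetical user classes.
from kyoka.algorithm.td_learning.sarsa_lambda import SarsaLambda

save_dir = "/tmp/kyoka_checkpoint"   # assumed to exist

algorithm = SarsaLambda(alpha=0.1, gamma=0.9)
algorithm.setUp(MyDomain(), MyPolicy(), MyTableValueFunction())
algorithm.run_gpi(nb_iteration=1000)
algorithm.save(save_dir)   # value-function data + sarsalambda_algorithm_state.pickle

# Later: resume from the checkpoint with a fresh instance.
resumed = SarsaLambda(alpha=0.1, gamma=0.9)
resumed.setUp(MyDomain(), MyPolicy(), MyTableValueFunction())
resumed.load(save_dir)     # restores both the table and the eligibility trace
resumed.run_gpi(nb_iteration=1000)
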
6 changes: 0 additions & 6 deletions kyoka/value_function/base_action_value_function.py
@@ -13,12 +13,6 @@ def update_function(self, state, action, new_value):
def setUp(self):
pass

def provide_data_to_store(self):
return None

def receive_data_to_restore(self, restored_data):
pass


def __build_err_msg(self, msg):
base_msg = "[ {0} ] class does not implement [ {1} ] method"
6 changes: 0 additions & 6 deletions kyoka/value_function/base_keras_action_value_function.py
@@ -26,12 +26,6 @@ def save_model_weights(self, file_path):
def load_model_weights(self, file_path):
self.model.load_weights(file_path)

def provide_data_to_store(self):
pass

def receive_data_to_restore(self, restored_data):
pass

def generate_model(self):
err_msg = self.__build_err_msg("generate_model")
raise NotImplementedError(err_msg)
6 changes: 0 additions & 6 deletions kyoka/value_function/base_state_value_function.py
@@ -13,12 +13,6 @@ def update_function(self, state, new_value):
def setUp(self):
pass

def provide_data_to_store(self):
return None

def receive_data_to_restore(self, stored_data):
pass


def __build_err_msg(self, msg):
base_msg = "[ {0} ] class does not implement [ {1} ] method"
27 changes: 23 additions & 4 deletions kyoka/value_function/base_table_action_value_function.py
@@ -1,7 +1,11 @@
from kyoka.value_function.base_action_value_function import BaseActionValueFunction
import os
import pickle

class BaseTableActionValueFunction(BaseActionValueFunction):

SAVE_FILE_NAME = "table_action_value_function_data.pickle"

def setUp(self):
self.table = self.generate_initial_table()

@@ -11,11 +15,15 @@ def calculate_value(self, state, action):
def update_function(self, state, action, new_value):
self.update_table(self.table, state, action, new_value)

def provide_data_to_store(self):
return self.table
def save(self, save_dir_path):
file_path = os.path.join(save_dir_path, self.SAVE_FILE_NAME)
self.__pickle_data(file_path, self.table)

def receive_data_to_restore(self, restored_data):
self.table = restored_data
def load(self, load_dir_path):
file_path = os.path.join(load_dir_path, self.SAVE_FILE_NAME)
if not os.path.exists(file_path):
raise IOError('The saved data of "TableActionValueFunction" is not found on [ %s ]'% load_dir_path)
self.table = self.__unpickle_data(file_path)

def generate_initial_table(self):
err_msg = self.__build_err_msg("generate_initial_table")
@@ -30,6 +38,17 @@ def update_table(self, table, state, action, new_value):
raise NotImplementedError(err_msg)


def __gen_save_file_path(self, base_dir_path):
return os.path.join(base_dir_path, self.SAVE_FILE_NAME)

def __pickle_data(self, file_path, data):
with open(file_path, "wb") as f:
pickle.dump(data, f)

def __unpickle_data(self, file_path):
with open(file_path, "rb") as f:
return pickle.load(f)

def __build_err_msg(self, msg):
base_msg = "[ {0} ] class does not implement [ {1} ] method"
return base_msg.format(self.__class__.__name__, msg)
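BaseTableActionValueFunction drops the provide_data_to_store / receive_data_to_restore hooks (removed from the value-function base classes earlier in this commit) in favor of save(dir) / load(dir), which pickle the whole lookup table to a fixed file name inside the given directory and raise IOError when loading from a directory that has no such file. A small standalone illustration of that mechanism, written without the library itself:

# Standalone illustration of the pickle-per-directory scheme now used by
# BaseTableActionValueFunction.save()/load().
import os
import pickle
import tempfile

SAVE_FILE_NAME = "table_action_value_function_data.pickle"

def save_table(table, save_dir_path):
    with open(os.path.join(save_dir_path, SAVE_FILE_NAME), "wb") as f:
        pickle.dump(table, f)

def load_table(load_dir_path):
    file_path = os.path.join(load_dir_path, SAVE_FILE_NAME)
    if not os.path.exists(file_path):
        raise IOError("saved table not found in [ %s ]" % load_dir_path)
    with open(file_path, "rb") as f:
        return pickle.load(f)

tmp_dir = tempfile.mkdtemp()
save_table({("state", "action"): 1.0}, tmp_dir)
print(load_table(tmp_dir))   # {('state', 'action'): 1.0}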
