
Commit

Merge 1c7f11c into 8cd3ebb
ishikota committed Oct 26, 2016
2 parents 8cd3ebb + 1c7f11c commit ead8f15
Showing 10 changed files with 214 additions and 74 deletions.
2 changes: 2 additions & 0 deletions kyoka/algorithm/base_rl_algorithm.py
@@ -27,6 +27,8 @@ def update_value_function(self, domain, policy, value_function):
raise NotImplementedError(err_msg)

def run_gpi(self, nb_iteration, finish_rules=[], callbacks=[], verbose=1):
if not all([hasattr(self, attr) for attr in ["domain", "value_function", "policy"]]):
raise Exception('You need to call "setUp" method before calling "run_gpi" method.')
callbacks = self.__wrap_item_if_single(callbacks)
finish_rules = self.__wrap_item_if_single(finish_rules)
default_finish_rule = WatchIterationCount(nb_iteration, log_interval=float('inf') if verbose==0 else 1)
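The new guard makes run_gpi fail fast when setUp was never called. Below is a minimal sketch of the call order it enforces, using classes from the maze sample in this commit; the EpsilonGreedyPolicy import path and the MazeDomain/MazeTableValueFunction constructor arguments are assumptions, since neither appears in this diff.

from kyoka.algorithm.td_learning.q_learning import QLearning
from kyoka.policy.epsilon_greedy_policy import EpsilonGreedyPolicy  # import path assumed
from sample.maze.maze_domain import MazeDomain
from sample.maze.maze_table_value_function import MazeTableValueFunction

domain = MazeDomain()                   # constructor arguments assumed/omitted
value_func = MazeTableValueFunction()   # constructor arguments assumed/omitted
value_func.setUp()
policy = EpsilonGreedyPolicy(eps=0.1)

algo = QLearning()
# Calling algo.run_gpi(...) at this point would now raise:
#   'You need to call "setUp" method before calling "run_gpi" method.'
algo.setUp(domain, policy, value_func)  # binds the domain/value_function/policy attributes the guard checks
algo.run_gpi(nb_iteration=100)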
51 changes: 51 additions & 0 deletions kyoka/callback/base_performance_watcher.py
@@ -0,0 +1,51 @@
from kyoka.callback.base_callback import BaseCallback

class BasePerformanceWatcher(BaseCallback):

def setUp(self, domain, value_function):
pass

def tearDown(self, domain, value_function):
pass

def define_performance_test_interval(self):
err_msg = self.__build_err_msg("define_performance_test_interval")
raise NotImplementedError(err_msg)

def run_performance_test(self, domain, value_function):
err_msg = self.__build_err_msg("run_performance_test")
raise NotImplementedError(err_msg)

def define_log_message(self, iteration_count, domain, value_function, test_result):
base_msg = "Performance test result : %s (nb_iteration=%d)"
return base_msg % (test_result, iteration_count)

def define_log_tag(self):
return self.__class__.__name__


def before_gpi_start(self, domain, value_function):
self.performance_log = []
self.test_interval = self.define_performance_test_interval()
self.setUp(domain, value_function)

def after_update(self, iteration_count, domain, value_function):
iteration_count = iteration_count + 1 # Fix to 1-index
if iteration_count % self.test_interval == 0:
result = self.run_performance_test(domain, value_function)
self.performance_log.append(result)
message = self.define_log_message(iteration_count, domain, value_function, result)
self.log(message)

def after_gpi_finish(self, domain, value_function):
self.tearDown(domain, value_function)


def log(self, message):
if message and len(message) != 0:
print "[%s] %s" % (self.define_log_tag(), message)

def __build_err_msg(self, msg):
base_msg = "[ {0} ] class does not implement [ {1} ] method"
return base_msg.format(self.__class__.__name__, msg)
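For reference, the contract of this new base class: subclasses must implement define_performance_test_interval and run_performance_test, while setUp, tearDown, define_log_message and define_log_tag are optional hooks. A minimal sketch of a subclass follows; the class name and the fixed return value are purely illustrative.

from kyoka.callback.base_performance_watcher import BasePerformanceWatcher

class ConstantScoreWatcher(BasePerformanceWatcher):
    # Illustrative subclass: reports a fixed score every 10 GPI iterations.

    def define_performance_test_interval(self):
        return 10  # after_update runs the test whenever iteration_count % 10 == 0

    def run_performance_test(self, domain, value_function):
        # Return any metric; it is appended to self.performance_log and
        # formatted by define_log_message (the default log tag is the class name).
        return 1.0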

14 changes: 0 additions & 14 deletions sample/maze/maze_performance_logger.py

This file was deleted.

21 changes: 21 additions & 0 deletions sample/maze/maze_performance_watcher.py
@@ -0,0 +1,21 @@
from kyoka.callback.base_performance_watcher import BasePerformanceWatcher
from sample.maze.maze_helper import MazeHelper

class MazePerformanceWatcher(BasePerformanceWatcher):

def define_performance_test_interval(self):
return 1

def run_performance_test(self, domain, value_function):
step_to_goal = MazeHelper.measure_performance(domain, value_function)
policy = MazeHelper.visualize_policy(domain, value_function)
return step_to_goal, policy

def define_log_message(self, iteration_count, domain, value_function, test_result):
step_to_goal, _ = test_result
return "Step = %d (nb_iteration=%d)" % (step_to_goal,iteration_count)

def tearDown(self, domain, value_function):
msg_prefix = "Policy which agent learned is like this.\n"
self.log(msg_prefix + MazeHelper.visualize_policy(domain, value_function))

11 changes: 2 additions & 9 deletions sample/maze/script/measure_performance
@@ -12,9 +12,6 @@ sys.path.append(root)
sys.path.append(src_path)
sys.path.append(sample_path)

import logging as log
log.basicConfig(format='[%(levelname)s] %(message)s', level=log.DEBUG)

from kyoka.algorithm.montecarlo.montecarlo import MonteCarlo
from kyoka.algorithm.td_learning.sarsa import Sarsa
from kyoka.algorithm.td_learning.q_learning import QLearning
@@ -27,7 +24,7 @@ from kyoka.finish_rule.watch_iteration_count import WatchIterationCount
from sample.maze.maze_domain import MazeDomain
from sample.maze.maze_table_value_function import MazeTableValueFunction
from sample.maze.maze_helper import MazeHelper
from sample.maze.maze_performance_logger import MazePerformanceLogger
from sample.maze.maze_performance_watcher import MazePerformanceWatcher
from sample.maze.maze_transformer import MazeTransformer

SUPPORT_ALGORITHM = ["human", "montecarlo", "sarsa", "qlearning", "sarsalambda", "qlambda"]
@@ -56,7 +53,7 @@ value_func.setUp()

TEST_LENGTH = 100
policy = EpsilonGreedyPolicy(eps=0.1)
callbacks = [MazePerformanceLogger()]
callbacks = [MazePerformanceWatcher()]
if maze_type in ["blocking", "shortcut"]:
transfomer = MazeTransformer()
transformed_maze_filepath = MAZE_FILE_PATH[:-len(".txt")] + "_transformed.txt"
@@ -71,10 +68,6 @@ RL_algo = {
"qlambda": lambda :QLambda()
}[algo]()

log.info("start to measure performnce for %d episode" % TEST_LENGTH)
RL_algo.setUp(domain, policy, value_func)
RL_algo.run_gpi(TEST_LENGTH, callbacks=callbacks)
log.info("finished to measure performnce for %d episode" % TEST_LENGTH)
log.info("performance_log = %s" % callbacks[0].step_log)
log.info("Policy which agent learned is like this.\n%s" % MazeHelper.visualize_policy(domain, value_func))

13 changes: 2 additions & 11 deletions sample/ticktacktoe/script/measure_performance
@@ -12,9 +12,6 @@ sys.path.append(root)
sys.path.append(src_path)
sys.path.append(sample_path)

import logging as log
log.basicConfig(format='[%(levelname)s] %(message)s', level=log.INFO)

from kyoka.algorithm.montecarlo.montecarlo import MonteCarlo
from kyoka.algorithm.td_learning.sarsa import Sarsa
from kyoka.algorithm.td_learning.q_learning import QLearning
@@ -30,7 +27,7 @@ from sample.ticktacktoe.ticktacktoe_table_value_function import TickTackToeTable
from sample.ticktacktoe.ticktacktoe_helper import TickTackToeHelper
from sample.ticktacktoe.ticktacktoe_manual_policy import TickTackToeManualPolicy
from sample.ticktacktoe.ticktacktoe_perfect_policy import TickTackToePerfectPolicy
from sample.ticktacktoe.ticktacktoe_performance_logger import TickTackToePerformanceLogger
from sample.ticktacktoe.ticktacktoe_performance_watcher import TickTackToePerformanceWatcher

SUPPORT_ALGORITHM = ["montecarlo", "sarsa", "qlearning", "sarsalambda", "qlambda", "minimax"]

@@ -59,14 +56,8 @@ RL_algo = {
"qlambda": lambda :QLambda()
}[algo]()

callback = TickTackToePerformanceLogger()
callback.set_performance_test_interval(TEST_INTERVAL)
callback.set_is_first_player(is_first_player)
callback.set_test_game_count(TEST_GAME_COUNT)
callback = TickTackToePerformanceWatcher(TEST_INTERVAL, TEST_GAME_COUNT, is_first_player)

log.info("start to measure performnce for %d episode" % TEST_LENGTH)
RL_algo.setUp(domain, policy, value_func)
RL_algo.run_gpi(TEST_LENGTH, callbacks=callback)
log.info("finished to measure performnce for %d episode" % TEST_LENGTH)
log.info("performance log = %s" % callback.game_log)

40 changes: 0 additions & 40 deletions sample/ticktacktoe/ticktacktoe_performance_logger.py

This file was deleted.

37 changes: 37 additions & 0 deletions sample/ticktacktoe/ticktacktoe_performance_watcher.py
@@ -0,0 +1,37 @@
from kyoka.callback.base_performance_watcher import BasePerformanceWatcher
from kyoka.policy.greedy_policy import GreedyPolicy
from sample.ticktacktoe.ticktacktoe_domain import TickTackToeDomain
from sample.ticktacktoe.ticktacktoe_table_value_function import TickTackToeTableValueFunction
from sample.ticktacktoe.ticktacktoe_perfect_policy import TickTackToePerfectPolicy
from sample.ticktacktoe.ticktacktoe_helper import TickTackToeHelper

class TickTackToePerformanceWatcher(BasePerformanceWatcher):

def __init__(self, test_interval, test_game_count, is_first_player=True):
self.test_interval = test_interval
self.test_game_count = test_game_count
self.is_first_player = is_first_player

def define_performance_test_interval(self):
return self.test_interval

def run_performance_test(self, domain, value_function):
players = self.__setup_players(value_function)
game_results = [TickTackToeHelper.measure_performance(domain, value_function, players)\
for _ in range(self.test_game_count)]
result_count = [game_results.count(result) for result in [-1, 0, 1]]
result_rate = [1.0 * count / len(game_results) for count in result_count]
return result_rate

def define_log_message(self, iteration_count, domain, value_function, test_result):
return "VS PerfectPolicy average result: lose=%f, draw=%f, win=%f" % tuple(test_result)


def __setup_players(self, value_function):
domains = [TickTackToeDomain(is_first_player=is_first)\
for is_first in [self.is_first_player, not self.is_first_player]]
value_funcs = [value_function, TickTackToeTableValueFunction()]
players = [GreedyPolicy(), TickTackToePerfectPolicy()]
players = players if self.is_first_player else players[::-1]
return players
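The setter-based configuration of the deleted TickTackToePerformanceLogger (set_performance_test_interval, set_is_first_player, set_test_game_count) becomes a single constructor call, as the updated measure_performance script above shows. A hedged sketch of the new wiring; the numeric values below are placeholders, not the script's TEST_INTERVAL/TEST_GAME_COUNT constants.

from sample.ticktacktoe.ticktacktoe_performance_watcher import TickTackToePerformanceWatcher

# All configuration moves into the constructor; placeholder values shown.
callback = TickTackToePerformanceWatcher(test_interval=100,
                                         test_game_count=50,
                                         is_first_player=True)
# callback.define_performance_test_interval() -> 100
# callback.run_performance_test(domain, value_function) plays 50 games of the
# learned GreedyPolicy against TickTackToePerfectPolicy and returns
# [lose_rate, draw_rate, win_rate].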

8 changes: 8 additions & 0 deletions tests/kyoka/algorithm/base_rl_algorithm_test.py
@@ -140,6 +140,14 @@ def test_set_callback(self):
callback.after_update.assert_called_with(1, "domain", value_func)
callback.after_gpi_finish.assert_called_with("domain", value_func)

def test_error_when_run_gpi_called_without_setup(self):
algo = self.TestImplementation()
with self.assertRaises(Exception) as e:
algo.run_gpi(nb_iteration=2)
self.include("setUp", e.exception.message)
self.include("run_gpi", e.exception.message)


def __setup_stub_domain(self):
mock_domain = Mock()
mock_domain.generate_initial_state.return_value = 0
91 changes: 91 additions & 0 deletions tests/kyoka/callback/base_performance_watcher_test.py
@@ -0,0 +1,91 @@
from tests.base_unittest import BaseUnitTest
from kyoka.callback.base_performance_watcher import BasePerformanceWatcher
from nose.tools import raises

import sys
import StringIO

class BasePerformanceWatcherTest(BaseUnitTest):

def setUp(self):
self.capture = StringIO.StringIO()
sys.stdout = self.capture

def tearDown(self):
sys.stdout = sys.__stdout__

@raises(NotImplementedError)
def test_implementation_check_on_define_performance_test_interval(self):
BasePerformanceWatcher().define_performance_test_interval()

@raises(NotImplementedError)
def test_implementation_check_on_run_performance_test(self):
BasePerformanceWatcher().run_performance_test("dummy", "dummy")

def test_default_log_message(self):
watcher = self.TestMinimumImplementation()
watcher.before_gpi_start("dummy", "dummy")
watcher.after_update(1, "dummy", "dummy")
watcher.after_update(2, "dummy", "dummy")
expected = "[TestMinimumImplementation] Performance test result : 3 (nb_iteration=2)\n"
self.eq(expected, self.capture.getvalue())

def test_setup_is_called(self):
watcher = self.TestCompleteImplementation()
self.false(watcher.is_setup_called)
watcher.before_gpi_start("dummy", "dummy")
self.true(watcher.is_setup_called)

def test_teardown_is_called(self):
watcher = self.TestCompleteImplementation()
self.false(watcher.is_teardown_called)
watcher.after_gpi_finish("dummy", "dummy")
self.true(watcher.is_teardown_called)

def test_timing_of_performance_test(self):
watcher = self.TestCompleteImplementation()
watcher.before_gpi_start("dummy", "dummy")
watcher.after_update(0, "dummy", "dummy")
self.eq('', self.capture.getvalue())
watcher.after_update(1, "dummy", "dummy")
self.eq('[Test] test:1\n', self.capture.getvalue())
watcher.after_update(2, "dummy", "dummy")
self.eq('[Test] test:1\n', self.capture.getvalue())
watcher.after_update(3, "dummy", "dummy")
self.eq('[Test] test:1\n[Test] test:4\n', self.capture.getvalue())
self.eq([1, 4], watcher.performance_log)

class TestCompleteImplementation(BasePerformanceWatcher):

def __init__(self):
self.is_setup_called = False
self.is_teardown_called = False
self.test_count = 0

def setUp(self, domain, value_function):
self.is_setup_called = True

def tearDown(self, domain, value_function):
self.is_teardown_called = True

def define_performance_test_interval(self):
return 2

def run_performance_test(self, _domain, _value_function):
self.test_count += 1
return self.test_count**2

def define_log_message(self, iteration_count, domain, value_function, test_result):
return "test:%s" % test_result

def define_log_tag(self):
return "Test"

class TestMinimumImplementation(BasePerformanceWatcher):

def define_performance_test_interval(self):
return 2

def run_performance_test(self, _domain, _value_function):
return 3
