-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
214 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from kyoka.callback.base_callback import BaseCallback | ||
|
||
class BasePerformanceWatcher(BaseCallback): | ||
|
||
def setUp(self, domain, value_function): | ||
pass | ||
|
||
def tearDown(self, domain, value_function): | ||
pass | ||
|
||
def define_performance_test_interval(self): | ||
err_msg = self.__build_err_msg("define_performance_test_interval") | ||
raise NotImplementedError(err_msg) | ||
|
||
def run_performance_test(self, domain, value_function): | ||
err_msg = self.__build_err_msg("run_performance_test") | ||
raise NotImplementedError(err_msg) | ||
|
||
def define_log_message(self, iteration_count, domain, value_function, test_result): | ||
base_msg = "Performance test result : %s (nb_iteration=%d)" | ||
return base_msg % (test_result, iteration_count) | ||
|
||
def define_log_tag(self): | ||
return self.__class__.__name__ | ||
|
||
|
||
def before_gpi_start(self, domain, value_function): | ||
self.performance_log = [] | ||
self.test_interval = self.define_performance_test_interval() | ||
self.setUp(domain, value_function) | ||
|
||
def after_update(self, iteration_count, domain, value_function): | ||
iteration_count = iteration_count + 1 # Fix to 1-index | ||
if iteration_count % self.test_interval == 0: | ||
result = self.run_performance_test(domain, value_function) | ||
self.performance_log.append(result) | ||
message = self.define_log_message(iteration_count, domain, value_function, result) | ||
self.log(message) | ||
|
||
def after_gpi_finish(self, domain, value_function): | ||
self.tearDown(domain, value_function) | ||
|
||
|
||
def log(self, message): | ||
if message and len(message) != 0: | ||
print "[%s] %s" % (self.define_log_tag(), message) | ||
|
||
def __build_err_msg(self, msg): | ||
base_msg = "[ {0} ] class does not implement [ {1} ] method" | ||
return base_msg.format(self.__class__.__name__, msg) | ||
|
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from kyoka.callback.base_performance_watcher import BasePerformanceWatcher | ||
from sample.maze.maze_helper import MazeHelper | ||
|
||
class MazePerformanceWatcher(BasePerformanceWatcher): | ||
|
||
def define_performance_test_interval(self): | ||
return 1 | ||
|
||
def run_performance_test(self, domain, value_function): | ||
step_to_goal = MazeHelper.measure_performance(domain, value_function) | ||
policy = MazeHelper.visualize_policy(domain, value_function) | ||
return step_to_goal, policy | ||
|
||
def define_log_message(self, iteration_count, domain, value_function, test_result): | ||
step_to_goal, _ = test_result | ||
return "Step = %d (nb_iteration=%d)" % (step_to_goal,iteration_count) | ||
|
||
def tearDown(self, domain, value_function): | ||
msg_prefix = "Policy which agent learned is like this.\n" | ||
self.log(msg_prefix + MazeHelper.visualize_policy(domain, value_function)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from kyoka.callback.base_performance_watcher import BasePerformanceWatcher | ||
from kyoka.policy.greedy_policy import GreedyPolicy | ||
from sample.ticktacktoe.ticktacktoe_domain import TickTackToeDomain | ||
from sample.ticktacktoe.ticktacktoe_table_value_function import TickTackToeTableValueFunction | ||
from sample.ticktacktoe.ticktacktoe_perfect_policy import TickTackToePerfectPolicy | ||
from sample.ticktacktoe.ticktacktoe_helper import TickTackToeHelper | ||
|
||
class TickTackToePerformanceWatcher(BasePerformanceWatcher): | ||
|
||
def __init__(self, test_interval, test_game_count, is_first_player=True): | ||
self.test_interval = test_interval | ||
self.test_game_count = test_game_count | ||
self.is_first_player = is_first_player | ||
|
||
def define_performance_test_interval(self): | ||
return self.test_interval | ||
|
||
def run_performance_test(self, domain, value_function): | ||
players = self.__setup_players(value_function) | ||
game_results = [TickTackToeHelper.measure_performance(domain, value_function, players)\ | ||
for _ in range(self.test_game_count)] | ||
result_count = [game_results.count(result) for result in [-1, 0, 1]] | ||
result_rate = [1.0 * count / len(game_results) for count in result_count] | ||
return result_rate | ||
|
||
def define_log_message(self, iteration_count, domain, value_function, test_result): | ||
return "VS PerfectPolicy average result: lose=%f, draw=%f, win=%f" % tuple(test_result) | ||
|
||
|
||
def __setup_players(self, value_function): | ||
domains = [TickTackToeDomain(is_first_player=is_first)\ | ||
for is_first in [self.is_first_player, not self.is_first_player]] | ||
value_funcs = [value_function, TickTackToeTableValueFunction()] | ||
players = [GreedyPolicy(), TickTackToePerfectPolicy()] | ||
players = players if self.is_first_player else players[::-1] | ||
return players | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from tests.base_unittest import BaseUnitTest | ||
from kyoka.callback.base_performance_watcher import BasePerformanceWatcher | ||
from nose.tools import raises | ||
|
||
import sys | ||
import StringIO | ||
|
||
class BasePerformanceWatcherTest(BaseUnitTest): | ||
|
||
def setUp(self): | ||
self.capture = StringIO.StringIO() | ||
sys.stdout = self.capture | ||
|
||
def tearDown(self): | ||
sys.stdout = sys.__stdout__ | ||
|
||
@raises(NotImplementedError) | ||
def test_implementation_check_on_define_performance_test_interval(self): | ||
BasePerformanceWatcher().define_performance_test_interval() | ||
|
||
@raises(NotImplementedError) | ||
def test_implementation_check_on_run_performance_test(self): | ||
BasePerformanceWatcher().run_performance_test("dummy", "dummy") | ||
|
||
def test_default_log_message(self): | ||
watcher = self.TestMinimumImplementation() | ||
watcher.before_gpi_start("dummy", "dummy") | ||
watcher.after_update(1, "dummy", "dummy") | ||
watcher.after_update(2, "dummy", "dummy") | ||
expected = "[TestMinimumImplementation] Performance test result : 3 (nb_iteration=2)\n" | ||
self.eq(expected, self.capture.getvalue()) | ||
|
||
def test_setup_is_called(self): | ||
watcher = self.TestCompleteImplementation() | ||
self.false(watcher.is_setup_called) | ||
watcher.before_gpi_start("dummy", "dummy") | ||
self.true(watcher.is_setup_called) | ||
|
||
def test_teardown_is_called(self): | ||
watcher = self.TestCompleteImplementation() | ||
self.false(watcher.is_teardown_called) | ||
watcher.after_gpi_finish("dummy", "dummy") | ||
self.true(watcher.is_teardown_called) | ||
|
||
def test_timing_of_performance_test(self): | ||
watcher = self.TestCompleteImplementation() | ||
watcher.before_gpi_start("dummy", "dummy") | ||
watcher.after_update(0, "dummy", "dummy") | ||
self.eq('', self.capture.getvalue()) | ||
watcher.after_update(1, "dummy", "dummy") | ||
self.eq('[Test] test:1\n', self.capture.getvalue()) | ||
watcher.after_update(2, "dummy", "dummy") | ||
self.eq('[Test] test:1\n', self.capture.getvalue()) | ||
watcher.after_update(3, "dummy", "dummy") | ||
self.eq('[Test] test:1\n[Test] test:4\n', self.capture.getvalue()) | ||
self.eq([1, 4], watcher.performance_log) | ||
|
||
class TestCompleteImplementation(BasePerformanceWatcher): | ||
|
||
def __init__(self): | ||
self.is_setup_called = False | ||
self.is_teardown_called = False | ||
self.test_count = 0 | ||
|
||
def setUp(self, domain, value_function): | ||
self.is_setup_called = True | ||
|
||
def tearDown(self, domain, value_function): | ||
self.is_teardown_called = True | ||
|
||
def define_performance_test_interval(self): | ||
return 2 | ||
|
||
def run_performance_test(self, _domain, _value_function): | ||
self.test_count += 1 | ||
return self.test_count**2 | ||
|
||
def define_log_message(self, iteration_count, domain, value_function, test_result): | ||
return "test:%s" % test_result | ||
|
||
def define_log_tag(self): | ||
return "Test" | ||
|
||
class TestMinimumImplementation(BasePerformanceWatcher): | ||
|
||
def define_performance_test_interval(self): | ||
return 2 | ||
|
||
def run_performance_test(self, _domain, _value_function): | ||
return 3 | ||
|