diff --git a/quantdsl/interfaces/calcandplot.py b/quantdsl/interfaces/calcandplot.py index b04a52f..15ab5c1 100644 --- a/quantdsl/interfaces/calcandplot.py +++ b/quantdsl/interfaces/calcandplot.py @@ -29,50 +29,51 @@ def calc_and_plot(*args, **kwargs): class CalcAndPlot(object): def __init__(self): self.result_values_computed_count = 0 + self.call_result_id = None + self.is_evaluation_ready = Event() + + def is_evaluation_complete(self, event): + return isinstance(event, CallResult.Created) and event.entity_id == self.call_result_id + + def on_evaluation_complete(self, _): + self.is_evaluation_ready.set() def run(self, title, source_code, observation_date, interest_rate, path_count, perturbation_factor, price_process, periodisation, supress_plot=False): - app = QuantDslApplicationWithMultithreadingAndPythonObjects() + with QuantDslApplicationWithMultithreadingAndPythonObjects() as app: - start_compile = datetime.datetime.now() - contract_specification = app.compile(source_code) - end_compile = datetime.datetime.now() - print("Compilation in {}s".format((end_compile - start_compile).total_seconds())) + start_compile = datetime.datetime.now() + contract_specification = app.compile(source_code) + end_compile = datetime.datetime.now() + print("Compilation in {}s".format((end_compile - start_compile).total_seconds())) - start_calc = datetime.datetime.now() - evaluation, market_simulation = self.calc_results(app, interest_rate, observation_date, - path_count, perturbation_factor, contract_specification, - price_process['name'], price_process) + subscribe(self.is_evaluation_complete, self.on_evaluation_complete) + start_calc = datetime.datetime.now() - is_evaluation_ready = Event() + evaluation, market_simulation = self.calc_results(app, interest_rate, observation_date, + path_count, perturbation_factor, contract_specification, + price_process['name'], price_process) - call_result_id = make_call_result_id(evaluation.id, evaluation.contract_specification_id) + self.call_result_id = make_call_result_id(evaluation.id, evaluation.contract_specification_id) - def is_evaluation_complete(event): - return isinstance(event, CallResult.Created) and event.entity_id == call_result_id - def on_evaluation_complete(_): - is_evaluation_ready.set() + while self.call_result_id not in app.call_result_repo: + if self.is_evaluation_ready.wait(timeout=2): + break - subscribe(is_evaluation_complete, on_evaluation_complete) + unsubscribe(self.is_evaluation_complete, self.on_evaluation_complete) - while call_result_id not in app.call_result_repo: - if is_evaluation_ready.wait(timeout=2): - break + fair_value_stderr, fair_value_mean, periods = self.read_results(app, evaluation, market_simulation, + path_count) - unsubscribe(is_evaluation_complete, on_evaluation_complete) + end_calc = datetime.datetime.now() + self.print_results(fair_value_mean, fair_value_stderr, periods) + print("") + print("Results in {}s".format((end_calc - start_calc).total_seconds())) - fair_value_stderr, fair_value_mean, periods = self.read_results(app, evaluation, market_simulation, - path_count) - - end_calc = datetime.datetime.now() - self.print_results(fair_value_mean, fair_value_stderr, periods) - print("") - print("Results in {}s".format((end_calc - start_calc).total_seconds())) - - supress_plot = supress_plot or os.getenv('SUPRESS_PLOT') - if not supress_plot and plt and len(periods) > 1: - self.plot_results(interest_rate, path_count, perturbation_factor, periods, title, periodisation) + supress_plot = supress_plot or os.getenv('SUPRESS_PLOT') + if not supress_plot and plt and len(periods) > 1: + self.plot_results(interest_rate, path_count, perturbation_factor, periods, title, periodisation) def calc_results(self, app, interest_rate, observation_date, path_count, perturbation_factor, contract_specification, price_process_name, calibration_params): @@ -90,41 +91,42 @@ def calc_results(self, app, interest_rate, observation_date, path_count, perturb ) call_costs = app.calc_call_costs(contract_specification.id) - total_cost = sum(call_costs.values()) - - times = collections.deque() - - def is_result_value_computed(event): - return isinstance(event, ResultValueComputed) - - def count_result_values_computed(event): - times.append(datetime.datetime.now()) - if len(times) > 0.5 * total_cost: - times.popleft() - if len(times) > 1: - duration = times[-1] - times[0] - rate = len(times) / duration.total_seconds() - else: - rate = 0.001 - eta = (total_cost - self.result_values_computed_count) / rate - assert isinstance(event, ResultValueComputed) - self.result_values_computed_count += 1 - sys.stdout.write( - "\r{:.2f}% complete ({}/{}) {:.2f}/s eta {:.0f}s".format( - (100.0 * self.result_values_computed_count) / total_cost, - self.result_values_computed_count, - total_cost, - rate, - eta - ) - ) - sys.stdout.flush() + self.total_cost = sum(call_costs.values()) - subscribe(is_result_value_computed, count_result_values_computed) + self.times = collections.deque() + + subscribe(self.is_result_value_computed, self.count_result_values_computed) evaluation = app.evaluate(contract_specification, market_simulation) - # unsubscribe(is_result_value_computed, count_result_values_computed) + unsubscribe(self.is_result_value_computed, self.count_result_values_computed) return evaluation, market_simulation + def is_result_value_computed(self, event): + return isinstance(event, ResultValueComputed) + + def count_result_values_computed(self, event): + self.times.append(datetime.datetime.now()) + if len(self.times) > 0.5 * self.total_cost: + self.times.popleft() + if len(self.times) > 1: + duration = self.times[-1] - self.times[0] + rate = len(self.times) / duration.total_seconds() + else: + rate = 0.001 + eta = (self.total_cost - self.result_values_computed_count) / rate + assert isinstance(event, ResultValueComputed) + self.result_values_computed_count += 1 + sys.stdout.write( + "\r{:.2f}% complete ({}/{}) {:.2f}/s eta {:.0f}s".format( + (100.0 * self.result_values_computed_count) / self.total_cost, + self.result_values_computed_count, + self.total_cost, + rate, + eta + ) + ) + sys.stdout.flush() + + def read_results(self, app, evaluation, market_simulation, path_count): assert isinstance(evaluation, ContractValuation) diff --git a/quantdsl/tests/test_calc_and_plot.py b/quantdsl/tests/test_calc_and_plot.py index 9c9c818..5fdce7a 100644 --- a/quantdsl/tests/test_calc_and_plot.py +++ b/quantdsl/tests/test_calc_and_plot.py @@ -1,9 +1,17 @@ from unittest.case import TestCase +from eventsourcing.domain.model.events import assert_event_handlers_empty + from quantdsl.interfaces.calcandplot import calc_and_plot class TestCalcAndPlot(TestCase): + def setUp(self): + assert_event_handlers_empty() + + def tearDown(self): + assert_event_handlers_empty() + def test(self): source_code = """from quantdsl.lib.storage2 import GasStorage