diff --git a/quantdsl/semantics.py b/quantdsl/semantics.py index 7e0e3fd..87c9d66 100644 --- a/quantdsl/semantics.py +++ b/quantdsl/semantics.py @@ -152,11 +152,11 @@ def assert_args_arg(self, args, posn, required_type): desc += str(args[posn]) raise DslSyntaxError(error_msg, desc, self.node) - def list_instances(self, dsl_type): - return list(self.find_instances(dsl_type)) + def list_instances(self, *dsl_types): + return list(self.find_instances(*dsl_types)) - def has_instances(self, dsl_type): - for _ in self.find_instances(dsl_type): + def has_instances(self, *dsl_types): + for _ in self.find_instances(*dsl_types): return True else: return False @@ -729,13 +729,13 @@ def apply(self, dsl_globals=None, present_time=None, observation_date=None, pend call_cache_key_dict = {} for arg_name, arg_value in raw_dsl_locals.items(): if isinstance(arg_value, FunctionCall): - if not list(arg_value.functionDef.find_instances(FunctionCall, StochasticObject)): + if not arg_value.functionDef.list_instances(FunctionCall, StochasticObject): try: arg_value = arg_value.call_functions() except DslError as e: raise Exception("Can't evaluate {}: {}: {}".format(arg_name, arg_value, e)) elif isinstance(arg_value, DslExpression): - if not arg_value.find_instances(StochasticObject): + if not arg_value.list_instances(StochasticObject): try: arg_value = arg_value.evaluate() except DslError as e: @@ -824,8 +824,8 @@ def create_hash(self, obj): if isinstance(obj, dict): return hash(tuple(sorted([(a, self.create_hash(b)) for a, b in obj.items()]))) - if isinstance(obj, list): - return hash(tuple(sorted([self.create_hash(a) for a in obj]))) + # if isinstance(obj, list): + # return hash(tuple(sorted([self.create_hash(a) for a in obj]))) raise DslSystemError("Can't create hash from obj type '%s'" % type(obj), obj, node=obj.node if isinstance(obj, DslObject) else None) diff --git a/quantdsl/tests/test_application.py b/quantdsl/tests/test_application.py index 3c8abe4..a47b66a 100644 --- a/quantdsl/tests/test_application.py +++ b/quantdsl/tests/test_application.py @@ -1219,22 +1219,22 @@ def ProfitFromRunning(start_date, underlying, time_since_off): # Check single-sided vs. double sided deltas. self.assert_contract_value(specification.format(end_date='2012-01-13'), expected_value=11.771, - expected_call_count=47, + expected_call_count=37, periodisation='monthly', expected_deltas={'SPARKSPREAD-2012-1': 11.978}) self.assert_contract_value(specification.format(end_date='2012-01-13'), expected_value=11.771, - expected_call_count=47, + expected_call_count=37, periodisation='monthly', expected_deltas={'SPARKSPREAD-2012-1': 11.978}, is_double_sided_deltas=False) # Check the call counts. - self.assert_contract_value(specification.format(end_date='2012-01-13'), 11.771, expected_call_count=47) - self.assert_contract_value(specification.format(end_date='2012-01-14'), expected_call_count=51) - self.assert_contract_value(specification.format(end_date='2012-01-15'), expected_call_count=55) - self.assert_contract_value(specification.format(end_date='2012-01-16'), expected_call_count=59) + self.assert_contract_value(specification.format(end_date='2012-01-13'), 11.771, expected_call_count=37) + self.assert_contract_value(specification.format(end_date='2012-01-14'), expected_call_count=40) + self.assert_contract_value(specification.format(end_date='2012-01-15'), expected_call_count=43) + self.assert_contract_value(specification.format(end_date='2012-01-16'), expected_call_count=46) def test_generate_valuation_power_plant_option_power_and_gas_forward(self): specification = """ @@ -1290,11 +1290,11 @@ def Tomorrow(today): PowerPlant(Date('2012-1-1'), Date('{end_date}'), Stopped(2)) """ - self.assert_contract_value(specification.format(end_date='2012-01-13'), 22.195, expected_call_count=47) - self.assert_contract_value(specification.format(end_date='2012-01-14'), expected_call_count=51) - self.assert_contract_value(specification.format(end_date='2012-01-15'), expected_call_count=55) - self.assert_contract_value(specification.format(end_date='2012-01-16'), expected_call_count=59) - self.assert_contract_value(specification.format(end_date='2012-01-17'), expected_call_count=63) + self.assert_contract_value(specification.format(end_date='2012-01-13'), 22.195, expected_call_count=37) + self.assert_contract_value(specification.format(end_date='2012-01-14'), expected_call_count=40) + self.assert_contract_value(specification.format(end_date='2012-01-15'), expected_call_count=43) + self.assert_contract_value(specification.format(end_date='2012-01-16'), expected_call_count=46) + self.assert_contract_value(specification.format(end_date='2012-01-17'), expected_call_count=49) def test_call_recombinations_with_function_calls_advancing_values(self): # This wasn't working, because each function call was being carried into the next @@ -1319,7 +1319,7 @@ def f2(d): f(9, 1) """ - self.assert_contract_value(specification, expected_call_count=28) + self.assert_contract_value(specification, expected_call_count=20) class TestObservationDate(ApplicationTestCase):