Skip to content

Commit

Permalink
Fixed call cache.
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbywater committed Oct 2, 2017
1 parent ce9c274 commit d34ae18
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 56 deletions.
2 changes: 0 additions & 2 deletions quantdsl/application/with_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def protected_loop_on_evaluation_queue(self):
if not self.has_thread_errored.is_set():
self.thread_exception = e
self.has_thread_errored.set()
if not isinstance(e, TimeoutError):
raise

def get_result(self, contract_valuation):
assert isinstance(contract_valuation, ContractValuation)
Expand Down
1 change: 1 addition & 0 deletions quantdsl/domain/services/contract_valuations.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def compute_call_result(contract_valuation, call_requirement, market_simulation,
def evaluate_dsl_expr(dsl_expr, first_commodity_name, simulation_id, interest_rate, present_time, simulated_value_dict,
perturbation_dependencies, dependency_results, path_count, perturbation_factor, periodisation,
approximate_discounting):

evaluation_kwds = {
'simulated_value_dict': simulated_value_dict,
'simulation_id': simulation_id,
Expand Down
98 changes: 54 additions & 44 deletions quantdsl/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,31 +661,35 @@ def validateCallArgs(self, dsl_locals):
raise DslSyntaxError('expected call arg not found',
"arg '%s' not in call arg namespace %s" % (call_arg_name, dsl_locals.keys()))

def apply(self, dsl_globals=None, effective_present_time=None, pending_call_stack=None, is_destacking=False,
as_call_arg=False, **dsl_locals):
def apply(self, _dsl_globals=None, effective_present_time=None, pending_call_stack=None, is_destacking=False,
**raw_dsl_locals):

# Decide either to stub out the function in the caller, and put
# the call on the pending stack; or to actually apply the args and
# generate a DSL expression.
do_apply = pending_call_stack is None or is_destacking or 'inline' in self.decorator_names

# Sort out the namespaces.
if dsl_globals is None:
dsl_globals = DslNamespace()
if _dsl_globals is None:
_dsl_globals = DslNamespace()
if self.module_namespace is None:
module_namespace = DslNamespace()
else:
module_namespace = self.module_namespace
# Todo: This can be simpler... module_namespace and dsl_globals are trying to do the same thing.
dsl_globals = DslNamespace(itertools.chain(
new_dsl_globals = DslNamespace(itertools.chain(
self.enclosed_namespace.items(),
module_namespace.items(),
dsl_globals.items())
_dsl_globals.items())
)

dsl_locals = DslNamespace(dsl_locals)

# Validate the call args with the definition.
self.validateCallArgs(dsl_locals)
self.validateCallArgs(raw_dsl_locals)

# Create the cache key.
new_dsl_locals = DslNamespace()
call_cache_key_dict = {}
for arg_name, arg_value in dsl_locals.items():
for arg_name, arg_value in raw_dsl_locals.items():
if isinstance(arg_value, FunctionCall):
if not list(arg_value.functionDef.find_instances(FunctionCall, StochasticObject)):
try:
Expand All @@ -700,76 +704,82 @@ def apply(self, dsl_globals=None, effective_present_time=None, pending_call_stac
raise Exception("Can't evaluate {}, a non-stochastic expression: {}: {}".format(arg_name,
arg_value, e))

new_dsl_locals[arg_name] = arg_value
call_cache_key_dict[arg_name] = arg_value
call_cache_key_dict["__effective_present_time__"] = effective_present_time or 'None'
call_cache_key_dict["__do_apply__"] = do_apply
call_cache_key = self.create_hash(call_cache_key_dict)

# Check the call cache, to see whether this function has already been called with these args.
if not is_destacking and call_cache_key in self.call_cache:
if call_cache_key in self.call_cache:
return self.call_cache[call_cache_key]

if pending_call_stack and not is_destacking and not 'inline' in self.decorator_names and not as_call_arg:
# Don't actually apply anything now, just stack the call expression
# and return a stub. This is why apply() appears to be called twice,
# in FunctionCall.reduce() and generate_stubbed_calls(), the later making
# the call that was deferred in the former.

# Create a new stub - the stub ID is the name of the return value of the function call..
stub_id = create_uuid4()
dsl_stub = Stub(stub_id, node=self.node)

# Put the function call on the call stack, with the stub ID.
# assert isinstance(pending_call_stack, PendingCallQueue)
pending_call_stack.put(
stub_id=stub_id,
stacked_function_def=self,
stacked_locals=dsl_locals.copy(),
stacked_globals=dsl_globals.copy(),
effective_present_time=effective_present_time
)
# Return the stub so that the containing DSL can be fully evaluated
# once the stacked function call has been evaluated.
dsl_expr = dsl_stub
else:
if do_apply:
# Select expression from body.
dsl_expr = self.body
ns = dsl_globals.combine(dsl_locals)
ns = new_dsl_globals.combine(new_dsl_locals)
while isinstance(dsl_expr, BaseIf):
# Todo: Also allow user defined functions that just do dates or numbers in test expression.
# it doesn't have or expand into DSL elements that are the functions of time (Wait, Choice, Market,
# etc).
dsl_expr = dsl_expr.select_expression(**ns)

# Add this function to the dslNamespace (just in case it's called by itself).
new_dsl_globals = DslNamespace(dsl_globals)
new_dsl_globals[self.name] = self
# Add this function to the namespace (it might recurse).
ns[self.name] = self

# Reduce the selected expression.
assert isinstance(dsl_expr, DslExpression)
dsl_expr = dsl_expr.substitute_names(new_dsl_globals.combine(dsl_locals))
dsl_expr = dsl_expr.substitute_names(ns.combine(new_dsl_locals))
dsl_expr = dsl_expr.call_functions(
effective_present_time=effective_present_time,
pending_call_stack=pending_call_stack
)

# Cache the result.
if not is_destacking:
self.call_cache[call_cache_key] = dsl_expr
else:
# Stack the call expression, with the call args,
# and return a Stub, for the calling expression.

# Create a new Stub ID.
stub_id = create_uuid4()

# Put the stub ID on the call stack, with this
# FunctionDef, the prepared call args, and the effective
# present time. This defines a pending call.
# assert isinstance(pending_call_stack, PendingCallQueue)
# Todo: Extract object class PendingCall.
pending_call_stack.put(
stub_id=stub_id,
stacked_function_def=self,
stacked_locals=new_dsl_locals.copy(),
stacked_globals=new_dsl_globals.copy(),
effective_present_time=effective_present_time
)

# Return the stub so that the containing DSL can be fully evaluated
# once the stacked function call has been evaluated.
dsl_expr = Stub(stub_id, node=self.node)

# Cache the expression.
self.call_cache[call_cache_key] = dsl_expr

return dsl_expr

def create_hash(self, obj):
if isinstance(obj, relativedelta):
return hash(repr(obj))
if isinstance(obj, (
int, float, six.string_types, datetime.datetime, datetime.date, datetime.timedelta, relativedelta)):
if isinstance(obj, six.integer_types + (float, six.string_types, datetime.date, datetime.timedelta,
relativedelta)):
return hash(obj)

if isinstance(obj, dict):
return hash(tuple(sorted([(a, self.create_hash(b)) for a, b in obj.items()])))

if isinstance(obj, list):
return hash(tuple(sorted([self.create_hash(a) for a in obj])))

if isinstance(obj, DslObject):
return hash(obj)

raise DslSystemError("Can't create hash from obj type '%s'" % type(obj), obj,
node=obj.node if isinstance(obj, DslObject) else None)

Expand Down
36 changes: 26 additions & 10 deletions quantdsl/tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,9 @@ def f(x):
self.assert_contract_value(code + "f(1) + f(1) + f(1)", expected_call_count=2)
self.assert_contract_value(code + "f(1) + f(1) + f(1)", expected_call_count=2)
self.assert_contract_value(code + "f(f(1))", expected_call_count=2)
self.assert_contract_value(code + "f(f(f(1)))", expected_call_count=1)
self.assert_contract_value(code + "f(f(f(1))) + f(f(f(1)))", expected_call_count=1)
self.assert_contract_value(code + "f(f(f(1))) + f(f(f(f(1))))", expected_call_count=1)
self.assert_contract_value(code + "f(f(f(1)))", expected_call_count=2)
self.assert_contract_value(code + "f(f(f(1))) + f(f(f(1)))", expected_call_count=2)
self.assert_contract_value(code + "f(f(f(1))) + f(f(f(f(1))))", expected_call_count=2)

def test_call_cache_inlined_function(self):
code = """
Expand Down Expand Up @@ -574,6 +574,29 @@ def f2(x):
x
f1(0, f2(1), %s)
"""
self.assert_contract_value(code % 0, 1, expected_call_count=2)
self.assert_contract_value(code % 1, 2, expected_call_count=4)
self.assert_contract_value(code % 2, 2, expected_call_count=6)
self.assert_contract_value(code % 3, 2, expected_call_count=8)
self.assert_contract_value(code % 4, 2, expected_call_count=10)

def test_call_cache_recombine_branches_with_referenced_function_call_arg(self):
code = """
def f1(x, y, t):
if t <= 0:
x + y(1)
else:
if x <= 0:
Max(f1(x, y, t-1), f1(x+1, y, t-1))
else:
Max(f1(x-1, y, t-1), f1(x, y, t-1))
@inline
def f2(x):
x
f1(0, f2, %s)
"""
self.assert_contract_value(code % 0, 1, expected_call_count=2)
self.assert_contract_value(code % 1, 2, expected_call_count=4)
Expand Down Expand Up @@ -648,13 +671,6 @@ def fib(n):
self.assert_contract_value(fib_tmpl % 5, 5, expected_call_count=7)
self.assert_contract_value(fib_tmpl % 6, 8, expected_call_count=8)
self.assert_contract_value(fib_tmpl % 7, 13, expected_call_count=9)
# self.assert_contract_value(fib_tmpl % 17, 1597, expected_call_count=19)

def test_add(self):
fib_tmpl = """
TimeDelta('1d') + %s
"""
self.assert_contract_value(fib_tmpl % 4, 3, expected_call_count=None)

def test_compare_args_error(self):
code = "TimeDelta('1d') > 1"
Expand Down

0 comments on commit d34ae18

Please sign in to comment.