Skip to content

Commit

Permalink
Second attempt to implement module namespaces.
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbywater committed Sep 12, 2017
1 parent 63826b8 commit 7ca4eb0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 33 deletions.
6 changes: 5 additions & 1 deletion quantdsl/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,9 @@ def __init__(self, *args, **kwds):
self.call_cache = {}
self.enclosed_namespace = DslNamespace()

# Second attempt to implement module namespaces...
self.module_namespace = None

def validate(self, args):
self.assert_args_len(args, required_len=4)

Expand Down Expand Up @@ -657,7 +660,8 @@ def apply(self, dsl_globals=None, effective_present_time=None, pending_call_stac
else:
pass
# assert isinstance(dsl_globals, DslNamespace)
dsl_globals = DslNamespace(itertools.chain(self.enclosed_namespace.items(), dsl_globals.items()))
dsl_globals = DslNamespace(itertools.chain(self.enclosed_namespace.items(), self.module_namespace.items(),
dsl_globals.items()))
dsl_locals = DslNamespace(dsl_locals)

# Validate the call args with the definition.
Expand Down
71 changes: 40 additions & 31 deletions quantdsl/syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import six

from quantdsl.exceptions import DslSyntaxError
from quantdsl.semantics import FunctionDef, DslNamespace


class DslParser(object):
Expand Down Expand Up @@ -62,6 +63,9 @@ def visitModule(self, node):
# assert isinstance(node, ast.Module)
body = []

# Namespace for function defs in module.
module_namespace = DslNamespace()

def inline(body):
flat = []
assert isinstance(body, list), body
Expand All @@ -76,6 +80,12 @@ def inline(body):
for n in node.body:
dsl_object = self.visitAstNode(n)

if isinstance(dsl_object, FunctionDef):
# Put function def in module namespace.
module_namespace[dsl_object.name] = dsl_object
# Share module namespace with this function.
dsl_object.module_namespace = module_namespace

if isinstance(dsl_object, list):
for _dsl_object in inline(dsl_object):
body.append(_dsl_object)
Expand All @@ -84,28 +94,22 @@ def inline(body):

return self.dsl_classes['Module'](body, node=node)

def visitImport(self, node):
"""
Visitor method for ast.Import nodes.
Returns the result of visiting the expression held by the return statement.
"""
assert isinstance(node, ast.Import)
nodes = []
for imported in node.names:
name = imported.name
if name == 'quantdsl.semantics':
continue
# spec = importlib.util.find_spec()
module = importlib.import_module(name)
path = module.__file__.strip('c')
source = open(path).read() # .py not .pyc
dsl_node = self.parse(source, filename=path)
assert isinstance(dsl_node, self.dsl_classes['Module']), type(dsl_node)
for node in dsl_node.body:
nodes.append(node)

return nodes
# def visitImport(self, node):
# """
# Visitor method for ast.Import nodes.
#
# Returns the result of visiting the expression held by the return statement.
# """
# assert isinstance(node, ast.Import)
# nodes = []
# for imported in node.names:
# name = imported.name
# if name == 'quantdsl.semantics':
# continue
# # spec = importlib.util.find_spec()
# node = self.import_python_module(name, node, nodes)
#
# return nodes

def visitImportFrom(self, node):
"""
Expand All @@ -114,19 +118,24 @@ def visitImportFrom(self, node):
Returns the result of visiting the expression held by the return statement.
"""
assert isinstance(node, ast.ImportFrom)
nodes = []
if node.module == 'quantdsl.semantics':
return nodes
imported_names = [a.name for a in node.names]
module = importlib.import_module(node.module)
return []
from_names = [a.name for a in node.names]
dsl_module = self.import_python_module(node.module)
nodes = []
for node in dsl_module.body:
if isinstance(node, FunctionDef) and node.name in from_names:
nodes.append(node)
return nodes

def import_python_module(self, module_name):
nodes = []
module = importlib.import_module(module_name)
path = module.__file__.strip('c')
source = open(path).read()
source = open(path).read() # .py not .pyc
dsl_node = self.parse(source, filename=path)
assert isinstance(dsl_node, self.dsl_classes['Module']), type(dsl_node)
for dsl_obj in dsl_node.body:
# if dsl_obj.name in imported_names:
nodes.append(dsl_obj)
return nodes
return dsl_node

def visitReturn(self, node):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
class ExpressionTests(ContractValuationTestCase, TestCase):
def test_generate_valuation_addition(self):
specification_tmpl = """
from quantdsl.lib.storage1 import GasStorage, Date
from quantdsl.lib.storage1 import GasStorage
GasStorage(Date('%(start_date)s'), Date('%(end_date)s'), '%(commodity)s', %(quantity)s, %(limit)s, TimeDelta('1m'))
"""
Expand Down

0 comments on commit 7ca4eb0

Please sign in to comment.