Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mycpp: only emit StackRoots if Collect() possbile #1963

Closed
wants to merge 11 commits into from
5 changes: 5 additions & 0 deletions mycpp/NINJA_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@ def ShouldSkipBenchmark(name):
'mycpp/testpkg/module2.py',
'mycpp/examples/modules.py',
],
'test_root_call_graph': [
'mycpp/testpkg/module1.py',
'mycpp/testpkg/module2.py',
'mycpp/examples/test_root_call_graph.py',
],
'parse': [], # added dynamically from mycpp/examples/parse.translate.txt
}

Expand Down
145 changes: 135 additions & 10 deletions mycpp/cppgen_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ def __init__(self,
field_gc=None,
decl=False,
forward_decl=False,
stack_roots_warn=None):
stack_roots_warn=None,
call_graph=None):
self.types = types
self.const_lookup = const_lookup
self.f = f
Expand All @@ -429,6 +430,7 @@ def __init__(self,
self.local_var_list = [] # Collected at assignment
self.prepend_to_block = None # For writing vars after {
self.current_func_node = None
self.current_func_name = None
self.current_stmt_node = None
# Temporary lists to use as output params for generators
self.yield_accumulators = {
Expand All @@ -445,6 +447,7 @@ def __init__(self,
self.current_method_name = None

self.imported_names = set() # MemberExpr -> module::Foo() or self->foo
self.module_aliases: Dict[str, str] = {}

# HACK for conditional import inside mylib.PYTHON
# in core/shell.py
Expand All @@ -456,6 +459,8 @@ def __init__(self,

self.writing_default_arg = False

self.call_graph: Dict[str, Dict[str, int]] = call_graph or {}

def log(self, msg, *args):
ind_str = self.indent * ' '
log(ind_str + msg, *args)
Expand Down Expand Up @@ -542,6 +547,7 @@ def report_error(self, node: Union[Statement, Expression], msg: str):
def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
# Skip some stdlib stuff. A lot of it is brought in by 'import
# typing'.
self.full_module_name = o.fullname
if o.fullname in ('__future__', 'sys', 'types', 'typing', 'abc',
'_ast', 'ast', '_weakrefset', 'collections',
'cStringIO', 're', 'builtins'):
Expand Down Expand Up @@ -756,6 +762,93 @@ def _IsInstantiation(self, o):

return False

def add_to_call_graph(self, o):
full_callee = None
if isinstance(o.callee, NameExpr):
full_callee = o.callee.fullname

elif isinstance(o.callee, MemberExpr):
if isinstance(o.callee.expr, NameExpr):
is_module = (isinstance(o.callee.expr, NameExpr) and
o.callee.expr.name in self.imported_names)
if is_module:
full_callee = '%s.%s' % (self.module_aliases.get(o.callee.expr.name, o.callee.expr.name), o.callee.name)

elif o.callee.expr.name == 'self':
assert self.current_class_name
full_callee = '%s.%s.%s' % (self.full_module_name, self.current_class_name, o.callee.name)

else:
local_type = None
for name, _, t in self.local_var_list:
if name == o.callee.expr.name:
local_type = t
break

if local_type:
if isinstance(local_type, Instance):
full_callee = '%s.%s' % (local_type.type.fullname, o.callee.name)

elif isinstance(local_type, UnionType):
assert len(local_type.items) == 2
full_callee = '%s.%s' % (local_type.items[0].type.fullname, o.callee.expr.name)

else:
assert not isinstance(local_type, CallableType)
# primitive type or string. don't care.
pass

else:
# context or exception handler. probably safe to ignore.
pass

else:
t = self.types[o.callee.expr]
if isinstance(t, Instance):
full_callee = '%s.%s' % (t.type.fullname, o.callee.name)

elif isinstance(t, UnionType):
assert len(t.items) == 2
full_callee = '%s.%s' % (t.items[0].type.fullname, o.callee.name)

elif isinstance(t, CallableType):
# codegen doesn't call into the GC.
assert o.callee.expr.fullname.startswith('_devbuild')

else:
# constructors of things that we don't care about.
pass

else:
# Don't currently get here
raise AssertionError()

if full_callee:
if full_callee.startswith('mycpp.'):
# For compatability with testpkg...
full_callee = full_callee[6:]

self.call_graph[self.current_func_name][full_callee] = 1

def call_path_exists(self, src, dst, visited):
"""Do a DFS from src to dst. Returns true if a path was found."""

if self.decl or self.forward_decl:
return False

visited.add(src)
if src not in self.call_graph:
return False

for neighbor in self.call_graph[src]:
if neighbor == dst:
return True

if neighbor not in visited and self.call_path_exists(neighbor, dst, visited):
return True

return False

def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
if o.callee.name == 'probe':
assert len(o.args) >= 2 and len(o.args) < 13, o.args
Expand All @@ -779,6 +872,9 @@ def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
self.def_write(')')
return

if self.decl and self.current_func_name:
self.add_to_call_graph(o)

if o.callee.name == 'isinstance':
assert len(o.args) == 2, o.args
obj = o.args[0]
Expand Down Expand Up @@ -1311,7 +1407,7 @@ def _write_tuple_unpacking(self,
item_c_type = GetCType(item_type)
# declare it at the top of the function
if self.decl:
self.local_var_list.append((lval_item.name, item_c_type))
self.local_var_list.append((lval_item.name, item_c_type, item_type))
self.def_write_ind('%s', lval_item.name)
else:
# Could be MemberExpr like self.foo, self.bar = baz
Expand Down Expand Up @@ -1411,6 +1507,8 @@ def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
# problem, but I believe that's gone now.

callee = o.rvalue.callee
if self.decl and self.current_func_name:
self.add_to_call_graph(o.rvalue)

if callee.name == 'NewDict':
lval_type = self.types[lval]
Expand All @@ -1425,7 +1523,7 @@ def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:

c_type = GetCType(lval_type)
if self.decl:
self.local_var_list.append((lval.name, c_type))
self.local_var_list.append((lval.name, c_type, lval_type))

assert c_type.endswith('*')

Expand Down Expand Up @@ -1462,7 +1560,7 @@ def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
else:
# Normal variable
if self.decl:
self.local_var_list.append((lval.name, subtype_name))
self.local_var_list.append((lval.name, subtype_name, subtype_name.replace('::', '.')))
self.def_write_ind('%s = %s<%s>(', lval.name, cast_kind,
subtype_name)

Expand Down Expand Up @@ -1503,7 +1601,7 @@ def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
if self.current_func_node:
self.def_write_ind('%s = ', lval.name)
if self.decl:
self.local_var_list.append((lval.name, c_type))
self.local_var_list.append((lval.name, c_type, lval_type))
else:
# globals always get a type -- they're not mutated
self.def_write_ind('%s %s = ', c_type, lval.name)
Expand Down Expand Up @@ -1739,6 +1837,9 @@ def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
return

reverse = False
need_roots = self.call_path_exists(self.current_func_name,
'mylib.MaybeCollect',
set({}))

# for i, x in enumerate(...):
index0_name = None
Expand Down Expand Up @@ -1842,7 +1943,7 @@ def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
if index0_name:
# can't initialize two things in a for loop, so do it on a separate line
if self.decl:
self.local_var_list.append((index0_name, 'int'))
self.local_var_list.append((index0_name, 'int', None))
self.def_write_ind('%s = 0;\n', index0_name)
index_update = ', ++%s' % index0_name
else:
Expand Down Expand Up @@ -1872,7 +1973,7 @@ def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
# it's called in a loop by _ExecuteList(). Although the 'child'
# variable is already live by other means.
# TODO: Test how much this affects performance.
if CTypeIsManaged(c_item_type):
if CTypeIsManaged(c_item_type) and need_roots:
self.def_write_ind(' StackRoot _for(&')
self.accept(index_expr)
self.def_write_ind(');\n')
Expand Down Expand Up @@ -2336,7 +2437,7 @@ def _WriteFuncParams(self,
# only do it in one place. TODO: Check if locals are used in
# __init__ after allocation.
if update_locals:
self.local_var_list.append((arg_name, c_type))
self.local_var_list.append((arg_name, c_type, arg_type))

# We can't use __str__ on these Argument objects? That seems like an
# oversight
Expand Down Expand Up @@ -2410,12 +2511,28 @@ def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
# write default values in the declaration only
write_defaults=self.decl)

# Produce a fully qualified name for the function for the call graph.
func_name = o.name
if self.current_class_name:
func_name = '%s.%s' % (self.current_class_name, o.name)

func_name = '%s.%s' % (self.full_module_name, func_name)
if func_name.startswith('mycpp.'):
# For compatability with testpkg...
func_name = func_name[6:]

self.current_func_name = func_name

if self.decl:
assert func_name not in self.call_graph, func_name
self.call_graph[func_name] = {}

self.always_write(');\n')
self.accept(
o.body) # Collect member_vars, but don't write anything

self.current_func_node = None
self.current_func_name = None
return

self.def_write(') ')
Expand All @@ -2427,11 +2544,12 @@ def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
#log('local_vars %s', self.local_vars[o])
self.prepend_to_block = [
(lval_name, c_type, lval_name in arg_names)
for (lval_name, c_type) in self.local_vars[o]
for (lval_name, c_type, _) in self.local_vars[o]
]

self.accept(o.body)
self.current_func_node = None
self.current_func_name = None

def visit_overloaded_func_def(self,
o: 'mypy.nodes.OverloadedFuncDef') -> T:
Expand Down Expand Up @@ -2725,8 +2843,10 @@ def visit_import_from(self, o: 'mypy.nodes.ImportFrom') -> T:
for name, alias in o.names:
if alias:
self.imported_names.add(alias)
self.module_aliases[alias] = '%s.%s' % (o.id, alias)
else:
self.imported_names.add(name)
self.module_aliases[name] = '%s.%s' % (o.id, name)

if o.id in ('__future__', 'typing'):
return # do nothing
Expand Down Expand Up @@ -2843,7 +2963,12 @@ def visit_block(self, block: 'mypy.nodes.Block') -> T:
roots.append(lval_name)
#self.log('roots %s', roots)

if len(roots):
need_roots = self.current_func_name.startswith('core.shell.') or self.call_path_exists(
self.current_func_name,
'mylib.MaybeCollect',
set({}))

if len(roots) and need_roots:
if (self.stack_roots_warn and
len(roots) > self.stack_roots_warn):
log('WARNING: %s::%s() has %d stack roots. Consider refactoring this function.'
Expand Down
Loading
Loading