diff --git a/ast_tools/passes/ssa.py b/ast_tools/passes/ssa.py index 97d8240..77ca468 100644 --- a/ast_tools/passes/ssa.py +++ b/ast_tools/passes/ssa.py @@ -62,6 +62,12 @@ def visit(self, node: ast.AST) -> ast.AST: else: return super().visit(node) + def visit_Assign(self, node): + # visit RHS first + node.value = self.visit(node.value) + node.targets = [self.visit(t) for t in node.targets] + return node + def visit_If(self, node: ast.If) -> tp.List[ast.stmt]: test = self.visit(node.test) @@ -209,8 +215,16 @@ def visit_Name(self, node: ast.Name) -> ast.Name: name = node.id ctx = node.ctx if isinstance(ctx, ast.Load): + # Names in Load context should not be added to the name table + # as it makes them seem like they have been modified. + try: + return ast.Name( + id=self.name_table[name], + ctx=ctx) + except KeyError: + pass return ast.Name( - id=self.name_table.setdefault(name, name), + id=name, ctx=ctx) else: return ast.Name( diff --git a/tests/test_ssa.py b/tests/test_ssa.py index dd6bea3..89b8f74 100644 --- a/tests/test_ssa.py +++ b/tests/test_ssa.py @@ -1,4 +1,5 @@ import ast +import inspect import pytest @@ -119,3 +120,57 @@ def test_imbalanced(a, b, c, d): for x in (False, True): for y in (False, True): assert imbalanced(x, y) == imbalanced_ssa(x, y) + + +def test_reassign_arg(): + def bar(x): + return x + + @end_rewrite() + @ssa() + @begin_rewrite() + def foo(a, b): + if b: + a = bar(a) + return a + assert inspect.getsource(foo) == """\ +def foo(a, b): + a0 = bar(a) + a1 = a0 if b else a + __return_value0 = a1 + return __return_value0 +""" + + +def test_double_nested_function_call(): + def bar(x): + return x + + def baz(x): + return x + 1 + + @end_rewrite() + @ssa() + @begin_rewrite() + def foo(a, b, c): + if b: + a = bar(a) + else: + a = bar(a) + if c: + b = bar(b) + else: + b = bar(b) + return a, b + print(inspect.getsource(foo)) + assert inspect.getsource(foo) == """\ +def foo(a, b, c): + a0 = bar(a) + a1 = bar(a) + a2 = a0 if b else a1 + b0 = bar(b) + b1 = bar(b) + b2 = b0 if c else b1 + __return_value0 = a2, b2 + return __return_value0 +"""