From 8c36678d14e13d45df701fff450ed3e852c2cfbe Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 22 Aug 2019 16:41:36 -0700 Subject: [PATCH 1/5] Add test for function argument --- tests/test_ssa.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_ssa.py b/tests/test_ssa.py index dd6bea3..6ad999c 100644 --- a/tests/test_ssa.py +++ b/tests/test_ssa.py @@ -1,4 +1,5 @@ import ast +import inspect import pytest @@ -119,3 +120,24 @@ def test_imbalanced(a, b, c, d): for x in (False, True): for y in (False, True): assert imbalanced(x, y) == imbalanced_ssa(x, y) + + +def test_reassign_arg(): + def bar(x): + return x + + @end_rewrite() + @ssa() + @begin_rewrite() + def foo(a, b): + if b: + a = bar(a) + return a + assert inspect.getsource(foo) == """\ +def foo(a, b): + a0 = bar(a) + a1 = a0 if b else a + __return_value0 = a1 + return __return_value0 +""" + From b3a2f10be63946396134e338bb6e3e1c239a4fb1 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 22 Aug 2019 16:46:27 -0700 Subject: [PATCH 2/5] Reorder visit_Assign in ssa --- ast_tools/passes/ssa.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ast_tools/passes/ssa.py b/ast_tools/passes/ssa.py index 97d8240..b219bed 100644 --- a/ast_tools/passes/ssa.py +++ b/ast_tools/passes/ssa.py @@ -62,6 +62,11 @@ def visit(self, node: ast.AST) -> ast.AST: else: return super().visit(node) + def visit_Assign(self, node): + node.value = self.visit(node.value) + node.targets = [self.visit(t) for t in node.targets] + return node + def visit_If(self, node: ast.If) -> tp.List[ast.stmt]: test = self.visit(node.test) From faccf6a34d1ba5f2cc067615e7249afbe8a1fcdb Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 22 Aug 2019 17:02:32 -0700 Subject: [PATCH 3/5] Add failing test --- tests/test_ssa.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_ssa.py b/tests/test_ssa.py index 6ad999c..f5bee21 100644 --- a/tests/test_ssa.py +++ b/tests/test_ssa.py @@ -141,3 +141,36 @@ def foo(a, b): return __return_value0 """ + +def test_double_nested_function_call(): + def bar(x): + return x + + def baz(x): + return x + 1 + + @end_rewrite() + @ssa() + @begin_rewrite() + def foo(a, b, c): + if b: + a = bar(a) + else: + a = bar(a) + if c: + b = bar(b) + else: + b = bar(b) + return a, b + print(inspect.getsource(foo)) + assert inspect.getsource(foo) == """\ +def foo(a, b, c): + a0 = bar(a) + a1 = bar(a) + a2 = a0 if b else a1 + b0 = bar(b) + b1 = bar(b) + b2 = b0 if c else b1 + __return_value0 = a2 + return __return_value0 +""" From b5cb232d31b55435a69ff87f1af732c19a1180c7 Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Fri, 23 Aug 2019 10:05:44 -0700 Subject: [PATCH 4/5] Add comment about visiting RHS first --- ast_tools/passes/ssa.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ast_tools/passes/ssa.py b/ast_tools/passes/ssa.py index b219bed..402dacc 100644 --- a/ast_tools/passes/ssa.py +++ b/ast_tools/passes/ssa.py @@ -63,6 +63,7 @@ def visit(self, node: ast.AST) -> ast.AST: return super().visit(node) def visit_Assign(self, node): + # visit RHS first node.value = self.visit(node.value) node.targets = [self.visit(t) for t in node.targets] return node From 6a078d63a8057cce3cad376af23f2fb1a29557e6 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Fri, 23 Aug 2019 11:06:27 -0700 Subject: [PATCH 5/5] Fix unnecessary muxing --- ast_tools/passes/ssa.py | 10 +++++++++- tests/test_ssa.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ast_tools/passes/ssa.py b/ast_tools/passes/ssa.py index 402dacc..77ca468 100644 --- a/ast_tools/passes/ssa.py +++ b/ast_tools/passes/ssa.py @@ -215,8 +215,16 @@ def visit_Name(self, node: ast.Name) -> ast.Name: name = node.id ctx = node.ctx if isinstance(ctx, ast.Load): + # Names in Load context should not be added to the name table + # as it makes them seem like they have been modified. + try: + return ast.Name( + id=self.name_table[name], + ctx=ctx) + except KeyError: + pass return ast.Name( - id=self.name_table.setdefault(name, name), + id=name, ctx=ctx) else: return ast.Name( diff --git a/tests/test_ssa.py b/tests/test_ssa.py index f5bee21..89b8f74 100644 --- a/tests/test_ssa.py +++ b/tests/test_ssa.py @@ -171,6 +171,6 @@ def foo(a, b, c): b0 = bar(b) b1 = bar(b) b2 = b0 if c else b1 - __return_value0 = a2 + __return_value0 = a2, b2 return __return_value0 """