From d5c8d456e8d5fb7613f07461bed12f508ac1f512 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Thu, 22 Aug 2019 11:30:50 -0700 Subject: [PATCH 1/3] Remove unused imports --- ast_tools/passes/ssa.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ast_tools/passes/ssa.py b/ast_tools/passes/ssa.py index 97d8240..9fd02d4 100644 --- a/ast_tools/passes/ssa.py +++ b/ast_tools/passes/ssa.py @@ -1,12 +1,7 @@ import ast from collections import ChainMap, Counter -import itertools -import warnings -import weakref import typing as tp -import astor - from . import Pass from . import PASS_ARGS_T from ast_tools.common import gen_free_prefix, is_free_name From cbc3d30230a6f32ace5ae157afe01f6092e79d8c Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Thu, 22 Aug 2019 14:41:39 -0700 Subject: [PATCH 2/3] bool to bit pass --- ast_tools/passes/__init__.py | 1 + ast_tools/passes/bool_to_bit.py | 74 +++++++++++++++++++++++++++++++++ ast_tools/passes/ssa.py | 5 ++- tests/test_bool_to_bit.py | 56 +++++++++++++++++++++++++ 4 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 ast_tools/passes/bool_to_bit.py create mode 100644 tests/test_bool_to_bit.py diff --git a/ast_tools/passes/__init__.py b/ast_tools/passes/__init__.py index bd6e3a7..87324b2 100644 --- a/ast_tools/passes/__init__.py +++ b/ast_tools/passes/__init__.py @@ -3,3 +3,4 @@ from .debug import * from .ssa import * from .util import * +from .bool_to_bit import * diff --git a/ast_tools/passes/bool_to_bit.py b/ast_tools/passes/bool_to_bit.py new file mode 100644 index 0000000..ce9020d --- /dev/null +++ b/ast_tools/passes/bool_to_bit.py @@ -0,0 +1,74 @@ +import ast +import typing as tp + +from . import Pass +from . import PASS_ARGS_T + +from ast_tools.stack import SymbolTable + +__ALL__ = ['bool_to_bit'] + +class AndTransformer(ast.NodeTransformer): + def visit_BoolOp(self, node: ast.BoolOp) -> ast.expr: + # Can't get more specific on return type because if + # len(node.values) == 1 (which it shouldn't be) + # then the return type is expr otherwise + # the return type is Union[BinOp, BoolOp] + + if isinstance(node.op, ast.And): + values = node.values + assert values # should not be empty + expr = self.visit(values[0]) + for v in map(self.visit, values[1:]): + expr = ast.BinOp(expr, ast.BitAnd(), v) + return expr + else: + return self.generic_visit(node) + +class OrTransformer(ast.NodeTransformer): + def visit_BoolOp(self, node: ast.BoolOp) -> ast.expr: + if isinstance(node.op, ast.Or): + values = node.values + assert values # should not be empty + expr = self.visit(values[0]) + for v in map(self.visit, values[1:]): + expr = ast.BinOp(expr, ast.BitOr(), v) + return expr + else: + return self.generic_visit(node) + +class NotTransformer(ast.NodeTransformer): + def visit_Not(self, node: ast.Not) -> ast.Invert: + return ast.Invert() + +class bool_to_bit(Pass): + ''' + Pass to replace bool operators (and, or, not) + with bit operators (&, |, ~) + ''' + def __init__(self, + replace_and: bool = True, + replace_or: bool = True, + replace_not: bool = True, + ): + self.replace_and = replace_and + self.replace_or = replace_or + self.replace_not = replace_not + + def rewrite(self, + tree: ast.AST, + env: SymbolTable, + metadata: tp.MutableMapping) -> PASS_ARGS_T: + if self.replace_and: + visitor = AndTransformer() + tree = visitor.visit(tree) + + if self.replace_or: + visitor = OrTransformer() + tree = visitor.visit(tree) + + if self.replace_not: + visitor = NotTransformer() + tree = visitor.visit(tree) + + return tree, env, metadata diff --git a/ast_tools/passes/ssa.py b/ast_tools/passes/ssa.py index 9fd02d4..be04450 100644 --- a/ast_tools/passes/ssa.py +++ b/ast_tools/passes/ssa.py @@ -337,7 +337,10 @@ class ssa(Pass): def __init__(self, return_prefix: str = '__return_value'): self.return_prefix = return_prefix - def rewrite(self, tree: ast.AST, env: SymbolTable, metadata: tp.MutableMapping): + def rewrite(self, + tree: ast.AST, + env: SymbolTable, + metadata: tp.MutableMapping) -> PASS_ARGS_T: if not isinstance(tree, ast.FunctionDef): raise TypeError('ssa should only be applied to functions') r_name = gen_free_prefix(tree, env, self.return_prefix) diff --git a/tests/test_bool_to_bit.py b/tests/test_bool_to_bit.py new file mode 100644 index 0000000..e8a3399 --- /dev/null +++ b/tests/test_bool_to_bit.py @@ -0,0 +1,56 @@ +import ast +import inspect + +import pytest + + +from ast_tools.passes import begin_rewrite, end_rewrite, bool_to_bit + +def test_and(): + @end_rewrite() + @bool_to_bit() + @begin_rewrite() + def and_f(x, y): + return x and y + + assert inspect.getsource(and_f) == '''\ +def and_f(x, y): + return x & y +''' + +def test_or(): + @end_rewrite() + @bool_to_bit() + @begin_rewrite() + def or_f(x, y): + return x or y + + assert inspect.getsource(or_f) == '''\ +def or_f(x, y): + return x | y +''' + +def test_not(): + @end_rewrite() + @bool_to_bit() + @begin_rewrite() + def not_f(x): + return not x + + assert inspect.getsource(not_f) == '''\ +def not_f(x): + return ~x +''' + +def test_xor(): + @end_rewrite() + @bool_to_bit() + @begin_rewrite() + def xor(x, y): + return x and not y or not x and y + + assert inspect.getsource(xor) == '''\ +def xor(x, y): + return x & ~y | ~x & y +''' + From 63cd7d1f43c276b2c3d9f1d7efc36cbc30fe837d Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Thu, 22 Aug 2019 21:12:30 -0700 Subject: [PATCH 3/3] Dedupe --- ast_tools/passes/bool_to_bit.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/ast_tools/passes/bool_to_bit.py b/ast_tools/passes/bool_to_bit.py index ce9020d..7a20cd5 100644 --- a/ast_tools/passes/bool_to_bit.py +++ b/ast_tools/passes/bool_to_bit.py @@ -8,39 +8,39 @@ __ALL__ = ['bool_to_bit'] -class AndTransformer(ast.NodeTransformer): +class BoolOpTransformer(ast.NodeTransformer): def visit_BoolOp(self, node: ast.BoolOp) -> ast.expr: # Can't get more specific on return type because if # len(node.values) == 1 (which it shouldn't be) # then the return type is expr otherwise # the return type is Union[BinOp, BoolOp] - if isinstance(node.op, ast.And): + if isinstance(node.op, self.match): values = node.values assert values # should not be empty expr = self.visit(values[0]) for v in map(self.visit, values[1:]): - expr = ast.BinOp(expr, ast.BitAnd(), v) + expr = ast.BinOp(expr, self.replace(), v) return expr else: return self.generic_visit(node) -class OrTransformer(ast.NodeTransformer): - def visit_BoolOp(self, node: ast.BoolOp) -> ast.expr: - if isinstance(node.op, ast.Or): - values = node.values - assert values # should not be empty - expr = self.visit(values[0]) - for v in map(self.visit, values[1:]): - expr = ast.BinOp(expr, ast.BitOr(), v) - return expr - else: - return self.generic_visit(node) + +class AndTransformer(BoolOpTransformer): + match = ast.And + replace = ast.BitAnd + + +class OrTransformer(BoolOpTransformer): + match = ast.Or + replace = ast.BitOr + class NotTransformer(ast.NodeTransformer): def visit_Not(self, node: ast.Not) -> ast.Invert: return ast.Invert() + class bool_to_bit(Pass): ''' Pass to replace bool operators (and, or, not)