diff --git a/ast_tools/passes/__init__.py b/ast_tools/passes/__init__.py index bd6e3a7..87324b2 100644 --- a/ast_tools/passes/__init__.py +++ b/ast_tools/passes/__init__.py @@ -3,3 +3,4 @@ from .debug import * from .ssa import * from .util import * +from .bool_to_bit import * diff --git a/ast_tools/passes/bool_to_bit.py b/ast_tools/passes/bool_to_bit.py new file mode 100644 index 0000000..7a20cd5 --- /dev/null +++ b/ast_tools/passes/bool_to_bit.py @@ -0,0 +1,74 @@ +import ast +import typing as tp + +from . import Pass +from . import PASS_ARGS_T + +from ast_tools.stack import SymbolTable + +__ALL__ = ['bool_to_bit'] + +class BoolOpTransformer(ast.NodeTransformer): + def visit_BoolOp(self, node: ast.BoolOp) -> ast.expr: + # Can't get more specific on return type because if + # len(node.values) == 1 (which it shouldn't be) + # then the return type is expr otherwise + # the return type is Union[BinOp, BoolOp] + + if isinstance(node.op, self.match): + values = node.values + assert values # should not be empty + expr = self.visit(values[0]) + for v in map(self.visit, values[1:]): + expr = ast.BinOp(expr, self.replace(), v) + return expr + else: + return self.generic_visit(node) + + +class AndTransformer(BoolOpTransformer): + match = ast.And + replace = ast.BitAnd + + +class OrTransformer(BoolOpTransformer): + match = ast.Or + replace = ast.BitOr + + +class NotTransformer(ast.NodeTransformer): + def visit_Not(self, node: ast.Not) -> ast.Invert: + return ast.Invert() + + +class bool_to_bit(Pass): + ''' + Pass to replace bool operators (and, or, not) + with bit operators (&, |, ~) + ''' + def __init__(self, + replace_and: bool = True, + replace_or: bool = True, + replace_not: bool = True, + ): + self.replace_and = replace_and + self.replace_or = replace_or + self.replace_not = replace_not + + def rewrite(self, + tree: ast.AST, + env: SymbolTable, + metadata: tp.MutableMapping) -> PASS_ARGS_T: + if self.replace_and: + visitor = AndTransformer() + tree = visitor.visit(tree) + + if self.replace_or: + visitor = OrTransformer() + tree = visitor.visit(tree) + + if self.replace_not: + visitor = NotTransformer() + tree = visitor.visit(tree) + + return tree, env, metadata diff --git a/ast_tools/passes/ssa.py b/ast_tools/passes/ssa.py index 97d8240..be04450 100644 --- a/ast_tools/passes/ssa.py +++ b/ast_tools/passes/ssa.py @@ -1,12 +1,7 @@ import ast from collections import ChainMap, Counter -import itertools -import warnings -import weakref import typing as tp -import astor - from . import Pass from . import PASS_ARGS_T from ast_tools.common import gen_free_prefix, is_free_name @@ -342,7 +337,10 @@ class ssa(Pass): def __init__(self, return_prefix: str = '__return_value'): self.return_prefix = return_prefix - def rewrite(self, tree: ast.AST, env: SymbolTable, metadata: tp.MutableMapping): + def rewrite(self, + tree: ast.AST, + env: SymbolTable, + metadata: tp.MutableMapping) -> PASS_ARGS_T: if not isinstance(tree, ast.FunctionDef): raise TypeError('ssa should only be applied to functions') r_name = gen_free_prefix(tree, env, self.return_prefix) diff --git a/tests/test_bool_to_bit.py b/tests/test_bool_to_bit.py new file mode 100644 index 0000000..e8a3399 --- /dev/null +++ b/tests/test_bool_to_bit.py @@ -0,0 +1,56 @@ +import ast +import inspect + +import pytest + + +from ast_tools.passes import begin_rewrite, end_rewrite, bool_to_bit + +def test_and(): + @end_rewrite() + @bool_to_bit() + @begin_rewrite() + def and_f(x, y): + return x and y + + assert inspect.getsource(and_f) == '''\ +def and_f(x, y): + return x & y +''' + +def test_or(): + @end_rewrite() + @bool_to_bit() + @begin_rewrite() + def or_f(x, y): + return x or y + + assert inspect.getsource(or_f) == '''\ +def or_f(x, y): + return x | y +''' + +def test_not(): + @end_rewrite() + @bool_to_bit() + @begin_rewrite() + def not_f(x): + return not x + + assert inspect.getsource(not_f) == '''\ +def not_f(x): + return ~x +''' + +def test_xor(): + @end_rewrite() + @bool_to_bit() + @begin_rewrite() + def xor(x, y): + return x and not y or not x and y + + assert inspect.getsource(xor) == '''\ +def xor(x, y): + return x & ~y | ~x & y +''' +