Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ast_tools/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .debug import *
from .ssa import *
from .util import *
from .bool_to_bit import *
74 changes: 74 additions & 0 deletions ast_tools/passes/bool_to_bit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import ast
import typing as tp

from . import Pass
from . import PASS_ARGS_T

from ast_tools.stack import SymbolTable

__ALL__ = ['bool_to_bit']

class BoolOpTransformer(ast.NodeTransformer):
def visit_BoolOp(self, node: ast.BoolOp) -> ast.expr:
# Can't get more specific on return type because if
# len(node.values) == 1 (which it shouldn't be)
# then the return type is expr otherwise
# the return type is Union[BinOp, BoolOp]

if isinstance(node.op, self.match):
values = node.values
assert values # should not be empty
expr = self.visit(values[0])
for v in map(self.visit, values[1:]):
expr = ast.BinOp(expr, self.replace(), v)
return expr
else:
return self.generic_visit(node)


class AndTransformer(BoolOpTransformer):
match = ast.And
replace = ast.BitAnd


class OrTransformer(BoolOpTransformer):
match = ast.Or
replace = ast.BitOr


class NotTransformer(ast.NodeTransformer):
def visit_Not(self, node: ast.Not) -> ast.Invert:
return ast.Invert()


class bool_to_bit(Pass):
'''
Pass to replace bool operators (and, or, not)
with bit operators (&, |, ~)
'''
def __init__(self,
replace_and: bool = True,
replace_or: bool = True,
replace_not: bool = True,
):
self.replace_and = replace_and
self.replace_or = replace_or
self.replace_not = replace_not

def rewrite(self,
tree: ast.AST,
env: SymbolTable,
metadata: tp.MutableMapping) -> PASS_ARGS_T:
if self.replace_and:
visitor = AndTransformer()
tree = visitor.visit(tree)

if self.replace_or:
visitor = OrTransformer()
tree = visitor.visit(tree)

if self.replace_not:
visitor = NotTransformer()
tree = visitor.visit(tree)

return tree, env, metadata
10 changes: 4 additions & 6 deletions ast_tools/passes/ssa.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import ast
from collections import ChainMap, Counter
import itertools
import warnings
import weakref
import typing as tp

import astor

from . import Pass
from . import PASS_ARGS_T
from ast_tools.common import gen_free_prefix, is_free_name
Expand Down Expand Up @@ -342,7 +337,10 @@ class ssa(Pass):
def __init__(self, return_prefix: str = '__return_value'):
self.return_prefix = return_prefix

def rewrite(self, tree: ast.AST, env: SymbolTable, metadata: tp.MutableMapping):
def rewrite(self,
tree: ast.AST,
env: SymbolTable,
metadata: tp.MutableMapping) -> PASS_ARGS_T:
if not isinstance(tree, ast.FunctionDef):
raise TypeError('ssa should only be applied to functions')
r_name = gen_free_prefix(tree, env, self.return_prefix)
Expand Down
56 changes: 56 additions & 0 deletions tests/test_bool_to_bit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import ast
import inspect

import pytest


from ast_tools.passes import begin_rewrite, end_rewrite, bool_to_bit

def test_and():
@end_rewrite()
@bool_to_bit()
@begin_rewrite()
def and_f(x, y):
return x and y

assert inspect.getsource(and_f) == '''\
def and_f(x, y):
return x & y
'''

def test_or():
@end_rewrite()
@bool_to_bit()
@begin_rewrite()
def or_f(x, y):
return x or y

assert inspect.getsource(or_f) == '''\
def or_f(x, y):
return x | y
'''

def test_not():
@end_rewrite()
@bool_to_bit()
@begin_rewrite()
def not_f(x):
return not x

assert inspect.getsource(not_f) == '''\
def not_f(x):
return ~x
'''

def test_xor():
@end_rewrite()
@bool_to_bit()
@begin_rewrite()
def xor(x, y):
return x and not y or not x and y

assert inspect.getsource(xor) == '''\
def xor(x, y):
return x & ~y | ~x & y
'''