-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
258 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,88 +1,163 @@ | ||
import ast | ||
import itertools | ||
import silica | ||
from silica.ast_utils import * | ||
from .replace_symbols import replace_symbols | ||
from .constant_fold import constant_fold | ||
from silica.memory import MemoryType | ||
from copy import deepcopy | ||
|
||
class ForLoopDesugarer(ast.NodeTransformer): | ||
def __init__(self): | ||
super().__init__() | ||
self.loopvars = set() | ||
|
||
def visit_For(self, node): | ||
new_body = [] | ||
for s in node.body: | ||
result = self.visit(s) | ||
if isinstance(result, list): | ||
new_body.extend(result) | ||
else: | ||
new_body.append(result) | ||
node.body = new_body | ||
if is_call(node.iter) and is_name(node.iter.func) and \ | ||
node.iter.func.id == "range": | ||
try: | ||
range_object = eval(astor.to_source(node.iter)) | ||
body = [] | ||
for i in range_object: | ||
symbol_table = {node.target.id: ast.Num(i)} | ||
for child in node.body: | ||
body.append( | ||
constant_fold(replace_symbols(deepcopy(child), symbol_table)) | ||
) | ||
return body | ||
except Exception as e: | ||
pass | ||
# TODO: Support mixed regular and keyword args | ||
if 4 > len(node.iter.args) > 0: | ||
assert isinstance(node.target, ast.Name) | ||
if len(node.iter.args) <= 2: | ||
incr = 1 | ||
# TODO: should probably delete | ||
import astor | ||
|
||
counter = itertools.count() | ||
|
||
def get_call_func(node): | ||
if is_name(node.func): | ||
return node.func.id | ||
# Should handle nested expressions/alternate types | ||
raise NotImplementedError(type(node.value.func)) # pragma: no cover | ||
|
||
|
||
|
||
def desugar_for_loops(tree, type_table): | ||
class ForLoopDesugarer(ast.NodeTransformer): | ||
def __init__(self): | ||
super().__init__() | ||
self.loopvars = set() | ||
|
||
def visit_For(self, node): | ||
new_body = [] | ||
for s in node.body: | ||
result = self.visit(s) | ||
if isinstance(result, list): | ||
new_body.extend(result) | ||
else: | ||
incr = eval(astor.to_source(node.iter.args[2])) | ||
if len(node.iter.args) == 1: | ||
start = 0 | ||
stop = node.iter.args[0] | ||
new_body.append(result) | ||
node.body = new_body | ||
|
||
|
||
|
||
# range() iterator | ||
if is_call(node.iter) and is_name(node.iter.func) and \ | ||
node.iter.func.id == "range": | ||
try: | ||
range_object = eval(astor.to_source(node.iter)) | ||
body = [] | ||
for i in range_object: | ||
symbol_table = {node.target.id: ast.Num(i)} | ||
for child in node.body: | ||
body.append( | ||
constant_fold(replace_symbols(deepcopy(child), symbol_table)) | ||
) | ||
return body | ||
except Exception as e: | ||
pass | ||
|
||
if len(node.iter.args) > 0: | ||
if not len(node.iter.args) < 4: | ||
raise TypeError("range expected at most 3 arguments, got {}" | ||
.format(len(node.iter.args))) | ||
if not is_name(node.target): | ||
raise NotImplementedError("Target is not a named argument: `{}`" | ||
.format(to_source(node.target))) | ||
|
||
start = ast.Num(0) | ||
stop = node.iter.args[0] | ||
step = ast.Num(1) | ||
|
||
if len(node.iter.args) > 1: | ||
start = node.iter.args[0] | ||
stop = node.iter.args[1] | ||
|
||
if len(node.iter.args) == 3: | ||
step = node.iter.args[2] | ||
else: | ||
start = node.iter.args[0].n | ||
stop = node.iter.args[1] | ||
else: | ||
assert len(node.iter.keywords) | ||
start = 0 | ||
stop = None | ||
incr = 1 | ||
for keyword in node.iter.keywords: | ||
if keyword.arg == "start": | ||
start = keyword.value | ||
elif keyword.arg == "stop": | ||
stop = keyword.value | ||
elif keyword.arg == "step": | ||
step = keyword.value | ||
assert stop is not None | ||
width = None | ||
if isinstance(stop, ast.Num): | ||
width = (stop.n - 1).bit_length() | ||
else: | ||
width = None | ||
for keyword in node.iter.keywords: | ||
if keyword.arg == "bit_width": | ||
assert isinstance(keyword.value, ast.Num) | ||
width = keyword.value.n | ||
break | ||
assert width is not None | ||
self.loopvars.add((node.target.id, width)) | ||
return [ | ||
ast.Assign([ast.Name(node.target.id, ast.Store())], ast.parse(f"bits({start}, {width})").body[0].value), | ||
ast.Assign([ast.Name(node.target.id + "_cout", ast.Store())], ast.NameConstant(False)), | ||
ast.While(ast.parse(f"not_({node.target.id}_cout)").body[0].value, | ||
node.body + [ | ||
ast.parse(f"{node.target.id}, {node.target.id}_cout = add({node.target.id}, bits({incr}, {width}), cout=True)").body[0] | ||
], []) | ||
] | ||
else: # pragma: no cover | ||
print_ast(node) | ||
raise NotImplementedError("Unsupported for loop construct `{}`".format(to_source(node.iter))) | ||
|
||
def desugar_for_loops(tree): | ||
raise NotImplementedError("keyword iterators are not yet supported") | ||
|
||
bit_width = eval(astor.to_source(stop).rstrip() + "-" + astor.to_source(start).rstrip()).bit_length() | ||
self.loopvars.add((node.target.id, bit_width)) | ||
|
||
return [ | ||
ast.Assign([ast.Name(node.target.id, ast.Store())], start), | ||
ast.While(ast.BinOp(ast.Name(node.target.id, ast.Load()), ast.Lt(), stop), | ||
node.body + [ | ||
ast.Assign([ast.Name(node.target.id, ast.Store())], ast.BinOp( | ||
ast.Name(node.target.id, ast.Load()), ast.Add(), step)) | ||
], []) | ||
] | ||
# for _ in name | ||
elif is_name(node.iter): | ||
# TODO: currently just assumes that this is a list | ||
# TODO: maybe this should just reshape the for-loop to use a range() | ||
# iterator and then just call the previous branch | ||
# TODO: index should be an unused identifier | ||
index = '__silica_count_' + str(next(counter)) | ||
start = ast.Num(0) | ||
stop = ast.Num(type_table[node.iter.id][1]) | ||
step = ast.Num(1) | ||
|
||
bit_width = stop.n.bit_length() | ||
self.loopvars.add((index, bit_width)) | ||
|
||
return [ | ||
ast.Assign([ast.Name(index, ast.Store())], start), | ||
ast.While(ast.BinOp(ast.Name(index, ast.Load()), ast.Lt(), stop), | ||
[ast.Assign([ast.Name(node.target.id, ast.Store())], | ||
ast.Subscript(node.iter, ast.Index(ast.Name(index, ast.Load()))))] + | ||
node.body + [ | ||
ast.Assign([ast.Name(index, ast.Store())], ast.BinOp( | ||
ast.Name(index, ast.Load()), ast.Add(), step)) | ||
], []) | ||
] | ||
else: # pragma: no cover | ||
print_ast(node) | ||
raise NotImplementedError("Unsupported for loop construct `{}`".format(to_source(node.iter))) | ||
|
||
|
||
desugarer = ForLoopDesugarer() | ||
desugarer.visit(tree) | ||
return tree, desugarer.loopvars | ||
|
||
def propagate_types(tree): | ||
class TypePropagator(ast.NodeTransformer): | ||
def __init__(self): | ||
super().__init__() | ||
self.loopvars = {} | ||
|
||
def visit_Assign(self, node): | ||
self.generic_visit(node) | ||
if is_list(node.value): | ||
self.loopvars[node.targets[0].id] = ('list', len(node.value.elts)) | ||
|
||
return node | ||
|
||
propagator = TypePropagator() | ||
propagator.visit(tree) | ||
return tree, propagator.loopvars | ||
|
||
def aaaaaaaaaaaaaaaaaaaaaaaaaaaaa(tree, width_table, globals, locals): | ||
class AAAAAA(ast.NodeTransformer): | ||
def __init__(self): | ||
super().__init__() | ||
self.loopvars = {} | ||
|
||
def visit_Assign(self, node): | ||
self.generic_visit(node) | ||
if is_list(node.value): | ||
fst = node.value.elts[0] | ||
if is_name(fst): | ||
self.loopvars[node.targets[0].id] = MemoryType(len(node.value.elts), width_table[fst.id]) | ||
elif is_call(fst): | ||
if fst.func.id == 'bits': | ||
self.loopvars[node.targets[0].id] = MemoryType(len(node.value.elts), fst.args[1].n) | ||
else: | ||
raise NotImplementedError("Getting width of `{}` is unsupported.".format(to_source(fst))) | ||
else: | ||
raise NotImplementedError("Getting width of `{}` is unsupported.".format(to_source(fst))) | ||
|
||
return node | ||
|
||
aaaa = AAAAAA() | ||
aaaa.visit(tree) | ||
return tree, aaaa.loopvars |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.