Skip to content

Commit

Permalink
Merge 0c38342 into 13b4725
Browse files Browse the repository at this point in the history
  • Loading branch information
THofstee committed Oct 5, 2018
2 parents 13b4725 + 0c38342 commit 0d8d195
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 105 deletions.
6 changes: 4 additions & 2 deletions silica/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ def visit_Assign(self, node):
if isinstance(node.targets[0], ast.Name):
if isinstance(node.value, ast.Yield):
pass # width specified at compile time
elif isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and node.value.func.id == "coroutine_create":
elif isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and \
node.value.func.id == "coroutine_create":
pass
elif isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Attribute) and node.value.func.attr == "send":
elif isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Attribute) and \
node.value.func.attr == "send":
pass
elif node.targets[0].id not in self.width_table:
self.width_table[node.targets[0].id] = get_width(node.value, self.width_table)
Expand Down
21 changes: 20 additions & 1 deletion silica/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from silica.type_check import TypeChecker
from silica.analysis import CollectInitialWidthsAndTypes
from silica.transformations.promote_widths import PromoteWidths
from silica.transformations.desugar_for_loops import propagate_types, get_final_widths

import veriloggen as vg

Expand Down Expand Up @@ -81,16 +82,34 @@ def compile(coroutine, file_name=None, mux_strategy="one-hot", output='verilog',
inline_yield_from_functions(tree, func_globals, func_locals)
constant_fold(tree)
specialize_list_comps(tree, func_globals, func_locals)
desugar_for_loops(tree)
tree, list_lens = propagate_types(tree)
tree, loopvars = desugar_for_loops(tree, list_lens)

ast_utils.print_ast(tree)

width_table = {}
if coroutine._inputs:
for input_, type_ in coroutine._inputs.items():
width_table[input_] = get_input_width(type_)

for name,width in loopvars:
width_table[name] = width

constant_fold(tree)
type_table = {}

for name,_ in loopvars:
type_table[name] = 'uint'

CollectInitialWidthsAndTypes(width_table, type_table).visit(tree)
PromoteWidths(width_table, type_table).visit(tree)
tree, loopvars = get_final_widths(tree, width_table, func_locals, func_globals)

ast_utils.print_ast(tree)

for name,width in loopvars.items():
width_table[name] = width

# Desugar(width_table).visit(tree)
type_table = {}
TypeChecker(width_table, type_table).check(tree)
Expand Down
6 changes: 6 additions & 0 deletions silica/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@ def __init__(self, height, width):

def __eq__(self, other):
return isinstance(other, MemoryType) and self.height == other.height and self.width == other.width

def __str__(self):
return "<MemoryType: {}, {}>".format(self.height, self.width)

def __repr__(self):
return "<MemoryType: {}, {}>".format(self.height, self.width)
223 changes: 146 additions & 77 deletions silica/transformations/desugar_for_loops.py
Original file line number Diff line number Diff line change
@@ -1,88 +1,157 @@
import ast
import itertools
import silica
from silica.ast_utils import *
from .replace_symbols import replace_symbols
from .constant_fold import constant_fold
from silica.memory import MemoryType
from copy import deepcopy
import astor

class ForLoopDesugarer(ast.NodeTransformer):
def __init__(self):
super().__init__()
self.loopvars = set()

def visit_For(self, node):
new_body = []
for s in node.body:
result = self.visit(s)
if isinstance(result, list):
new_body.extend(result)
else:
new_body.append(result)
node.body = new_body
if is_call(node.iter) and is_name(node.iter.func) and \
node.iter.func.id == "range":
try:
range_object = eval(astor.to_source(node.iter))
body = []
for i in range_object:
symbol_table = {node.target.id: ast.Num(i)}
for child in node.body:
body.append(
constant_fold(replace_symbols(deepcopy(child), symbol_table))
)
return body
except Exception as e:
pass
# TODO: Support mixed regular and keyword args
if 4 > len(node.iter.args) > 0:
assert isinstance(node.target, ast.Name)
if len(node.iter.args) <= 2:
incr = 1
counter = itertools.count()

def get_call_func(node):
if is_name(node.func):
return node.func.id
# Should handle nested expressions/alternate types
raise NotImplementedError(type(node.value.func)) # pragma: no cover


def desugar_for_loops(tree, type_table):
class ForLoopDesugarer(ast.NodeTransformer):
def __init__(self):
super().__init__()
self.loopvars = set()

def visit_For(self, node):
new_body = []
for s in node.body:
result = self.visit(s)
if isinstance(result, list):
new_body.extend(result)
else:
incr = eval(astor.to_source(node.iter.args[2]))
if len(node.iter.args) == 1:
start = 0
stop = node.iter.args[0]
new_body.append(result)
node.body = new_body

# range() iterator
if is_call(node.iter) and is_name(node.iter.func) and \
node.iter.func.id == "range":
try:
range_object = eval(astor.to_source(node.iter))
body = []
for i in range_object:
symbol_table = {node.target.id: ast.Num(i)}
for child in node.body:
body.append(
constant_fold(replace_symbols(deepcopy(child), symbol_table))
)
return body
except Exception as e:
pass

if len(node.iter.args) > 0:
if not len(node.iter.args) < 4:
raise TypeError("range expected at most 3 arguments, got {}"
.format(len(node.iter.args)))
if not is_name(node.target):
raise NotImplementedError("Target is not a named argument: `{}`"
.format(to_source(node.target)))

start = ast.Num(0)
stop = node.iter.args[0]
step = ast.Num(1)

if len(node.iter.args) > 1:
start = node.iter.args[0]
stop = node.iter.args[1]

if len(node.iter.args) == 3:
step = node.iter.args[2]
else:
start = node.iter.args[0].n
stop = node.iter.args[1]
else:
assert len(node.iter.keywords)
start = 0
stop = None
incr = 1
for keyword in node.iter.keywords:
if keyword.arg == "start":
start = keyword.value
elif keyword.arg == "stop":
stop = keyword.value
elif keyword.arg == "step":
step = keyword.value
assert stop is not None
width = None
if isinstance(stop, ast.Num):
width = (stop.n - 1).bit_length()
else:
width = None
for keyword in node.iter.keywords:
if keyword.arg == "bit_width":
assert isinstance(keyword.value, ast.Num)
width = keyword.value.n
break
assert width is not None
self.loopvars.add((node.target.id, width))
return [
ast.Assign([ast.Name(node.target.id, ast.Store())], ast.parse(f"bits({start}, {width})").body[0].value),
ast.Assign([ast.Name(node.target.id + "_cout", ast.Store())], ast.NameConstant(False)),
ast.While(ast.parse(f"not_({node.target.id}_cout)").body[0].value,
node.body + [
ast.parse(f"{node.target.id}, {node.target.id}_cout = add({node.target.id}, bits({incr}, {width}), cout=True)").body[0]
], [])
]
else: # pragma: no cover
print_ast(node)
raise NotImplementedError("Unsupported for loop construct `{}`".format(to_source(node.iter)))

def desugar_for_loops(tree):
raise NotImplementedError("keyword iterators are not yet supported")

bit_width = eval(astor.to_source(stop).rstrip() + "-" + astor.to_source(start).rstrip()).bit_length()
self.loopvars.add((node.target.id, bit_width))

return [
ast.Assign([ast.Name(node.target.id, ast.Store())], start),
ast.While(ast.BinOp(ast.Name(node.target.id, ast.Load()), ast.Lt(), stop),
node.body + [
ast.Assign([ast.Name(node.target.id, ast.Store())], ast.BinOp(
ast.Name(node.target.id, ast.Load()), ast.Add(), step))
], [])
]
# for _ in name
elif is_name(node.iter):
# TODO: currently just assumes that this is a list
# TODO: maybe this should just reshape the for-loop to use a range()
# iterator and then just call the previous branch
index = '__silica_count_' + str(next(counter))
start = ast.Num(0)
stop = ast.Num(type_table[node.iter.id][1])
step = ast.Num(1)

bit_width = stop.n.bit_length()
self.loopvars.add((index, bit_width))

return [
ast.Assign([ast.Name(index, ast.Store())], start),
ast.While(ast.BinOp(ast.Name(index, ast.Load()), ast.Lt(), stop),
[ast.Assign([ast.Name(node.target.id, ast.Store())],
ast.Subscript(node.iter, ast.Index(ast.Name(index, ast.Load()))))] +
node.body + [
ast.Assign([ast.Name(index, ast.Store())], ast.BinOp(
ast.Name(index, ast.Load()), ast.Add(), step))
], [])
]
else: # pragma: no cover
print_ast(node)
raise NotImplementedError("Unsupported for loop construct `{}`".format(to_source(node.iter)))


desugarer = ForLoopDesugarer()
desugarer.visit(tree)
return tree, desugarer.loopvars

def propagate_types(tree):
class TypePropagator(ast.NodeTransformer):
def __init__(self):
super().__init__()
self.loopvars = {}

def visit_Assign(self, node):
self.generic_visit(node)
if is_list(node.value):
self.loopvars[node.targets[0].id] = ('list', len(node.value.elts))

return node

propagator = TypePropagator()
propagator.visit(tree)
return tree, propagator.loopvars

def get_final_widths(tree, width_table, globals, locals):
class WidthCalculator(ast.NodeTransformer):
def __init__(self):
super().__init__()
self.loopvars = {}

def visit_Assign(self, node):
self.generic_visit(node)
if is_list(node.value):
fst = node.value.elts[0]
if is_name(fst):
self.loopvars[node.targets[0].id] = MemoryType(len(node.value.elts), width_table[fst.id])
elif is_call(fst):
if fst.func.id == 'bits':
self.loopvars[node.targets[0].id] = MemoryType(len(node.value.elts), fst.args[1].n)
else:
raise NotImplementedError("Getting width of `{}` is unsupported.".format(to_source(fst)))
else:
raise NotImplementedError("Getting width of `{}` is unsupported.".format(to_source(fst)))

return node

wc = WidthCalculator()
wc.visit(tree)
return tree, wc.loopvars
2 changes: 1 addition & 1 deletion silica/transformations/promote_widths.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self, width_table, type_table):

def check_valid(self, int_length, expected_length):
if int_length > expected_length:
raise TypeError("Cannot promote integer with greated width than other operand")
raise TypeError("Cannot promote integer with greater width than other operand")

def make(self, value, width, type_):
return ast.parse(f"{type_}({value}, {width})").body[0].value
Expand Down
Loading

0 comments on commit 0d8d195

Please sign in to comment.