Skip to content

Commit

Permalink
Merge 86c8c38 into e2e0ee4
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt committed Dec 14, 2018
2 parents e2e0ee4 + 86c8c38 commit 6ba4941
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 38 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
astor
graphviz
git+git://github.com/phanrahan/magma.git#egg=magma
git+git://github.com/phanrahan/magma.git@ssa#egg=magma
git+git://github.com/phanrahan/mantle.git#egg=mantle
git+git://github.com/phanrahan/loam.git#egg=loam
coreir
Expand Down
1 change: 1 addition & 0 deletions silica/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .coroutine import coroutine, Coroutine, generator, Generator
from .function import function, compile_function
from .compile import compile
from ._config import Config
config = Config()
Expand Down
39 changes: 4 additions & 35 deletions silica/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from silica.visitors import collect_names
import silica.verilog as verilog
from .memory import MemoryType
from .width import get_io_width
from silica.transformations.specialize_arguments import specialize_arguments
from silica.type_check import TypeChecker
from silica.type_check import TypeChecker, to_type_str
from silica.analysis import CollectInitialWidthsAndTypes, collect_sub_coroutines
from silica.transformations.promote_widths import PromoteWidths
from silica.transformations.desugar_for_loops import propagate_types, get_final_widths
Expand Down Expand Up @@ -47,51 +48,19 @@ def visit_Call(self, node):
ListCompSpecializer().visit(tree)


def get_io_width(type_):
if type_ is magma.Bit:
return None
elif isinstance(type_, magma.ArrayKind):
if isinstance(type_.T, magma.ArrayKind):
elem_width = get_io_width(type_.T)
if isinstance(elem_width, tuple):
return (type_.N, ) + elem_width
else:
return (type_.N, elem_width)
else:
return type_.N
else:
raise NotImplementedError(type_)


def add_coroutine_to_tables(coroutine, width_table, type_table, sub_coroutine_name=None):
if coroutine._inputs:
for input_, type_ in coroutine._inputs.items():
if sub_coroutine_name:
input_ = "_si_sub_co_" + sub_coroutine_name + "_" + input_
width_table[input_] = get_io_width(type_)
if isinstance(type_, m.BitKind):
type_ = "bit"
elif isinstance(type_, m.UIntKind):
type_ = "uint"
elif isinstance(type_, m.BitsKind):
type_ = "bits"
else:
raise NotImplementedError(type_)
type_table[input_] = type_
type_table[input_] = to_type_str(type_)
if coroutine._outputs:
for output, type_ in coroutine._outputs.items():
if sub_coroutine_name:
output = "_si_sub_co_" + sub_coroutine_name + "_" + output
width_table[output] = get_io_width(type_)
if isinstance(type_, m.BitKind):
type_ = "bit"
elif isinstance(type_, m.UIntKind):
type_ = "uint"
elif isinstance(type_, m.BitsKind):
type_ = "bits"
else:
raise NotImplementedError(type_)
type_table[output] = type_
type_table[output] = to_type_str(type_)


def compile(coroutine, file_name=None, mux_strategy="one-hot", output='verilog', strategy="by_statement"):
Expand Down
66 changes: 66 additions & 0 deletions silica/function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import types
import magma as m
import silica.verilog as verilog
import inspect
from silica.visitors import collect_stores
from silica.analysis import CollectInitialWidthsAndTypes
from silica.type_check import to_type_str
from silica.width import get_io_width
import ast
import veriloggen as vg


class ReturnReplacer(ast.NodeTransformer):
def visit_Return(self, node):
if isinstance(node.value, ast.Tuple):
raise NotImplementedError()
return ast.Assign([ast.Name("O", ast.Store())], node.value)


def function(fn : types.FunctionType):
""" TODO: Do implicit conversion of Python types to BitVector """
return fn


def compile_function(fn : types.FunctionType, file_name : str):
stack = inspect.stack()
func_locals = stack[1].frame.f_locals
func_globals = stack[1].frame.f_globals
tree = m.ast_utils.get_func_ast(fn)
ctx = verilog.Context(tree.name)
width_table = {}
type_table = {}

def get_len(t):
try:
return len(t)
except Exception:
return 1


inputs = inspect.getfullargspec(fn).annotations
for input_, type_ in inputs.items():
type_table[input_] = to_type_str(type_)
width_table[input_] = get_io_width(type_)

inputs = { i : get_len(t) for i,t in inputs.items() }
output = inspect.signature(fn).return_annotation
assert not isinstance(output, tuple), "Not implemented"
type_table["O"] = to_type_str(output)
width_table["O"] = get_io_width(output)
outputs = { "O" : get_len(output) }

CollectInitialWidthsAndTypes(width_table, type_table, func_locals, func_globals).visit(tree)

ctx.declare_ports(inputs, outputs)

stores = collect_stores(tree)
for store in stores:
ctx.declare_wire(store, width_table[store])

tree = ReturnReplacer().visit(tree)

ctx.module.Always(vg.SensitiveAll())([ctx.translate(s) for s in tree.body])

with open(file_name, "w") as f:
f.write(ctx.to_verilog())
12 changes: 12 additions & 0 deletions silica/type_check.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import ast
from .width import get_width
import astor
import magma as m


def to_type_str(type_):
if isinstance(type_, m.BitKind):
return "bit"
elif isinstance(type_, m.UIntKind):
return "uint"
elif isinstance(type_, m.BitsKind):
return "bits"
else:
raise NotImplementedError(type_)


class TypeChecker(ast.NodeVisitor):
Expand Down
9 changes: 7 additions & 2 deletions silica/verilog.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .cfg.util import find_branch_join

class Context:
def __init__(self, name, sub_coroutines):
def __init__(self, name, sub_coroutines=[]):
self.module = vg.Module(name)
self.sub_coroutines = sub_coroutines

Expand Down Expand Up @@ -117,9 +117,14 @@ def translate(self, stmt):
return vg.Eq
elif is_if(stmt):
body = [self.translate(stmt) for stmt in stmt.body]
return vg.If(
if_ = vg.If(
self.translate(stmt.test),
)(body)
if stmt.orelse:
if_.Else(
[self.translate(stmt) for stmt in stmt.orelse]
)
return if_
elif is_if_exp(stmt):
return vg.Cond(
self.translate(stmt.test),
Expand Down
16 changes: 16 additions & 0 deletions silica/width.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,19 @@ def get_width(node, width_table, func_locals={}, func_globals={}):


raise NotImplementedError(ast.dump(node))


def get_io_width(type_):
if type_ is m.Bit:
return None
elif isinstance(type_, m.ArrayKind):
if isinstance(type_.T, m.ArrayKind):
elem_width = get_io_width(type_.T)
if isinstance(elem_width, tuple):
return (type_.N, ) + elem_width
else:
return (type_.N, elem_width)
else:
return type_.N
else:
raise NotImplementedError(type_)
136 changes: 136 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import silica as si


# TODO: These were adapted from magma's SSA tests, perhaps we can find a nice
# way to reuse test functions?
def test_basic():
@si.function
def basic_if(I: si.Bits(2), S: si.Bit) -> si.Bit:
if S:
x = I[0]
else:
x = I[1]
return x

si.compile_function(basic_if, file_name="tests/build/basic_if.v")
with open("tests/build/basic_if.v", "r") as f:
result = f.read()
print(result)
assert result == """\
module basic_if
(
output O,
input return,
input [2-1:0] I,
input S
);
wire x;
always @(*) begin
if(S) begin
x = I[0];
end else begin
x = I[1];
end
O = x;
end
endmodule
"""


def test_default():
@si.function
def default(I: si.Bits(2), S: si.Bit) -> si.Bit:
x = I[1]
if S:
x = I[0]
return x

si.compile_function(default, file_name="tests/build/default.v")
with open("tests/build/default.v", "r") as f:
result = f.read()
print(result)
assert result == """\
module default
(
output O,
input return,
input [2-1:0] I,
input S
);
wire x;
always @(*) begin
x = I[1];
if(S) begin
x = I[0];
end
O = x;
end
endmodule
"""


def test_nested():
@si.function
def nested(I: si.Bits(4), S: si.Bits(2)) -> si.Bit:
if S[0]:
if S[1]:
x = I[0]
else:
x = I[1]
else:
if S[1]:
x = I[2]
else:
x = I[3]
return x

si.compile_function(nested, file_name="tests/build/nested.v")
with open("tests/build/nested.v", "r") as f:
result = f.read()
print(result)
assert result == """\
module nested
(
output O,
input return,
input [4-1:0] I,
input [2-1:0] S
);
wire x;
always @(*) begin
if(S[0]) begin
if(S[1]) begin
x = I[0];
end else begin
x = I[1];
end
end else if(S[1]) begin
x = I[2];
end else begin
x = I[3];
end
O = x;
end
endmodule
"""

0 comments on commit 6ba4941

Please sign in to comment.