Skip to content

Commit

Permalink
Merge 16922e7 into 5129e5b
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt authored Sep 13, 2019
2 parents 5129e5b + 16922e7 commit f4c1381
Show file tree
Hide file tree
Showing 9 changed files with 385 additions and 11 deletions.
1 change: 1 addition & 0 deletions silica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .function import function, compile_function
from .compile import compile
from ._config import Config
from .types import Channel, In, Out
config = Config()

import magma as m
Expand Down
172 changes: 172 additions & 0 deletions silica/channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import astor
import silica as si
import ast
import silica.types as types
import inspect


def desugar_channels(tree, coroutine):
new_inputs = {}
channels = {}
for input_, type_ in coroutine._inputs.items():
if not isinstance(type_, types.Channel):
new_inputs[input_] = type_
else:
if coroutine._outputs is inspect._empty:
coroutine._outputs = {}
channels[input_] = type_
if type_.direction is types.In:
new_inputs[input_ + "_data"] = type_.type_
new_inputs[input_ + "_valid"] = si.Bit
coroutine._outputs[input_ + "_ready"] = si.Bit
else:
new_inputs[input_ + "_ready"] = si.Bit
coroutine._outputs[input_ + "_data"] = type_.type_
coroutine._outputs[input_ + "_valid"] = si.Bit
coroutine._inputs = new_inputs

defaults = []
for name, channel in channels.items():
if channel.direction is types.Out:
defaults.append(f"{name}_valid = 0")
defaults.append(f"{name}_data = 0")
else:
defaults.append(f"{name}_ready = 0")
default_str = '\n '.join(defaults)

class Transformer(ast.NodeTransformer):
def visit_Expr(self, node):
node.value = self.visit(node.value)
if isinstance(node.value, list):
return node.value
return node

def visit(self, node):
node = super().visit(node)
for block_field in ["body", "orelse"]:
if hasattr(node, block_field) and not isinstance(node, ast.IfExp):
new_block = []
for statement in getattr(node, block_field):
if isinstance(statement, list):
new_block.extend(statement)
else:
new_block.append(statement)
setattr(node, block_field, new_block)
return node

def visit_Call(self, node):
if isinstance(node, ast.Call) and \
isinstance(node.func, ast.Attribute) and \
isinstance(node.func.value, ast.Name) and \
node.func.value.id in channels:
if node.func.attr == "is_full":
return ast.UnaryOp(ast.Invert(),
ast.Name(node.func.value.id + "_ready",
ast.Load()))
elif node.func.attr == "is_empty":
return ast.UnaryOp(ast.Invert(),
ast.Name(node.func.value.id + "_valid",
ast.Load()))
elif node.func.attr == "push":
inputs_str = ", ".join(coroutine._inputs)
outputs_str = ", ".join(coroutine._outputs)
wait_block = ast.parse(f"""\
while True:
{node.func.value.id}_valid = 1
{node.func.value.id}_data = {astor.to_source(node.args[0]).rstrip()}
if {node.func.value.id}_ready:
break
{inputs_str} = yield {outputs_str}
{default_str}
""").body
return wait_block
else:
assert False, f"Got unexpected channel method {astor.to_source(node).rstrip()}"
return node

def visit_Assign(self, node):
if isinstance(node.value, ast.Call) and \
isinstance(node.value.func, ast.Attribute) and \
isinstance(node.value.func.value, ast.Name) and \
node.value.func.value.id in channels:
assert node.value.func.attr == "pop", f"Got unexpected channel method {astor.to_source(node).rstrip()}"
inputs_str = ", ".join(coroutine._inputs)
outputs_str = ", ".join(coroutine._outputs)
wait_block = ast.parse(f"""\
while True:
{node.value.func.value.id}_ready = 1
if {node.value.func.value.id}_valid:
break
{inputs_str} = yield {outputs_str}
{default_str}
""").body
node.value = ast.Name(node.value.func.value.id + "_data", ast.Load())
return wait_block + [node]
if not isinstance(node.value, ast.Yield):
return node
assert len(node.targets) == 1
targets = node.targets[0]
if isinstance(targets, ast.Name):
targets = [targets]
elif isinstance(targets, ast.Tuple):
targets = targets.elts
new_targets = []
new_values = []
for target in targets:
assert isinstance(target, ast.Name)
if target.id in channels:
assert channels[target.id].direction is types.In
new_targets.extend((
ast.Name(target.id + "_data", ast.Store()),
ast.Name(target.id + "_valid", ast.Store())
))
new_values.append(
ast.Name(target.id + "_ready", ast.Load())
)
else:
new_targets.append(target)

value = node.value.value
if value is not None:
if isinstance(value, ast.Name):
values = [value]
else:
assert isinstance(value, ast.Tuple), ast.dump(value)
values = value.elts
for value in values:
assert isinstance(value, ast.Name)
if value.id in channels:
new_targets.append(
ast.Name(value.id + "_ready", ast.Store())
)
new_values.extend((
ast.Name(value.id + "_data", ast.Load()),
ast.Name(value.id + "_valid", ast.Load())
))
else:
new_values.append(value)
if len(new_values) == 1:
node.value.value = new_values[0]
else:
node.value.value = ast.Tuple(new_values, ast.Load())
else:
for name, channel in channels.items():
if channel.direction is types.Out:
new_targets.append(ast.Name(name + "_ready",
ast.Store()))
if len(new_targets) > 1:
node.targets = [ast.Tuple(new_targets, ast.Store())]
else:
node.targets = new_targets
defaults = []
for name, channel in channels.items():
if channel.direction is types.Out:
defaults.append(ast.parse(f"{name}_valid = 0").body[0])
defaults.append(ast.parse(f"{name}_data = 0").body[0])
else:
defaults.append(ast.parse(f"{name}_ready = 0").body[0])
return [node] + defaults

tree = Transformer().visit(tree)
print(astor.to_source(tree))
return tree, coroutine
2 changes: 2 additions & 0 deletions silica/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys

import silica
from silica.channel import desugar_channels
from silica.coroutine import Coroutine
from silica.cfg import ControlFlowGraph, BasicBlock, HeadBlock
from silica.cfg.control_flow_graph import render_paths_between_yields, build_state_info, render_fsm, get_constant
Expand Down Expand Up @@ -73,6 +74,7 @@ def compile(coroutine, file_name=None, mux_strategy="one-hot", output='verilog',

has_ce = coroutine.has_ce
tree = ast_utils.get_ast(coroutine._definition).body[0] # Get the first element of the ast.Module
tree, coroutine = desugar_channels(tree, coroutine)
module_name = coroutine._name
func_locals.update(coroutine._defn_locals)
func_locals.update(func_globals)
Expand Down
3 changes: 3 additions & 0 deletions silica/type_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .width import get_width
import astor
import magma as m
import silica.types as types


def to_type_str(type_):
Expand All @@ -11,6 +12,8 @@ def to_type_str(type_):
return "uint"
elif isinstance(type_, m.BitsKind):
return "bits"
elif isinstance(type_, types.Channel):
return to_type_str(type_.type_)
else:
raise NotImplementedError(type_)

Expand Down
16 changes: 16 additions & 0 deletions silica/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import enum


class Direction(enum.Enum):
In = 0
Out = 1


In = Direction.In
Out = Direction.Out


class Channel:
def __init__(self, type_, direction):
self.type_ = type_
self.direction = direction
3 changes: 3 additions & 0 deletions silica/width.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import silica.types as types
import ast
import astor
from .memory import MemoryType
Expand Down Expand Up @@ -130,5 +131,7 @@ def get_io_width(type_):
return (type_.N, elem_width)
else:
return type_.N
elif isinstance(type_, types.Channel):
return get_io_width(type_.type_)
else:
raise NotImplementedError(type_)
26 changes: 16 additions & 10 deletions tests/test_downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def Downsample(
data_out_valid = keep & data_in_valid
data_out_data = data_in_data
data_in_ready = data_out_ready | ~keep
if data_in_ready:
if data_in_ready & data_in_valid:
if x == 31:
data_in_valid, data_in_data, data_out_ready = yield \
data_in_ready, data_out_data, data_out_valid
Expand Down Expand Up @@ -100,8 +100,9 @@ def test_downsample_loops_simple():
tester.poke(magma_downsample.data_out_ready, 1)
tester.eval()
tester.expect(magma_downsample.data_out_data, y * 32 + x)
tester.expect(magma_downsample.data_out_valid, (y % 2 == 0) &
(x % 2 == 0))
if (y % 2 == 0) & (x % 2 == 0):
tester.expect(magma_downsample.data_out_valid,
(y % 2 == 0) & (x % 2 == 0))
tester.expect(magma_downsample.data_in_ready, 1)
tester.step(2)
tester.poke(magma_downsample.data_in_valid, 0)
Expand Down Expand Up @@ -132,21 +133,26 @@ def test_downsample_loops_simple_random_stalls():
for y in range(32):
for x in range(32):
while True:
in_valid = random.getrandbits(1)
in_valid = hwtypes.Bit(random.getrandbits(1))
tester.poke(magma_downsample.data_in_valid, in_valid)
tester.poke(magma_downsample.data_in_data, y * 32 + x)
out_ready = random.getrandbits(1)
tester.poke(magma_downsample.data_out_ready, out_ready)
tester.eval()
tester.expect(magma_downsample.data_out_data, y * 32 + x)
keep = hwtypes.Bit((y % 2 == 0) & (x % 2 == 0))
tester.expect(magma_downsample.data_out_valid,
keep & in_valid)
out_valid = keep & in_valid
if out_ready:
tester.expect(magma_downsample.data_out_valid,
out_valid)
if out_valid & out_ready:
tester.expect(magma_downsample.data_out_data,
y * 32 + x)
in_ready = hwtypes.Bit(out_ready) | ~keep
tester.expect(magma_downsample.data_in_ready,
in_ready)
if in_valid & in_ready:
tester.expect(magma_downsample.data_in_ready,
in_ready)
tester.step(2)
if in_ready:
if in_ready & in_valid:
break

tester.compile_and_run("verilator", flags=["-Wno-fatal"],
Expand Down
Loading

0 comments on commit f4c1381

Please sign in to comment.