Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Channel macros #24

Merged
merged 4 commits into from
Sep 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions silica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .function import function, compile_function
from .compile import compile
from ._config import Config
from .types import Channel, In, Out
config = Config()

import magma as m
Expand Down
172 changes: 172 additions & 0 deletions silica/channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import astor
import silica as si
import ast
import silica.types as types
import inspect


def desugar_channels(tree, coroutine):
new_inputs = {}
channels = {}
for input_, type_ in coroutine._inputs.items():
if not isinstance(type_, types.Channel):
new_inputs[input_] = type_
else:
if coroutine._outputs is inspect._empty:
coroutine._outputs = {}
channels[input_] = type_
if type_.direction is types.In:
new_inputs[input_ + "_data"] = type_.type_
new_inputs[input_ + "_valid"] = si.Bit
coroutine._outputs[input_ + "_ready"] = si.Bit
else:
new_inputs[input_ + "_ready"] = si.Bit
coroutine._outputs[input_ + "_data"] = type_.type_
coroutine._outputs[input_ + "_valid"] = si.Bit
coroutine._inputs = new_inputs

defaults = []
for name, channel in channels.items():
if channel.direction is types.Out:
defaults.append(f"{name}_valid = 0")
defaults.append(f"{name}_data = 0")
else:
defaults.append(f"{name}_ready = 0")
default_str = '\n '.join(defaults)

class Transformer(ast.NodeTransformer):
def visit_Expr(self, node):
node.value = self.visit(node.value)
if isinstance(node.value, list):
return node.value
return node

def visit(self, node):
node = super().visit(node)
for block_field in ["body", "orelse"]:
if hasattr(node, block_field) and not isinstance(node, ast.IfExp):
new_block = []
for statement in getattr(node, block_field):
if isinstance(statement, list):
new_block.extend(statement)
else:
new_block.append(statement)
setattr(node, block_field, new_block)
return node

def visit_Call(self, node):
if isinstance(node, ast.Call) and \
isinstance(node.func, ast.Attribute) and \
isinstance(node.func.value, ast.Name) and \
node.func.value.id in channels:
if node.func.attr == "is_full":
return ast.UnaryOp(ast.Invert(),
ast.Name(node.func.value.id + "_ready",
ast.Load()))
elif node.func.attr == "is_empty":
return ast.UnaryOp(ast.Invert(),
ast.Name(node.func.value.id + "_valid",
ast.Load()))
elif node.func.attr == "push":
inputs_str = ", ".join(coroutine._inputs)
outputs_str = ", ".join(coroutine._outputs)
wait_block = ast.parse(f"""\
while True:
{node.func.value.id}_valid = 1
{node.func.value.id}_data = {astor.to_source(node.args[0]).rstrip()}
if {node.func.value.id}_ready:
break
{inputs_str} = yield {outputs_str}
{default_str}
""").body
return wait_block
else:
assert False, f"Got unexpected channel method {astor.to_source(node).rstrip()}"
return node

def visit_Assign(self, node):
if isinstance(node.value, ast.Call) and \
isinstance(node.value.func, ast.Attribute) and \
isinstance(node.value.func.value, ast.Name) and \
node.value.func.value.id in channels:
assert node.value.func.attr == "pop", f"Got unexpected channel method {astor.to_source(node).rstrip()}"
inputs_str = ", ".join(coroutine._inputs)
outputs_str = ", ".join(coroutine._outputs)
wait_block = ast.parse(f"""\
while True:
{node.value.func.value.id}_ready = 1
if {node.value.func.value.id}_valid:
break
{inputs_str} = yield {outputs_str}
{default_str}
""").body
node.value = ast.Name(node.value.func.value.id + "_data", ast.Load())
return wait_block + [node]
if not isinstance(node.value, ast.Yield):
return node
assert len(node.targets) == 1
targets = node.targets[0]
if isinstance(targets, ast.Name):
targets = [targets]
elif isinstance(targets, ast.Tuple):
targets = targets.elts
new_targets = []
new_values = []
for target in targets:
assert isinstance(target, ast.Name)
if target.id in channels:
assert channels[target.id].direction is types.In
new_targets.extend((
ast.Name(target.id + "_data", ast.Store()),
ast.Name(target.id + "_valid", ast.Store())
))
new_values.append(
ast.Name(target.id + "_ready", ast.Load())
)
else:
new_targets.append(target)

value = node.value.value
if value is not None:
if isinstance(value, ast.Name):
values = [value]
else:
assert isinstance(value, ast.Tuple), ast.dump(value)
values = value.elts
for value in values:
assert isinstance(value, ast.Name)
if value.id in channels:
new_targets.append(
ast.Name(value.id + "_ready", ast.Store())
)
new_values.extend((
ast.Name(value.id + "_data", ast.Load()),
ast.Name(value.id + "_valid", ast.Load())
))
else:
new_values.append(value)
if len(new_values) == 1:
node.value.value = new_values[0]
else:
node.value.value = ast.Tuple(new_values, ast.Load())
else:
for name, channel in channels.items():
if channel.direction is types.Out:
new_targets.append(ast.Name(name + "_ready",
ast.Store()))
if len(new_targets) > 1:
node.targets = [ast.Tuple(new_targets, ast.Store())]
else:
node.targets = new_targets
defaults = []
for name, channel in channels.items():
if channel.direction is types.Out:
defaults.append(ast.parse(f"{name}_valid = 0").body[0])
defaults.append(ast.parse(f"{name}_data = 0").body[0])
else:
defaults.append(ast.parse(f"{name}_ready = 0").body[0])
return [node] + defaults

tree = Transformer().visit(tree)
print(astor.to_source(tree))
return tree, coroutine
2 changes: 2 additions & 0 deletions silica/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys

import silica
from silica.channel import desugar_channels
from silica.coroutine import Coroutine
from silica.cfg import ControlFlowGraph, BasicBlock, HeadBlock
from silica.cfg.control_flow_graph import render_paths_between_yields, build_state_info, render_fsm, get_constant
Expand Down Expand Up @@ -73,6 +74,7 @@ def compile(coroutine, file_name=None, mux_strategy="one-hot", output='verilog',

has_ce = coroutine.has_ce
tree = ast_utils.get_ast(coroutine._definition).body[0] # Get the first element of the ast.Module
tree, coroutine = desugar_channels(tree, coroutine)
module_name = coroutine._name
func_locals.update(coroutine._defn_locals)
func_locals.update(func_globals)
Expand Down
3 changes: 3 additions & 0 deletions silica/type_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .width import get_width
import astor
import magma as m
import silica.types as types


def to_type_str(type_):
Expand All @@ -11,6 +12,8 @@ def to_type_str(type_):
return "uint"
elif isinstance(type_, m.BitsKind):
return "bits"
elif isinstance(type_, types.Channel):
return to_type_str(type_.type_)
else:
raise NotImplementedError(type_)

Expand Down
16 changes: 16 additions & 0 deletions silica/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import enum


class Direction(enum.Enum):
In = 0
Out = 1


In = Direction.In
Out = Direction.Out


class Channel:
def __init__(self, type_, direction):
self.type_ = type_
self.direction = direction
3 changes: 3 additions & 0 deletions silica/width.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import silica.types as types
import ast
import astor
from .memory import MemoryType
Expand Down Expand Up @@ -130,5 +131,7 @@ def get_io_width(type_):
return (type_.N, elem_width)
else:
return type_.N
elif isinstance(type_, types.Channel):
return get_io_width(type_.type_)
else:
raise NotImplementedError(type_)
26 changes: 16 additions & 10 deletions tests/test_downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def Downsample(
data_out_valid = keep & data_in_valid
data_out_data = data_in_data
data_in_ready = data_out_ready | ~keep
if data_in_ready:
if data_in_ready & data_in_valid:
if x == 31:
data_in_valid, data_in_data, data_out_ready = yield \
data_in_ready, data_out_data, data_out_valid
Expand Down Expand Up @@ -100,8 +100,9 @@ def test_downsample_loops_simple():
tester.poke(magma_downsample.data_out_ready, 1)
tester.eval()
tester.expect(magma_downsample.data_out_data, y * 32 + x)
tester.expect(magma_downsample.data_out_valid, (y % 2 == 0) &
(x % 2 == 0))
if (y % 2 == 0) & (x % 2 == 0):
tester.expect(magma_downsample.data_out_valid,
(y % 2 == 0) & (x % 2 == 0))
tester.expect(magma_downsample.data_in_ready, 1)
tester.step(2)
tester.poke(magma_downsample.data_in_valid, 0)
Expand Down Expand Up @@ -132,21 +133,26 @@ def test_downsample_loops_simple_random_stalls():
for y in range(32):
for x in range(32):
while True:
in_valid = random.getrandbits(1)
in_valid = hwtypes.Bit(random.getrandbits(1))
tester.poke(magma_downsample.data_in_valid, in_valid)
tester.poke(magma_downsample.data_in_data, y * 32 + x)
out_ready = random.getrandbits(1)
tester.poke(magma_downsample.data_out_ready, out_ready)
tester.eval()
tester.expect(magma_downsample.data_out_data, y * 32 + x)
keep = hwtypes.Bit((y % 2 == 0) & (x % 2 == 0))
tester.expect(magma_downsample.data_out_valid,
keep & in_valid)
out_valid = keep & in_valid
if out_ready:
tester.expect(magma_downsample.data_out_valid,
out_valid)
if out_valid & out_ready:
tester.expect(magma_downsample.data_out_data,
y * 32 + x)
in_ready = hwtypes.Bit(out_ready) | ~keep
tester.expect(magma_downsample.data_in_ready,
in_ready)
if in_valid & in_ready:
tester.expect(magma_downsample.data_in_ready,
in_ready)
tester.step(2)
if in_ready:
if in_ready & in_valid:
break

tester.compile_and_run("verilator", flags=["-Wno-fatal"],
Expand Down
Loading