Skip to content

Commit

Permalink
Rework serializer and lbmem implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt committed Nov 4, 2018
1 parent 6fc4dff commit 7e24a62
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 86 deletions.
2 changes: 1 addition & 1 deletion silica/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def visit_Assign(self, node):
elif node.targets[0].id not in self.width_table:
self.width_table[node.targets[0].id] = get_width(node.value, self.width_table, self.func_locals, self.func_globals)
if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and \
node.value.func.id in {"bits", "uint"}:
node.value.func.id in {"bits", "uint", "bit"}:
self.type_table[node.targets[0].id] = node.value.func.id
elif isinstance(node.value, ast.NameConstant) and node.value.value in [True, False]:
self.type_table[node.targets[0].id] = "bit"
Expand Down
36 changes: 2 additions & 34 deletions silica/cfg/control_flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from silica.cfg.types import BasicBlock, Yield, Branch, HeadBlock, State
from silica.cfg.ssa import SSAReplacer, convert_to_ssa, parse_expr
from .liveness import liveness_analysis
from .util import find_branch_join
from ..memory import MemoryType
import silica.ast_utils as ast_utils

Expand Down Expand Up @@ -103,39 +104,6 @@ def get_io(tree):
return IOCollector().run(tree)


def get_next_block(block, seen):
if isinstance(block, Branch):
return find_branch_join(block, seen)
elif isinstance(block, (BasicBlock, Yield)):
return block.outgoing_edge[0]
raise NotImplementedError(block)


def find_branch_join(branch, seen=[]):
seen.append(branch)
curr_true_block = branch.true_edge
while curr_true_block not in seen:
seen.append(curr_true_block)
curr_true_block = get_next_block(curr_true_block, seen)
curr_false_block = branch.false_edge
while curr_false_block not in seen:
curr_false_block = get_next_block(curr_false_block, seen)
return curr_false_block
# curr_false_block = branch.false_edge
# curr_true_block = branch.true_edge
# while curr_false_block != curr_true_block:
# seen.add(curr_false_block)
# seen.add(curr_true_block)
# next_true_block = get_next_block(curr_true_block, seen)
# if next_true_block == curr_false_block:
# break
# next_false_block = get_next_block(curr_false_block, seen)
# if next_false_block == curr_true_block:
# break
# curr_true_block, curr_false_block = next_true_block, next_false_block
# return curr_false_block


def get_stores_on_branch(curr_block, join_block, var_counter):
stores = set()
while curr_block != join_block:
Expand Down Expand Up @@ -212,7 +180,7 @@ class ControlFlowGraph:
def __init__(self, tree, width_table, func_locals, func_globals):
self.blocks = []
self.curr_block = None
self.curr_yield_id = 0
self.curr_yield_id = 1
self.local_vars = set()
self.width_table = width_table
self.func_locals = func_locals
Expand Down
4 changes: 4 additions & 0 deletions silica/cfg/ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

from collections import defaultdict
from silica.cfg.types import BasicBlock, Yield, Branch, HeadBlock, State
from .util import find_branch_join


def parse_expr(expr):
Expand Down Expand Up @@ -239,6 +240,9 @@ def get_conds_up_to(path, predecessor, cfg):
if cfg.curr_yield_id > 1:
conds.add(f"yield_state == 0")
elif isinstance(block, Branch):
# join_block = find_branch_join(block)
# if join_block in path and path.index(join_block) > i:
# continue
cond = block.cond
if path[i + 1] is block.false_edge:
cond = ast.UnaryOp(ast.Invert(), cond)
Expand Down
34 changes: 34 additions & 0 deletions silica/cfg/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from .types import Branch, BasicBlock, Yield


def get_next_block(block, seen):
if isinstance(block, Branch):
return find_branch_join(block, seen)
elif isinstance(block, (BasicBlock, Yield)):
return block.outgoing_edge[0]
raise NotImplementedError(block)


def find_branch_join(branch, seen=[]):
seen.append(branch)
curr_true_block = branch.true_edge
while curr_true_block not in seen:
seen.append(curr_true_block)
curr_true_block = get_next_block(curr_true_block, seen)
curr_false_block = branch.false_edge
while curr_false_block not in seen:
curr_false_block = get_next_block(curr_false_block, seen)
return curr_false_block
# curr_false_block = branch.false_edge
# curr_true_block = branch.true_edge
# while curr_false_block != curr_true_block:
# seen.add(curr_false_block)
# seen.add(curr_true_block)
# next_true_block = get_next_block(curr_true_block, seen)
# if next_true_block == curr_false_block:
# break
# next_false_block = get_next_block(curr_false_block, seen)
# if next_false_block == curr_true_block:
# break
# curr_true_block, curr_false_block = next_true_block, next_false_block
# return curr_false_block
3 changes: 2 additions & 1 deletion silica/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def compile(coroutine, file_name=None, mux_strategy="one-hot", output='verilog',
tree = ast_utils.get_ast(coroutine._definition).body[0] # Get the first element of the ast.Module
module_name = coroutine._name
func_locals.update(coroutine._defn_locals)
func_locals.update(func_globals)
specialize_arguments(tree, coroutine)
specialize_constants(tree, coroutine._defn_locals)
specialize_evals(tree, func_globals, func_locals)
Expand Down Expand Up @@ -255,7 +256,7 @@ def get_len(t):
wens = {}
# if initial_basic_block:
# states = states[1:]
verilog.compile_states(ctx, states, cfg.curr_yield_id == 1, width_table,
verilog.compile_states(ctx, states, cfg.curr_yield_id == 3, width_table,
registers, sub_coroutines, strategy)
# cfg.render()
verilog_str = ""
Expand Down
12 changes: 9 additions & 3 deletions silica/transformations/promote_widths.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ def __init__(self, width_table, type_table):
self.type_table = type_table

def check_valid(self, int_length, expected_length):
if expected_length is None and int_length > 1 or int_length > expected_length:
raise TypeError("Cannot promote integer with greater width than other operand")
if expected_length is None and int_length <= 1:
return
if int_length <= expected_length:
return
raise TypeError("Cannot promote integer with greater width than other operand")

def make(self, value, width, type_):
return ast.parse(f"{type_}({value}, {width})").body[0].value
if type_ == "bit":
return ast.parse(f"{type_}({value})").body[0].value
else:
return ast.parse(f"{type_}({value}, {width})").body[0].value

def get_type(self, node):
if isinstance(node, ast.Name):
Expand Down
2 changes: 1 addition & 1 deletion silica/width.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_width(node, width_table, func_locals={}, func_globals={}):
left_width = get_width(node.left, width_table)
right_width = get_width(node.right, width_table)
if left_width != right_width:
raise TypeError(f"Binary operation with mismatched widths {ast.dump(node)}")
raise TypeError(f"Binary operation with mismatched widths ({left_width}, {right_width}) {ast.dump(node)}")
return left_width
elif isinstance(node, ast.IfExp):
left_width = get_width(node.body, width_table)
Expand Down
File renamed without changes.
File renamed without changes.
70 changes: 46 additions & 24 deletions tests/test_lbmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,31 +65,53 @@ def SILbMem(depth=64, lbmem_width=8):
def mem(wdata : si.Bits(8), wen : si.Bit):
lbmem = memory(depth, lbmem_width)
waddr = uint(0, eval(math.ceil(math.log2(depth))))
count = uint(0, 3)
count = uint(0, 4)
state = bit(0)
wdata, wen = yield
while True:
while (count < uint(7, 3)) | ~wen:
rdata = lbmem[waddr - uint(count, 6)]
valid = bit(0)
if wen:
raddr = waddr - uint(count, 6)
else:
raddr = waddr - uint(count - 1, 6)
rdata = lbmem[raddr]
valid = (state & (bit(count != 1) | wen)) | (bit(count == 7) & wen)
if state == 0:
state = (count == 7) & wen
if wen:
lbmem[waddr] = wdata
count = count + 1
waddr = waddr + 1
wdata, wen = yield rdata, valid
lbmem[waddr] = wdata
rdata = lbmem[waddr - uint(count, 6)]
waddr = waddr + 1
valid = bit(1)
wdata, wen = yield rdata, valid
while count > 0:
valid = bit(1)
rdata = lbmem[waddr - uint(count, 6)]
else:
state = (count != 1) | wen
if ~wen:
count = count - 1
if wen:
lbmem[waddr] = wdata
waddr = waddr + 1
wdata, wen = yield rdata, valid
if wen:
lbmem[waddr] = wdata
waddr = waddr + 1
wdata, wen = yield rdata, valid


# while True:
# while (count < uint(7, 3)) | ~wen:
# rdata = lbmem[waddr - uint(count, 6)]
# valid = bit(0)
# if wen:
# lbmem[waddr] = wdata
# count = count + 1
# waddr = waddr + 1
# wdata, wen = yield rdata, valid
# lbmem[waddr] = wdata
# rdata = lbmem[waddr - uint(count, 6)]
# waddr = waddr + 1
# valid = bit(1)
# wdata, wen = yield rdata, valid
# while count > 0:
# valid = bit(1)
# rdata = lbmem[waddr - uint(count, 6)]
# if ~wen:
# count = count - 1
# if wen:
# lbmem[waddr] = wdata
# waddr = waddr + 1
# wdata, wen = yield rdata, valid

# while True:
# waddr = yield from FillingState(lbmem_width, depth, lbmem, raddr, waddr, wdata, wen)
Expand Down Expand Up @@ -121,10 +143,10 @@ def test_lbmem():
tester.poke(si_lbmem.wen, 1)
lbmem.send((i, BitVector(1)))
tester.step(1)
assert lbmem.valid == (i == 7)
assert lbmem.valid == (i == 7), (lbmem.valid, i)
# tester.print(si_lbmem.valid)
tester.expect(si_lbmem.valid, i == 7), "Valid only on last write"
assert lbmem.rdata == 0, "should be 0, even on first read"
assert lbmem.rdata == 0, f"should be 0, even on first read, iter={i}, got {lbmem.rdata}"
tester.expect(si_lbmem.rdata, 0)
tester.step(1)
for i in range(0, 8):
Expand All @@ -137,8 +159,8 @@ def test_lbmem():
# there's nothing left to drain
rdata = i + 1 if i < 7 else 0
valid = i < 7
assert lbmem.rdata == rdata, i
assert lbmem.valid == valid, i
assert lbmem.rdata == rdata, (i, (lbmem.rdata, lbmem.valid), (rdata, valid))
assert lbmem.valid == valid, (i, (lbmem.rdata, lbmem.valid), (rdata, valid))
tester.expect(si_lbmem.rdata, rdata)
# tester.print(si_lbmem.valid)
tester.expect(si_lbmem.valid, valid)
Expand Down Expand Up @@ -166,7 +188,7 @@ def test_lbmem():
# print(drain_state)

tester.compile_and_run(target="verilator", directory="tests/build",
flags=['-Wno-fatal', '--trace'])
flags=['-Wno-fatal'])
verilog_lbmem = m.DefineFromVerilogFile("verilog/lbmem.v",
type_map={"CLK": m.In(m.Clock)})[0]
verilog_tester = tester.retarget(verilog_lbmem, verilog_lbmem.CLK)
Expand Down
38 changes: 22 additions & 16 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,29 @@


@coroutine
def Serializer4(I0 : Bits(16), I1 : Bits(16), I2 : Bits(16), I3 : Bits(16)):
def Serializer4(valid : Bit, I0 : Bits(16), I1 : Bits(16), I2 : Bits(16), I3 : Bits(16)):
# data = [bits(0, 16) for _ in range(3)]
data0 = bits(0, 16)
data1 = bits(0, 16)
data2 = bits(0, 16)
# I0, I1, I2, I3 = yield
I0, I1, I2, I3 = yield
O = bits(0, 16)
while True:
O = I0
# data = I[1:]
data0 = I1
data1 = I2
data2 = I3
I0, I1, I2, I3 = yield O
O = data0
I0, I1, I2, I3 = yield O
O = data1
I0, I1, I2, I3 = yield O
O = data2
I0, I1, I2, I3 = yield O
valid, I0, I1, I2, I3 = yield O
if valid:
O = I0
# data = I[1:]
data0 = I1
data1 = I2
data2 = I3
valid, I0, I1, I2, I3 = yield O
O = data0
valid, I0, I1, I2, I3 = yield O
O = data1
valid, I0, I1, I2, I3 = yield O
O = data2
else:
O = bits(0, 16)
# for i in range(3):
# O = data[i]
# I0, I1, I2, I3 = yield O
Expand All @@ -53,20 +56,23 @@ def test_ser4():
# serializer_si = m.DefineFromVerilogFile("tests/build/serializer_si.v",
# type_map={"CLK": m.In(m.Clock)})[0]
tester = fault.Tester(serializer_si, serializer_si.CLK)
tester.step(1)
for I in inputs:
tester.poke(serializer_si.valid, 1)
for j in range(len(I)):
tester.poke(getattr(serializer_si, f"I{j}"), I[j])
tester.step(1)
ser.send(I)
ser.send([1] + I)
assert ser.O == I[0]
tester.print(serializer_si.O)
tester.expect(serializer_si.O, I[0])
tester.step(1)
for i in range(3):
tester.poke(serializer_si.valid, 0)
for j in range(len(I)):
tester.poke(getattr(serializer_si, f"I{j}"), 0)
tester.step(1)
ser.send([0,0,0,0])
ser.send([0,0,0,0,0])
assert ser.O == I[i + 1]
tester.expect(serializer_si.O, I[i + 1])
tester.step(1)
Expand Down
2 changes: 1 addition & 1 deletion verilog/lbmem.v
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module lbmem(

reg state = 0;

reg [3:0] cnt = 0;
reg [4:0] cnt = 0;

always @(posedge CLK) begin
if (state==0) begin
Expand Down
13 changes: 8 additions & 5 deletions verilog/serializer.v
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

module serializer(
input CLK,
input valid,
// yosys doesn't like this syntax
// input [3:0][15:0] I,
input [15:0] I0,
Expand All @@ -12,21 +13,23 @@ module serializer(
reg [15:0] s1;
reg [15:0] s2;
reg [15:0] s3;
reg state = 0;

reg [1:0] cnt = 0;

always @(posedge CLK) begin
cnt <= (cnt==3) ? 0 : (cnt + 1);
end

always @(posedge CLK) begin
if (cnt==2'h0) begin
if (state == 0 && valid) begin
cnt <= 0;
// s1 <= I[1];
// s2 <= I[2];
// s3 <= I[3];
s1 <= I1;
s2 <= I2;
s3 <= I3;
state <= 1;
end else begin
cnt <= cnt + 1;
state <= cnt == 2 ? 0 : 1;
end
end

Expand Down

0 comments on commit 7e24a62

Please sign in to comment.