diff --git a/README.md b/README.md index 8ab37d1..25eb991 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ A language for constructing hardware finite state machines using coroutines. # Setup ``` +pip install -r requirements.txt pip install -e . ``` diff --git a/requirements.txt b/requirements.txt index 76be0ef..2fd2bce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ git+git://github.com/phanrahan/magma.git#egg=magma git+git://github.com/phanrahan/mantle.git#egg=mantle git+git://github.com/phanrahan/loam.git#egg=loam coreir -veriloggen +git+git://github.com/leonardt/veriloggen.git#egg=veriloggen diff --git a/setup.py b/setup.py index db11cf2..da99675 100644 --- a/setup.py +++ b/setup.py @@ -4,8 +4,9 @@ name='silica', version='0.1-alpha', description='', - packages=["silica"], + packages=["silica", "silica.cfg"], install_requires=[ - "fault==0.27" + "fault==0.27", + "python-constraint" ] ) diff --git a/silica/cfg/control_flow_graph.py b/silica/cfg/control_flow_graph.py index 696fa6e..bfedecd 100644 --- a/silica/cfg/control_flow_graph.py +++ b/silica/cfg/control_flow_graph.py @@ -10,12 +10,15 @@ import ast import astor import magma +import constraint import silica from silica.transformations import specialize_constants, replace_symbols, constant_fold -from silica.visitors import collect_names, collect_stores +from silica.visitors import collect_names, collect_stores, collect_loads from silica.cfg.types import BasicBlock, Yield, Branch, HeadBlock, State from silica.cfg.ssa import SSAReplacer +from .liveness import liveness_analysis +from ..memory import MemoryType def get_constant(node): if isinstance(node, ast.Num) and len(stmt.targets) == 1: @@ -197,18 +200,24 @@ class ControlFlowGraph: * ``self.curr_block`` - the current block used by the construction algorithm """ - def __init__(self, tree): + def __init__(self, tree, width_table, func_locals): self.blocks = [] self.curr_block = None self.curr_yield_id = 0 self.local_vars = set() + self.width_table = width_table + self.func_locals = func_locals # inputs, outputs = parse_arguments(tree.args.args) inputs, outputs = get_io(tree) self.build(tree) + # for var in self.replacer.id_counter: + # width = self.width_table[var] + # for i in range(self.replacer.id_counter[var] + 1): + # self.width_table[f"{var}_{i}"] = width self.bypass_conds() - replacer = SSAReplacer() - var_counter = {} + # replacer = SSAReplacer() + # var_counter = {} # for block in self.blocks: # if isinstance(block, (HeadBlock, BasicBlock)): # replacer.process_block(block) @@ -226,10 +235,47 @@ def __init__(self, tree): self.render() raise error # self.paths = promote_live_variables(self.paths) + liveness_analysis(self) self.states, self.state_vars = build_state_info(self.paths, outputs, inputs) - # self.render() + + self.registers = set() + for path in self.paths: + for block in path: + if isinstance(block, (HeadBlock, Yield)): + loads = self.collect_loads_before_yield(block.outgoing_edge[0], set()) + for load in loads: + if load in block.live_ins and "_si_tmp_val_" not in load: + prefix = "_".join(load.split("_")[:-1]) + self.registers.add(prefix) + block.loads[load] = prefix + # render_paths_between_yields(self.paths) + for path in self.paths: + for block in path: + if isinstance(block, Yield): + stores = set() + for edge, _ in block.incoming_edges: + stores |= self.collect_stores_after_yield(edge, set()) + for store in stores: + prefix = "_".join(store.split("_")[:-1]) + for path2 in self.paths: + for block2 in path2: + if isinstance(block2, Yield): + if prefix in block2.loads.values(): + store_id = -1 + for store in stores: + try: + if prefix in store and "_si_tmp_val_" not in store: + store_id = max(int(store.split("_")[-1]), store_id) + except ValueError: + pass + if store_id < 0: + store_id = 0 + block.stores[prefix + "_" + str(store_id)] = prefix + break + + # self.render() # render_fsm(self.states) # exit() @@ -253,7 +299,7 @@ def build(self, func_def): # else: # raise NotImplementedError() # assert isinstance(func_def.body[-1], ast.While), "FSMs should end with a ``while True:``" - self.replacer = SSAReplacer() + self.replacer = SSAReplacer(self.width_table) for statement in func_def.body: self.process_stmt(statement) # self.process_stmt(func_def.body[-1]) @@ -261,7 +307,7 @@ def build(self, func_def): self.remove_if_trues() - def find_paths(self, block, initial_block): + def find_paths(self, block, initial_block, conds=[]): """ Given a block, recursively build paths to yields """ @@ -269,14 +315,59 @@ def find_paths(self, block, initial_block): if isinstance(block.value, ast.Yield) and block.value.value is None or \ isinstance(block.value, ast.Assign) and block.value.value.value is None: initial_block.initial_yield = block - return [path for path in self.find_paths(block.outgoing_edge[0], initial_block)] + return [path for path in self.find_paths(block.outgoing_edge[0], initial_block, conds)] else: - return [[copy(block)]] + return [[(block)]] elif isinstance(block, BasicBlock): - return [[copy(block)] + path for path in self.find_paths(block.outgoing_edge[0], initial_block)] + return [[(block)] + path for path in self.find_paths(block.outgoing_edge[0], initial_block, conds)] elif isinstance(block, Branch): - true_paths = [[copy(block)] + path for path in self.find_paths(block.true_edge, initial_block)] - false_paths = [[copy(block)] + path for path in self.find_paths(block.false_edge, initial_block)] + true_paths = [] + false_paths = [] + # print("begin", block) + for value, paths, edge in [(True, true_paths, block.true_edge), + (False, false_paths, block.false_edge)]: + _conds = conds[:] + [(block.cond, value)] + problem = constraint.Problem() + variables = {} + constraints = [] + # print("===========") + seen = set() + for cond, result in _conds: + args = collect_names(cond) + for special_func in ["uint", "bit"]: + if special_func in args: + args.remove(special_func) + for arg in args: + if arg in seen: + continue + seen.add(arg) + base_arg = "_".join(arg.split("_")[:-1]) + width = self.width_table[arg] + if width is None: + width = 1 + # -1 hack to work around ~0 == -1 in Python + problem.addVariable(arg, list(range(-1, 1 << width))) + if base_arg not in variables: + variables[base_arg] = [] + variables[base_arg].append(arg) + check = "!=" if result else "==" + _constraint = f"lambda {', '.join(args)}: {astor.to_source(cond).rstrip()} {check} 0" + # print(_constraint) + problem.addConstraint( + eval(_constraint, self.func_locals), + tuple(args)) + for same in variables.values(): + if len(same) > 1: + _constraint = f"lambda {', '.join(same)}: { ' == '.join(same) }" + # print(_constraint) + problem.addConstraint( + eval(_constraint, self.func_locals), + tuple(same)) + # print(problem.getSolution(), value, edge) + # print("===========") + if problem.getSolution(): + paths += [[(block)] + path for path in + self.find_paths(edge, initial_block, _conds)] for path in true_paths: path[0].true_edge = path[1] for path in false_paths: @@ -308,13 +399,14 @@ def collect_paths_between_yields(self): if isinstance(block, (Yield, HeadBlock)): if isinstance(block, Yield) and block.is_initial_yield: continue # Skip initial yield - paths.extend([deepcopy(block)] + path for path in self.find_paths(block.outgoing_edge[0], block)) + paths.extend([block] + path for path in self.find_paths(block.outgoing_edge[0], block, [])) for path in paths: for i in range(len(path) - 1): path[i].next = path[i + 1] path[-1].next = None path[-1].terminal = True path[0].terminal = False + # render_paths_between_yields(paths) return paths def bypass_conds(self): @@ -410,12 +502,11 @@ def process_branch(self, stmt): orig_bb = self.curr_block true_stores = set() for sub_stmt in stmt.body: - # true_stores |= collect_stores(sub_stmt) - true_stores |= collect_names(sub_stmt, ast.Store) + true_stores |= collect_stores(sub_stmt) phi_vars = {} for var in true_stores: if var in self.replacer.id_counter: - phi_vars[var] = [None, f"{var}_{self.replacer.id_counter[var]}"] + phi_vars[var] = [None, None] # stmt.body holds the True path for both If and While nodes orig_id_counter = copy(self.replacer.id_counter) for sub_stmt in stmt.body: @@ -425,24 +516,39 @@ def process_branch(self, stmt): if var in self.replacer.id_counter and var in phi_vars: phi_vars[var][0] = f"{var}_{self.replacer.id_counter[var]}" if isinstance(stmt, ast.While): - for base_var, (true_var, false_var) in phi_vars.items(): - self.curr_block.statements.append( - # ast.parse(f"{orig_var} = phi({true_var}, {orig_var}, cond={astor.to_source(stmt.test)})").body[0]) - ast.parse(f"{false_var} = {true_var} if {astor.to_source(stmt.test).rstrip()} else {false_var}").body[0]) - for (name, index), value in self.replacer.array_stores.items(): + # for base_var, (true_var, false_var) in phi_vars.items(): + # if not (isinstance(stmt.test, ast.NameConstant) and stmt.test.value == True): + # self.replacer.id_counter[base_var] += 1 + # if false_var is None: + # false_var = f"{base_var}_{self.replacer.id_counter[base_var]}" + # self.replacer.id_counter[base_var] += 1 + + # next_var = f"{base_var}_{self.replacer.id_counter[base_var]}" + # self.curr_block.statements.append( + # # ast.parse(f"{orig_var} = phi({true_var}, {orig_var}, cond={astor.to_source(stmt.test)})").body[0]) + # ast.parse(f"{next_var} = phi({true_var} if {astor.to_source(stmt.test).rstrip()} else {false_var})").body[0]) + seen = set() + for (name, index), (value, orig_value) in self.replacer.array_stores.items(): index_hash = "_".join(ast.dump(i) for i in index) count = self.replacer.index_map[index_hash] if not (name, index) in orig_array_stores or \ orig_index_map[index_hash] < count: if (name, index, value, count) in self.replacer.array_store_processed: continue + # if name not in seen: + # seen.add(name) + # width = self.width_table[name] + # if isinstance(width, MemoryType): + # for i in range(width.height): + # self.curr_block.append(ast.parse(f"{name}[{i}] = {orig_value}[{i}]").body[0]) + # else: + # self.curr_block.append(ast.parse(f"{name} = {orig_value}").body[0]) self.replacer.array_store_processed.add((name, index, value, count)) index_str = "" for i in index: index_str = f"[{astor.to_source(i).rstrip()}]" + index_str - false_var = name + index_str - true_var = name + f"_{value}_i{count}" - self.curr_block.statements.append( ast.parse(f"{name}{index_str} = {true_var} if {astor.to_source(stmt.test).rstrip()} else {false_var}").body[0]) + true_var = name + f"_si_tmp_val_{value}_i{count}" + self.curr_block.statements.append(ast.parse(f"{name}{index_str} = {true_var}").body[0]) # Exit the current basic block by looping back to the branch # node add_edge(self.curr_block, branch) @@ -453,15 +559,15 @@ def process_branch(self, stmt): elif isinstance(stmt, (ast.If,)): end_then_block = self.curr_block true_counts = {} - for (name, index), value in self.replacer.array_stores.items(): + for (name, index), (value, orig_value) in self.replacer.array_stores.items(): index_hash = "_".join(ast.dump(i) for i in index) count = self.replacer.index_map[index_hash] true_counts[name, index] = count if stmt.orelse: false_stores = set() for sub_stmt in stmt.orelse: - # false_stores |= collect_stores(sub_stmt) - false_stores |= collect_names(sub_stmt, ast.Store) + false_stores |= collect_stores(sub_stmt) + # false_stores |= collect_names(sub_stmt, ast.Store) self.curr_block = self.gen_new_block() add_false_edge(branch, self.curr_block) load_store_offset = {} @@ -477,45 +583,52 @@ def process_branch(self, stmt): for var, diff in load_store_offset.items(): if var in false_stores: self.replacer.id_counter[var] += diff + self.width_table[f"{var}_{self.replacer.id_counter[var]}"] = self.width_table[var] self.replacer.load_store_offset = {} for var in false_stores: if var in self.replacer.id_counter and var in phi_vars: phi_vars[var][1] = f"{var}_{self.replacer.id_counter[var]}" end_else_block = self.curr_block - self.curr_block = self.gen_new_block() - for base_var, (true_var, false_var) in phi_vars.items(): - if stmt.orelse: - target = false_var - else: - target = true_var - self.curr_block.statements.append( - # 0, ast.parse(f"{last_var} = phi({true_var}, {false_var}, cond={astor.to_source(stmt.test)})").body[0]) - ast.parse(f"{target} = {true_var} if {astor.to_source(stmt.test).rstrip()} else {false_var}").body[0]) - for (name, index), value in self.replacer.array_stores.items(): + seen = set() + for (name, index), (value, orig_value) in self.replacer.array_stores.items(): index_hash = "_".join(ast.dump(i) for i in index) count = self.replacer.index_map[index_hash] if not (name, index) in orig_array_stores or \ orig_index_map[index_hash] < count: if (name, index, value, count) in self.replacer.array_store_processed: continue + # if name not in seen: + # seen.add(name) + # width = self.width_table[name] + # if isinstance(width, MemoryType): + # for i in range(width.height): + # self.curr_block.append(ast.parse(f"{name}[{i}] = {orig_value}[{i}]").body[0]) + # else: + # self.curr_block.append(ast.parse(f"{name} = {orig_value}").body[0]) self.replacer.array_store_processed.add((name, index, value, count)) index_str = "" for i in index: index_str = f"[{astor.to_source(i).rstrip()}]" + index_str - false_var = name + index_str - if stmt.orelse: - if (name, index) not in true_counts: - true_var = name + index_str - false_var = name + f"_{value}_i{count}" - elif true_counts[name, index] == count: - true_var = name + f"_{value}_i{count}" - else: - true_var = name + f"_{value}_i{count - 1}" - false_var = name + f"_{value}_i{count}" - else: - true_var = name + f"_{value}_i{count}" - self.curr_block.statements.insert( - 0, ast.parse(f"{name}{index_str} = {true_var} if {astor.to_source(stmt.test).rstrip()} else {false_var}").body[0]) + true_var = name + f"_si_tmp_val_{value}_i{count}" + self.curr_block.statements.append(ast.parse(f"{name}{index_str} = {true_var}").body[0]) + self.curr_block = self.gen_new_block() + for base_var, (true_var, false_var) in phi_vars.items(): + self.replacer.id_counter[base_var] += 1 + if false_var is None: + false_var = f"{base_var}_{self.replacer.id_counter[base_var]}" + self.replacer.id_counter[base_var] += 1 + self.width_table[false_var] = self.width_table[base_var] + next_var = f"{base_var}_{self.replacer.id_counter[base_var]}" + self.width_table[next_var] = self.width_table[base_var] + width = self.width_table[base_var] + if isinstance(width, MemoryType): + for i in range(width.height): + self.curr_block.append( + ast.parse(f"{next_var}[{i}] = phi({true_var}[{i}] if {astor.to_source(stmt.test).rstrip()} else {false_var}[{i}])").body[0] + ) + else: + self.curr_block.statements.append( + ast.parse(f"{next_var} = phi({true_var} if {astor.to_source(stmt.test).rstrip()} else {false_var})").body[0]) add_edge(end_then_block, self.curr_block) if stmt.orelse: @@ -555,19 +668,31 @@ def process_assign(self, stmt): else: raise NotImplementedError(stmt.value.value) array_stores_to_process = [] - for (name, index), value in self.replacer.array_stores.items(): + seen = set() + for (name, index), (value, orig_value) in self.replacer.array_stores.items(): index_hash = "_".join(ast.dump(i) for i in index) count = self.replacer.index_map[index_hash] if (name, index, value, count) in self.replacer.array_store_processed: continue + # if name not in seen: + # seen.add(name) + # width = self.width_table[name] + # if isinstance(width, MemoryType): + # for i in range(width.height): + # self.curr_block.append(ast.parse(f"{name}[{i}] = {orig_value}[{i}]").body[0]) + # else: + # self.curr_block.append(ast.parse(f"{name} = {orig_value}").body[0]) index_str = "" for i in index: index_str = f"[{astor.to_source(i).rstrip()}]" + index_str self.replacer.array_store_processed.add((name, index, value, count)) - true_var = name + f"_{value}_i{count}" + true_var = name + f"_si_tmp_val_{value}_i{count}" array_stores_to_process.append( ast.parse(f"{name}{index_str} = {true_var}").body[0]) - self.add_new_yield(stmt, output_map, array_stores_to_process) + self.add_new_yield(stmt, output_map) + for var in self.replacer.id_counter: + self.replacer.id_counter[var] += 1 + self.width_table[f"{var}_{self.replacer.id_counter[var]}"] = self.width_table[var] else: self.replacer.visit(stmt) self.curr_block.add(stmt) @@ -699,6 +824,39 @@ def get_basic_blocks_followed_by_branches(self): isinstance(block.outgoing_edge[0], Branch) return filter(is_basicblock_followed_by_branch, self.blocks) + + def collect_loads_before_yield(self, block, seen): + if isinstance(block, Yield): + return set() + if block in seen: + return set() + seen.add(block) + loads = set() + if isinstance(block, BasicBlock): + for stmt in block.statements: + loads |= collect_loads(stmt) + elif isinstance(block, Branch): + loads |= collect_loads(block.cond) + for edge, _ in block.outgoing_edges: + loads |= self.collect_loads_before_yield(edge, seen) + return loads + + + def collect_stores_after_yield(self, block, seen): + if isinstance(block, (Yield, HeadBlock)): + return set() + if block in seen: + return set() + seen.add(block) + stores = set() + if isinstance(block, BasicBlock): + for stmt in block.statements: + stores |= collect_stores(stmt) + for edge, _ in block.incoming_edges: + stores |= self.collect_stores_after_yield(edge, seen) + return stores + + def render_fsm(states): from graphviz import Digraph dot = Digraph(name="top") diff --git a/silica/liveness.py b/silica/cfg/liveness.py similarity index 50% rename from silica/liveness.py rename to silica/cfg/liveness.py index 614b41e..a443933 100644 --- a/silica/liveness.py +++ b/silica/cfg/liveness.py @@ -1,5 +1,6 @@ import ast -import silica.cfg.types as cfg_types +from .types import Yield, HeadBlock, Branch +import silica.ast_utils as ast_utils class Analyzer(ast.NodeVisitor): def __init__(self): @@ -8,6 +9,9 @@ def __init__(self): # self.return_values = set() def visit_Call(self, node): + # if ast_utils.is_name(node.func) and node.func.id == "phi": + # # skip + # return # Ignore function calls for arg in node.args: self.visit(arg) @@ -23,6 +27,8 @@ def visit_Name(self, node): self.gen.add(node.id) else: self.kill.add(node.id) + if node.id in self.gen: + self.gen.remove(node.id) # def visit_Return(self, node): # if isinstance(node, ast.Tuple): @@ -34,31 +40,45 @@ def visit_Name(self, node): def analyze(node): analyzer = Analyzer() - if isinstance(node, cfg_types.Yield): + if isinstance(node, Yield): analyzer.visit(node.value) analyzer.gen = set(value for value in node.output_map.values()) - if not node.terminal: # We only use assigments - analyzer.gen = set() - else: - analyzer.kill = set() - elif isinstance(node, cfg_types.HeadBlock): + # if not node.terminal: # We only use assigments + # analyzer.gen = set() + # else: + # analyzer.kill = set() + elif isinstance(node, HeadBlock): for statement in node.initial_statements: analyzer.visit(statement) if hasattr(node, "initial_yield"): analyzer.visit(node.initial_yield.value) - elif isinstance(node, cfg_types.Branch): + elif isinstance(node, Branch): analyzer.visit(node.cond) else: for statement in node.statements: analyzer.visit(statement) return analyzer.gen, analyzer.kill +def do_analysis(block, seen=set()): + if block in seen: + return block.live_ins + seen.add(block) + block.gen, block.kill = analyze(block) + if not isinstance(block, Yield): + if isinstance(block, Branch): + true_live_ins = do_analysis(block.true_edge, seen) + false_live_ins = do_analysis(block.false_edge, seen) + block.live_outs = true_live_ins | false_live_ins + else: + block.live_outs = do_analysis(block.outgoing_edge[0], seen) + still_live = block.live_outs - block.kill + block.live_ins = block.gen | still_live + return block.live_ins + def liveness_analysis(cfg): - for path in cfg.paths_between_yields: - path[-1].live_outs = set() - for node in reversed(path): - node.gen, node.kill = analyze(node) - if node.next is not None: - node.live_outs = node.next.live_ins - still_live = node.live_outs - node.kill - node.live_ins = node.gen | still_live + for block in cfg.blocks: + if isinstance(block, (HeadBlock, Yield)): + block.live_outs = do_analysis(block.outgoing_edge[0]) + block.gen, block.kill = analyze(block) + still_live = block.live_outs - block.kill + block.live_ins = still_live diff --git a/silica/cfg/ssa.py b/silica/cfg/ssa.py index aef0173..c28e34e 100644 --- a/silica/cfg/ssa.py +++ b/silica/cfg/ssa.py @@ -3,7 +3,8 @@ class SSAReplacer(ast.NodeTransformer): - def __init__(self): + def __init__(self, width_table): + self.width_table = width_table self.id_counter = {} self.phi_vars = {} self.load_store_offset = {} @@ -43,17 +44,22 @@ def visit_Assign(self, node): node.targets[0].slice = self.visit(node.targets[0].slice) store_offset = self.load_store_offset.get(node.targets[0].value.id, 0) name = self.get_name(node.targets[0]) - name += f"_{self.id_counter[name]}" + prev_name = name + f"_{self.id_counter[name] - store_offset}" + name += f"_{self.id_counter[name] + store_offset}" index = self.get_index(node.targets[0]) - if name not in self.array_stores: - self.array_stores[name, index] = 0 + if (name, index) not in self.array_stores: + self.array_stores[name, index] = (0, prev_name) + num = 0 else: - self.array_stores[name, index] += 1 + val = self.array_stores[name, index] + num = val[0] + 1 + self.array_stores[name, index] = (num, val[1]) index_hash = "_".join(ast.dump(i) for i in index) if index_hash not in self.index_map: self.index_map[index_hash] = len(self.index_map) - node.targets[0].value.id += f"_{self.id_counter[node.targets[0].value.id]}_{self.array_stores[name, index]}_i{self.index_map[index_hash]}" + node.targets[0].value.id = f"{name}_si_tmp_val_{num}_i{self.index_map[index_hash]}" node.targets[0] = node.targets[0].value + node.targets[0].ctx = ast.Store() # self.increment_id(name) # if name not in self.seen: # self.increment_id(name) @@ -80,4 +86,5 @@ def visit_Name(self, node): # self.increment_id(node.id) # self.seen.add(node.id) store_offset = self.load_store_offset.get(node.id, 0) - return ast.Name(f"{node.id}_{self.id_counter[node.id] + store_offset}", ast.Store) + self.width_table[f"{node.id}_{self.id_counter[node.id]}"] = self.width_table[node.id] + return ast.Name(f"{node.id}_{self.id_counter[node.id] + store_offset}", ast.Store()) diff --git a/silica/cfg/types.py b/silica/cfg/types.py index 29721cb..6daebbd 100644 --- a/silica/cfg/types.py +++ b/silica/cfg/types.py @@ -28,6 +28,7 @@ class HeadBlock(Block): def __init__(self): super().__init__() self.initial_statements = [] + self.loads = {} def add(self, stmt): self.initial_statements.append(stmt) @@ -44,6 +45,9 @@ def __init__(self): def add(self, stmt): self.statements.append(stmt) + def append(self, stmt): + self.statements.append(stmt) + def __iter__(self): return iter(self.statements) @@ -69,6 +73,8 @@ def __init__(self, value, output_map={}, array_stores_to_process=[]): self.value = value self.output_map = output_map self.array_stores_to_process = array_stores_to_process + self.loads = {} + self.stores = {} @property def is_initial_yield(self): diff --git a/silica/compile.py b/silica/compile.py index 8887aa8..362a30e 100644 --- a/silica/compile.py +++ b/silica/compile.py @@ -11,7 +11,6 @@ from silica.cfg import ControlFlowGraph, BasicBlock, HeadBlock from silica.cfg.control_flow_graph import render_paths_between_yields, build_state_info, render_fsm, get_constant import silica.ast_utils as ast_utils -from silica.liveness import liveness_analysis from silica.transformations import specialize_constants, replace_symbols, \ constant_fold, desugar_for_loops, specialize_evals, inline_yield_from_functions from silica.visitors import collect_names @@ -95,12 +94,8 @@ def compile(coroutine, file_name=None, mux_strategy="one-hot", output='verilog', type_table = {} TypeChecker(width_table, type_table).check(tree) # DesugarArrays().run(tree) - cfg = ControlFlowGraph(tree) - for var in cfg.replacer.id_counter: - width = width_table[var] - for i in range(cfg.replacer.id_counter[var] + 1): - width_table[f"{var}_{i}"] = width - liveness_analysis(cfg) + cfg = ControlFlowGraph(tree, width_table, func_locals) + # cfg.render() # render_paths_between_yields(cfg.paths) if output == 'magma': @@ -108,10 +103,12 @@ def compile(coroutine, file_name=None, mux_strategy="one-hot", output='verilog', return compile_magma(coroutine, file_name, mux_strategy, output) registers = set() + registers |= cfg.registers outputs = tuple() for path in cfg.paths: - registers |= path[0].live_ins # Union + registers |= set(path[0].loads.values()) # Union outputs += (collect_names(path[-1].value, ctx=ast.Load), ) + assert all(outputs[1] == output for output in outputs[1:]), "Yield statements must all have the same outputs except for the first" outputs = outputs[1] states = cfg.states @@ -166,9 +163,12 @@ def get_len(t): width = width_table[var] for i in range(cfg.replacer.id_counter[var] + 1): if f"{var}_{i}" not in registers: - ctx.declare_wire(f"{var}_{i}", width) + if isinstance(width, MemoryType): + ctx.declare_wire(f"{var}_{i}", width.width, width.height) + else: + ctx.declare_wire(f"{var}_{i}", width) - for (name, index), value in cfg.replacer.array_stores.items(): + for (name, index), (value, orig_value) in cfg.replacer.array_stores.items(): width = width_table[name] if isinstance(width, MemoryType): width = width.width @@ -178,8 +178,10 @@ def get_len(t): index_hash = "_".join(ast.dump(i) for i in index) count = cfg.replacer.index_map[index_hash] for i in range(count + 1): - var = name + f"_{value}_i{count}" - width_table[var] = width + var = name + f"_si_tmp_val_{value}_i{i}" + if var not in width_table: + width_table[var] = width + ctx.declare_wire(var, width) # declare regs for register in registers: @@ -192,12 +194,14 @@ def get_len(t): init_body = [ctx.assign(ctx.get_by_name(key), value) for key,value in initial_values.items() if value is not None] if cfg.curr_yield_id > 1: - ctx.declare_reg("yield_state", (cfg.curr_yield_id - 1).bit_length()) + yield_state_width = (cfg.curr_yield_id - 1).bit_length() + ctx.declare_reg("yield_state", yield_state_width) + ctx.declare_wire(f"yield_state_next", yield_state_width) init_body.append(ctx.assign(ctx.get_by_name("yield_state"), 0)) - if initial_basic_block: - for statement in states[0].statements: - verilog.process_statement(statement) + # if initial_basic_block: + # for statement in states[0].statements: + # verilog.process_statement(statement) # init_body.append(ctx.translate(statement)) # TODO: redefinition bug? # temp_var_promoter.visit(statement) @@ -207,9 +211,9 @@ def get_len(t): waddrs = {} wdatas = {} wens = {} - if initial_basic_block: - states = states[1:] - verilog.compile_states(ctx, states, cfg.curr_yield_id == 1, width_table, strategy) + # if initial_basic_block: + # states = states[1:] + verilog.compile_states(ctx, states, cfg.curr_yield_id == 1, width_table, registers, strategy) # cfg.render() with open(file_name, "w") as f: diff --git a/silica/verilog.py b/silica/verilog.py index 34515d8..3207155 100644 --- a/silica/verilog.py +++ b/silica/verilog.py @@ -5,6 +5,9 @@ from functools import * from .ast_utils import * import magma +from .visitors.collect_stores import collect_stores +from .cfg.types import HeadBlock +from .memory import MemoryType class Context: def __init__(self, name): @@ -31,7 +34,7 @@ def declare_reg(self, name, width=None, height=None): self.module.Reg(name, width, height) def assign(self, lhs, rhs): - return vg.Subst(lhs, rhs, self.is_reg(lhs)) + return vg.Subst(lhs, rhs, not self.is_reg(lhs)) def initial(self, body=[]): return self.module.Initial(body) @@ -70,6 +73,8 @@ def translate(self, stmt): return vg.And elif is_bit_xor(stmt): return vg.Xor + elif is_bit_or(stmt): + return vg.Or elif is_compare(stmt): assert(len(stmt.ops) == len(stmt.comparators) == 1) return self.translate(stmt.ops[0])( @@ -88,6 +93,8 @@ def translate(self, stmt): return vg.Unot elif is_lt(stmt): return vg.LessThan + elif is_gt(stmt): + return vg.GreaterThan elif is_name(stmt): return self.get_by_name(stmt.id) elif is_name_constant(stmt): @@ -131,14 +138,15 @@ def visit_Subscript(self, node): node.slice.upper = ast.Num(0) return node -class RemoveMagmaFuncs(ast.NodeTransformer): +class RemoveFuncs(ast.NodeTransformer): def visit_Call(self, node): - if isinstance(node.func, ast.Name) and node.func.id in ['bits', 'uint', 'bit']: - return node.args[0] + if isinstance(node.func, ast.Name) and node.func.id in ['bits', 'uint', + 'bit', 'phi']: + return self.visit(node.args[0]) return node def process_statement(stmt): - RemoveMagmaFuncs().visit(stmt) + RemoveFuncs().visit(stmt) SwapSlices().visit(stmt) return stmt @@ -172,13 +180,16 @@ def visit_Subscript(self, node): return node -def compile_statements(ctx, seq, states, one_state, width_table, statements): +def compile_statements(ctx, seq, comb_body, states, one_state, width_table, + registers, statements): module = ctx.module # temp_var_promoter = TempVarPromoter(width_table) for statement in statements: conds = [] yields = set() contained = [state for state in states if statement in state.statements] + stores = collect_stores(statement) + has_reg = any(x in registers for x in stores) if contained != states: for state in states: if statement in state.statements: @@ -193,29 +204,91 @@ def compile_statements(ctx, seq, states, one_state, width_table, statements): # conds.append(" & ".join(these_conds)) if not one_state: conds = [module.get_vars()["yield_state"] == yield_id for yield_id in yields] - process_statement(statement) - if conds: - cond = reduce(vg.Lor, conds) - seq.If(cond)( - vg.Subst(ctx.translate(statement.targets[0]), ctx.translate(statement.value), 1) - ) + statement = process_statement(statement) + if has_reg: + stmt = vg.Subst(ctx.translate(statement.targets[0]), ctx.translate(statement.value), 1) + + if conds: + cond = reduce(vg.Lor, conds) + seq.If(cond)( + stmt + ) + else: + seq(stmt) else: - seq( + comb_body.append( vg.Subst(ctx.translate(statement.targets[0]), ctx.translate(statement.value), 1) ) else: process_statement(statement) - seq( - vg.Subst(ctx.translate(statement.targets[0]), ctx.translate(statement.value), 1) - ) + try: + stmt = vg.Subst(ctx.translate(statement.targets[0]), ctx.translate(statement.value), 1) + except Exception as e: + ctx.module.Always(vg.SensitiveAll())(comb_body) + print(ctx.module.to_verilog()) + raise e + if has_reg: + seq(stmt) + else: + comb_body.append(stmt) -def compile_states(ctx, states, one_state, width_table, strategy="by_statement"): +def compile_states(ctx, states, one_state, width_table, registers, + strategy="by_statement"): module = ctx.module seq = vg.TmpSeq(module, module.get_ports()["CLK"]) comb_body = [] if strategy == "by_statement": + seen = set() + for i, state in enumerate(states): + if isinstance(state.path[0], HeadBlock): + continue + if state.path[0] not in seen: + for target, value in state.path[0].loads.items(): + if (target, value) not in seen: + seen.add((target, value)) + width = width_table[value] + if isinstance(width, MemoryType): + for i in range(width.height): + comb_body.append( + ctx.assign(vg.Pointer(ctx.get_by_name(target), i), + vg.Pointer(ctx.get_by_name(value), i))) + comb_body[-1].blk = 1 + else: + comb_body.append(ctx.assign(ctx.get_by_name(target), ctx.get_by_name(value))) + comb_body[-1].blk = 1 + for value, target in state.path[0].stores.items(): + if (target, value) not in seen: + seen.add((target, value)) + width = width_table[value] + if target in registers: + if not one_state: + cond = ctx.get_by_name('yield_state_next') == state.start_yield_id + if isinstance(width, MemoryType): + if_body = [] + for i in range(width.height): + if_body.append(ctx.assign(vg.Pointer(ctx.get_by_name(target), i), + vg.Pointer(ctx.get_by_name(value), i))) + seq.If(cond)(if_body) + else: + seq.If(cond)(ctx.assign(ctx.get_by_name(target), ctx.get_by_name(value))) + else: + if isinstance(width, MemoryType): + for i in range(width.height): + seq(ctx.assign(vg.Pointer(ctx.get_by_name(target), i), + vg.Pointer(ctx.get_by_name(value), i))) + else: + seq(ctx.assign(ctx.get_by_name(target), ctx.get_by_name(value))) + else: + if isinstance(width, MemoryType): + for i in range(width.height): + comb_body.append(ctx.assign( + vg.Pointer(ctx.get_by_name(target), i), + vg.Pointer(ctx.get_by_name(value), i)) + ) + else: + comb_body.append(ctx.assign(ctx.get_by_name(target), ctx.get_by_name(value))) statements = [] for state in states: index = len(statements) @@ -224,8 +297,9 @@ def compile_states(ctx, states, one_state, width_table, strategy="by_statement") index = statements.index(statement) else: statements.insert(index, statement) - compile_statements(ctx, seq, states, one_state, width_table, statements) + compile_statements(ctx, seq, comb_body, states, one_state, width_table, registers, statements) if not one_state: + next_state = None for i, state in enumerate(states): conds = [] if state.conds: @@ -240,18 +314,27 @@ def compile_states(ctx, states, one_state, width_table, strategy="by_statement") output_stmts = [] for output, var in state.path[-1].output_map.items(): output_stmts.append(ctx.assign(ctx.get_by_name(output), ctx.get_by_name(var))) + output_stmts[-1].blk = 1 stmts = [] - stmts.append(ctx.assign(ctx.get_by_name('yield_state'), state.end_yield_id)) + next_yield = ctx.assign(ctx.get_by_name('yield_state_next'), state.end_yield_id) + if next_state is None: + next_state = vg.If(cond)(next_yield) + else: + next_state.Elif(cond)(next_yield) for stmt in state.path[-1].array_stores_to_process: stmts.append(ctx.translate(process_statement(stmt))) if_stmt(stmts) - comb_cond = reduce(vg.Land, conds, ctx.get_by_name('yield_state') == state.end_yield_id) + comb_cond = reduce( + vg.Land, + conds + [ctx.get_by_name('yield_state_next') == state.end_yield_id], + ctx.get_by_name('yield_state') == state.start_yield_id) if i == 0: comb_body.append(vg.If(comb_cond)(output_stmts)) else: comb_body[-1] = comb_body[-1].Elif(comb_cond)(output_stmts) + comb_body.insert(0, next_state) else: for output, var in states[0].path[-1].output_map.items(): @@ -263,3 +346,7 @@ def compile_states(ctx, states, one_state, width_table, strategy="by_statement") ctx.module.Always(vg.SensitiveAll())( comb_body ) + if not one_state: + seq( + ctx.assign(ctx.get_by_name('yield_state'), ctx.get_by_name('yield_state_next')) + ) diff --git a/silica/visitors/__init__.py b/silica/visitors/__init__.py index 6af182a..c4c5d2b 100644 --- a/silica/visitors/__init__.py +++ b/silica/visitors/__init__.py @@ -1,2 +1,3 @@ from .collect_names import collect_names from .collect_stores import collect_stores +from .collect_loads import collect_loads diff --git a/silica/visitors/collect_loads.py b/silica/visitors/collect_loads.py new file mode 100644 index 0000000..3156ab0 --- /dev/null +++ b/silica/visitors/collect_loads.py @@ -0,0 +1,20 @@ +import ast + +class LoadCollector(ast.NodeVisitor): + def __init__(self): + super().__init__() + self.names = set() + + def visit_Assign(self, node): + # skip stores + self.visit(node.value) + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Load): + self.names.add(node.id) + + +def collect_loads(tree): + visitor = LoadCollector() + visitor.visit(tree) + return visitor.names diff --git a/tests/test_counter.py b/tests/test_counter.py index 276c5a9..cfba6ed 100644 --- a/tests/test_counter.py +++ b/tests/test_counter.py @@ -8,10 +8,11 @@ @silica.coroutine def SilicaCounter(width, init=0, incr=1): - O = bits(init, width) + count = bits(init, width) while True: + O = count + count = count + bits(incr, width) yield O - O = O + bits(incr, width) def test_counter(): N = 3 diff --git a/tests/test_detect.py b/tests/test_detect.py index 3e1087e..1a2ab90 100644 --- a/tests/test_detect.py +++ b/tests/test_detect.py @@ -9,15 +9,15 @@ @silica.coroutine(inputs={"I" : Bit}) def SIDetect111(): cnt = uint(0, 2) - O = bit(0) + I = yield while True: - I = yield O if (I): if (cnt<3): cnt = cnt+1 else: cnt = 0 O = (cnt == 3) + I = yield O @silica.coroutine def inputs_generator(inputs): @@ -26,9 +26,9 @@ def inputs_generator(inputs): I = i yield I -inputs = list(map(bool, [1,1,0,1,1,1,0,1,0,1,1,1,1,1,1])) def test_detect111(): detect = SIDetect111() + inputs = list(map(bool, [1,1,0,1,1,1,0,1,0,1,1,1,1,1,1])) outputs = list(map(bool, [0,0,0,0,0,1,0,0,0,0,0,1,1,1,1])) si_detect = silica.compile(detect, file_name="tests/build/si_detect.v") # si_detect = m.DefineFromVerilogFile("tests/build/si_detect.v", @@ -36,9 +36,10 @@ def test_detect111(): tester = fault.Tester(si_detect, si_detect.CLK) for i, o in zip(inputs, outputs): tester.poke(si_detect.I, i) - tester.step(2) + tester.eval() tester.expect(si_detect.O, o) assert o == detect.send(i) + tester.step(2) tester.compile_and_run(target="verilator", directory="tests/build", flags=['-Wno-fatal'], include_verilog_libraries=['../cells_sim.v']) diff --git a/tests/test_fifo.py b/tests/test_fifo.py index ab36ec9..a523bf8 100644 --- a/tests/test_fifo.py +++ b/tests/test_fifo.py @@ -12,19 +12,17 @@ def SilicaFifo(): buffer = memory(4, 4) raddr = uint(0, 3) waddr = uint(0, 3) - rdata = buffer[raddr] - full = bit(0) - empty = bit(1) + wdata, wen, ren = yield while True: - wdata, wen, ren = yield rdata, empty, full + full = (waddr[:2] == raddr[:2]) & (waddr[2] != raddr[2]) + empty = waddr == raddr + rdata = buffer[raddr[:2]] if wen & ~full: buffer[waddr[:2]] = wdata waddr = waddr + uint(1, 3) if ren & ~empty: raddr = raddr + uint(1, 3) - rdata = buffer[raddr[:2]] - full = (waddr[:2] == raddr[:2]) & (waddr[2] != raddr[2]) - empty = waddr == raddr + wdata, wen, ren = yield rdata, empty, full # buffer = memory(4, 4) # raddr = uint(0, 3) @@ -58,20 +56,20 @@ def SilicaFifo(): # {'wdata': 13, 'wen': 0, 'ren': 1, 'rdata': 10, 'full': False, 'empty': False, 'buffer': [9, 10, 11, 12], 'raddr': 1, 'waddr': 0}, # {'wdata': 14, 'wen': 1, 'ren': 1, 'rdata': 11, 'full': False, 'empty': False, 'buffer': [14, 10, 11, 12], 'raddr': 2, 'waddr': 1}, {'wdata': 1, 'wen': 0, 'ren': 1, 'rdata': 0, 'full': False, 'empty': True, 'buffer': [0, 0, 0, 0], 'raddr': 0, 'waddr': 0}, - {'wdata': 2, 'wen': 1, 'ren': 0, 'rdata': 2, 'full': False, 'empty': False, 'buffer': [2, 0, 0, 0], 'raddr': 0, 'waddr': 1}, - {'wdata': 3, 'wen': 1, 'ren': 1, 'rdata': 3, 'full': False, 'empty': False, 'buffer': [2, 3, 0, 0], 'raddr': 1, 'waddr': 2}, + {'wdata': 2, 'wen': 1, 'ren': 0, 'rdata': 0, 'full': False, 'empty': True, 'buffer': [2, 0, 0, 0], 'raddr': 0, 'waddr': 1}, + {'wdata': 3, 'wen': 1, 'ren': 1, 'rdata': 2, 'full': False, 'empty': False, 'buffer': [2, 3, 0, 0], 'raddr': 1, 'waddr': 2}, {'wdata': 4, 'wen': 1, 'ren': 0, 'rdata': 3, 'full': False, 'empty': False, 'buffer': [2, 3, 4, 0], 'raddr': 1, 'waddr': 3}, - {'wdata': 5, 'wen': 0, 'ren': 1, 'rdata': 4, 'full': False, 'empty': False, 'buffer': [2, 3, 4, 0], 'raddr': 2, 'waddr': 3}, - {'wdata': 6, 'wen': 0, 'ren': 1, 'rdata': 0, 'full': False, 'empty': True, 'buffer': [2, 3, 4, 0], 'raddr': 3, 'waddr': 3}, - {'wdata': 7, 'wen': 1, 'ren': 0, 'rdata': 7, 'full': False, 'empty': False, 'buffer': [2, 3, 4, 7], 'raddr': 3, 'waddr': 4}, - {'wdata': 8, 'wen': 0, 'ren': 1, 'rdata': 2, 'full': False, 'empty': True, 'buffer': [2, 3, 4, 7], 'raddr': 4, 'waddr': 4}, - {'wdata': 9, 'wen': 1, 'ren': 1, 'rdata': 9, 'full': False, 'empty': False, 'buffer': [9, 3, 4, 7], 'raddr': 4, 'waddr': 5}, + {'wdata': 5, 'wen': 0, 'ren': 1, 'rdata': 3, 'full': False, 'empty': False, 'buffer': [2, 3, 4, 0], 'raddr': 2, 'waddr': 3}, + {'wdata': 6, 'wen': 0, 'ren': 1, 'rdata': 4, 'full': False, 'empty': False, 'buffer': [2, 3, 4, 0], 'raddr': 3, 'waddr': 3}, + {'wdata': 7, 'wen': 1, 'ren': 0, 'rdata': 0, 'full': False, 'empty': True, 'buffer': [2, 3, 4, 7], 'raddr': 3, 'waddr': 4}, + {'wdata': 8, 'wen': 0, 'ren': 1, 'rdata': 7, 'full': False, 'empty': False, 'buffer': [2, 3, 4, 7], 'raddr': 4, 'waddr': 4}, + {'wdata': 9, 'wen': 1, 'ren': 1, 'rdata': 2, 'full': False, 'empty': True, 'buffer': [9, 3, 4, 7], 'raddr': 4, 'waddr': 5}, {'wdata': 10, 'wen': 1, 'ren': 0, 'rdata': 9, 'full': False, 'empty': False, 'buffer': [9, 10, 4, 7], 'raddr': 4, 'waddr': 6}, {'wdata': 11, 'wen': 1, 'ren': 0, 'rdata': 9, 'full': False, 'empty': False, 'buffer': [9, 10, 11, 7], 'raddr': 4, 'waddr': 7}, - {'wdata': 12, 'wen': 1, 'ren': 0, 'rdata': 9, 'full': True, 'empty': False, 'buffer': [9, 10, 11, 12], 'raddr': 4, 'waddr': 0}, + {'wdata': 12, 'wen': 1, 'ren': 0, 'rdata': 9, 'full': False, 'empty': False, 'buffer': [9, 10, 11, 12], 'raddr': 4, 'waddr': 0}, {'wdata': 13, 'wen': 1, 'ren': 0, 'rdata': 9, 'full': True, 'empty': False, 'buffer': [9, 10, 11, 12], 'raddr': 4, 'waddr': 0}, - {'wdata': 13, 'wen': 0, 'ren': 1, 'rdata': 10, 'full': False, 'empty': False, 'buffer': [9, 10, 11, 12], 'raddr': 5, 'waddr': 0}, - {'wdata': 14, 'wen': 1, 'ren': 1, 'rdata': 11, 'full': False, 'empty': False, 'buffer': [14, 10, 11, 12], 'raddr': 6, 'waddr': 1}, + {'wdata': 13, 'wen': 0, 'ren': 1, 'rdata': 9, 'full': True, 'empty': False, 'buffer': [9, 10, 11, 12], 'raddr': 5, 'waddr': 0}, + {'wdata': 14, 'wen': 1, 'ren': 1, 'rdata': 10, 'full': False, 'empty': False, 'buffer': [14, 10, 11, 12], 'raddr': 6, 'waddr': 1}, ] @si.coroutine @@ -100,11 +98,12 @@ def test_fifo(): tester.poke(si_fifo.interface.ports[input_], trace[input_]) tester.print(si_fifo.interface.ports[input_]) fifo.send(args) - tester.step(2) + tester.step(1) for output in outputs: assert getattr(fifo, output) == trace[output], (i, output, getattr(fifo, output), trace[output]) tester.expect(si_fifo.interface.ports[output], trace[output]) tester.print(si_fifo.interface.ports[output]) + tester.step(1) for state in states: assert getattr(fifo, state) == trace[state], (i, state, getattr(fifo, state), trace[state]) diff --git a/tests/test_lbmem.py b/tests/test_lbmem.py index 8f2e46d..7b2df78 100644 --- a/tests/test_lbmem.py +++ b/tests/test_lbmem.py @@ -63,12 +63,32 @@ def DrainingState(lbmem_width, depth, lbmem, raddr, waddr, wdata, wen): @si.coroutine(inputs={"wdata": si.Bits(8), "wen": si.Bit}) def SILbMem(depth=64, lbmem_width=8): lbmem = memory(depth, lbmem_width) - raddr = uint(0, eval(math.ceil(math.log2(depth)))) waddr = uint(0, eval(math.ceil(math.log2(depth)))) + count = uint(0, 3) wdata, wen = yield while True: - waddr = yield from FillingState(lbmem_width, depth, lbmem, raddr, waddr, wdata, wen) - waddr, raddr = yield from DrainingState(lbmem_width, depth, lbmem, raddr, waddr, wdata, wen) + while (count < uint(7, 3)) | ~wen: + rdata = lbmem[waddr - uint(count, 6)] + valid = bit(0) + if wen: + lbmem[waddr] = wdata + count = count + 1 + waddr = waddr + 1 + wdata, wen = yield rdata, valid + lbmem[waddr] = wdata + while count > 0: + valid = bit(1) + rdata = lbmem[waddr - uint(count, 6)] + if ~wen: + count = count - 1 + if wen: + lbmem[waddr] = wdata + waddr = waddr + 1 + wdata, wen = yield rdata, valid + + # while True: + # waddr = yield from FillingState(lbmem_width, depth, lbmem, raddr, waddr, wdata, wen) + # waddr, raddr = yield from DrainingState(lbmem_width, depth, lbmem, raddr, waddr, wdata, wen) @si.coroutine def inputs_generator(inputs): @@ -92,17 +112,19 @@ def test_lbmem(): tester.poke(si_lbmem.wdata, i) tester.poke(si_lbmem.wen, 1) lbmem.send((i, BitVector(1))) - tester.step(2) + tester.step(1) assert lbmem.valid == (i == 7) # tester.print(si_lbmem.valid) tester.expect(si_lbmem.valid, i == 7), "Valid only on last write" assert lbmem.rdata == 0, "should be 0, even on first read" tester.expect(si_lbmem.rdata, 0) + tester.step(1) for i in range(0, 8): tester.poke(si_lbmem.wdata, 0) tester.poke(si_lbmem.wen, 0) lbmem.send((0, BitVector(0))) - tester.step(2) + # print(lbmem.lbmem) + tester.step(1) # i + 1 because 0 was read on the final write, 7 when i == 7 because # there's nothing left to drain rdata = i + 1 if i < 7 else 0 @@ -112,7 +134,7 @@ def test_lbmem(): tester.expect(si_lbmem.rdata, rdata) # tester.print(si_lbmem.valid) tester.expect(si_lbmem.valid, valid) - tester.eval() + tester.step(1) # print(f"inputs={inputs}, rdata={lbmem.rdata}, valid={lbmem.valid}") # wen = inputs[1] diff --git a/tests/test_piso.py b/tests/test_piso.py index 0467057..4795956 100644 --- a/tests/test_piso.py +++ b/tests/test_piso.py @@ -1,3 +1,4 @@ +from magma import * import silica import magma as m m.set_mantle_target("ice40") @@ -15,9 +16,8 @@ def DefinePISO(n): def SIPISO(): values = bits(0, n) # O = values[-1] - O = values[n-1] + PI, SI, LOAD = yield while True: - PI, SI, LOAD = yield O if LOAD: values = PI else: @@ -26,6 +26,7 @@ def SIPISO(): values[i] = values[i - 1] values[0] = SI O = values[n-1] + PI, SI, LOAD = yield O SIPISO._name = f"SIPISO{n}" return SIPISO @@ -46,46 +47,76 @@ def inputs_generator(message): def test_PISO(): piso = DefinePISO(10)() - piso_magma = silica.compile(piso, "tests/build/piso_si.v") - tester = fault.Tester(piso_magma, piso_magma.CLK) + si_piso = silica.compile(piso, "tests/build/si_piso.v") + # si_piso = m.DefineFromVerilogFile("tests/build/si_piso.v", + # type_map={"CLK": m.In(m.Clock)})[0] + tester = fault.Tester(si_piso, si_piso.CLK) message = [0xDE, 0xAD, 0xBE, 0xEF] inputs = inputs_generator(message) for i in range(len(message)): - expected_outputs = inputs.PI[:] + actual_outputs = [] + expected_outputs = [False] + inputs.PI[:] expected_state = inputs.PI[:] piso.send((inputs.PI, inputs.SI, inputs.LOAD)) + actual_outputs.insert(0, piso.O) # print(f"PI={inputs.PI}, SI={inputs.SI}, LOAD={inputs.LOAD}, O={piso.O}, values={piso.values}") - tester.poke(piso_magma.PI , BitVector(inputs.PI)) - tester.poke(piso_magma.SI , BitVector(inputs.SI)) - tester.poke(piso_magma.LOAD, BitVector(inputs.LOAD)) + tester.poke(si_piso.PI , BitVector(inputs.PI)) + tester.poke(si_piso.SI , BitVector(inputs.SI)) + tester.poke(si_piso.LOAD, BitVector(inputs.LOAD)) next(inputs) tester.step(2) - actual_outputs = [] - for i in range(10): - expected_state = [inputs.SI] + expected_state[:-1] - tester.poke(piso_magma.PI , BitVector(inputs.PI)) - tester.poke(piso_magma.SI , BitVector(inputs.SI)) - tester.poke(piso_magma.LOAD, BitVector(inputs.LOAD)) - actual_outputs.insert(0, piso.O) + tester.expect(si_piso.O, piso.O) + for j in range(10): + tester.poke(si_piso.PI , BitVector(inputs.PI)) + tester.poke(si_piso.SI , BitVector(inputs.SI)) + tester.poke(si_piso.LOAD, BitVector(inputs.LOAD)) + assert piso.values == expected_state, (i, j) piso.send((inputs.PI, inputs.SI, inputs.LOAD)) - tester.step(2) - tester.expect(piso_magma.O, piso.O) - assert piso.values == expected_state + actual_outputs.insert(0, piso.O) + expected_state = [inputs.SI] + expected_state[:-1] + tester.step(1) + tester.expect(si_piso.O, piso.O) + tester.step(1) next(inputs) assert actual_outputs == expected_outputs tester.compile_and_run(target="verilator", directory="tests/build", flags=['-Wno-fatal']) # check that they are the same - mantle_PISO = mantle.DefinePISO(10) - mantle_tester = tester.retarget(mantle_PISO, clock=mantle_PISO.CLK) - m.compile("tests/build/mantle_piso", mantle_PISO) + + # mantle_PISO = mantle.DefinePISO(10) + from mantle import Mux + from mantle.common.register import _RegisterName, Register, FFs + T = Bits(10) + n = 10 + init = 0 + has_ce = False + has_reset = False + class _PISO(Circuit): + name = _RegisterName('PISO', n, init, has_ce, has_reset) + IO = ['SI', In(Bit), 'PI', In(T), 'LOAD', In(Bit), + 'O', Out(Bit)] + ClockInterface(has_ce,has_reset) + @classmethod + def definition(piso): + def mux2(y): + return curry(Mux(2), prefix='I') + + mux = braid(col(mux2, n), forkargs=['S']) + reg = Register(n - 1, init, has_ce=has_ce, has_reset=has_reset) + + si = concat(array(piso.SI),reg.O[0:n-1]) + mux(si, piso.PI, piso.LOAD) + reg(mux.O[:-1]) + wire(mux.O[-1], piso.O) + wireclock(piso, reg) + mantle_tester = tester.retarget(_PISO, clock=_PISO.CLK) + m.compile("tests/build/mantle_piso", _PISO) mantle_tester.compile_and_run(target="verilator", directory="tests/build", include_verilog_libraries=['../cells_sim.v']) if __name__ == '__main__': print("===== BEGIN : SILICA RESULTS =====") - evaluate_circuit("piso_si", "SIPISO10") + evaluate_circuit("si_piso", "SIPISO10") print("===== END : SILICA RESULTS =====") print("===== BEGIN : MANTLE RESULTS =====") evaluate_circuit("mantle_piso", "PISO10") diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 4e20034..66881ca 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -14,9 +14,8 @@ def Serializer4(): data1 = bits(0, 16) data2 = bits(0, 16) # I0, I1, I2, I3 = yield - O = bits(0, 16) + I0, I1, I2, I3 = yield while True: - I0, I1, I2, I3 = yield O O = I0 # data = I[1:] data0 = I1 @@ -28,6 +27,7 @@ def Serializer4(): O = data1 I0, I1, I2, I3 = yield O O = data2 + I0, I1, I2, I3 = yield O # for i in range(3): # O = data[i] # I0, I1, I2, I3 = yield O @@ -54,18 +54,20 @@ def test_ser4(): for I in inputs: for j in range(len(I)): tester.poke(getattr(serializer_si, f"I{j}"), I[j]) - tester.step(2) + tester.step(1) ser.send(I) assert ser.O == I[0] tester.print(serializer_si.O) tester.expect(serializer_si.O, I[0]) + tester.step(1) for i in range(3): for j in range(len(I)): - tester.poke(getattr(serializer_si, f"I{j}"), I[j]) - tester.step(2) + tester.poke(getattr(serializer_si, f"I{j}"), 0) + tester.step(1) ser.send([0,0,0,0]) assert ser.O == I[i + 1] tester.expect(serializer_si.O, I[i + 1]) + tester.step(1) tester.compile_and_run(target="verilator", directory="tests/build", flags=['-Wno-fatal']) diff --git a/tests/test_tff.py b/tests/test_tff.py index 34ac3f4..5d5938b 100644 --- a/tests/test_tff.py +++ b/tests/test_tff.py @@ -8,10 +8,12 @@ @silica.coroutine(inputs={"I": silica.Bit}) def TFF(init=0): - O = bit(init) + state = bit(init) + I = yield while True: + O = state + state = I ^ state I = yield O - O = I ^ O def test_TFF(): @@ -22,8 +24,8 @@ def test_TFF(): tester = fault.Tester(si_tff, si_tff.CLK) # Should toggle for i in range(5): - assert tff.O == i % 2 tff.send(True) + assert tff.O == i % 2 tester.poke(si_tff.I, 1) tester.expect(si_tff.O, i % 2) tester.step(2) diff --git a/tests/test_uart.py b/tests/test_uart.py index 4d7b154..f23ffa5 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -24,7 +24,6 @@ @silica.coroutine(inputs={"data": silica.Bits(8), "valid": silica.Bit}) def uart_transmitter(): - message = bits(0, 8) data, valid = yield while True: if valid: @@ -54,29 +53,32 @@ def uart_transmitter(): def test_UART(): uart = uart_transmitter() - si_uart = silica.compile(uart, "tests/build/uart.v") + si_uart = silica.compile(uart, "tests/build/si_uart.v") + # si_uart = m.DefineFromVerilogFile("tests/build/si_uart.v", + # type_map={"CLK": m.In(m.Clock)})[0] tester = fault.Tester(si_uart, si_uart.CLK) tester.step(2) for message in [0xDE, 0xAD]: + tester.step(1) tester.expect(si_uart.ready, 1) tester.poke(si_uart.data, message) tester.poke(si_uart.valid, 1) tester.step(2) + tester.expect(si_uart.tx, 0) tester.poke(si_uart.data, 0xFF) tester.poke(si_uart.valid, 0) tester.expect(si_uart.ready, 0) + tester.step(1) # start bit - tester.expect(si_uart.tx, 0) for i in range(8): - tester.step(2) + tester.print(si_uart.CLK) tester.expect(si_uart.tx, (message >> (7-i)) & 1) - tester.step(2) + tester.step(2) # end bit tester.expect(si_uart.tx, 1) tester.step(2) tester.expect(si_uart.ready, 1) - tester.eval() tester.compile_and_run(target="verilator", directory="tests/build", flags=['-Wno-fatal']) @@ -88,7 +90,7 @@ def test_UART(): flags=['-Wno-fatal']) if __name__ == '__main__': print("===== BEGIN : SILICA RESULTS =====") - evaluate_circuit("uart", "uart_transmitter") + evaluate_circuit("si_uart", "uart_transmitter") print("===== END : SILICA RESULTS =====") import shutil shutil.copy('verilog/uart.v', 'tests/build/verilog_uart.v') diff --git a/verilog/detect111.v b/verilog/detect111.v index 06d16d5..cc1d51f 100644 --- a/verilog/detect111.v +++ b/verilog/detect111.v @@ -5,15 +5,15 @@ module detect111( ); reg [1:0] cnt = 0; - /* wire [1:0] cnt_next; */ + wire [1:0] cnt_next; always @(posedge CLK) begin cnt <= I ? (cnt==3 ? cnt : cnt+1) : 0; end // mealey version - /* assign cnt_next = I ? (cnt==3 ? cnt : cnt+1) : 0; */ - // assign O = (cnt_next==3); - assign O = (cnt==3); + assign cnt_next = I ? (cnt==3 ? cnt : cnt+1) : 0; + assign O = (cnt_next==3); + /* assign O = (cnt==3); */ endmodule diff --git a/verilog/lbmem.v b/verilog/lbmem.v index 8e2bbd9..99c4583 100644 --- a/verilog/lbmem.v +++ b/verilog/lbmem.v @@ -20,7 +20,7 @@ module lbmem( cnt <= cnt + {3'h0,wen}; end else begin - cnt <= wen ? cnt : cnt-1; + cnt <= wen ? cnt : cnt-1; end end @@ -32,11 +32,11 @@ module lbmem( state <= (cnt!=1 | wen); end end - assign valid = state; + assign valid = (state & (cnt!=1 | wen)) | (cnt == 7 & wen); reg [5:0] waddr = 0; wire [5:0] raddr; - assign raddr = waddr - {2'h0,cnt}; + assign raddr = waddr - {2'h0, wen ? cnt : cnt-1}; always @(posedge CLK) begin if (wen) begin diff --git a/verilog/serializer.v b/verilog/serializer.v index 05bd834..5b53729 100644 --- a/verilog/serializer.v +++ b/verilog/serializer.v @@ -7,7 +7,7 @@ module serializer( input [15:0] I1, input [15:0] I2, input [15:0] I3, - output reg [15:0] O + output [15:0] O ); reg [15:0] s1; reg [15:0] s2; @@ -16,21 +16,27 @@ module serializer( reg [1:0] cnt = 0; always @(posedge CLK) begin - cnt <= cnt + 1; + cnt <= (cnt==3) ? 0 : (cnt + 1); end always @(posedge CLK) begin + if (cnt==2'h0) begin + // s1 <= I[1]; + // s2 <= I[2]; + // s3 <= I[3]; + s1 <= I1; + s2 <= I2; + s3 <= I3; + end + end + + always @(*) begin case(cnt) // 2'h0 : O = I[0][15:0]; - 2'h0 : begin - s1 <= I1; - s2 <= I2; - s3 <= I3; - O <= I0; - end - 2'h1 : O <= s1; - 2'h2 : O <= s2; - 2'h3 : O <= s3; + 2'h0 : O = I0; + 2'h1 : O = s1; + 2'h2 : O = s2; + 2'h3 : O = s3; endcase end diff --git a/verilog/uart.v b/verilog/uart.v index ea68251..d1d7788 100644 --- a/verilog/uart.v +++ b/verilog/uart.v @@ -27,6 +27,9 @@ module uart_tx(input CLK, input [7:0] data, input valid, output tx, output ready state <= 0; end endcase + $display("valid=%d", valid); + $display("tx=%d", tx); + $display("send_cnt=%d", send_cnt); end - assign ready = state == 0; + assign ready = state == 0 & ~valid; endmodule