Permalink
Browse files

pytholite: move expression and register handling to separate modules

  • Loading branch information...
1 parent f59fd69 commit bf5ce8dc20339c93a201a38d02b9a59d58082278 @sbourdeauducq sbourdeauducq committed Nov 11, 2012
Showing with 169 additions and 152 deletions.
  1. +13 −152 migen/pytholite/compiler.py
  2. +104 −0 migen/pytholite/expr.py
  3. +52 −0 migen/pytholite/reg.py
View
@@ -1,62 +1,14 @@
import inspect
import ast
-from operator import itemgetter
from migen.fhdl.structure import *
from migen.fhdl.structure import _Slice
-from migen.fhdl import visit as fhdl
+from migen.pytholite.reg import *
+from migen.pytholite.expr import *
from migen.pytholite import transel
from migen.pytholite.io import make_io_object, gen_io
from migen.pytholite.fsm import *
-class FinalizeError(Exception):
- pass
-
-class _AbstractLoad:
- def __init__(self, target, source):
- self.target = target
- self.source = source
-
- def lower(self):
- if not self.target.finalized:
- raise FinalizeError
- return self.target.sel.eq(self.target.source_encoding[self.source])
-
-class _LowerAbstractLoad(fhdl.NodeTransformer):
- def visit_unknown(self, node):
- if isinstance(node, _AbstractLoad):
- return node.lower()
- else:
- return node
-
-class _Register:
- def __init__(self, name, nbits):
- self.name = name
- self.storage = Signal(BV(nbits), name=self.name)
- self.source_encoding = {}
- self.finalized = False
-
- def load(self, source):
- if source not in self.source_encoding:
- self.source_encoding[source] = len(self.source_encoding) + 1
- return _AbstractLoad(self, source)
-
- def finalize(self):
- if self.finalized:
- raise FinalizeError
- self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name)
- self.finalized = True
-
- def get_fragment(self):
- if not self.finalized:
- raise FinalizeError
- # do nothing when sel == 0
- items = sorted(self.source_encoding.items(), key=itemgetter(1))
- cases = [(Constant(v, self.sel.bv),
- self.storage.eq(k)) for k, v in items]
- sync = [Case(self.sel, *cases)]
- return Fragment(sync=sync)
-
def _is_name_used(node, name):
for n in ast.walk(node):
if isinstance(n, ast.Name) and n.id == name:
@@ -68,6 +20,7 @@ def __init__(self, ioo, symdict, registers):
self.ioo = ioo
self.symdict = symdict
self.registers = registers
+ self.ec = ExprCompiler(self.symdict)
def visit_top(self, node):
if isinstance(node, ast.Module) \
@@ -109,12 +62,15 @@ def visit_block(self, statements):
def visit_assign(self, sa, node, statements):
if isinstance(node.value, ast.Call):
+ is_special = False
try:
- value = self.visit_expr_call(node.value)
+ value = self.ec.visit_expr_call(node.value)
except NotImplementedError:
+ is_special = True
+ if is_special:
return self.visit_assign_special(sa, node, statements)
else:
- value = self.visit_expr(node.value)
+ value = self.ec.visit_expr(node.value)
if isinstance(value, Value):
r = []
for target in node.targets:
@@ -146,7 +102,7 @@ def visit_assign_special(self, sa, node, statements):
targetname = node.targets[0].id
else:
targetname = "unk"
- reg = _Register(targetname, nbits)
+ reg = ImplRegister(targetname, nbits)
self.registers.append(reg)
for target in node.targets:
if isinstance(target, ast.Name):
@@ -173,6 +129,7 @@ def visit_io_pattern(self, sa, targets, model, args, statements):
or not isinstance(ystatement.value, ast.Yield) \
or not isinstance(ystatement.value.value, ast.Name) \
or ystatement.value.value.id != modelname:
+ print(ast.dump(ystatement))
raise NotImplementedError("Unrecognized I/O pattern")
# following optional statements are assignments to registers
@@ -202,7 +159,7 @@ def visit_io_pattern(self, sa, targets, model, args, statements):
return fstatement
def visit_if(self, sa, node):
- test = self.visit_expr(node.test)
+ test = self.ec.visit_expr(node.test)
states_t, exit_states_t = self.visit_block(node.body)
states_f, exit_states_f = self.visit_block(node.orelse)
exit_states = exit_states_t + exit_states_f
@@ -218,7 +175,7 @@ def visit_if(self, sa, node):
exit_states)
def visit_while(self, sa, node):
- test = self.visit_expr(node.test)
+ test = self.ec.visit_expr(node.test)
states_b, exit_states_b = self.visit_block(node.body)
test_state = [If(test, AbstractNextState(states_b[0]))]
@@ -269,102 +226,6 @@ def visit_expr_statement(self, sa, node):
sa.assemble(states, exit_states)
else:
raise NotImplementedError
-
- # expressions
- def visit_expr(self, node):
- if isinstance(node, ast.Call):
- return self.visit_expr_call(node)
- elif isinstance(node, ast.BinOp):
- return self.visit_expr_binop(node)
- elif isinstance(node, ast.Compare):
- return self.visit_expr_compare(node)
- elif isinstance(node, ast.Name):
- return self.visit_expr_name(node)
- elif isinstance(node, ast.Num):
- return self.visit_expr_num(node)
- else:
- raise NotImplementedError
-
- def visit_expr_call(self, node):
- if isinstance(node.func, ast.Name):
- callee = self.symdict[node.func.id]
- else:
- raise NotImplementedError
- if callee == transel.bitslice:
- if len(node.args) != 2 and len(node.args) != 3:
- raise TypeError("bitslice() takes 2 or 3 arguments")
- val = self.visit_expr(node.args[0])
- low = ast.literal_eval(node.args[1])
- if len(node.args) == 3:
- up = ast.literal_eval(node.args[2])
- else:
- up = low + 1
- return _Slice(val, low, up)
- else:
- raise NotImplementedError
-
- def visit_expr_binop(self, node):
- left = self.visit_expr(node.left)
- right = self.visit_expr(node.right)
- if isinstance(node.op, ast.Add):
- return left + right
- elif isinstance(node.op, ast.Sub):
- return left - right
- elif isinstance(node.op, ast.Mult):
- return left * right
- elif isinstance(node.op, ast.LShift):
- return left << right
- elif isinstance(node.op, ast.RShift):
- return left >> right
- elif isinstance(node.op, ast.BitOr):
- return left | right
- elif isinstance(node.op, ast.BitXor):
- return left ^ right
- elif isinstance(node.op, ast.BitAnd):
- return left & right
- else:
- raise NotImplementedError
-
- def visit_expr_compare(self, node):
- test = self.visit_expr(node.left)
- r = None
- for op, rcomparator in zip(node.ops, node.comparators):
- comparator = self.visit_expr(rcomparator)
- if isinstance(op, ast.Eq):
- comparison = test == comparator
- elif isinstance(op, ast.NotEq):
- comparison = test != comparator
- elif isinstance(op, ast.Lt):
- comparison = test < comparator
- elif isinstance(op, ast.LtE):
- comparison = test <= comparator
- elif isinstance(op, ast.Gt):
- comparison = test > comparator
- elif isinstance(op, ast.GtE):
- comparison = test >= comparator
- else:
- raise NotImplementedError
- if r is None:
- r = comparison
- else:
- r = r & comparison
- test = comparator
- return r
-
- def visit_expr_name(self, node):
- if node.id == "True":
- return Constant(1)
- if node.id == "False":
- return Constant(0)
- r = self.symdict[node.id]
- if isinstance(r, _Register):
- r = r.storage
- if isinstance(r, int):
- r = Constant(r)
- return r
-
- def visit_expr_num(self, node):
- return Constant(node.n)
def make_pytholite(func, **ioresources):
ioo = make_io_object(**ioresources)
@@ -381,7 +242,7 @@ def make_pytholite(func, **ioresources):
regf += register.get_fragment()
fsm = implement_fsm(states)
- fsmf = _LowerAbstractLoad().visit(fsm.get_fragment())
+ fsmf = LowerAbstractLoad().visit(fsm.get_fragment())
ioo.fragment = regf + fsmf
return ioo
View
@@ -0,0 +1,104 @@
+import ast
+
+from migen.fhdl.structure import *
+from migen.pytholite import transel
+from migen.pytholite.reg import *
+
+class ExprCompiler:
+ def __init__(self, symdict):
+ self.symdict = symdict
+
+ def visit_expr(self, node):
+ if isinstance(node, ast.Call):
+ return self.visit_expr_call(node)
+ elif isinstance(node, ast.BinOp):
+ return self.visit_expr_binop(node)
+ elif isinstance(node, ast.Compare):
+ return self.visit_expr_compare(node)
+ elif isinstance(node, ast.Name):
+ return self.visit_expr_name(node)
+ elif isinstance(node, ast.Num):
+ return self.visit_expr_num(node)
+ else:
+ raise NotImplementedError
+
+ def visit_expr_call(self, node):
+ if isinstance(node.func, ast.Name):
+ callee = self.symdict[node.func.id]
+ else:
+ raise NotImplementedError
+ if callee == transel.bitslice:
+ if len(node.args) != 2 and len(node.args) != 3:
+ raise TypeError("bitslice() takes 2 or 3 arguments")
+ val = self.visit_expr(node.args[0])
+ low = ast.literal_eval(node.args[1])
+ if len(node.args) == 3:
+ up = ast.literal_eval(node.args[2])
+ else:
+ up = low + 1
+ return _Slice(val, low, up)
+ else:
+ raise NotImplementedError
+
+ def visit_expr_binop(self, node):
+ left = self.visit_expr(node.left)
+ right = self.visit_expr(node.right)
+ if isinstance(node.op, ast.Add):
+ return left + right
+ elif isinstance(node.op, ast.Sub):
+ return left - right
+ elif isinstance(node.op, ast.Mult):
+ return left * right
+ elif isinstance(node.op, ast.LShift):
+ return left << right
+ elif isinstance(node.op, ast.RShift):
+ return left >> right
+ elif isinstance(node.op, ast.BitOr):
+ return left | right
+ elif isinstance(node.op, ast.BitXor):
+ return left ^ right
+ elif isinstance(node.op, ast.BitAnd):
+ return left & right
+ else:
+ raise NotImplementedError
+
+ def visit_expr_compare(self, node):
+ test = self.visit_expr(node.left)
+ r = None
+ for op, rcomparator in zip(node.ops, node.comparators):
+ comparator = self.visit_expr(rcomparator)
+ if isinstance(op, ast.Eq):
+ comparison = test == comparator
+ elif isinstance(op, ast.NotEq):
+ comparison = test != comparator
+ elif isinstance(op, ast.Lt):
+ comparison = test < comparator
+ elif isinstance(op, ast.LtE):
+ comparison = test <= comparator
+ elif isinstance(op, ast.Gt):
+ comparison = test > comparator
+ elif isinstance(op, ast.GtE):
+ comparison = test >= comparator
+ else:
+ raise NotImplementedError
+ if r is None:
+ r = comparison
+ else:
+ r = r & comparison
+ test = comparator
+ return r
+
+ def visit_expr_name(self, node):
+ if node.id == "True":
+ return Constant(1)
+ if node.id == "False":
+ return Constant(0)
+ r = self.symdict[node.id]
+ if isinstance(r, ImplRegister):
+ r = r.storage
+ if isinstance(r, int):
+ r = Constant(r)
+ return r
+
+ def visit_expr_num(self, node):
+ return Constant(node.n)
View
@@ -0,0 +1,52 @@
+from operator import itemgetter
+
+from migen.fhdl.structure import *
+from migen.fhdl import visit as fhdl
+
+class FinalizeError(Exception):
+ pass
+
+class AbstractLoad:
+ def __init__(self, target, source):
+ self.target = target
+ self.source = source
+
+ def lower(self):
+ if not self.target.finalized:
+ raise FinalizeError
+ return self.target.sel.eq(self.target.source_encoding[self.source])
+
+class LowerAbstractLoad(fhdl.NodeTransformer):
+ def visit_unknown(self, node):
+ if isinstance(node, AbstractLoad):
+ return node.lower()
+ else:
+ return node
+
+class ImplRegister:
+ def __init__(self, name, nbits):
+ self.name = name
+ self.storage = Signal(BV(nbits), name=self.name)
+ self.source_encoding = {}
+ self.finalized = False
+
+ def load(self, source):
+ if source not in self.source_encoding:
+ self.source_encoding[source] = len(self.source_encoding) + 1
+ return AbstractLoad(self, source)
+
+ def finalize(self):
+ if self.finalized:
+ raise FinalizeError
+ self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name)
+ self.finalized = True
+
+ def get_fragment(self):
+ if not self.finalized:
+ raise FinalizeError
+ # do nothing when sel == 0
+ items = sorted(self.source_encoding.items(), key=itemgetter(1))
+ cases = [(Constant(v, self.sel.bv),
+ self.storage.eq(k)) for k, v in items]
+ sync = [Case(self.sel, *cases)]
+ return Fragment(sync=sync)

0 comments on commit bf5ce8d

Please sign in to comment.