Skip to content

Commit

Permalink
Fix #127
Browse files Browse the repository at this point in the history
Lookup cellvars in a more direct way
  • Loading branch information
jandecaluwe committed Sep 27, 2015
1 parent beed676 commit 0bc41cc
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 2 deletions.
6 changes: 4 additions & 2 deletions myhdl/_extractHierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from myhdl._util import _isGenFunc, _flatten, _genfunc
from myhdl._misc import _isGenSeq
from myhdl._resolverefs import _resolveRefs
from myhdl._getcellvars import _getCellVars


_profileFunc = None
Expand Down Expand Up @@ -316,9 +317,8 @@ def extractor(self, frame, event, arg):
symdict = frame.f_globals.copy()
symdict.update(frame.f_locals)
cellvars = []
cellvars.extend(frame.f_code.co_cellvars)

#All nested functions will be in co_consts
# All nested functions will be in co_consts
if func:
local_gens = []
consts = func.__code__.co_consts
Expand All @@ -327,6 +327,8 @@ def extractor(self, frame, event, arg):
if genfunc.__code__ in consts:
local_gens.append(item)
if local_gens:
cellvarlist = _getCellVars(symdict, local_gens)
cellvars.extend(cellvarlist)
objlist = _resolveRefs(symdict, local_gens)
cellvars.extend(objlist)
#for dict in (frame.f_globals, frame.f_locals):
Expand Down
33 changes: 33 additions & 0 deletions myhdl/_getcellvars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import absolute_import
import ast
import itertools
from types import FunctionType

from myhdl._util import _flatten
from myhdl._enum import EnumType
from myhdl._Signal import SignalType


class Data():
pass

def _getCellVars(symdict, arg):
gens = _flatten(arg)
data = Data()
data.symdict = symdict
v = _GetCellVars(data)
for gen in gens:
v.visit(gen.ast)
return list(data.objset)

class _GetCellVars(ast.NodeVisitor):
def __init__(self, data):
self.data = data
self.data.objset = set()

def visit_Name(self, node):

if node.id in self.data.symdict:
self.data.objset.add(node.id)

self.generic_visit(node)
148 changes: 148 additions & 0 deletions myhdl/test/bugs/test_issue_127.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
''' Bitonic sort '''

# http://www.myhdl.org/examples/bitonic/

from __future__ import absolute_import

import unittest
from random import randrange

from myhdl import Signal, intbv, \
always_comb, instance, \
delay, toVHDL, StopSimulation


ASCENDING = True
DESCENDING = False


# modules

def compare(a_1, a_2, z_1, z_2, direction):
""" Combinatorial circuit with two input and two output signals.
Sorting to 'direction'. """

@always_comb
def logic():
''' Combinatorial logic '''
if direction == (a_1 > a_2):
z_1.next = a_2
z_2.next = a_1
else:
z_1.next = a_1
z_2.next = a_2

return logic


def feedthru(in_a, out_z):
""" Equivalent of 'doing nothing'. """

@always_comb
def logic():
''' Combinatorial logic '''
out_z.next = in_a

return logic


def bitonic_merge(list_a, list_z, direction):
""" bitonicMerge:
Generates the output from the input list of signals.
Recursive. """
len_list = len(list_a)
half_len = len_list//2
width = len(list_a[0])

if len_list > 1:
tmp = [Signal(intbv(0)[width:]) for _ in range(len_list)]

comp = [compare(list_a[i], list_a[i+half_len], tmp[i], tmp[i+half_len], \
direction) for i in range(half_len)]

lo_merge = bitonic_merge( tmp[:half_len], list_z[:half_len], direction )
hi_merge = bitonic_merge( tmp[half_len:], list_z[half_len:], direction )

return comp, lo_merge, hi_merge
else:
feed = feedthru(list_a[0], list_z[0])
return feed


def bitonic_sort(list_a, list_z, direction):
""" bitonicSort:
Produces a bitonic sequence.
Recursive. """
len_list = len(list_a)
half_len = len_list//2
width = len(list_a[0])

if len_list > 1:
tmp = [Signal(intbv(0)[width:]) for _ in range(len_list)]

lo_sort = bitonic_sort( list_a[:half_len], tmp[:half_len], ASCENDING )
hi_sort = bitonic_sort( list_a[half_len:], tmp[half_len:], DESCENDING )

merge = bitonic_merge( tmp, list_z, direction )
return lo_sort, hi_sort, merge
else:
feed = feedthru(list_a[0], list_z[0])
return feed


# tests

def array8sorter(a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7,
z_0, z_1, z_2, z_3, z_4, z_5, z_6, z_7):
''' Sort Array with 8 values '''

list_a = [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7]
list_z = [z_0, z_1, z_2, z_3, z_4, z_5, z_6, z_7]

sort = bitonic_sort(list_a, list_z, ASCENDING)
return sort


class TestBitonicSort(unittest.TestCase):
''' Test class for bitonic sort '''

def test_sort(self):
""" Check the functionality of the bitonic sort """
length = 8
width = 4

def test_impl():
''' test implementation '''
inputs = [ Signal(intbv(0)[width:]) for _ in range(length) ]
outputs = [ Signal(intbv(0)[width:]) for _ in range(length) ]
z_0, z_1, z_2, z_3, z_4, z_5, z_6, z_7 = outputs
a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7 = inputs

inst = array8sorter(a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7,
z_0, z_1, z_2, z_3, z_4, z_5, z_6, z_7)

@instance
def check():
''' testbench input and validation '''
for i in range(100):
data = [randrange(2**width) for i in range(length)]
for i in range(length):
inputs[i].next = data[i]
yield delay(10)
data.sort()
self.assertEqual(data, outputs, 'wrong data')
raise StopSimulation

return inst, check



# convert

def test_issue_127():
''' Convert to VHDL '''
length = 8
width = 4
sigs = [Signal(intbv(0)[width:]) for _ in range(2*length)]
toVHDL(array8sorter, *sigs)

0 comments on commit 0bc41cc

Please sign in to comment.