Skip to content

Commit

Permalink
Add duplicates param to Pattern.src_idx().
Browse files Browse the repository at this point in the history
  • Loading branch information
lebedov committed Mar 20, 2016
1 parent d359d6d commit 0a20d5c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
21 changes: 16 additions & 5 deletions neurokernel/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1602,12 +1602,12 @@ def __getitem__(self, key):
sel_1 = self.sel.expand(key[1])
selector = [f+t for f, t in itertools.product(sel_0, sel_1)]
if len(key) > 2:
return self.sel.select(self.data[list(key[2:])], selector=selector)
return self.sel.select(self.data[list(key[2:])], selector=selector)
else:
return self.sel.select(self.data, selector=selector)

def src_idx(self, src_int, dest_int,
src_type=None, dest_type=None, dest_ports=None):
src_type=None, dest_type=None, dest_ports=None, duplicates=False):
"""
Retrieve source ports connected to the specified destination ports.
Expand All @@ -1632,6 +1632,8 @@ def src_idx(self, src_int, dest_int,
Path-like selector corresponding to ports in destination
interface. If not specified, all ports in the destination
interface are considered.
duplicates : bool
If True, include duplicate ports in output.
Returns
-------
Expand Down Expand Up @@ -1679,9 +1681,13 @@ def src_idx(self, src_int, dest_int,
f = lambda x: x[self.from_slice][0] in from_idx and x[self.to_slice][0] in to_idx
idx = self.data.select(f).index

# Remove duplicate tuples from output without perturbing the order
# of the remaining tuples:
return OrderedDict.fromkeys([x[self.from_slice] for x in idx]).keys()
if not duplicates:

# Remove duplicate tuples from output without perturbing the order
# of the remaining tuples:
return OrderedDict.fromkeys([x[self.from_slice] for x in idx]).keys()
else:
return [x[self.from_slice] for x in idx]

def dest_idx(self, src_int, dest_int,
src_type=None, dest_type=None, src_ports=None):
Expand Down Expand Up @@ -1714,6 +1720,11 @@ def dest_idx(self, src_int, dest_int,
-------
idx : list of tuple
Destination ports connected to the specified source ports.
Notes
-----
No `duplicates` parameter is provided because fan-in from multiple
source ports to a single destination port is not permitted.
"""

assert src_int != dest_int
Expand Down
13 changes: 13 additions & 0 deletions tests/test_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,19 @@ def test_src_idx_dest_type(self):
[('aaa',)])
self.assertItemsEqual(q.src_idx(0, 1, dest_type='gpot'), [])

def test_src_idx_duplicates(self):
p = Pattern('/[aaa,bbb][0:3]', '/[xxx,yyy][0:3]')
p['/aaa[0]', '/yyy[0]'] = 1
p['/aaa[0]', '/yyy[1]'] = 1
p['/aaa[0]', '/yyy[2]'] = 1
p['/xxx[0]', '/bbb[0]'] = 1
p['/xxx[0]', '/bbb[1]'] = 1
p['/xxx[1]', '/bbb[2]'] = 1
self.assertItemsEqual(p.src_idx(0, 1, duplicates=True),
[('aaa', 0), ('aaa', 0), ('aaa', 0)])
self.assertItemsEqual(p.src_idx(1, 0, duplicates=True),
[('xxx', 0), ('xxx', 0), ('xxx', 1)])

def test_dest_idx(self):
p = Pattern('/[aaa,bbb][0:3]', '/[xxx,yyy][0:3]')
p['/aaa[0]', '/yyy[0]'] = 1
Expand Down

0 comments on commit 0a20d5c

Please sign in to comment.