Skip to content

Commit

Permalink
Fix fan-in check, require in/out ports in BaseModule/Module construct…
Browse files Browse the repository at this point in the history
…ors, update LPU.py and demos in base.py/core.py accordingly.
  • Loading branch information
lebedov committed Feb 11, 2015
1 parent 54575d5 commit 5b605c2
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 70 deletions.
7 changes: 2 additions & 5 deletions neurokernel/LPU/LPU.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,14 +444,11 @@ def __init__(self, dt, n_dict, s_dict, input_file=None, output_file=None,
np.double)
data_spike = np.zeros(self.num_public_spike + len(in_ports_ids_spk),
np.bool)
super(LPU, self).__init__(sel, sel_gpot, sel_spk, data_gpot, data_spike,
super(LPU, self).__init__(sel, sel_in, sel_out,
sel_gpot, sel_spk, data_gpot, data_spike,
columns, port_data, port_ctrl, port_time,
self.LPU_id, device, debug, time_sync)

self.interface[sel_in_gpot, 'io', 'type'] = ['in', 'gpot']
self.interface[sel_out_gpot, 'io', 'type'] = ['out', 'gpot']
self.interface[sel_in_spk, 'io', 'type'] = ['in', 'spike']
self.interface[sel_out_spk, 'io', 'type'] = ['out', 'spike']
self.sel_in_gpot_ids = self.pm['gpot'].ports_to_inds(self.sel_in_gpot)
self.sel_out_gpot_ids = self.pm['gpot'].ports_to_inds(self.sel_out_gpot)
self.sel_in_spk_ids = self.pm['spike'].ports_to_inds(self.sel_in_spk)
Expand Down
82 changes: 43 additions & 39 deletions neurokernel/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,13 @@ class BaseModule(ControlledProcess):
Parameters
----------
selector : str, unicode, or sequence
sel : str, unicode, or sequence
Path-like selector describing the module's interface of
exposed ports.
sel_in : str, unicode, or sequence
Selector describing all input ports in the module's interface.
sel_out : str, unicode, or sequence
Selector describing all input ports in the module's interface.
data : numpy.ndarray
Data array to associate with ports. Array length must equal the number
of ports in a module's interface.
Expand Down Expand Up @@ -125,12 +129,18 @@ def max_steps(self, value):
(self._max_steps, value))
self._max_steps = value

def __init__(self, selector, data, columns=['interface', 'io', 'type'],
def __init__(self, sel, sel_in, sel_out,
data, columns=['interface', 'io', 'type'],
port_data=PORT_DATA, port_ctrl=PORT_CTRL, port_time=PORT_TIME,
id=None, debug=False, time_sync=False):
self.debug = debug
self.time_sync = time_sync

# Require several necessary attribute columns:
assert 'interface' in columns
assert 'io' in columns
assert 'type' in columns

# Generate a unique ID if none is specified:
if id is None:
id = uid()
Expand All @@ -152,15 +162,23 @@ def __init__(self, selector, data, columns=['interface', 'io', 'type'],
self.net = 'none'

# Create module interface given the specified ports:
self.interface = Interface(selector, columns)
self.interface = Interface(sel, columns)

# Set the interface ID to 0; we assume that a module only has one
# interface:
self.interface[sel, 'interface'] = 0

# Set the interface ID to 0; we assume that a module only has one interface:
self.interface[selector, 'interface'] = 0
# Set port I/O status:
assert SelectorMethods.is_in(sel_in, sel)
assert SelectorMethods.is_in(sel_out, sel)
assert SelectorMethods.are_disjoint(sel_in, sel_out)
self.interface[sel_in, 'io'] = 'in'
self.interface[sel_out, 'io'] = 'out'

# Set up mapper between port identifiers and their associated data:
assert len(data) == len(self.interface)
self.data = data
self.pm = PortMapper(selector, self.data)
self.pm = PortMapper(sel, self.data)

# Patterns connecting this module instance with other modules instances.
# Keyed on the IDs of those modules:
Expand Down Expand Up @@ -259,14 +277,15 @@ def connect(self, m, pat, int_0, int_1):
assert m.interface.is_compatible(0, pat.interface, int_1, True)

# Check that no fan-in from different source modules occurs as a result
# of the new connection by getting the union of all input ports for the
# interfaces of all existing patterns connected to the current module
# and ensuring that the input ports from the new pattern don't overlap:
# of the new connection by getting the union of all connected input
# ports for the interfaces of all existing patterns connected to the
# current module and ensuring that the input ports from the new pattern
# don't overlap:
if self.patterns:
curr_in_ports = reduce(set.union,
[set(self.patterns[i].in_ports(self.pat_ints[i][0]).to_tuples()) \
for i in self.patterns.keys()])
assert not curr_in_ports.intersection(pat.in_ports(int_0).to_tuples())
[set(self.patterns[i].connected_ports(self.pat_ints[i][0]).in_ports(tuples=True)) \
for i in self.patterns.keys()])
assert not curr_in_ports.intersection(pat.connected_ports(int_0).in_ports(tuples=True))

# The pattern instances associated with the current
# module are keyed on the IDs of the modules to which they connect:
Expand Down Expand Up @@ -1237,21 +1256,6 @@ class MyModule(BaseModule):
Example of derived module class.
"""

def __init__(self, sel, sel_in, sel_out, data,
columns=['interface', 'io', 'type'],
port_data=PORT_DATA, port_ctrl=PORT_CTRL,
port_time=PORT_TIME,
id=None):
super(MyModule, self).__init__(sel, data, columns, port_data, port_ctrl,
port_time, id, True, True)

assert SelectorMethods.is_in(sel_in, sel)
assert SelectorMethods.is_in(sel_out, sel)
assert SelectorMethods.are_disjoint(sel_in, sel_out)

self.interface[sel_in, 'io', 'type'] = ['in', 'x']
self.interface[sel_out, 'io', 'type'] = ['out', 'x']

def run_step(self):
super(MyModule, self).run_step()

Expand Down Expand Up @@ -1294,29 +1298,29 @@ def run_step(self):
# Make sure that all ports in the patterns' interfaces are set so
# that they match those of the modules:
pat12 = Pattern(m1_int_sel, m2_int_sel)
pat12.interface[m1_int_sel_out] = [0, 'in', 'x']
pat12.interface[m1_int_sel_in] = [0, 'out', 'x']
pat12.interface[m2_int_sel_in] = [1, 'out', 'x']
pat12.interface[m2_int_sel_out] = [1, 'in', 'x']
pat12.interface[m1_int_sel_out] = [0, 'in']
pat12.interface[m1_int_sel_in] = [0, 'out']
pat12.interface[m2_int_sel_in] = [1, 'out']
pat12.interface[m2_int_sel_out] = [1, 'in']
pat12['/a[2]', '/b[0]'] = 1
pat12['/a[3]', '/b[1]'] = 1
pat12['/b[3]', '/a[0]'] = 1
man.connect(m1, m2, pat12, 0, 1)

pat23 = Pattern(m2_int_sel, m3_int_sel)
pat23.interface[m2_int_sel_out] = [0, 'in', 'x']
pat23.interface[m2_int_sel_in] = [0, 'out', 'x']
pat23.interface[m3_int_sel_in] = [1, 'out', 'x']
pat23.interface[m3_int_sel_out] = [1, 'in', 'x']
pat23.interface[m2_int_sel_out] = [0, 'in']
pat23.interface[m2_int_sel_in] = [0, 'out']
pat23.interface[m3_int_sel_in] = [1, 'out']
pat23.interface[m3_int_sel_out] = [1, 'in']
pat23['/b[4]', '/c[0]'] = 1
pat23['/c[2]', '/b[2]'] = 1
man.connect(m2, m3, pat23, 0, 1)

pat31 = Pattern(m3_int_sel, m1_int_sel)
pat31.interface[m3_int_sel_out] = [0, 'in', 'x']
pat31.interface[m1_int_sel_in] = [1, 'out', 'x']
pat31.interface[m3_int_sel_in] = [0, 'out', 'x']
pat31.interface[m1_int_sel_out] = [1, 'in', 'x']
pat31.interface[m3_int_sel_out] = [0, 'in']
pat31.interface[m1_int_sel_in] = [1, 'out']
pat31.interface[m3_int_sel_in] = [0, 'out']
pat31.interface[m1_int_sel_out] = [1, 'in']
pat31['/c[3]', '/a[1]'] = 1
pat31['/a[4]', '/c[1]'] = 1
man.connect(m3, m1, pat31, 0, 1)
Expand Down
50 changes: 24 additions & 26 deletions neurokernel/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,18 @@ class Module(BaseModule):
Parameters
----------
selector : str, unicode, or sequence
sel : str, unicode, or sequence
Path-like selector describing the module's interface of
exposed ports.
sel_in : str, unicode, or sequence
Selector describing all input ports in the module's interface.
sel_out : str, unicode, or sequence
Selector describing all input ports in the module's interface.
sel_gpot : str, unicode, or sequence
Path-like selector describing the graded potential ports in the module's
Selector describing all graded potential ports in the module's
interface.
sel_spike : str, unicode, or sequence
Path-like selector describing the spiking ports in the module's
interface.
Selector describing all spiking ports in the module's interface.
data_gpot : numpy.ndarray
Data array to associate with graded potential ports. Array length
must equal the number of graded potential ports in the module's interface.
Expand Down Expand Up @@ -76,7 +79,9 @@ class Module(BaseModule):
destination modules.
"""

def __init__(self, selector, sel_gpot, sel_spike, data_gpot, data_spike,
def __init__(self, sel, sel_in, sel_out,
sel_gpot, sel_spike,
data_gpot, data_spike,
columns=['interface', 'io', 'type'],
port_data=PORT_DATA, port_ctrl=PORT_CTRL, port_time=PORT_TIME,
id=None, device=None, debug=False, time_sync=False):
Expand All @@ -101,7 +106,6 @@ def __init__(self, selector, sel_gpot, sel_spike, data_gpot, data_spike,
# Reformat logger name:
LoggerMixin.__init__(self, 'mod %s' % self.id)


# Data port:
if port_data == port_ctrl:
raise ValueError('data and control ports must differ')
Expand All @@ -114,15 +118,20 @@ def __init__(self, selector, sel_gpot, sel_spike, data_gpot, data_spike,
self.net = 'none'

# Create module interface given the specified ports:
self.interface = Interface(selector, columns)
self.interface = Interface(sel, columns)

# Set the interface ID to 0
# we assume that a module only has one interface:
self.interface[selector, 'interface'] = 0
self.interface[sel, 'interface'] = 0

# Set port types:
assert SelectorMethods.is_in(sel_gpot, selector)
assert SelectorMethods.is_in(sel_spike, selector)
assert SelectorMethods.is_in(sel_in, sel)
assert SelectorMethods.is_in(sel_out, sel)
assert SelectorMethods.are_disjoint(sel_in, sel_out)
self.interface[sel_in, 'io'] = 'in'
self.interface[sel_out, 'io'] = 'out'
assert SelectorMethods.is_in(sel_gpot, sel)
assert SelectorMethods.is_in(sel_spike, sel)
assert SelectorMethods.are_disjoint(sel_gpot, sel_spike)
self.interface[sel_gpot, 'type'] = 'gpot'
self.interface[sel_spike, 'type'] = 'spike'
Expand Down Expand Up @@ -586,26 +595,15 @@ def __init__(self, sel,
columns=['interface', 'io', 'type'],
port_data=PORT_DATA, port_ctrl=PORT_CTRL, port_time=PORT_TIME,
id=None, device=None):
super(MyModule, self).__init__(sel, ','.join([sel_in_gpot,
sel_out_gpot]),
','.join([sel_in_spike,
sel_out_spike]),
super(MyModule, self).__init__(sel,
','.join([sel_in_gpot, sel_in_spike]),
','.join([sel_out_gpot, sel_out_spike]),
','.join([sel_in_gpot, sel_out_gpot]),
','.join([sel_in_spike, sel_out_spike]),
data_gpot, data_spike,
columns, port_data, port_ctrl, port_time,
id, None, True, True)

assert SelectorMethods.is_in(sel_in_gpot, sel)
assert SelectorMethods.is_in(sel_out_gpot, sel)
assert SelectorMethods.are_disjoint(sel_in_gpot, sel_out_gpot)
assert SelectorMethods.is_in(sel_in_spike, sel)
assert SelectorMethods.is_in(sel_out_spike, sel)
assert SelectorMethods.are_disjoint(sel_in_spike, sel_out_spike)

self.interface[sel_in_gpot, 'io', 'type'] = ['in', 'gpot']
self.interface[sel_out_gpot, 'io', 'type'] = ['out', 'gpot']
self.interface[sel_in_spike, 'io', 'type'] = ['in', 'spike']
self.interface[sel_out_spike, 'io', 'type'] = ['out', 'spike']

def run_step(self):
super(MyModule, self).run_step()

Expand Down

0 comments on commit 5b605c2

Please sign in to comment.