Skip to content

Commit

Permalink
Add option to PortMapper classes to control copying of data array pas…
Browse files Browse the repository at this point in the history
…sed to constructors.
  • Loading branch information
lebedov committed Feb 10, 2015
1 parent 890a15b commit 1692413
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
10 changes: 8 additions & 2 deletions neurokernel/plsel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2058,6 +2058,9 @@ class PortMapper(BasePortMapper):
Integer indices to map to port identifiers. If no map is specified,
it is assumed to be an array of consecutive integers from 0
through one less than the number of ports.
make_copy : bool
If True, map a copy of the specified data array to the specified
port identifiers.
Attributes
----------
Expand All @@ -2075,7 +2078,7 @@ class PortMapper(BasePortMapper):
The selectors may not contain any '*' or '[:]' characters.
"""

def __init__(self, selector, data=None, portmap=None):
def __init__(self, selector, data=None, portmap=None, make_copy=True):
super(PortMapper, self).__init__(selector, portmap)
N = len(self)

Expand All @@ -2092,7 +2095,10 @@ def __init__(self, selector, data=None, portmap=None):

# The port mapper may map identifiers to some portion of the data array:
assert N <= len(data)
self.data = data.copy()
if make_copy:
self.data = data.copy()
else:
self.data = data

def copy(self):
"""
Expand Down
7 changes: 5 additions & 2 deletions neurokernel/pm_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from plsel import BasePortMapper

class GPUPortMapper(PortMapper):
def __init__(self, selector, data=None, portmap=None):
def __init__(self, selector, data=None, portmap=None, make_copy=True):
super(PortMapper, self).__init__(selector, portmap)
N = len(self)

Expand All @@ -26,7 +26,10 @@ def __init__(self, selector, data=None, portmap=None):

# The port mapper may map identifiers to some portion of the data array:
assert N <= len(data)
self.data = data.copy()
if make_copy:
self.data = data.copy()
else:
self.data = data

def get_inds_nonzero(self):
raise NotImplementedError
Expand Down
6 changes: 6 additions & 0 deletions tests/test_plsel.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,12 @@ def test_copy(self):
assert_array_equal(pm2.data, pm1.data)
assert_series_equal(pm2.portmap, pm1.portmap)

data = np.random.rand(5)
pm0 = PortMapper('/foo[0:5]', data, portmap, False)
pm1 = pm0.copy()
data[0] = 1.0
assert pm0.data[0] == 1.0

def test_dtype(self):
pm = PortMapper('/foo/bar[0:10],/foo/baz[0:10]', self.data)
assert pm.dtype == np.float64
Expand Down

0 comments on commit 1692413

Please sign in to comment.