Skip to content

Commit

Permalink
don't special case for py3k+numpy
Browse files Browse the repository at this point in the history
py3k+numpy non-copying recv works fine now, with released pyzmq.  There was no need to make any changes in pyzmq.

closes ipythongh-478
  • Loading branch information
minrk committed Jul 16, 2011
1 parent b8d87c2 commit 22b4e57
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
8 changes: 6 additions & 2 deletions IPython/parallel/tests/test_newserialized.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
# Imports
#-------------------------------------------------------------------------------

import sys

from unittest import TestCase

from IPython.testing.decorators import parametric
from IPython.utils import newserialized as ns
from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
from IPython.parallel.tests.clienttest import skip_without

if sys.version_info[0] >= 3:
buffer = memoryview

class CanningTestCase(TestCase):
def test_canning(self):
Expand Down Expand Up @@ -88,10 +92,10 @@ def test_ndarray_serialized(self):
self.assertEquals(md['shape'], a.shape)
self.assertEquals(md['dtype'], a.dtype.str)
buff = ser1.getData()
self.assertEquals(buff, numpy.getbuffer(a))
self.assertEquals(buff, buffer(a))
s = ns.Serialized(buff, td, md)
final = ns.unserialize(s)
self.assertEquals(numpy.getbuffer(a), numpy.getbuffer(final))
self.assertEquals(buffer(a), buffer(final))
self.assertTrue((a==final).all())
self.assertEquals(a.dtype.str, final.dtype.str)
self.assertEquals(a.shape, final.shape)
Expand Down
14 changes: 6 additions & 8 deletions IPython/utils/newserialized.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class SerializationError(Exception):
py3k = True
else:
py3k = False
if sys.version_info[:2] <= (2,6):
memoryview = buffer

#-----------------------------------------------------------------------------
# Classes and functions
Expand Down Expand Up @@ -101,10 +103,7 @@ def __init__(self, unSerialized):
self.data = None
self.obj = unSerialized.getObject()
if numpy is not None and isinstance(self.obj, numpy.ndarray):
if py3k or len(self.obj.shape) == 0: # length 0 arrays are just pickled
# FIXME:
# also use pickle for numpy arrays on py3k, since
# pyzmq doesn't rebuild from memoryviews properly
if len(self.obj.shape) == 0: # length 0 arrays are just pickled
self.typeDescriptor = 'pickle'
self.metadata = {}
else:
Expand All @@ -125,7 +124,7 @@ def __init__(self, unSerialized):

def _generateData(self):
if self.typeDescriptor == 'ndarray':
self.data = numpy.getbuffer(self.obj)
self.data = buffer(self.obj)
elif self.typeDescriptor in ('bytes', 'buffer'):
self.data = self.obj
elif self.typeDescriptor == 'pickle':
Expand Down Expand Up @@ -158,11 +157,10 @@ def getObject(self):
typeDescriptor = self.serialized.getTypeDescriptor()
if numpy is not None and typeDescriptor == 'ndarray':
buf = self.serialized.getData()
if isinstance(buf, (bytes, buffer)):
if isinstance(buf, (bytes, buffer, memoryview)):
result = numpy.frombuffer(buf, dtype = self.serialized.metadata['dtype'])
else:
# memoryview
result = numpy.array(buf, dtype = self.serialized.metadata['dtype'])
raise TypeError("Expected bytes or buffer/memoryview, but got %r"%type(buf))
result.shape = self.serialized.metadata['shape']
elif typeDescriptor == 'pickle':
result = pickle.loads(self.serialized.getData())
Expand Down

0 comments on commit 22b4e57

Please sign in to comment.