Skip to content

Commit

Permalink
Improve ndpointer to allow shape and flags checking as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
teoliphant committed Aug 13, 2006
1 parent eee00f8 commit 3fa71a7
Showing 1 changed file with 55 additions and 8 deletions.
63 changes: 55 additions & 8 deletions numpy/lib/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sys, os
import inspect
import types
from numpy.core.numerictypes import obj2sctype
from numpy.core.multiarray import dtype
from numpy.core.numerictypes import obj2sctype, integer
from numpy.core.multiarray import dtype, _flagdict
from numpy.core import product, ndarray

__all__ = ['issubclass_', 'get_numpy_include', 'issubsctype',
Expand Down Expand Up @@ -76,30 +76,77 @@ def ctypes_load_library(libname, loader_path):
libpath = os.path.join(libdir, libname)
return ctypes.cdll[libpath]

def _num_fromflags(flaglist):
num = 0
for val in flaglist:
num += _flagdict[val]
return num

def _flags_fromnum(num):
res = []
for key, value in _flagdict.items():
if (num & value):
res.append(key)
return res

class _ndptr(object):
def from_param(cls, obj):
if not isinstance(obj, ndarray):
raise TypeError("argument must be an ndarray")
if obj.dtype != cls._dtype_:
raise TypeError("array must have data type", cls._dtype_)
if cls._ndim_ and obj.ndim != cls._ndim_:
raise TypeError("array must have %d dimension(s)" % cls._ndim_)
if cls._shape_ and obj.shape != cls._shape_:
raise TypeError("array must have shape ", cls._shape_)
if cls._flags_ and ((obj.flags.num & cls._flags_) != cls._flags_):
raise TypeError("array must have flags ",
_flags_fromnum(cls._flags_))
return obj.ctypes
from_param = classmethod(from_param)

# Factory for a type-checking object with from_param defined
# Factory for an array-checking object with from_param defined
_pointer_type_cache = {}
def ndpointer(datatype):
def ndpointer(datatype, ndim=None, shape=None, flags=None):
datatype = dtype(datatype)
num = None
if flags is not None:
if isinstance(flags, str):
flags = flags.split(',')
elif isinstance(flags, (int, integer)):
num = flags
flags = _flags_fromnum(flags)
if num is None:
flags = [x.strip().upper() for x in flags]
num = _num_fromflags(flags)
try:
return _pointer_type_cache[datatype]
return _pointer_type_cache[(datatype, ndim, shape, num)]
except KeyError:
pass
pass
if datatype.names:
name = str(id(datatype))
else:
name = datatype.str
if ndim is not None:
name += "_%dd" % ndim
if shape is not None:
try:
strshape = [str(x) for x in shape]
except TypeError:
strshape = [str(shape)]
shape = (shape,)
shape = tuple(shape)
name += "_"+"x".join(strshape)
if flags is not None:
name += "_"+"_".join(flags)
else:
flags = []
klass = type("ndpointer_%s"%name, (_ndptr,),
{"_dtype_": datatype})
_pointer_type_cache[datatype] = klass
{"_dtype_": datatype,
"_shape_" : shape,
"_ndim_" : ndim,
"_flags_" : num})
_pointer_type_cache[datatype] = klass
return klass


Expand Down

0 comments on commit 3fa71a7

Please sign in to comment.