Permalink
Browse files

Refactor memoryview dtype validation

  • Loading branch information...
1 parent 1df5077 commit 98d0d6062c69312c55fca56db54a543246a2f3cb @markflorisson committed May 22, 2012
View
3 Cython/Compiler/ExprNodes.py
@@ -7554,8 +7554,6 @@ def analyse_types(self, env):
array_dtype = self.base_type_node.base_type_node.analyse(env)
axes = self.base_type_node.axes
- MemoryView.validate_memslice_dtype(self.pos, array_dtype)
-
self.type = error_type
self.shapes = []
ndim = len(axes)
@@ -7635,6 +7633,7 @@ def analyse_types(self, env):
axes[-1] = ('direct', 'contig')
self.coercion_type = PyrexTypes.MemoryViewSliceType(array_dtype, axes)
+ self.coercion_type.validate_memslice_dtype(self.pos)
self.type = self.get_cython_array_type(env)
MemoryView.use_cython_array_utility_code(env)
env.use_utility_code(MemoryView.typeinfo_to_format_code)
View
2 Cython/Compiler/FusedNode.py
@@ -184,7 +184,7 @@ def _specialize_function_args(self, args, fused_to_specific):
if arg.type.is_fused:
arg.type = arg.type.specialize(fused_to_specific)
if arg.type.is_memoryviewslice:
- MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype)
+ arg.type.validate_memslice_dtype(arg.pos)
def create_new_local_scope(self, node, env, f2s):
"""
View
40 Cython/Compiler/MemoryView.py
@@ -170,35 +170,6 @@ def src_conforms_to_dst(src, dst, broadcast=False):
return True
-def valid_memslice_dtype(dtype, i=0):
- """
- Return whether type dtype can be used as the base type of a
- memoryview slice.
-
- We support structs, numeric types and objects
- """
- if dtype.is_complex and dtype.real_type.is_int:
- return False
-
- if dtype.is_struct and dtype.kind == 'struct':
- for member in dtype.scope.var_entries:
- if not valid_memslice_dtype(member.type):
- return False
-
- return True
-
- return (
- dtype.is_error or
- # Pointers are not valid (yet)
- # (dtype.is_ptr and valid_memslice_dtype(dtype.base_type)) or
- (dtype.is_array and i < 8 and
- valid_memslice_dtype(dtype.base_type, i + 1)) or
- dtype.is_numeric or
- dtype.is_pyobject or
- dtype.is_fused or # accept this as it will be replaced by specializations later
- (dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type))
- )
-
def validate_memslice_dtype(pos, dtype):
if not valid_memslice_dtype(dtype):
error(pos, "Invalid base type for memoryview slice: %s" % dtype)
@@ -436,18 +407,13 @@ def get_is_contig_utility(c_contig, ndim):
def copy_src_to_dst_cname():
return "__pyx_memoryview_copy_contents"
-def verify_direct_dimensions(node):
- for access, packing in node.type.axes:
- if access != 'direct':
- error(self.pos, "All dimensions must be direct")
-
def copy_broadcast_memview_src_to_dst(src, dst, code):
"""
Copy the contents of slice src to slice dst. Does not support indirect
slices.
"""
- verify_direct_dimensions(src)
- verify_direct_dimensions(dst)
+ src.type.assert_direct_dims(src.pos)
+ dst.type.assert_direct_dims(dst.pos)
code.putln(code.error_goto_if_neg(
"%s(%s, %s, %d, %d, %d)" % (copy_src_to_dst_cname(),
@@ -471,7 +437,7 @@ def assign_scalar(dst, scalar, code):
Assign a scalar to a slice. dst must be a temp, scalar will be assigned
to a correct type and not just something assignable.
"""
- verify_direct_dimensions(dst)
+ dst.type.assert_direct_dims(dst.pos)
dtype = dst.type.dtype
type_decl = dtype.declaration_code("")
slice_decl = dst.type.declaration_code("")
View
2 Cython/Compiler/Nodes.py
@@ -886,8 +886,8 @@ def analyse(self, env, could_be_name = False):
if not MemoryView.validate_axes(self.pos, axes_specs):
self.type = error_type
else:
- MemoryView.validate_memslice_dtype(self.pos, base_type)
self.type = PyrexTypes.MemoryViewSliceType(base_type, axes_specs)
+ self.type.validate_memslice_dtype(self.pos)
self.use_memview_utilities(env)
return self.type
View
39 Cython/Compiler/PyrexTypes.py
@@ -604,6 +604,45 @@ def declare_attribute(self, attribute, env, pos):
return True
+ def valid_dtype(self, dtype, i=0):
+ """
+ Return whether type dtype can be used as the base type of a
+ memoryview slice.
+
+ We support structs, numeric types and objects
+ """
+ if dtype.is_complex and dtype.real_type.is_int:
+ return False
+
+ if dtype.is_struct and dtype.kind == 'struct':
+ for member in dtype.scope.var_entries:
+ if not self.valid_dtype(member.type):
+ return False
+
+ return True
+
+ return (
+ dtype.is_error or
+ # Pointers are not valid (yet)
+ # (dtype.is_ptr and valid_memslice_dtype(dtype.base_type)) or
+ (dtype.is_array and i < 8 and
+ self.valid_dtype(dtype.base_type, i + 1)) or
+ dtype.is_numeric or
+ dtype.is_pyobject or
+ dtype.is_fused or # accept this as it will be replaced by specializations later
+ (dtype.is_typedef and self.valid_dtype(dtype.typedef_base_type))
+ )
+
+ def validate_memslice_dtype(self, pos):
+ if not self.valid_dtype(self.dtype):
+ error(pos, "Invalid base type for memoryview slice: %s" % self.dtype)
+
+ def assert_direct_dims(self, pos):
+ for access, packing in self.axes:
+ if access != 'direct':
+ error(pos, "All dimensions must be direct")
+ break
+
def specialization_suffix(self):
return "%s_%s" % (self.axes_to_name(), self.dtype_name)

0 comments on commit 98d0d60

Please sign in to comment.