Skip to content

Commit

Permalink
Refactor memoryview dtype validation
Browse files Browse the repository at this point in the history
  • Loading branch information
markflorisson committed Jul 23, 2012
1 parent 1df5077 commit 98d0d60
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 41 deletions.
3 changes: 1 addition & 2 deletions Cython/Compiler/ExprNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7554,8 +7554,6 @@ def analyse_types(self, env):
array_dtype = self.base_type_node.base_type_node.analyse(env)
axes = self.base_type_node.axes

MemoryView.validate_memslice_dtype(self.pos, array_dtype)

self.type = error_type
self.shapes = []
ndim = len(axes)
Expand Down Expand Up @@ -7635,6 +7633,7 @@ def analyse_types(self, env):
axes[-1] = ('direct', 'contig')

self.coercion_type = PyrexTypes.MemoryViewSliceType(array_dtype, axes)
self.coercion_type.validate_memslice_dtype(self.pos)
self.type = self.get_cython_array_type(env)
MemoryView.use_cython_array_utility_code(env)
env.use_utility_code(MemoryView.typeinfo_to_format_code)
Expand Down
2 changes: 1 addition & 1 deletion Cython/Compiler/FusedNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _specialize_function_args(self, args, fused_to_specific):
if arg.type.is_fused:
arg.type = arg.type.specialize(fused_to_specific)
if arg.type.is_memoryviewslice:
MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype)
arg.type.validate_memslice_dtype(arg.pos)

def create_new_local_scope(self, node, env, f2s):
"""
Expand Down
40 changes: 3 additions & 37 deletions Cython/Compiler/MemoryView.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,35 +170,6 @@ def src_conforms_to_dst(src, dst, broadcast=False):

return True

def valid_memslice_dtype(dtype, i=0):
"""
Return whether type dtype can be used as the base type of a
memoryview slice.
We support structs, numeric types and objects
"""
if dtype.is_complex and dtype.real_type.is_int:
return False

if dtype.is_struct and dtype.kind == 'struct':
for member in dtype.scope.var_entries:
if not valid_memslice_dtype(member.type):
return False

return True

return (
dtype.is_error or
# Pointers are not valid (yet)
# (dtype.is_ptr and valid_memslice_dtype(dtype.base_type)) or
(dtype.is_array and i < 8 and
valid_memslice_dtype(dtype.base_type, i + 1)) or
dtype.is_numeric or
dtype.is_pyobject or
dtype.is_fused or # accept this as it will be replaced by specializations later
(dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type))
)

def validate_memslice_dtype(pos, dtype):
if not valid_memslice_dtype(dtype):
error(pos, "Invalid base type for memoryview slice: %s" % dtype)
Expand Down Expand Up @@ -436,18 +407,13 @@ def get_is_contig_utility(c_contig, ndim):
def copy_src_to_dst_cname():
return "__pyx_memoryview_copy_contents"

def verify_direct_dimensions(node):
for access, packing in node.type.axes:
if access != 'direct':
error(self.pos, "All dimensions must be direct")

def copy_broadcast_memview_src_to_dst(src, dst, code):
"""
Copy the contents of slice src to slice dst. Does not support indirect
slices.
"""
verify_direct_dimensions(src)
verify_direct_dimensions(dst)
src.type.assert_direct_dims(src.pos)
dst.type.assert_direct_dims(dst.pos)

code.putln(code.error_goto_if_neg(
"%s(%s, %s, %d, %d, %d)" % (copy_src_to_dst_cname(),
Expand All @@ -471,7 +437,7 @@ def assign_scalar(dst, scalar, code):
Assign a scalar to a slice. dst must be a temp, scalar will be assigned
to a correct type and not just something assignable.
"""
verify_direct_dimensions(dst)
dst.type.assert_direct_dims(dst.pos)
dtype = dst.type.dtype
type_decl = dtype.declaration_code("")
slice_decl = dst.type.declaration_code("")
Expand Down
2 changes: 1 addition & 1 deletion Cython/Compiler/Nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,8 +886,8 @@ def analyse(self, env, could_be_name = False):
if not MemoryView.validate_axes(self.pos, axes_specs):
self.type = error_type
else:
MemoryView.validate_memslice_dtype(self.pos, base_type)
self.type = PyrexTypes.MemoryViewSliceType(base_type, axes_specs)
self.type.validate_memslice_dtype(self.pos)
self.use_memview_utilities(env)

return self.type
Expand Down
39 changes: 39 additions & 0 deletions Cython/Compiler/PyrexTypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,45 @@ def declare_attribute(self, attribute, env, pos):

return True

def valid_dtype(self, dtype, i=0):
"""
Return whether type dtype can be used as the base type of a
memoryview slice.
We support structs, numeric types and objects
"""
if dtype.is_complex and dtype.real_type.is_int:
return False

if dtype.is_struct and dtype.kind == 'struct':
for member in dtype.scope.var_entries:
if not self.valid_dtype(member.type):
return False

return True

return (
dtype.is_error or
# Pointers are not valid (yet)
# (dtype.is_ptr and valid_memslice_dtype(dtype.base_type)) or
(dtype.is_array and i < 8 and
self.valid_dtype(dtype.base_type, i + 1)) or
dtype.is_numeric or
dtype.is_pyobject or
dtype.is_fused or # accept this as it will be replaced by specializations later
(dtype.is_typedef and self.valid_dtype(dtype.typedef_base_type))
)

def validate_memslice_dtype(self, pos):
if not self.valid_dtype(self.dtype):
error(pos, "Invalid base type for memoryview slice: %s" % self.dtype)

def assert_direct_dims(self, pos):
for access, packing in self.axes:
if access != 'direct':
error(pos, "All dimensions must be direct")
break

def specialization_suffix(self):
return "%s_%s" % (self.axes_to_name(), self.dtype_name)

Expand Down

0 comments on commit 98d0d60

Please sign in to comment.