Permalink
Browse files

Allow arbitrary Cython types compatible with C expressions as vector …

…expression dtypes
  • Loading branch information...
1 parent 3fdcc52 commit 31460e859df4a8dc17d5eb97f06516608f4e9d06 @markflorisson committed May 29, 2012
Showing with 85 additions and 24 deletions.
  1. +60 −23 Cython/Compiler/Vector.py
  2. +1 −1 Cython/minivect
  3. +24 −0 tests/run/elementwise.pyx
View
@@ -5,31 +5,18 @@
from Cython.minivect import minitypes
from Cython.minivect import miniutils
from Cython.minivect import minierror
+from Cython.minivect import codegen
from Cython.minivect import specializers
debug = False
-class Context(miniast.CContext):
- def getpos(self, node):
- return node.pos
-
- def getchildren(self, node):
- return node.child_attrs
-
- def declare_type(self, type):
- if type.is_typewrapper:
- return type.opaque_type.declaration_code("")
-
- return super(Context, self).declare_type(type)
-
-
class TypeMapper(minitypes.TypeMapper):
- def map_type(self, type):
+ def map_type(self, type, wrap=False):
if type.is_typedef:
return minitypes.TypeWrapper(type)
elif type.is_memoryviewslice:
- return minitypes.ArrayType(self.map_type(type.dtype),
- len(type.axes),
+ dtype = self.map_type(type.dtype, wrap=wrap)
+ return minitypes.ArrayType(dtype, len(type.axes),
is_c_contig=type.is_c_contig,
is_f_contig=type.is_f_contig)
elif type.is_float:
@@ -43,7 +30,39 @@ def map_type(self, type):
elif type == PyrexTypes.c_int_type:
return minitypes.IntType()
- raise minierror.UnmappableTypeError(type)
+ if wrap:
+ return minitypes.TypeWrapper(type)
+ else:
+ raise minierror.UnmappableTypeError(type)
+
+class CCodeGen(codegen.CCodeGen):
+ def visit_NodeWrapper(self, node):
+ node.opaque_node.generate_evaluation_code(self.code)
+ return node.result()
+
+class CCodeGenCleanup(codegen.CodeGenCleanup):
+ def visit_NodeWrapper(self, node):
+ node.opaque_node.generate_disposal_code(self.code)
+
+class Context(miniast.CContext):
+
+ codegen_cls = CCodeGen
+ cleanup_codegen_cls = CCodeGenCleanup
+
+ def getpos(self, node):
+ return node.pos
+
+ def getchildren(self, node):
+ return node.child_attrs
+
+ def declare_type(self, type):
+ if type.is_typewrapper:
+ return type.opaque_type.declaration_code("")
+
+ return super(Context, self).declare_type(type)
+
+ def may_error(self, node):
+ return node.type.is_pyobject
class ElementalNode(ExprNodes.ExprNode):
"""
@@ -117,21 +136,36 @@ def __init__(self, context, env):
self.funcargs = []
self.error = False
- def map_type(self, node):
+ def map_type(self, node, **kwds):
try:
- return super(ElementalMapper, self).map_type(node)
+ return super(ElementalMapper, self).map_type(node, **kwds)
except minierror.UnmappableTypeError, e:
error(node.pos, "Unsupported type in elementwise "
"operation: %s" % (node.type,))
raise
+ def need_wrapper_node(self, minitype):
+ while True:
+ if minitype.is_array:
+ minitype = minitype.dtype
+ elif minitype.is_pointer:
+ minitype = minitype.base_type
+ else:
+ break
+
+ if minitype.is_typewrapper:
+ type = minitype.opaque_type
+ return type.is_pyobject or type.is_complex
+
+ return False
+
def visit_ExprNode(self, node):
assert not node.is_elemental
+ b = self.astbuilder
node = node.coerce_to_simple(self.env)
varname = '__pyx_op%d' % len(self.operands)
self.operands.append(node)
- funcarg = self.astbuilder.funcarg(
- self.astbuilder.variable(self.map_type(node), varname))
+ funcarg = b.funcarg(b.variable(self.map_type(node, wrap=True), varname))
self.funcargs.append(funcarg)
return funcarg.variable
@@ -140,7 +174,10 @@ def visit_SingleAssignmentNode(self, node):
self.visit(node.rhs))
def visit_BinopNode(self, node):
- minitype = self.map_type(node)
+ minitype = self.map_type(node, wrap=True)
+ if self.need_wrapper_node(minitype):
+ return self.astbuilder.wrap(node)
+
op1 = self.visit(node.operand1)
op2 = self.visit(node.operand2)
return self.astbuilder.binop(minitype, node.operator, op1, op2)
View
@@ -65,3 +65,27 @@ def test_typedef(np.int32_t[:] m):
"""
m[:] = m + m + m
return m
+
+def test_arbitrary_dtypes(long[:] m1, long double[:] m2):
+ """
+ >>> a = np.arange(10, dtype='l')
+ >>> b = np.arange(10, dtype=np.longdouble)
+ >>> test_arbitrary_dtypes(a, b)
+ >>> a
+ array([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18])
+ >>> b
+ array([ 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0], dtype=float128)
+ """
+ m1[:] = m1 + m1
+ m2[:] = m2 + m2
+
+# def test_tougher_arbitrary_dtypes(double complex[:] m1, object[:] m2):
+# """
+# >>> a = np.arange(10, dtype=np.complex64)
+# >>> b = np.arange(10, dtype=np.object)
+# >>> test_tougher_arbitrary_dtypes(a, b)
+# >>> a
+# >>> b
+# """
+# m1[:] = m1 + m1
+# m2[:] = m2 + m2

0 comments on commit 31460e8

Please sign in to comment.