Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Support complex numbers in vector expressions

  • Loading branch information...
commit 1e53ae244de43d3effd140e5ac1a00b180dfb396 1 parent 31460e8
@markflorisson authored
Showing with 143 additions and 33 deletions.
  1. +129 −22 Cython/Compiler/Vector.py
  2. +1 −1  Cython/minivect
  3. +13 −10 tests/run/elementwise.pyx
View
151 Cython/Compiler/Vector.py
@@ -35,10 +35,24 @@ def map_type(self, type, wrap=False):
else:
raise minierror.UnmappableTypeError(type)
+class CythonSpecializerMixin(object):
+ def visit_NodeWrapper(self, node):
+ for op in node.cython_ops:
+ op.variable = self.visit(op.variable)
+ return node
+
class CCodeGen(codegen.CCodeGen):
def visit_NodeWrapper(self, node):
- node.opaque_node.generate_evaluation_code(self.code)
- return node.result()
+ code = CythonCCodeWriter()
+ code.enter_cfunc_scope()
+ node.opaque_node.generate_evaluation_code(code)
+
+ declaration_code = CythonCCodeWriter()
+ declaration_code.put_temp_declarations(code.funcstate)
+ self.code.declaration_point.putln(declaration_code.getvalue())
+ self.code.putln(code.getvalue())
+
+ return node.opaque_node.result()
class CCodeGenCleanup(codegen.CodeGenCleanup):
def visit_NodeWrapper(self, node):
@@ -46,8 +60,24 @@ def visit_NodeWrapper(self, node):
class Context(miniast.CContext):
- codegen_cls = CCodeGen
+ #codegen_cls = CCodeGen
cleanup_codegen_cls = CCodeGenCleanup
+ specializer_mixin_cls = CythonSpecializerMixin
+
+ def __init__(self, astbuilder=None, typemapper=None):
+ super(Context, self).__init__(astbuilder, typemapper)
+ # [OperandNode]
+ self.cython_operand_nodes = []
+
+ def codegen_cls(self, _, codewriter):
+ """
+ Monkeypatch all OperandNodes to have a codegen attribute, so they
+ can generate code for the miniast they wrap.
+ """
+ codegen = CCodeGen(self, codewriter)
+ for node in self.cython_operand_nodes:
+ node.codegen = codegen
+ return codegen
def getpos(self, node):
return node.pos
@@ -64,6 +94,41 @@ def declare_type(self, type):
def may_error(self, node):
return node.type.is_pyobject
+class CythonCCodeWriter(Code.CCodeWriter):
+ def mark_pos(self, pos):
+ pass
+
+class OperandNode(ExprNodes.ExprNode):
+ """
+ The purpose of this node is to wrap a miniast variable and dispatch
+ to the miniast code generator from within the Cython code generation
+ process.
+
+ This happens when certain operations are not supported natively in
+ elementwise expressions, such as operations on complex numbers or
+ objects. So the miniast has a NodeWrapper wrapping a Cython AST, of
+ which an OperandNode is a leaf, which has to return back again to
+ the miniast code generation process.
+
+ Summary:
+
+ miniast
+ -> cython ast
+ -> operand node
+ -> miniast
+ """
+
+ subexprs = []
+
+ def analyse_types(self, env):
+ "self.type is already set"
+
+ def generate_result_code(self, code):
+ pass
+
+ def result(self):
+ return self.codegen.visit(self.variable)
+
class ElementalNode(ExprNodes.ExprNode):
"""
Wraps a mapped AST.
@@ -84,9 +149,9 @@ def generate_result_code(self, code):
function = self.function
if self.all_contig:
- specializer = specializers.ContigSpecializer(self.context)
+ specializer = specializers.ContigSpecializer
else:
- specializer = specializers.StridedSpecializer(self.context)
+ specializer = specializers.StridedSpecializer
codes = self.context.run(function, [specializer])
(specialized_function, codewriter, (proto, impl)), = codes
@@ -127,13 +192,32 @@ def generate_result_code(self, code):
code.funcstate.release_temp(shape)
+def need_wrapper_node(type):
+ while True:
+ if type.is_ptr:
+ type = type.base_type
+ elif type.is_memoryviewslice:
+ type = type.dtype
+ else:
+ break
+
+ type = type.resolve()
+ return type.is_pyobject or type.is_complex
+
class ElementalMapper(specializers.ASTMapper):
+ wrapping = 0
+
def __init__(self, context, env):
super(ElementalMapper, self).__init__(context)
self.env = env
+ # operands to the function in callee space
self.operands = []
+ # miniast function arguments to the function
self.funcargs = []
+ # All OperandNodes founds in the Cython AST held by a
+ # miniast.NodeWrapper
+ self.cython_ops = []
self.error = False
def map_type(self, node, **kwds):
@@ -144,39 +228,62 @@ def map_type(self, node, **kwds):
"operation: %s" % (node.type,))
raise
- def need_wrapper_node(self, minitype):
- while True:
- if minitype.is_array:
- minitype = minitype.dtype
- elif minitype.is_pointer:
- minitype = minitype.base_type
- else:
- break
-
- if minitype.is_typewrapper:
- type = minitype.opaque_type
- return type.is_pyobject or type.is_complex
+ def get_dtype(self, type):
+ if type.is_memoryviewslice:
+ return type.dtype
+ return type
- return False
-
- def visit_ExprNode(self, node):
+ def register_operand(self, node):
+ """
+ Register a non-elemental subexpression, and pass it in to the function
+ we are generating as an argument.
+ """
assert not node.is_elemental
b = self.astbuilder
+
node = node.coerce_to_simple(self.env)
varname = '__pyx_op%d' % len(self.operands)
self.operands.append(node)
+
funcarg = b.funcarg(b.variable(self.map_type(node, wrap=True), varname))
self.funcargs.append(funcarg)
+ if self.wrapping:
+ result = OperandNode(node.pos, type=self.get_dtype(node.type),
+ variable=funcarg.variable)
+ self.context.cython_operand_nodes.append(result)
+ self.cython_ops.append(result)
+ return result
+
return funcarg.variable
+ def visit_ExprNode(self, node):
+ return self.register_operand(node)
+
def visit_SingleAssignmentNode(self, node):
return self.astbuilder.assign(self.visit(node.lhs.dst),
self.visit(node.rhs))
def visit_BinopNode(self, node):
minitype = self.map_type(node, wrap=True)
- if self.need_wrapper_node(minitype):
- return self.astbuilder.wrap(node)
+ if need_wrapper_node(node.type):
+ if not node.is_elemental:
+ return self.register_operand(node)
+
+ self.wrapping += 1
+ self.visitchildren(node)
+ self.wrapping -= 1
+
+ dtype = node.type
+ if dtype.is_memoryviewslice:
+ dtype = dtype.dtype
+
+ node = type(node)(node.pos, type=dtype, operator=node.operator,
+ operand1=node.operand1, operand2=node.operand2)
+ node.analyse_types(self.env)
+
+ result = self.astbuilder.wrap(node, cython_ops=self.cython_ops)
+ self.cython_ops = []
+ return result
op1 = self.visit(node.operand1)
op2 = self.visit(node.operand2)
2  Cython/minivect
@@ -1 +1 @@
-Subproject commit fe469802b23cca7e38d30978ca9f4329082d763d
+Subproject commit a9a6328cf0e8ae59bd461c5c2852e56afd284e3a
View
23 tests/run/elementwise.pyx
@@ -79,13 +79,16 @@ def test_arbitrary_dtypes(long[:] m1, long double[:] m2):
m1[:] = m1 + m1
m2[:] = m2 + m2
-# def test_tougher_arbitrary_dtypes(double complex[:] m1, object[:] m2):
-# """
-# >>> a = np.arange(10, dtype=np.complex64)
-# >>> b = np.arange(10, dtype=np.object)
-# >>> test_tougher_arbitrary_dtypes(a, b)
-# >>> a
-# >>> b
-# """
-# m1[:] = m1 + m1
-# m2[:] = m2 + m2
+def test_tougher_arbitrary_dtypes(double complex[:] m1, m2): #, object[:] m2):
+ """
+ >>> a = np.arange(10, dtype=np.complex128) + 1.2j
+ >>> b = np.arange(10, dtype=np.object)
+ >>> test_tougher_arbitrary_dtypes(a, b)
+ >>> a
+ array([ 0.+2.4j, 2.+2.4j, 4.+2.4j, 6.+2.4j, 8.+2.4j, 10.+2.4j,
+ 12.+2.4j, 14.+2.4j, 16.+2.4j, 18.+2.4j])
+ >>> b
+ array([0, 2, 4, 6, 8, 10, 12, 14, 16, 18], dtype=object)
+ """
+ m1[:] = m1 + m1
+ m2[:] = m2 + m2
Please sign in to comment.
Something went wrong with that request. Please try again.