Skip to content
Browse files

Support scalar arguments

  • Loading branch information...
1 parent 46fe298 commit a02e733397a3a8fdcd4426bd4d07430ddd863967 @markflorisson committed
Showing with 95 additions and 12 deletions.
  1. +1 −1 Cython/Compiler/ExprNodes.py
  2. +32 −10 Cython/Compiler/Vector.py
  3. +1 −1 Cython/minivect
  4. +61 −0 tests/array_expressions/elementwise.pyx
View
2 Cython/Compiler/ExprNodes.py
@@ -8080,7 +8080,7 @@ def analyse_memoryview_operation(self, env):
dtype1 = type1.dtype
if type2.is_memoryviewslice:
- type1.assert_direct_dims(self.pos)
+ type2.assert_direct_dims(self.pos)
ndim2 = type2.ndim
dtype2 = type2.dtype
View
42 Cython/Compiler/Vector.py
@@ -410,6 +410,7 @@ class SpecializationCaller(ExprNodes.ExprNode):
context: Context attribute
operands: all participating array views
+ scalar_operands: non-array operands
function: miniast function wrapping the array expression
During code generation:
@@ -621,6 +622,7 @@ def put_specialized_call(self, code, specializer, specialized_function,
else:
args.append(result)
+ args.extend(scalar_arg.result() for scalar_arg in self.scalar_operands)
call = "%s(%s)" % (specialized_function.mangled_name, ", ".join(args))
if self.may_error:
@@ -652,9 +654,9 @@ class ElementalNode(Nodes.StatNode):
may_error: indicates whether the expression may raise a sudden error
"""
- child_attrs = ['operands', 'temp_nodes', 'lhs', 'check_overlap', 'rhs',
- 'final_assignment_node', 'broadcast', 'final_broadcast',
- 'temp_dst']
+ child_attrs = ['operands', 'scalar_operands', 'temp_nodes', 'lhs',
+ 'check_overlap', 'rhs', 'final_assignment_node',
+ 'broadcast', 'final_broadcast', 'temp_dst']
check_overlap = None
may_error = None
@@ -672,7 +674,9 @@ def analyse_expressions(self, env):
self.rhs = SpecializationCaller(
self.operands[0].pos, context=self.minicontext,
- dst=self.lhs, operands=self.operands, function=self.rhs_function,
+ dst=self.lhs, operands=self.operands,
+ scalar_operands=self.scalar_operands,
+ function=self.rhs_function,
may_error=self.may_error)
self.rhs.analyse_types(env)
@@ -716,7 +720,11 @@ def final_assignment(self):
rhs_var = b.variable(typemapper.map_type(rhs_type, wrap=True), 'rhs')
if self.lhs.type.dtype.is_pyobject:
- body = b.assign(b.decref(lhs_var), b.incref(rhs_var))
+ rhs_tmp = b.temp(rhs_var.type.dtype)
+ body = b.stats(b.assign(rhs_tmp, rhs_var),
+ b.decref(lhs_var),
+ b.incref(rhs_tmp),
+ b.assign(lhs_var, rhs_tmp))
else:
body = b.assign(lhs_var, rhs_var)
@@ -724,6 +732,7 @@ def final_assignment(self):
func = b.function('final_assignment%d', body, args)
return SpecializationCaller(self.pos, context=self.minicontext,
operands=[self.rhs], function=func,
+ scalar_operands=[],
dst=self.lhs, may_error=False,
target=self.lhs.wrap_in_clone_node())
@@ -788,6 +797,9 @@ def generate_execution_code(self, code):
for op in self.operands:
op.generate_evaluation_code(code)
+ for scalar_op in self.scalar_operands:
+ scalar_op.generate_evaluation_code(code)
+
code.putln("/* Check overlapping memory */")
self.check_overlap.generate_evaluation_code(code)
@@ -908,6 +920,8 @@ def __init__(self, context, env, max_ndim):
self.env = env
# operands to the function in callee space
self.operands = []
+ # scalar operands to the function in callee space
+ self.scalar_operands = []
# miniast function arguments to the function
self.funcargs = []
self.error = False
@@ -931,12 +945,18 @@ def register_operand(self, node):
b = self.astbuilder
- node = node.coerce_to_temp(self.env)
- varname = '__pyx_op%d' % len(self.operands)
- self.operands.append(node)
-
minitype = self.map_type(node, wrap=True)
if node.type.is_memoryviewslice:
+ node = node.coerce_to_temp(self.env)
+ self.operands.append(node)
+ elif node.is_literal:
+ return b.constant(node.value, type=minitype)
+ else:
+ node = node.coerce_to_simple(self.env)
+ self.scalar_operands.append(node)
+
+ varname = '__pyx_op%d' % (len(self.operands) + len(self.scalar_operands))
+ if node.type.is_memoryviewslice:
funcarg = b.array_funcarg(b.variable(minitype, varname))
funcarg.type.ndim = min(funcarg.type.ndim, self.max_ndim)
else:
@@ -1040,7 +1060,9 @@ def visit_elemental(self, node, lhs=None):
posinfo=posinfo)
astmapper.operands.remove(lhs)
- node = ElementalNode(node.pos, operands=astmapper.operands,
+ node = ElementalNode(node.pos,
+ operands=astmapper.operands,
+ scalar_operands=astmapper.scalar_operands,
rhs_function=function,
minicontext=self.minicontext,
lhs=lhs)
2 Cython/minivect
@@ -1 +1 @@
-Subproject commit e30c76cbe16566f40204ba2864ad55a7fa312afb
+Subproject commit 599552222c9a6fac6d3331c2c100d59b4e67efc0
View
61 tests/array_expressions/elementwise.pyx
@@ -112,3 +112,64 @@ def test_overlapping_memory(fused_dtype_t[:] m1, fused_dtype_t[:, :] m2):
m2[...] = m2[::-1, :] + m1
m1[1:] = m1[:-1]
+@testcase
+def test_constant_scalar_complex_arguments(double complex[:] m):
+ """
+ >>> test_constant_scalar_complex_arguments(np.arange(10, dtype=np.complex128))
+ array([ 5.+4.j, 7.+4.j, 9.+4.j, 11.+4.j, 13.+4.j, 15.+4.j,
+ 17.+4.j, 19.+4.j, 21.+4.j, 23.+4.j])
+ """
+ m[:] = m + m + (5 + 4j)
+ return np.asarray(m)
+
+@testcase
+def test_constant_scalar_double_arguments(double[:] m):
+ """
+ >>> test_constant_scalar_double_arguments(np.arange(10, dtype=np.double))
+ array([ 5., 7., 9., 11., 13., 15., 17., 19., 21., 23.])
+ """
+ m[:] = m + m + 5.0
+ return np.asarray(m)
+
+@testcase
+def test_constant_external_arguments(np.uint64_t[:] m):
+ """
+ >>> test_constant_external_arguments(np.arange(10, dtype=np.uint64))
+ array([ 5, 7, 9, 11, 13, 15, 17, 19, 21, 23], dtype=uint64)
+ """
+ m[:] = m + m + 5
+ return np.asarray(m)
+
+@testcase
+def test_constant_object_arguments(object[:] m):
+ """
+ >>> test_constant_object_arguments(np.arange(10, dtype=np.object))
+ array([5, 7, 9, 11, 13, 15, 17, 19, 21, 23], dtype=object)
+ """
+ m[:] = m + m + 5
+ return np.asarray(m)
+
+cdef int func1():
+ print "func1"
+ return 4
+
+cdef int func2():
+ print "func2"
+ return 3
+
+cdef int func3():
+ print "func3"
+ return 2
+
+@testcase
+def test_evaluate_operands_once(int[:] m):
+ """
+ >>> test_evaluate_operands_once(np.arange(10, dtype='i'))
+ func1
+ func2
+ func3
+ array([ 5, 7, 9, 11, 13, 15, 17, 19, 21, 23], dtype=int32)
+ """
+ m[:] = m + func1() + m + func2()
+ m[:] = -func3() + m
+ return np.asarray(m)

0 comments on commit a02e733

Please sign in to comment.
Something went wrong with that request. Please try again.