Permalink
Browse files

Support (partial) elementwise function calls

  • Loading branch information...
markflorisson committed Jul 16, 2012
1 parent dd455de commit b42b157ce4b6c0c2dbf6cf5042a22d34d60b649e
Showing with 153 additions and 52 deletions.
  1. +97 −49 Cython/Compiler/ExprNodes.py
  2. +24 −2 Cython/Compiler/Vector.py
  3. +1 −1 Cython/minivect
  4. +31 −0 tests/array_expressions/elementwise.pyx
@@ -4011,7 +4011,7 @@ def analyse_types(self, env):
def function_type(self):
# Return the type of the function being called, coercing a function
# pointer to a function if necessary. If the function has fused
- # arguments, return the specific type.
+ # arguments, return the specialized type.
func_type = self.function.type
if func_type.is_ptr:
@@ -4026,61 +4026,47 @@ def is_simple(self):
# sequence for a function call or comparing values.
return False
- def analyse_c_function_call(self, env):
- if self.function.type is error_type:
- self.type = error_type
- return
-
- if self.function.type.is_cpp_class:
- overloaded_entry = self.function.type.scope.lookup("operator()")
- if overloaded_entry is None:
- self.type = PyrexTypes.error_type
- self.result_code = "<error>"
- return
- elif hasattr(self.function, 'entry'):
- overloaded_entry = self.function.entry
- elif self.function.is_index and self.function.is_fused_index:
- overloaded_entry = self.function.type.entry
+ def _analyse_overloaded_entry(self, env, overloaded_entry):
+ if self.function.type.is_fused:
+ functypes = self.function.type.get_all_specialized_function_types()
+ alternatives = [f.entry for f in functypes]
else:
- overloaded_entry = None
+ alternatives = overloaded_entry.all_alternatives()
- if overloaded_entry:
- if self.function.type.is_fused:
- functypes = self.function.type.get_all_specialized_function_types()
- alternatives = [f.entry for f in functypes]
- else:
- alternatives = overloaded_entry.all_alternatives()
+ entry = PyrexTypes.best_match(self.args, alternatives, self.pos, env)
- entry = PyrexTypes.best_match(self.args, alternatives, self.pos, env)
+ if not entry:
+ self.type = PyrexTypes.error_type
+ self.result_code = "<error>"
+ return None
- if not entry:
- self.type = PyrexTypes.error_type
- self.result_code = "<error>"
- return
+ entry.used = True
+ self.function.entry = entry
+ self.function.type = entry.type
+ func_type = self.function_type()
+ return func_type
- entry.used = True
- self.function.entry = entry
- self.function.type = entry.type
- func_type = self.function_type()
- else:
- func_type = self.function_type()
- if not func_type.is_cfunction:
- error(self.pos, "Calling non-function type '%s'" % func_type)
- self.type = PyrexTypes.error_type
- self.result_code = "<error>"
- return
- # Check no. of args
- max_nargs = len(func_type.args)
- expected_nargs = max_nargs - func_type.optional_arg_count
- actual_nargs = len(self.args)
- if func_type.optional_arg_count and expected_nargs != actual_nargs:
- self.has_optional_args = 1
- self.is_temp = 1
- # Coerce arguments
+ def _coerce_arguments(self, actual_nargs, env, func_type, max_nargs):
+ "Coerce arguments for a native function call"
some_args_in_temps = False
+ self.elemental_args = []
for i in xrange(min(max_nargs, actual_nargs)):
formal_type = func_type.args[i].type
- arg = self.args[i].coerce_to(formal_type, env)
+ arg = self.args[i]
+ if (arg.type.is_memoryviewslice and not
+ formal_type.is_memoryviewslice and
+ formal_type.assignable_from(arg.type.dtype)):
+ # elemental function call, this node will be replaced later
+ self.is_elemental = True
+ self.elemental_args.append(arg)
+ if not isinstance(self.function, NameNode):
+ error(self.function.pos,
+ "Function must be a C function name")
+ self.type = error_type
+ return
+ else:
+ arg = arg.coerce_to(formal_type, env)
+
if arg.is_temp:
if i > 0:
# first argument in temp doesn't impact subsequent arguments
@@ -4101,6 +4087,7 @@ def analyse_c_function_call(self, env):
some_args_in_temps = True
arg = arg.coerce_to_temp(env)
self.args[i] = arg
+
# handle additional varargs parameters
for i in xrange(max_nargs, actual_nargs):
arg = self.args[i]
@@ -4113,6 +4100,7 @@ def analyse_c_function_call(self, env):
self.args[i] = arg = arg.coerce_to(arg_ctype, env)
if arg.is_temp and i > 0:
some_args_in_temps = True
+
if some_args_in_temps:
# if some args are temps and others are not, they may get
# constructed in the wrong order (temps first) => make
@@ -4138,15 +4126,75 @@ def analyse_c_function_call(self, env):
#self.args[i] = arg.coerce_to_temp(env)
# instead: issue a warning
if i > 0 or i == 1 and self.self is not None: # skip first arg
- warning(arg.pos, "Argument evaluation order in C function call is undefined and may not be as expected", 0)
+ warning(arg.pos,
+ "Argument evaluation order in C function call "
+ "is undefined and may not be as expected", 0)
break
+ def _analyse_elemental_call(self, func_type):
+ if self.has_optional_args:
+ error(self.pos, "Default argument calls not supported for "
+ "elemental function calls")
+ self.type = error_type
+ else:
+ # Fake a memoryview return type. This type won't be used directly,
+ # but rather the types of the elemental arguments will be used
+ max_ndim = max(arg.type.ndim for arg in self.elemental_args)
+ dtype = self.elemental_args[0].type.dtype
+ return_type = PyrexTypes.MemoryViewSliceType(
+ dtype, [('direct', 'strided')] * max_ndim)
+ # func_type.return_type = return_type
+ self.type = return_type
+
+ def analyse_c_function_call(self, env):
+ if self.function.type is error_type:
+ self.type = error_type
+ return
+
+ if self.function.type.is_cpp_class:
+ overloaded_entry = self.function.type.scope.lookup("operator()")
+ if overloaded_entry is None:
+ self.type = PyrexTypes.error_type
+ self.result_code = "<error>"
+ return
+ elif hasattr(self.function, 'entry'):
+ overloaded_entry = self.function.entry
+ elif self.function.is_index and self.function.is_fused_index:
+ overloaded_entry = self.function.type.entry
+ else:
+ overloaded_entry = None
+
+ if overloaded_entry:
+ func_type = self._analyse_overloaded_entry(env, overloaded_entry)
+ if func_type is None:
+ return
+ else:
+ func_type = self.function_type()
+ if not func_type.is_cfunction:
+ error(self.pos, "Calling non-function type '%s'" % func_type)
+ self.type = PyrexTypes.error_type
+ self.result_code = "<error>"
+ return
+
+ # Check no. of args
+ max_nargs = len(func_type.args)
+ expected_nargs = max_nargs - func_type.optional_arg_count
+ actual_nargs = len(self.args)
+ if func_type.optional_arg_count and expected_nargs != actual_nargs:
+ self.has_optional_args = 1
+ self.is_temp = 1
+
+ self._coerce_arguments(actual_nargs, env, func_type, max_nargs)
+
# Calc result type and code fragment
if isinstance(self.function, NewExprNode):
self.type = PyrexTypes.CPtrType(self.function.class_type)
else:
self.type = func_type.return_type
+ if self.is_elemental:
+ self._analyse_elemental_call(func_type)
+
if self.function.is_name or self.function.is_attribute:
if self.function.entry and self.function.entry.utility_code:
self.is_temp = 1 # currently doesn't work for self.calculate_result_code()
View
@@ -667,9 +667,9 @@ def generate_result_code(self, code):
if_clause = "if"
if_clause = self._put_contig_specialization(code, if_clause,
contig, mixed_contig)
- if_clause = self._put_tiled_specialization(code, if_clause,
- mixed_contig)
if self.target.type.ndim > 1:
+ if_clause = self._put_tiled_specialization(code, if_clause,
+ mixed_contig)
if_clause = self._put_inner_contig_specializations(code, if_clause,
mixed_contig)
self._put_strided_specializations(code, if_clause, mixed_contig)
@@ -1239,6 +1239,22 @@ def visit_SingleAssignmentNode(self, node):
return self.astbuilder.assign(self.visit(lhs),
self.visit(node.rhs))
+ @elemental_dispatcher
+ def visit_SimpleCallNode(self, node, minitype):
+ b = self.astbuilder
+ miniargs = []
+ elemental_args = set(node.elemental_args)
+ for arg in node.args:
+ if arg in elemental_args:
+ # elementwise argument
+ miniargs.append(self.visit(arg))
+ else:
+ # normal function argument, create partial function
+ miniargs.append(self.register_operand(arg))
+
+ minifunc = b.funcname(minitype, node.function.entry.cname)
+ return b.funccall(minifunc, miniargs)
+
@elemental_dispatcher
def visit_UnopNode(self, node, minitype):
return self.astbuilder.unop(minitype, node.operator,
@@ -1273,6 +1289,8 @@ def visit_elemental(self, node, lhs=None, acquire_slice=None):
self.in_elemental -= 1
if not self.in_elemental:
+ # Convert the Cython AST to a minivect AST and generate code
+ # to select the right specialization
load_utilities(self.current_env())
b = self.minicontext.astbuilder
@@ -1337,8 +1355,12 @@ def visit_SingleAssignmentNode(self, node):
def visit_ExprNode(self, node):
if node.is_elemental:
if self.in_elemental:
+ # We are already in an elemental expression, just recursve
return self.visit_elemental(node)
+ # We are an outer expression which is not a direct assignment
+ # We have to create a new array to store the result of the
+ # expression, and convert to a minivect AST
env = self.current_env()
tmp_lhs, elemental_node = self._create_new_array_node(node, env)
result = ElementalNodeWrapper(node.pos, slice_result=tmp_lhs,
@@ -2,6 +2,8 @@
# tag: openmp
# mode: run
+from libc.math cimport sin, cos
+
include "utils.pxi"
@testcase
@@ -311,3 +313,32 @@ def test_operator_xor(bt[:, :, :] a, bt[:, :, :] b, result):
>>> test_operator_xor(i_a, i_b.copy(order='F'), r)
"""
equal(a ^ b, result)
+
+@testcase
+def test_function_calls(double[:, :, :] a, double[:, :, :] b):
+ """
+ >>> operands = d_a, d_b
+ >>> numpy_result = np.sin(d_a) + np.cos(d_b)
+ >>> our_result = test_function_calls(d_a, d_b)
+ >>> np.allclose(numpy_result, our_result)
+ True
+ """
+ return sin(a) + cos(b)
+
+cdef int getarg():
+ print "getarg called"
+ return 10
+
+cdef double elementwise_func(double a, int arg):
+ return a * 10
+
+@testcase
+def test_partial_function_call(double[:, :, :] a):
+ """
+ >>> numpy_result = d_a * 10 + 1
+ >>> our_result = test_partial_function_call(d_a)
+ getarg called
+ >>> np.allclose(numpy_result, our_result)
+ True
+ """
+ return elementwise_func(a, getarg()) + 1

0 comments on commit b42b157

Please sign in to comment.