Permalink
Browse files

Select specialization at runtime & support broadcasting

  • Loading branch information...
1 parent f4496bc commit b8ab0f3a2f17db4aab5850d8f03ad2e4c374e834 @markflorisson committed Jun 7, 2012
Showing with 348 additions and 97 deletions.
  1. +4 −0 Cython/Compiler/CythonScope.py
  2. +3 −1 Cython/Compiler/ExprNodes.py
  3. +168 −81 Cython/Compiler/Vector.py
  4. +42 −8 Cython/Utility/MemoryView.pyx
  5. +131 −7 tests/run/elementwise.pyx
@@ -127,6 +127,10 @@ def create_cython_scope(context):
# it across different contexts)
return CythonScope(context)
+def get_cython_scope(env):
+ "Return the CythonScope given an env in some context"
+ return env.global_scope().context.cython_scope
+
# Load test utilities for the cython scope
def load_testscope_utility(cy_util_name, **kwargs):
@@ -7713,7 +7713,9 @@ def infer_type(self, env):
return self.get_cython_array_type(env)
def get_cython_array_type(self, env):
- return env.global_scope().context.cython_scope.viewscope.lookup("array").type
+ from Cython.Compiler import CythonScope
+ cython_scope = CythonScope.get_cython_scope(env)
+ return cython_scope.viewscope.lookup("array").type
def generate_result_code(self, code):
import Buffer
View
@@ -1,6 +1,8 @@
+import copy
+
from Cython.Compiler import (ExprNodes, Nodes, PyrexTypes, Visitor,
- Code, Naming)
-from Cython.Compiler.Errors import error
+ Code, Naming, MemoryView, Errors)
+from Cython.Compiler.Errors import error, CompileError
from Cython.minivect import miniast
from Cython.minivect import minitypes
@@ -38,7 +40,7 @@ def map_type(self, type, wrap=False):
class CythonSpecializerMixin(object):
def visit_NodeWrapper(self, node):
- for op in node.cython_ops:
+ for op in node.operands:
op.variable = self.visit(op.variable)
return node
@@ -70,6 +72,9 @@ def visit_FunctionNode(self, node):
'__Pyx_RefNannySetupContext("%s", 1);' % node.mangled_name)
def visit_NodeWrapper(self, node):
+ for operand in node.operands:
+ operand.codegen = self
+
node = node.opaque_node
code = create_hybrid_code(self, self.code)
@@ -105,25 +110,10 @@ def visit_NodeWrapper(self, node):
class Context(miniast.CContext):
- #codegen_cls = CCodeGen
+ codegen_cls = CCodeGen
cleanup_codegen_cls = CCodeGenCleanup
specializer_mixin_cls = CythonSpecializerMixin
- def __init__(self, astbuilder=None, typemapper=None):
- super(Context, self).__init__(astbuilder, typemapper)
- # [OperandNode]
- self.cython_operand_nodes = []
-
- def codegen_cls(self, _, codewriter):
- """
- Monkeypatch all OperandNodes to have a codegen attribute, so they
- can generate code for the miniast they wrap.
- """
- codegen = CCodeGen(self, codewriter)
- for node in self.cython_operand_nodes:
- node.codegen = codegen
- return codegen
-
def getchildren(self, node):
return node.child_attrs
@@ -215,16 +205,85 @@ def result(self):
return ""
def generate_result_code(self, code):
- function = self.function
-
- if self.all_contig:
- specializer = specializers.ContigSpecializer
- else:
- specializer = specializers.StridedSpecializer
+ specializer_transforms = [
+ specializers.ContigSpecializer,
+ specializers.StridedSpecializer,
+ ]
self.context.original_cython_code = code
- codes = self.context.run(function, [specializer])
- (specialized_function, codewriter, (proto, impl)), = codes
+ codes = self.context.run(self.function, specializer_transforms)
+
+ self.temps = []
+
+ code.begin_block()
+ # Initialize a dest_shape and broadcasting variable
+ ndim = self.function.ndim
+ ones = ",".join("1" for i in range(ndim))
+ shape_temp = "__pyx_shape_temp"
+ broadcast_temp = "__pyx_broadcasting"
+ code.putln("Py_ssize_t %s[%d] = { %s };" % (shape_temp, ndim, ones))
+ code.putln("int %s = 0;" % broadcast_temp)
+
+ # broadcast all array operands
+ for idx, operand in enumerate(self.operands):
+ if operand.type.is_memoryviewslice:
+ self.broadcast(code, shape_temp, broadcast_temp, operand, ndim)
+
+ if_guard = "if"
+ for result in codes:
+ specializer = iter(result).next()
+ condition = self.condition(specializer, broadcast_temp)
+ if condition:
+ code.putln("%s (%s) {" % (if_guard, condition))
+ if_guard = " elif"
+ else:
+ code.putln(" else {")
+
+ self.put_specialized_call(code, shape_temp, broadcast_temp, *result)
+ code.put("}")
+
+ code.putln("")
+ code.end_block()
+
+ for temp in self.temps:
+ code.funcstate.release_temp(temp)
+
+ def broadcast(self, code, shape_temp, broadcast_temp, operand, ndim):
+ strides_temp = code.funcstate.allocate_temp(PyrexTypes.CArrayType(
+ PyrexTypes.c_py_ssize_t_type, operand.type.ndim), False)
+ for i in range(operand.type.ndim):
+ code.putln("%s[%d] = 0;" % (strides_temp, i))
+
+ self.temps.append(strides_temp)
+ # cdef void broadcast(Py_ssize_t *dst_shape, Py_ssize_t *dst_strides,
+ # int max_ndim, int ndim,
+ # Py_ssize_t *input_shape, Py_ssize_t *input_strides,
+ # bint *p_broadcast) nogil:
+ code.putln("__pyx_memoryview_broadcast(&%s[0], &%s[0], "
+ "%d, %d, "
+ "&%s.shape[0], &%s.strides[0],"
+ "&%s);" %
+ (shape_temp, strides_temp,
+ ndim, operand.type.ndim,
+ operand.result(), operand.result(),
+ broadcast_temp))
+
+ def condition(self, specializer, broadcast_temp):
+ if specializer.is_contig_specializer:
+ if not self.all_contig:
+ # todo: implement a memoryview flag to quickly check whether
+ # it is contig for each operand
+ return "0"
+ return "!%s" % broadcast_temp
+
+ def put_specialized_call(self, code, shape_temp, broadcast_temp,
+ specializer, specialized_function,
+ codewriter, result_code):
+ proto, impl = result_code
+
+ function = self.function
+ ndim = function.ndim
+
utility = Code.UtilityCode(proto=proto, impl=impl)
code.globalstate.use_utility_code(utility)
@@ -235,26 +294,17 @@ def generate_result_code(self, code):
print marker, 'impl', marker
print impl
- array_type = PyrexTypes.c_array_type(PyrexTypes.c_py_ssize_t_type,
- function.ndim)
- shape = code.funcstate.allocate_temp(array_type, manage_ref=False)
- for i in range(function.ndim):
- code.putln("%s[%d] = 0;" % (shape, i))
-
- args = ["&%s[0]" % shape, "&%s" % Naming.filename_cname,
+ # all function call arguments
+ args = ["&%s[0]" % shape_temp, "&%s" % Naming.filename_cname,
"&%s" % Naming.lineno_cname, "NULL"]
- for operand in self.operands:
+
+ # broadcast all array operands
+ for idx, operand in enumerate(self.operands):
result = operand.result()
if operand.type.is_memoryviewslice:
- for i in range(function.ndim):
- code.putln("if (%s.shape[%d] > %s[%d]) {" % (result, i, shape, i))
- code.putln( "%s[%d] = %s.shape[%d];" % (shape, i, result, i))
- code.putln("}")
-
- tp = operand.type.dtype.declaration_code("")
- args.append('(%s *) %s.data' % (tp, result))
- #args.append('&%s.shape[0]' % result)
- if not self.all_contig:
+ dtype_pointer_decl = operand.type.dtype.declaration_code("")
+ args.append('(%s *) %s.data' % (dtype_pointer_decl, result))
+ if not specializer.is_contig_specializer:
args.append("&%s.strides[0]" % result)
else:
args.append(result)
@@ -264,9 +314,13 @@ def generate_result_code(self, code):
code.funcstate.use_label(lbl)
code.putln("if (unlikely(%s < 0)) { goto %s; }" % (call, lbl))
- code.funcstate.release_temp(shape)
def need_wrapper_node(type):
+ """
+ Return whether a Cython node that needs to be mapped to a miniast Node,
+ should be mapped or wrapped (i.e., should minivect or Cython generate
+ the code to evaluate the expression?).
+ """
while True:
if type.is_ptr:
type = type.base_type
@@ -278,7 +332,36 @@ def need_wrapper_node(type):
type = type.resolve()
return type.is_pyobject or type.is_complex
+def get_dtype(type):
+ if type.is_memoryviewslice:
+ return type.dtype
+ return type
+
+class CythonASTInMiniastTransform(Visitor.VisitorTransform):
+
+ def __init__(self, env):
+ super(CythonASTInMiniastTransform, self).__init__()
+ self.env = env
+ self.operands = []
+
+ def visit_BinopNode(self, node):
+ dtype = get_dtype(node.type)
+ node = type(node)(node.pos, type=dtype, operator=node.operator,
+ operand1=self.visit(node.operand1),
+ operand2=self.visit(node.operand2))
+ node.analyse_types(self.env)
+ return node
+
+ def visit_ExprNode(self, node):
+ node = OperandNode(node.pos, type=get_dtype(node.type), node=node)
+ self.operands.append(node)
+ return node
+
class ElementalMapper(specializers.ASTMapper):
+ """
+ When some elementwise expression is found in the Cython AST, convert that
+ tree to a minivect AST.
+ """
wrapping = 0
@@ -289,9 +372,6 @@ def __init__(self, context, env):
self.operands = []
# miniast function arguments to the function
self.funcargs = []
- # All OperandNodes founds in the Cython AST held by a
- # miniast.NodeWrapper
- self.cython_ops = []
self.error = False
def map_type(self, node, **kwds):
@@ -302,17 +382,13 @@ def map_type(self, node, **kwds):
"operation: %s" % (node.type,))
raise
- def get_dtype(self, type):
- if type.is_memoryviewslice:
- return type.dtype
- return type
-
def register_operand(self, node):
"""
Register a non-elemental subexpression, and pass it in to the function
we are generating as an argument.
"""
assert not node.is_elemental
+
b = self.astbuilder
node = node.coerce_to_simple(self.env)
@@ -326,42 +402,37 @@ def register_operand(self, node):
funcarg = b.funcarg(b.variable(minitype, varname))
self.funcargs.append(funcarg)
-
- if self.wrapping:
- # we are inside a Cython AST, return something compatible
- result = OperandNode(node.pos, type=self.get_dtype(node.type),
- variable=funcarg.variable)
- self.context.cython_operand_nodes.append(result)
- self.cython_ops.append(result)
- return result
- else:
- # we are in a miniast
- return funcarg.variable
+ return funcarg.variable
def register_wrapper_node(self, node):
- if not node.is_elemental:
- return self.register_operand(node)
+ """
+ Create a miniast.NodeWrapper for functionality that Cython provides,
+ but that we want to use inside miniast expressions.
+ """
+ assert node.is_elemental
- self.wrapping += 1
- self.visitchildren(node)
- self.wrapping -= 1
+ transform = CythonASTInMiniastTransform(self.env)
+ try:
- if self.wrapping == 0:
- dtype = node.type
- if dtype.is_memoryviewslice:
- dtype = dtype.dtype
+ node = transform.visit(node)
+ except CompileError, e:
+ error(e.position, e.message_only)
+ return None
- node = type(node)(node.pos, type=dtype, operator=node.operator,
- operand1=node.operand1, operand2=node.operand2)
- node.analyse_types(self.env)
+ for operand in transform.operands:
+ operand.variable = self.register_operand(operand.node)
- result = self.astbuilder.wrap(node, cython_ops=self.cython_ops)
- self.cython_ops = []
- return result
- else:
- return node
+ def specialize_node(nodewrapper, memo):
+ return copy.deepcopy(node, memo)
+
+ return self.astbuilder.wrap(node, specialize_node,
+ operands=transform.operands)
def visit_ExprNode(self, node):
+ """
+ Some expression which cannot be converted to a miniast, but is passed
+ in as an argument to the function generated from the miniast.
+ """
return self.register_operand(node)
def visit_SingleAssignmentNode(self, node):
@@ -378,6 +449,12 @@ def visit_BinopNode(self, node):
return self.astbuilder.binop(minitype, node.operator, op1, op2)
class ElementWiseOperationsTransform(Visitor.EnvTransform):
+ """
+ Find elementwise expressions and run ElementalMapper to turn it into
+ a minivect AST. Our Cython tree ends here in an ElementalNode, which
+ responsibility is to call the function generated by minivect (as well
+ as to perform broadcasting and selection of the right specialization).
+ """
in_elemental = 0
@@ -393,6 +470,8 @@ def visit_elemental(self, node):
self.in_elemental -= 1
if not self.in_elemental:
+ load_utilities(self.current_env())
+
b = self.minicontext.astbuilder
b.pos = node.pos
astmapper = ElementalMapper(self.minicontext, self.current_env())
@@ -441,3 +520,11 @@ def visit_SingleAssignmentNode(self, node):
self.visitchildren(node)
return node
+
+def load_utilities(env):
+ from Cython.Compiler import CythonScope
+ cython_scope = CythonScope.get_cython_scope(env)
+ broadcast_utility.declare_in_scope(cython_scope.viewscope,
+ cython_scope=cython_scope, used=True)
+
+broadcast_utility = MemoryView.load_memview_cy_utility("Broadcasting")
Oops, something went wrong.

0 comments on commit b8ab0f3

Please sign in to comment.