Permalink
Browse files

Remove broadcasting leading dimensions from RHS

  • Loading branch information...
markflorisson committed Jun 18, 2012
1 parent e86e5aa commit 48672b0f7213dd97175df78a588741426b80707b
Showing with 150 additions and 67 deletions.
  1. +5 −0 Cython/Compiler/ExprNodes.py
  2. +10 −7 Cython/Compiler/PyrexTypes.py
  3. +99 −55 Cython/Compiler/Vector.py
  4. +1 −1 Cython/minivect
  5. +35 −4 tests/run/elementwise.pyx
@@ -3214,10 +3214,15 @@ class MemoryViewIndexNode(BufferIndexNode):
is_buffer_access = False
warned_untyped_idx = False
+ type = None
+
def analyse_types(self, env, getting=True):
# memoryviewslice indexing or slicing
import MemoryView
+ if self.type:
+ return
+
skip_child_analysis = True
indices = self.indices
@@ -525,6 +525,15 @@ def attributes_known(self):
return True
+ def c_f_contig_types(self, ndim):
+ follow_dim = [('direct', 'follow')]
+ contig_dim = [('direct', 'contig')]
+ to_axes_c = follow_dim * (ndim - 1) + contig_dim
+ to_axes_f = contig_dim + follow_dim * (ndim - 1)
+ to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c)
+ to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f)
+ return to_memview_c, to_memview_f
+
def declare_attribute(self, attribute, env, pos):
import MemoryView, Options
@@ -557,13 +566,7 @@ def declare_attribute(self, attribute, env, pos):
elif attribute in ("copy", "copy_fortran"):
ndim = len(self.axes)
- follow_dim = [('direct', 'follow')]
- contig_dim = [('direct', 'contig')]
- to_axes_c = follow_dim * (ndim - 1) + contig_dim
- to_axes_f = contig_dim + follow_dim * (ndim -1)
-
- to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c)
- to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f)
+ to_memview_c, to_memview_f = self.c_f_contig_types(ndim)
for to_memview, cython_name in [(to_memview_c, "copy"),
(to_memview_f, "copy_fortran")]:
View
@@ -12,7 +12,8 @@
from Cython.minivect import codegen
from Cython.minivect import specializers
-debug = False
+_debug = False
+_context_debug = False
class TypeMapper(minitypes.TypeMapper):
def map_type(self, type, wrap=False):
@@ -114,6 +115,8 @@ def visit_NodeWrapper(self, node):
class Context(miniast.CContext):
+ debug = _context_debug
+
codegen_cls = CCodeGen
cleanup_codegen_cls = CCodeGenCleanup
specializer_mixin_cls = CythonSpecializerMixin
@@ -230,8 +233,8 @@ class TempSliceMemory(ExprNodes.ExprNode):
subexprs = ['data']
def analyse_types(self, env):
- self.type = self.target.type
- self.dtype = self.type.dtype
+ # self.type = self.target.type
+ self.dtype = self.target.type.dtype
self.memsize = UtilNodes.ResultRefNode(
pos=self.pos, type=PyrexTypes.c_py_ssize_t_type)
self.data = MemoryAllocationNode(self.pos, dtype=self.dtype,
@@ -241,7 +244,7 @@ def analyse_types(self, env):
def generate_evaluation_code(self, code):
"set the size of memory to allocate before we evaluate subexpressions"
sizes = ["sizeof(%s)" % self.dtype.declaration_code("")]
- for i in range(self.type.ndim):
+ for i in range(self.target.type.ndim):
sizes.append("%s.shape[%d]" % (self.result(), i))
self.memsize.result_code = " * ".join(sizes)
@@ -252,10 +255,11 @@ def generate_result_code(self, code):
code.putln("%s.data = (char *) %s;" % (self.result(),
self.data.result()))
order = "__pyx_get_best_slice_order(%s, %d)" % (self.target.result(),
- self.type.ndim)
+ self.target.type.ndim)
t = (self.result(), self.result(),
self.dtype.declaration_code(""),
- self.type.ndim, order)
+ self.target.type.ndim, order)
+
code.putln("__pyx_fill_contig_strides_array("
"&%s.shape[0], &%s.strides[0], sizeof(%s), %d, %s);" % t)
@@ -314,7 +318,7 @@ def generate_result_code(self, code):
code.putln("%s = %d;" % (self.result(), broadcasting))
if self.init_shape:
- for i in range(self.dst_slice.type.ndim):
+ for i in range(self.max_ndim):
code.putln("%s.shape[%d] = 1;" % (self.dst_slice.result(), i))
for operand in self.operands:
@@ -326,11 +330,13 @@ def generate_result_code(self, code):
sig = "%s(&%s.shape[0], &%s.shape[0], &%s.strides[0], %d, %d, &%s)"
code.putln(code.error_goto_if_neg(sig % format_tuple, self.pos))
+def slice_type(type, ndim):
+ return PyrexTypes.MemoryViewSliceType(type.dtype, type.axes[-ndim:])
+
class UnbroadcastDestNode(ExprNodes.ExprNode):
subexprs = []
def analyse_types(self, env):
- self.type = PyrexTypes.MemoryViewSliceType(
- self.lhs.type.dtype, self.lhs.type.axes[-self.rhs.type.ndim:])
+ self.type = slice_type(self.lhs.type, self.rhs.type.ndim)
self.is_temp = True
def generate_result_code(self, code):
@@ -364,15 +370,64 @@ class SpecializationCaller(ExprNodes.ExprNode):
target = None
def analyse_types(self, env):
- self.all_contig = miniutils.all(
- op.type.is_contig for op in self.operands)
+ all_c_contig = miniutils.all(
+ op.type.is_c_contig for op in self.operands)
+ all_f_contig = miniutils.all(
+ op.type.is_f_contig for op in self.operands)
+ self.all_contig = all_c_contig or all_f_contig
+
+ rhs_ndim = max(op.type.ndim for op in self.operands)
if not self.target:
- self.target = TempSliceStruct(self.pos, ndim=self.function.ndim,
+ if self.dst.type.ndim >= rhs_ndim:
+ axes = self.dst.type.axes
+ elif self.all_contig:
+ types = self.dst.type.c_f_contig_types(rhs_ndim)
+ axes = types[all_f_contig].axes
+ else:
+ axes = [('direct', 'strided')] * rhs_ndim
+
+ self.target = TempSliceStruct(self.pos, ndim=rhs_ndim,
dtype=self.dst.type.dtype,
- axes=self.dst.type.axes)
+ axes=axes)
self.target.analyse_types(env)
+
self.type = self.target.type
+ def align_with_lhs(self, code):
+ """
+ Remove a broadcasting offset for the RHS and remove that offset. E.g.
+
+ a1d[:] = b2d[:] + 1
+
+ Here b2d is demoted to a 1d array, and shape[0] is asserted to be 1.
+
+ Note: we never broadcast the LHS with the RHS since we only want to
+ evaluate the RHS once, and then broadcast the result along the LHS.
+ """
+ if self.type.ndim <= self.dst.type.ndim:
+ return
+
+ code.putln("/* Align RHS with LHS */")
+ lhs_offset, rhs_offset = offsets(self.dst, self)
+ bound = self.type.ndim - rhs_offset
+
+ if bound > 1:
+ i = code.funcstate.allocate_temp(PyrexTypes.c_int_type,
+ manage_ref=False)
+ code.putln("for (%s = 0; %s < %d; %s++) {" % (i, i, bound, i))
+ else:
+ i = "0"
+
+ t = self.result(), i, self.result(), i, rhs_offset
+ code.putln( "%s.shape[%s] = %s.shape[%s + %d];" % t)
+ code.putln( "%s.strides[%s] = %s.strides[%s + %d];" % t)
+
+ if bound > 1:
+ code.putln("}")
+ code.funcstate.release_temp(i)
+
+ self.type = self.target.type = slice_type(self.dst.type, self.type.ndim)
+
def result(self):
return self.target.result()
@@ -419,7 +474,7 @@ def put_specialized_call(self, code, specializer, specialized_function,
utility = Code.UtilityCode(proto=proto, impl=impl)
code.globalstate.use_utility_code(utility)
- if debug:
+ if _debug:
marker = '-' * 20
print marker, 'proto', marker
print proto
@@ -455,6 +510,13 @@ def put_specialized_call(self, code, specializer, specialized_function,
else:
code.putln("(void) %s;" % call)
+def offsets(lhs, rhs):
+ lhs_ndim = lhs.type.ndim
+ rhs_ndim = rhs.type.ndim
+ lhs_offset = max(lhs_ndim - rhs_ndim, 0)
+ rhs_offset = max(rhs_ndim - lhs_ndim, 0)
+ return lhs_offset, rhs_offset
+
class ElementalNode(Nodes.StatNode):
"""
Evaluate the expression on the right hand side before assigning to the
@@ -486,14 +548,15 @@ def analyse_expressions(self, env):
# self.lhs is an UnbroadcastDestNode
self.lhs = self.lhs.lhs
- self.lhs.analyse_types(env)
+ #self.lhs.analyse_types(env)
self.lhs = self.lhs.coerce_to_simple(env)
self.rhs = SpecializationCaller(
self.operands[0].pos, context=self.minicontext,
dst=self.lhs, operands=self.operands, function=self.rhs_function,
may_error=self.may_error)
self.rhs.analyse_types(env)
+
self.temp_nodes.append(self.rhs.target)
for i, operand in enumerate(self.operands):
@@ -526,12 +589,12 @@ def final_assignment(self):
b = self.minicontext.astbuilder
typemapper = self.minicontext.typemapper
- lhs_offset, rhs_offset = self.offsets()
- self.rhs_type = PyrexTypes.MemoryViewSliceType(
+ lhs_offset, rhs_offset = offsets(self.lhs, self.rhs)
+ rhs_type = PyrexTypes.MemoryViewSliceType(
self.rhs.type.dtype, self.rhs.type.axes[rhs_offset:])
lhs_var = b.variable(typemapper.map_type(self.lhs.type, wrap=True), 'lhs')
- rhs_var = b.variable(typemapper.map_type(self.rhs_type, wrap=True), 'rhs')
+ rhs_var = b.variable(typemapper.map_type(rhs_type, wrap=True), 'rhs')
if self.lhs.type.dtype.is_pyobject:
body = b.assign(b.decref(lhs_var), b.incref(rhs_var))
@@ -552,33 +615,15 @@ def overlap(self):
return "1"
return "unlikely(%s)" % self.check_overlap.result()
- def offsets(self):
- lhs_ndim = self.lhs.type.ndim
- rhs_ndim = self.rhs.type.ndim
- lhs_offset = max(lhs_ndim - rhs_ndim, 0)
- rhs_offset = max(rhs_ndim - lhs_ndim, 0)
- return lhs_offset, rhs_offset
-
- def verify_final_shape(self, code):
- call = "__pyx_verify_shapes(%s, %s, %d, %d)" % (
- self.lhs.result(), self.rhs.result(),
- self.lhs.type.ndim, self.rhs.type.ndim)
- code.putln(code.error_goto_if_neg(call, self.pos))
-
def init_rhs_temp(self, code):
"""
In case of no overlapping memory, assign directly to the LHS.
"""
code.putln("%s.data = %s.data;" % (self.rhs.result(), self.lhs.result()))
-
- lhs_offset, rhs_offset = self.offsets()
- for i in range(rhs_offset):
- code.putln("%s.strides[%d] = 0;" % (self.rhs.result(), i))
-
+ lhs_offset, rhs_offset = offsets(self.lhs, self.rhs)
for i in range(self.rhs.type.ndim):
code.putln("%s.strides[%d] = %s.strides[%d];" % (
- self.rhs.result(), i + rhs_offset,
- self.lhs.result(), i + lhs_offset))
+ self.rhs.result(), i, self.lhs.result(), i + lhs_offset))
def advance_lhs_data_ptr(self, code):
"""
@@ -590,7 +635,7 @@ def advance_lhs_data_ptr(self, code):
m3[0, :] contains the data, which we need to broadcast over m3[1:, :]
"""
- lhs_offset, rhs_offset = self.offsets()
+ lhs_offset, rhs_offset = offsets(self.lhs, self.rhs)
lhs_r, rhs_r = self.lhs.result(), self.rhs.result()
def advance(i):
@@ -606,18 +651,11 @@ def advance(i):
advance(i + lhs_offset)
code.putln("}")
- def remove_rhs_offset(self, code):
- lhs_offset, rhs_offset = self.offsets()
- if rhs_offset:
- rhs_result = self.rhs.result()
- bound = self.rhs.type.ndim - rhs_offset
- i = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
- t = rhs_result, i, rhs_result, i, rhs_offset
- code.putln("for (%s = 0; %s < %d; %s++) {" % (i, i, bound, i))
- code.putln( "%s.shape[%s] = %s.shape[%s + %d];" % t)
- code.putln( "%s.strides[%s] = %s.strides[%s + %d];" % t)
- code.putln("}")
- code.funcstate.release_temp(i)
+ def verify_final_shape(self, code):
+ call = "__pyx_verify_shapes(%s, %s, %d, %d)" % (
+ self.lhs.result(), self.rhs.result(),
+ self.lhs.type.ndim, self.rhs.type.ndim)
+ code.putln(code.error_goto_if_neg(call, self.pos))
def generate_execution_code(self, code):
code.mark_pos(self.pos)
@@ -637,6 +675,8 @@ def generate_execution_code(self, code):
self.broadcast.generate_evaluation_code(code)
self.verify_final_shape(code)
+ self.rhs.align_with_lhs(code)
+
# Set rhs.data and rhs.strides
code.putln("/* Allocate scratch space if needed */")
code.putln("if (%s) {" % self.overlap())
@@ -661,8 +701,9 @@ def generate_execution_code(self, code):
code.putln("/* Final broadcasting assignment */")
if self.lhs.type.ndim == self.rhs.type.ndim:
- code.putln("if (%s) {" % self.final_broadcast.result())
- self.remove_rhs_offset(code)
+ code.putln("if (%s || %s) {" % (self.final_broadcast.result(),
+ self.overlap()))
+ # self.remove_rhs_offset(code)
self.final_assignment_node.broadcasting = self.final_broadcast.result()
self.final_assignment_node.generate_evaluation_code(code)
if self.lhs.type.ndim == self.rhs.type.ndim:
@@ -739,7 +780,7 @@ class ElementalMapper(specializers.ASTMapper):
wrapping = 0
- def __init__(self, context, env):
+ def __init__(self, context, env, max_ndim):
super(ElementalMapper, self).__init__(context)
self.env = env
# operands to the function in callee space
@@ -748,6 +789,7 @@ def __init__(self, context, env):
self.funcargs = []
self.error = False
self.may_error = False
+ self.max_ndim = max_ndim
def map_type(self, node, **kwds):
try:
@@ -773,6 +815,7 @@ def register_operand(self, node):
minitype = self.map_type(node, wrap=True)
if node.type.is_memoryviewslice:
funcarg = b.array_funcarg(b.variable(minitype, varname))
+ funcarg.type.ndim = min(funcarg.type.ndim, self.max_ndim)
else:
funcarg = b.funcarg(b.variable(minitype, varname))
@@ -851,7 +894,8 @@ def visit_elemental(self, node, lhs=None):
b = self.minicontext.astbuilder
b.pos = node.pos
- astmapper = ElementalMapper(self.minicontext, self.current_env())
+ astmapper = ElementalMapper(self.minicontext, self.current_env(),
+ max_ndim=lhs.type.ndim)
try:
body = astmapper.visit(node)
Oops, something went wrong.

0 comments on commit 48672b0

Please sign in to comment.