Skip to content

Commit

Permalink
Support objects in vector expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
markflorisson committed Jul 23, 2012
1 parent 1e53ae2 commit 451c802
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 53 deletions.
2 changes: 1 addition & 1 deletion Cython/Compiler/Code.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,7 +1884,7 @@ def put_setup_refcount_context(self, name, acquire_gil=False):
if acquire_gil:
self.globalstate.use_utility_code(
UtilityCode.load_cached("ForceInitThreads", "ModuleSetupCode.c"))
self.putln('__Pyx_RefNannySetupContext("%s", %d);' % (name, acquire_gil and 1 or 0))
self.putln('__Pyx_RefNannySetupContext("%s", %d);' % (name, int(acquire_gil)))

def put_finish_refcount_context(self):
self.putln("__Pyx_RefNannyFinishContext();")
Expand Down
148 changes: 119 additions & 29 deletions Cython/Compiler/Vector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from Cython.Compiler import ExprNodes, Nodes, PyrexTypes, Visitor, Code
from Cython.Compiler import (ExprNodes, Nodes, PyrexTypes, Visitor,
Code, Naming)
from Cython.Compiler.Errors import error

from Cython.minivect import miniast
Expand Down Expand Up @@ -41,22 +42,66 @@ def visit_NodeWrapper(self, node):
op.variable = self.visit(op.variable)
return node

def create_hybrid_code(codegen, old_minicode):
minicode = codegen.context.codewriter_cls(codegen.context)
minicode.indent = old_minicode.indent
code = CythonCCodeWriter(codegen.context, minicode)
code.level = minicode.indent
code.declaration_levels = list(old_minicode.declaration_levels)
code.codegen = codegen.clone(codegen.context, code)
return code

class CCodeGen(codegen.CCodeGen):

def __init__(self, context, codewriter):
super(CCodeGen, self).__init__(context, codewriter)
self.error_handlers = []

def visit_ErrorHandler(self, node):
self.error_handlers.append(node)
result = super(CCodeGen, self).visit_ErrorHandler(node)
self.error_handlers.pop()
return result

def visit_FunctionNode(self, node):
result = super(CCodeGen, self).visit_FunctionNode(node)
self.code.function_declarations.putln("__Pyx_RefNannyDeclarations")
self.code.before_loop.putln(
'__Pyx_RefNannySetupContext("%s", 1);' % node.mangled_name)

def visit_NodeWrapper(self, node):
code = CythonCCodeWriter()
node = node.opaque_node
code = create_hybrid_code(self, self.code)

# create funcstate and evaluate the expression
code.enter_cfunc_scope()
node.opaque_node.generate_evaluation_code(code)
node.generate_evaluation_code(code)
if node.type.is_pyobject:
code.put_incref(node.result(), node.type, nanny=True)
code.put_giveref(node.result())

declaration_code = CythonCCodeWriter()
# generate declarations for any temporaries
declaration_code = CythonCCodeWriter(self.context, code.minicode)
declaration_code.put_temp_declarations(code.funcstate)
self.code.declaration_point.putln(declaration_code.getvalue())
self.code.declaration_levels[0].putln(declaration_code.getvalue())
self.code.putln(code.getvalue())

return node.opaque_node.result()
return node.result()

class CCodeGenCleanup(codegen.CodeGenCleanup):
error_handler_level = 0
def visit_ErrorHandler(self, node):
self.error_handler_level += 1
super(CCodeGenCleanup, self).visit_ErrorHandler(node)
self.error_handler_level -= 1
if self.error_handler_level == 0:
self.code.putln("__Pyx_RefNannyFinishContext();")
return node

def visit_NodeWrapper(self, node):
node.opaque_node.generate_disposal_code(self.code)
code = create_hybrid_code(self, self.code)
node.opaque_node.generate_disposal_code(code)
self.code.putln(code.getvalue())

class Context(miniast.CContext):

Expand All @@ -79,9 +124,6 @@ def codegen_cls(self, _, codewriter):
node.codegen = codegen
return codegen

def getpos(self, node):
return node.pos

def getchildren(self, node):
return node.child_attrs

Expand All @@ -92,12 +134,39 @@ def declare_type(self, type):
return super(Context, self).declare_type(type)

def may_error(self, node):
return node.type.is_pyobject
return (node.type.resolve().is_pyobject or
(node.type.is_memoryviewslice and node.type.dtype.is_pyobject))

class CythonCCodeWriter(Code.CCodeWriter):

def __init__(self, context, minicode):
super(CythonCCodeWriter, self).__init__()
self.minicode = minicode
self.globalstate = context.original_cython_code.globalstate

def mark_pos(self, pos):
pass

def set_error_info(self, pos):
fn_var, lineno_var, col_var = [
self.minicode.mangle(v.name)
for v in self.codegen.function.posinfo.variables]

filename_idx = self.lookup_filename(pos[0])
return '*%s = %s[%d]; *%s = %s;' % (
fn_var, Naming.filetable_cname, filename_idx,
lineno_var, pos[1])

def error_goto(self, pos):
assert self.codegen.error_handlers

label = self.codegen.error_handlers[-1].error_label
return "{%s goto %s;}" % (self.set_error_info(pos), label.mangled_name)

def mangle(self, name):
"We are simultaneously a mini-CodeWriter and a Cython-CodeWriter"
return self.minicode.mangle(name)

class OperandNode(ExprNodes.ExprNode):
"""
The purpose of this node is to wrap a miniast variable and dispatch
Expand Down Expand Up @@ -153,6 +222,7 @@ def generate_result_code(self, code):
else:
specializer = specializers.StridedSpecializer

self.context.original_cython_code = code
codes = self.context.run(function, [specializer])
(specialized_function, codewriter, (proto, impl)), = codes
utility = Code.UtilityCode(proto=proto, impl=impl)
Expand All @@ -171,7 +241,8 @@ def generate_result_code(self, code):
for i in range(function.ndim):
code.putln("%s[%d] = 0;" % (shape, i))

args = ["&%s[0]" % shape]
args = ["&%s[0]" % shape, "&%s" % Naming.filename_cname,
"&%s" % Naming.lineno_cname, "NULL"]
for operand in self.operands:
result = operand.result()
if operand.type.is_memoryviewslice:
Expand All @@ -188,7 +259,9 @@ def generate_result_code(self, code):
args.append(result)

call = "%s(%s)" % (specialized_function.mangled_name, ", ".join(args))
code.putln(code.error_goto_if_neg(call, self.pos))
lbl = code.funcstate.error_label
code.funcstate.use_label(lbl)
code.putln("if (unlikely(%s < 0)) { goto %s; }" % (call, lbl))

code.funcstate.release_temp(shape)

Expand Down Expand Up @@ -256,23 +329,15 @@ def register_operand(self, node):

return funcarg.variable

def visit_ExprNode(self, node):
return self.register_operand(node)

def visit_SingleAssignmentNode(self, node):
return self.astbuilder.assign(self.visit(node.lhs.dst),
self.visit(node.rhs))

def visit_BinopNode(self, node):
minitype = self.map_type(node, wrap=True)
if need_wrapper_node(node.type):
if not node.is_elemental:
return self.register_operand(node)
def register_wrapper_node(self, node):
if not node.is_elemental:
return self.register_operand(node)

self.wrapping += 1
self.visitchildren(node)
self.wrapping -= 1
self.wrapping += 1
self.visitchildren(node)
self.wrapping -= 1

if self.wrapping == 0:
dtype = node.type
if dtype.is_memoryviewslice:
dtype = dtype.dtype
Expand All @@ -284,6 +349,20 @@ def visit_BinopNode(self, node):
result = self.astbuilder.wrap(node, cython_ops=self.cython_ops)
self.cython_ops = []
return result
else:
return node

def visit_ExprNode(self, node):
return self.register_operand(node)

def visit_SingleAssignmentNode(self, node):
return self.astbuilder.assign(self.visit(node.lhs.dst),
self.visit(node.rhs))

def visit_BinopNode(self, node):
minitype = self.map_type(node, wrap=True)
if need_wrapper_node(node.type):
return self.register_wrapper_node(node)

op1 = self.visit(node.operand1)
op2 = self.visit(node.operand2)
Expand All @@ -306,6 +385,7 @@ def visit_elemental(self, node):

if not self.in_elemental:
b = self.minicontext.astbuilder
b.pos = node.pos
astmapper = ElementalMapper(self.minicontext, self.current_env())
shapevar = b.variable(minitypes.Py_ssize_t.pointer(),
'__pyx_shape')
Expand All @@ -316,7 +396,17 @@ def visit_elemental(self, node):

name = '__pyx_array_expression%d' % self.funccount
self.funccount += 1
function = b.function(name, body, astmapper.funcargs, shapevar)

pos_args = (
b.variable(minitypes.c_string_type.pointer(), 'filename'),
b.variable(minitypes.int_type.pointer(), 'lineno'),
b.variable(minitypes.int_type.pointer(), 'column'))

position_argument = b.funcarg(b.variable(None, 'position'),
*pos_args)

function = b.function(name, body, astmapper.funcargs, shapevar,
position_argument)

all_contig = miniutils.all(op.type.is_contig
for op in astmapper.operands)
Expand Down
40 changes: 17 additions & 23 deletions tests/run/elementwise.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,29 +66,23 @@ def test_typedef(np.int32_t[:] m):
m[:] = m + m + m
return m

def test_arbitrary_dtypes(long[:] m1, long double[:] m2):
"""
>>> a = np.arange(10, dtype='l')
>>> b = np.arange(10, dtype=np.longdouble)
>>> test_arbitrary_dtypes(a, b)
>>> a
array([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18])
>>> b
array([ 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0], dtype=float128)
"""
m1[:] = m1 + m1
m2[:] = m2 + m2
cdef fused fused_dtype_t:
long
long double
double complex
object

def test_tougher_arbitrary_dtypes(double complex[:] m1, m2): #, object[:] m2):
def test_arbitrary_dtypes(fused_dtype_t[:] m):
"""
>>> a = np.arange(10, dtype=np.complex128) + 1.2j
>>> b = np.arange(10, dtype=np.object)
>>> test_tougher_arbitrary_dtypes(a, b)
>>> a
array([ 0.+2.4j, 2.+2.4j, 4.+2.4j, 6.+2.4j, 8.+2.4j, 10.+2.4j,
12.+2.4j, 14.+2.4j, 16.+2.4j, 18.+2.4j])
>>> b
array([0, 2, 4, 6, 8, 10, 12, 14, 16, 18], dtype=object)
>>> test_arbitrary_dtypes(np.arange(10, dtype='l'))
array([ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27])
>>> test_arbitrary_dtypes(np.arange(10, dtype=np.longdouble))
array([ 0.0, 3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0], dtype=float128)
>>> test_arbitrary_dtypes(np.arange(10, dtype=np.complex128) + 1.2j)
array([ 0.+3.6j, 3.+3.6j, 6.+3.6j, 9.+3.6j, 12.+3.6j, 15.+3.6j,
18.+3.6j, 21.+3.6j, 24.+3.6j, 27.+3.6j])
>>> test_arbitrary_dtypes(np.arange(10, dtype=np.object))
array([0, 3, 6, 9, 12, 15, 18, 21, 24, 27], dtype=object)
"""
m1[:] = m1 + m1
m2[:] = m2 + m2
m[:] = m + m + m
return np.asarray(m)

0 comments on commit 451c802

Please sign in to comment.