Skip to content

Commit

Permalink
Decref memoryview slice class attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
markflorisson committed Apr 10, 2012
1 parent 703dd88 commit d96dfdb
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 41 deletions.
18 changes: 16 additions & 2 deletions Cython/Compiler/ExprNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,11 +489,21 @@ def release_temp_result(self, code):
# ---------------- Code Generation -----------------

def make_owned_reference(self, code):
# If result is a pyobject, make sure we own
# a reference to it.
"""
If result is a pyobject, make sure we own a reference to it.
If the result is in a temp, it is already a new reference.
"""
if self.type.is_pyobject and not self.result_in_temp():
code.put_incref(self.result(), self.ctype())

def make_owned_memoryviewslice(self, code):
"""
Make sure we own the reference to this memoryview slice.
"""
if not self.result_in_temp():
code.put_incref_memoryviewslice(self.result(),
have_gil=self.in_nogil_context)

def generate_evaluation_code(self, code):
code.mark_pos(self.pos)

Expand Down Expand Up @@ -8809,6 +8819,10 @@ def annotate(self, code):
code.annotate((file, line, col-1), AnnotationItem(style='coerce', tag='coerce', text='[%s] to [%s]' % (self.arg.type, self.type)))

class CoerceToMemViewSliceNode(CoercionNode):
"""
Coerce an object to a memoryview slice. This holds a new reference in
a managed temp.
"""

def __init__(self, arg, dst_type, env):
assert dst_type.is_memoryviewslice
Expand Down
14 changes: 3 additions & 11 deletions Cython/Compiler/MemoryView.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def put_acquire_memoryviewslice(lhs_cname, lhs_type, lhs_pos, rhs, code,
"We can avoid decreffing the lhs if we know it is the first assignment"
assert rhs.type.is_memoryviewslice

pretty_rhs = isinstance(rhs, NameNode) or rhs.result_in_temp()
pretty_rhs = rhs.result_in_temp() or rhs.is_simple()
if pretty_rhs:
rhstmp = rhs.result()
else:
Expand All @@ -106,19 +106,11 @@ def put_assign_to_memviewslice(lhs_cname, rhs, rhs_cname, memviewslicetype, code
if not first_assignment:
code.put_xdecref_memoryviewslice(lhs_cname, have_gil=have_gil)

if rhs.is_name:
code.put_incref_memoryviewslice(rhs_cname, have_gil=have_gil)
if not rhs.result_in_temp():
rhs.make_owned_memoryviewslice(code)

code.putln("%s = %s;" % (lhs_cname, rhs_cname))

#code.putln("%s.memview = %s.memview;" % (lhs_cname, rhs_cname))
#code.putln("%s.data = %s.data;" % (lhs_cname, rhs_cname))
#for i in range(memviewslicetype.ndim):
# tup = (lhs_cname, i, rhs_cname, i)
# code.putln("%s.shape[%d] = %s.shape[%d];" % tup)
# code.putln("%s.strides[%d] = %s.strides[%d];" % tup)
# code.putln("%s.suboffsets[%d] = %s.suboffsets[%d];" % tup)

def get_buf_flags(specs):
is_c_contig, is_f_contig = is_cf_contig(specs)

Expand Down
55 changes: 28 additions & 27 deletions Cython/Compiler/ModuleNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,18 +988,10 @@ def generate_new_function(self, scope, code, cclass_entry):
type = scope.parent_type
base_type = type.base_type

py_attrs = []
memviewslice_attrs = []
py_buffers = []
for entry in scope.var_entries:
if entry.type.is_pyobject:
py_attrs.append(entry)
elif entry.type.is_memoryviewslice:
memviewslice_attrs.append(entry)
elif entry.type == PyrexTypes.c_py_buffer_type:
py_buffers.append(entry)
have_entries, (py_attrs, py_buffers, memoryview_slices) = \
scope.get_refcounted_entries(include_weakref=True)

need_self_cast = type.vtabslot_cname or py_attrs or memviewslice_attrs or py_buffers
need_self_cast = type.vtabslot_cname or have_entries
code.putln("")
code.putln(
"static PyObject *%s(PyTypeObject *t, PyObject *a, PyObject *k) {"
Expand Down Expand Up @@ -1044,7 +1036,7 @@ def generate_new_function(self, scope, code, cclass_entry):
else:
code.put_init_var_to_py_none(entry, "p->%s", nanny=False)

for entry in memviewslice_attrs:
for entry in memoryview_slices:
code.putln("p->%s.data = NULL;" % entry.cname)
code.putln("p->%s.memview = NULL;" % entry.cname)

Expand Down Expand Up @@ -1081,18 +1073,25 @@ def generate_dealloc_function(self, scope, code):
code.putln(
"static void %s(PyObject *o) {"
% scope.mangle_internal("tp_dealloc"))
py_attrs = []

weakref_slot = scope.lookup_here("__weakref__")
for entry in scope.var_entries:
if entry.type.is_pyobject and entry is not weakref_slot:
py_attrs.append(entry)
if py_attrs or weakref_slot in scope.var_entries:
_, (py_attrs, _, memoryview_slices) = scope.get_refcounted_entries()

if py_attrs or memoryview_slices or weakref_slot in scope.var_entries:
self.generate_self_cast(scope, code)

# call the user's __dealloc__
self.generate_usr_dealloc_call(scope, code)
if weakref_slot in scope.var_entries:
code.putln("if (p->__weakref__) PyObject_ClearWeakRefs(o);")

for entry in py_attrs:
code.put_xdecref("p->%s" % entry.cname, entry.type, nanny=False)

for entry in memoryview_slices:
code.put_xdecref_memoryviewslice("p->%s" % entry.cname,
have_gil=True)

if base_type:
tp_dealloc = TypeSlots.get_base_slot_function(scope, tp_slot)
if tp_dealloc is None:
Expand Down Expand Up @@ -1139,13 +1138,8 @@ def generate_traverse_function(self, scope, code, cclass_entry):
"static int %s(PyObject *o, visitproc v, void *a) {"
% slot_func)

py_attrs = []
py_buffers = []
for entry in scope.var_entries:
if entry.type.is_pyobject and entry.name != "__weakref__":
py_attrs.append(entry)
if entry.type == PyrexTypes.c_py_buffer_type:
py_buffers.append(entry)
have_entries, (py_attrs, py_buffers,
memoryview_slices) = scope.get_refcounted_entries()

if base_type or py_attrs:
code.putln("int e;")
Expand Down Expand Up @@ -1178,9 +1172,16 @@ def generate_traverse_function(self, scope, code, cclass_entry):
code.putln(
"}")

for entry in py_buffers:
code.putln("if (p->%s.obj) {" % entry.cname)
code.putln( "e = (*v)(p->%s.obj, a); if (e) return e;" % entry.cname)
for entry in py_buffers + memoryview_slices:
if entry.type == PyrexTypes.c_py_buffer_type:
cname = entry.cname + ".obj"
else:
# traverse the memoryview object, which should traverse the
# object exposing the buffer
cname = entry.cname + ".memview"

code.putln("if (p->%s) {" % cname)
code.putln( "e = (*v)(p->%s, a); if (e) return e;" % cname)
code.putln("}")

if cclass_entry.cname == '__pyx_memoryviewslice':
Expand Down
2 changes: 1 addition & 1 deletion Cython/Compiler/Nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5365,7 +5365,7 @@ def generate_execution_code(self, code):
"%s = %s;" % (
Naming.retval_cname,
self.value.result_as(self.return_type)))
self.value.generate_post_assignment_code(code)
self.value.generate_post_assignment_code(code)
self.value.free_temps(code)
else:
if self.return_type.is_pyobject:
Expand Down
17 changes: 17 additions & 0 deletions Cython/Compiler/Symtab.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,23 @@ def is_cpp(self):
def add_include_file(self, filename):
self.outer_scope.add_include_file(filename)

def get_refcounted_entries(self, include_weakref=False):
py_attrs = []
py_buffers = []
memoryview_slices = []

for entry in self.var_entries:
if entry.type.is_pyobject:
if include_weakref or entry.name != "weakref":
py_attrs.append(entry)
elif entry.type == PyrexTypes.c_py_buffer_type:
py_buffers.append(entry)
elif entry.type.is_memoryviewslice:
memoryview_slices.append(entry)

have_entries = py_attrs or py_buffers or memoryview_slices
return have_entries, (py_attrs, py_buffers, memoryview_slices)


class PreImportScope(Scope):

Expand Down
10 changes: 10 additions & 0 deletions tests/run/memoryview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ def test_coerce_to_temp():
print
print _coerce_to_temp()[4][4]

def test_extclass_attribute_dealloc():
"""
>>> test_extclass_attribute_dealloc()
acquired self.arr
2
released self.arr
"""
cdef ExtClassMockedAttr obj = ExtClassMockedAttr()
print obj.arr[4, 4]

cdef float[:,::1] global_mv = array((10,10), itemsize=sizeof(float), format='f')
global_mv = array((10,10), itemsize=sizeof(float), format='f')
cdef object global_obj
Expand Down

0 comments on commit d96dfdb

Please sign in to comment.