Skip to content

Commit

Permalink
Fix memory leak in python caused by @tf_should_use.
Browse files Browse the repository at this point in the history
The issue is that python's GC has trouble collecting objects with __del__ methods.

The solution is two pronged:
* Keep track of usage state outside of the class, via a dict mapping
  id(object) => state
* Remove __del__ (this was the source: python's GC couldn't collect wrapped
  objects), and instead use weakref.finalize to emit warnings just as the object
  is being garbage collected.
* Added tests for garbage collection [they were failing before i fixed the issue]

PiperOrigin-RevId: 158042388
  • Loading branch information
ebrevdo authored and Amit Patankar committed Jun 6, 2017
1 parent a823c8d commit 51c4049
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 39 deletions.
93 changes: 65 additions & 28 deletions tensorflow/python/util/tf_should_use.py
Expand Up @@ -17,14 +17,52 @@
from __future__ import division
from __future__ import print_function

import collections
import functools
import itertools
import traceback
import types

import six # pylint: disable=unused-import

from backports import weakref # pylint: disable=g-bad-import-order

from tensorflow.python.platform import tf_logging
from tensorflow.python.util import tf_decorator


class _RefInfoField(
collections.namedtuple(
'_RefInfoField', ('type_', 'repr_', 'creation_stack', 'object_used'))):
pass


# Thread-safe up to int32max/2 thanks to python's GIL; and may be safe even for
# higher values in Python 3.4+. We don't expect to ever count higher than this.
# https://mail.python.org/pipermail/python-list/2005-April/342279.html
_REF_ITER = itertools.count()

# Dictionary mapping id(obj) => _RefInfoField.
_REF_INFO = {}


def _deleted(obj_id, fatal_error):
obj = _REF_INFO[obj_id]
del _REF_INFO[obj_id]
if not obj.object_used:
if fatal_error:
logger = tf_logging.fatal
else:
logger = tf_logging.error
logger(
'==================================\n'
'Object was never used (type %s):\n%s\nIf you want to mark it as '
'used call its "mark_used()" method.\nIt was originally created '
'here:\n%s\n'
'==================================' %
(obj.type_, obj.repr_, obj.creation_stack))


def _add_should_use_warning(x, fatal_error=False):
"""Wraps object x so that if it is never used, a warning is logged.
Expand All @@ -39,14 +77,14 @@ def _add_should_use_warning(x, fatal_error=False):
"""
if x is None: # special corner case where x is None
return x
has_been_used = getattr(x, '_tf_object_has_been_used', None)
if has_been_used is not None:
x._tf_object_has_been_used = has_been_used # pylint: disable=protected-access
if hasattr(x, '_tf_ref_id'): # this is already a TFShouldUseWarningWrapper
return x

def override_method(method):
def fn(self, *args, **kwargs):
self._tf_object_has_been_used = True # pylint: disable=protected-access
# pylint: disable=protected-access
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
object_used=True)
return method(self, *args, **kwargs)
return fn

Expand All @@ -55,38 +93,36 @@ class TFShouldUseWarningWrapper(type(x)):

def __init__(self, true_self):
self.__dict__ = true_self.__dict__
stack = [x.strip() for x in traceback.format_stack()]
stack = [s.strip() for s in traceback.format_stack()]
# Remove top three stack entries from adding the wrapper
self._tf_object_creation_stack = '\n'.join(stack[:-3])
self._tf_object_has_been_used = False
self.creation_stack = '\n'.join(stack[:-3])
self._tf_ref_id = next(_REF_ITER)
_REF_INFO[self._tf_ref_id] = _RefInfoField(
type_=type(x),
repr_=repr(x),
creation_stack=stack,
object_used=False)

# Create a finalizer for self, which will be called when self is
# garbage collected. Can't add self as the args because the
# loop will break garbage collection. We keep track of
# ourselves via python ids.
weakref.finalize(self, _deleted, self._tf_ref_id, fatal_error)

# Not sure why this pylint warning is being used; this is not an
# old class form.
# pylint: disable=super-on-old-class
def __getattribute__(self, name):
if name != '_tf_object_has_been_used':
self._tf_object_has_been_used = True
if name == '_tf_ref_id':
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
if self._tf_ref_id in _REF_INFO:
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
object_used=True)
return super(TFShouldUseWarningWrapper, self).__getattribute__(name)

def __del__(self):
if not self._tf_object_has_been_used:
if fatal_error:
logger = tf_logging.fatal
else:
logger = tf_logging.error
logger(
'==================================\n'
'Object was never used (type %s):\n%s\nIf you want to mark it as '
'used call its "mark_used()" method.\nIt was originally created '
'here:\n%s\n'
'==================================' %
(type(x), x, self._tf_object_creation_stack))

if hasattr(super(TFShouldUseWarningWrapper, self), '__del__'):
return super(TFShouldUseWarningWrapper, self).__del__()

def mark_used(self, *args, **kwargs):
self._tf_object_has_been_used = True
_REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
object_used=True)
if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'):
return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs)
# pylint: enable=super-on-old-class
Expand All @@ -102,7 +138,8 @@ def mark_used(self, *args, **kwargs):

wrapped = TFShouldUseWarningWrapper(x)
wrapped.__doc__ = x.__doc__ # functools.wraps fails on some objects.
wrapped._tf_object_has_been_used = False # pylint: disable=protected-access
ref_id = wrapped._tf_ref_id # pylint: disable=protected-access
_REF_INFO[ref_id] = _REF_INFO[ref_id]._replace(object_used=False)
return wrapped


Expand Down
33 changes: 22 additions & 11 deletions tensorflow/python/util/tf_should_use_test.py
Expand Up @@ -20,6 +20,7 @@
from __future__ import print_function

import contextlib
import gc
import sys

from tensorflow.python.framework import constant_op
Expand All @@ -45,52 +46,60 @@ def capture_errors(*args, **unused_kwargs):
class TfShouldUseTest(test.TestCase):

def testAddShouldUseWarningWhenNotUsed(self):
c = constant_op.constant(0, name='blah')
c = constant_op.constant(0, name='blah0')
captured = []
with reroute_error(captured):
def in_this_function():
h = tf_should_use._add_should_use_warning(c)
del h
in_this_function()
self.assertIn('Object was never used', '\n'.join(captured))
self.assertIn('blah:0', '\n'.join(captured))
self.assertIn('blah0:0', '\n'.join(captured))
self.assertIn('in_this_function', '\n'.join(captured))
gc.collect()
self.assertFalse(gc.garbage)

def _testAddShouldUseWarningWhenUsed(self, fn):
c = constant_op.constant(0, name='blah')
def _testAddShouldUseWarningWhenUsed(self, fn, name):
c = constant_op.constant(0, name=name)
captured = []
with reroute_error(captured):
h = tf_should_use._add_should_use_warning(c)
fn(h)
del h
self.assertNotIn('Object was never used', '\n'.join(captured))
self.assertNotIn('blah:0', '\n'.join(captured))
self.assertNotIn('%s:0' % name, '\n'.join(captured))

def testAddShouldUseWarningWhenUsedWithAdd(self):
def add(h):
_ = h + 1
self._testAddShouldUseWarningWhenUsed(add)
self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
gc.collect()
self.assertFalse(gc.garbage)

def testAddShouldUseWarningWhenUsedWithGetName(self):
def get_name(h):
_ = h.name
self._testAddShouldUseWarningWhenUsed(get_name)
self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
gc.collect()
self.assertFalse(gc.garbage)

def testShouldUseResult(self):
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah')
return constant_op.constant(value, name='blah2')
captured = []
with reroute_error(captured):
return_const(0.0)
self.assertIn('Object was never used', '\n'.join(captured))
self.assertIn('blah:0', '\n'.join(captured))
self.assertIn('blah2:0', '\n'.join(captured))
self.assertIn('return_const', '\n'.join(captured))
gc.collect()
self.assertFalse(gc.garbage)

def testShouldUseResultWhenNotReallyUsed(self):
@tf_should_use.should_use_result
def return_const(value):
return constant_op.constant(value, name='blah')
return constant_op.constant(value, name='blah3')
captured = []
with reroute_error(captured):
with self.test_session():
Expand All @@ -100,8 +109,10 @@ def return_const(value):
v = constant_op.constant(1.0, name='meh')
v.eval()
self.assertIn('Object was never used', '\n'.join(captured))
self.assertIn('blah:0', '\n'.join(captured))
self.assertIn('blah3:0', '\n'.join(captured))
self.assertIn('return_const', '\n'.join(captured))
gc.collect()
self.assertFalse(gc.garbage)


if __name__ == '__main__':
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/tools/ci_build/install/install_pip_packages.sh
Expand Up @@ -85,3 +85,6 @@ pip2 install mock

pip2 install portpicker
pip3 install portpicker

pip2 install backports.weakref==1.0rc1
pip3 install backports.weakref==1.0rc1
Expand Up @@ -89,3 +89,6 @@ pip3.5 install wheel==0.29.0
pip3.5 install portpicker

pip3.5 install werkzeug

pip3.5 install backports.weakref==1.0rc1

1 change: 1 addition & 0 deletions tensorflow/tools/pip_package/setup.py
Expand Up @@ -39,6 +39,7 @@
'html5lib == 0.9999999', # identical to 1.0b8
'markdown == 2.2.0',
'bleach == 1.5.0',
'backports.weakref == 1.0rc1',
]

project_name = 'tensorflow'
Expand Down

0 comments on commit 51c4049

Please sign in to comment.