Skip to content

Commit

Permalink
Fix memory leak when using tf.layers
Browse files Browse the repository at this point in the history
Uses `weakref` so that PER_GRAPH_LAYER_NAME_UIDS doesn't prevent Graphs from being garbage collected.

Fixes tensorflow#11273
  • Loading branch information
drasmuss committed Jul 4, 2017
1 parent fc87d91 commit ad20e17
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
4 changes: 4 additions & 0 deletions tensorflow/contrib/keras/python/keras/backend.py
Expand Up @@ -21,6 +21,7 @@
from __future__ import division
from __future__ import print_function

import collections
import json
import os

Expand Down Expand Up @@ -263,6 +264,9 @@ def get_uid(prefix=''):
```
"""
graph = ops.get_default_graph()
if graph not in tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS:
tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph] = collections.defaultdict(
int)
layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph]
layer_name_uids[prefix] += 1
return layer_name_uids[prefix]
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/python/layers/base.py
Expand Up @@ -27,6 +27,7 @@
import copy
import functools
import re
import weakref

from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
Expand Down Expand Up @@ -671,8 +672,7 @@ def _object_list_uid(object_list):
# A global dictionary mapping graph objects to an index of counters used
# for various layer names in each graph.
# Allows to give unique autogenerated names to layers, in a graph-specific way.
PER_GRAPH_LAYER_NAME_UIDS = collections.defaultdict(
lambda: collections.defaultdict(int))
PER_GRAPH_LAYER_NAME_UIDS = weakref.WeakKeyDictionary()


def _unique_layer_name(name):
Expand All @@ -694,6 +694,8 @@ def _unique_layer_name(name):
```
"""
graph = ops.get_default_graph()
if graph not in PER_GRAPH_LAYER_NAME_UIDS:
PER_GRAPH_LAYER_NAME_UIDS[graph] = collections.defaultdict(int)
layer_name_uids = PER_GRAPH_LAYER_NAME_UIDS[graph]
layer_name_uids[name] += 1
return name + '_' + str(layer_name_uids[name])

0 comments on commit ad20e17

Please sign in to comment.