From 0753b0c79054aa848e8b9eabe486b9d221da2894 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Fri, 8 Sep 2017 16:52:18 -0700 Subject: [PATCH] Scope the scalar cache in the context. PiperOrigin-RevId: 168065417 --- tensorflow/python/eager/context.py | 5 +++++ tensorflow/python/framework/constant_op.py | 12 +++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 185cd9a7165ecf..8496a02947f2c2 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -53,6 +53,7 @@ def __init__(self): self.mode = _default_mode self.scope_name = "" self.recording_summaries = False + self.scalar_cache = {} # TODO(agarwal): rename to EagerContext / EagerRuntime ? @@ -157,6 +158,10 @@ def in_eager_mode(self): """Returns True if current thread is in EAGER mode.""" return self._eager_context.mode == EAGER_MODE + def scalar_cache(self): + """Per-device cache for scalars.""" + return self._eager_context.scalar_cache + @property def scope_name(self): """Returns scope name for the current thread.""" diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 819730a51b901b..a859645950db4b 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -74,10 +74,6 @@ def _eager_fill(dims, value): return result -# Rely on the GIL for thread-safety. -_scalar_cache = {} - - def convert_to_eager_tensor(t, dtype=None): """Converts the given `value` to an `EagerTensor`.""" if isinstance(ag_core.getval(t), ops.EagerTensor): @@ -88,13 +84,15 @@ def convert_to_eager_tensor(t, dtype=None): # Use a scalar cache. This will put each scalar of each type only once on # each device. Scalars don't use much device memory but copying scalars can # trigger memcpys which are slow. - device = context.context().device_name + ctx = context.context() + device = ctx.device_name cache_key = device, t, dtype, type(t) - tensor = _scalar_cache.get(cache_key, None) + scalar_cache = ctx.scalar_cache() + tensor = scalar_cache.get(cache_key, None) if tensor is not None: return tensor value = ops.EagerTensor(t, dtype=dtype) - _scalar_cache[cache_key] = value + scalar_cache[cache_key] = value return value return ops.EagerTensor(t, dtype=dtype)