Skip to content

Commit

Permalink
Remove workaround for broken ReplicaLocalVariable cross-replica read_…
Browse files Browse the repository at this point in the history
…value.

PiperOrigin-RevId: 253558060
Change-Id: Ia9a37ef474a7a59dc025d7b5a8fb45b78187a648
  • Loading branch information
petebu authored and sonnet-copybara committed Jun 17, 2019
1 parent 154a552 commit 1f3f060
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 34 deletions.
2 changes: 1 addition & 1 deletion requirements-tf.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
tf-nightly-2.0-preview==2.0.0.dev20190614
tf-nightly-2.0-preview==2.0.0.dev20190617
tfp-nightly==0.8.0.dev20190610
35 changes: 2 additions & 33 deletions sonnet/src/replicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,50 +20,19 @@
# from __future__ import google_type_annotations
from __future__ import print_function

import functools

import contextlib
import tensorflow as tf


@contextlib.contextmanager
def maybe_enter_scope(strategy):
"""Enter the strategy scope if it is not already active."""
if strategy is not tf.distribute.get_strategy():
with strategy.scope():
yield
else:
yield


def replica_local_read_value(v):
"""Replaces `read_value` on `v` so that it works in cross-replica context."""
@functools.wraps(v.read_value)
def wrapper():
with maybe_enter_scope(v.distribute_strategy): # pylint: disable=not-context-manager
ctx = tf.distribute.get_replica_context()
if ctx is None:
return v._values[0].read_value() # pylint: disable=protected-access
else:
return v.get().read_value()
return wrapper


def replica_local_creator(getter, **kwargs) -> tf.Variable:
def replica_local_creator(next_creator, **kwargs) -> tf.Variable:
"""Variable creator that by default creates replica local variables."""
if kwargs["synchronization"] == tf.VariableSynchronization.AUTO:
kwargs["synchronization"] = tf.VariableSynchronization.ON_READ
if kwargs["aggregation"] == tf.VariableAggregation.NONE:
kwargs["aggregation"] = tf.VariableAggregation.ONLY_FIRST_REPLICA
if kwargs["trainable"] is None:
kwargs["trainable"] = True
v = getter(**kwargs)

# TODO(petebu): Remove when local variables support x-replica read_value.
v.read_value = replica_local_read_value(v)
else:
v = getter(**kwargs)
return v
return next_creator(**kwargs)


class Replicator(tf.distribute.MirroredStrategy):
Expand Down

0 comments on commit 1f3f060

Please sign in to comment.