Updating suggested usage
The previously documented way of using the package with TF is no longer relevant. This commit updates the documentation and makes it possible to replace gradients in a more robust way.
yselivonchyk committed Jul 18, 2019
1 parent 039cc31 commit de64dac
Showing 1 changed file with 2 additions and 2 deletions.
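For context, the usage this commit hardens is the monkey-patching pattern from the project README. A minimal sketch, assuming TensorFlow 1.x and `memory_saving_gradients` on the import path; `gradients_memory` is one of the drop-in entry points the README documents:

```python
import tensorflow as tf
import memory_saving_gradients

# Monkey-patch tf.gradients to the memory-saving version with automatic
# checkpoint selection. Because the module saves a reference to the original
# implementation at import time (tf_gradient_function in the diff below),
# its own internal gradient calls keep working after this patch.
tf.__dict__["gradients"] = memory_saving_gradients.gradients_memory
```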
memory_saving_gradients.py: 2 additions & 2 deletions
```diff
@@ -15,16 +15,16 @@
 # save original gradients since tf.gradient could be monkey-patched to point
 # to our version
 from tensorflow.python.ops import gradients as tf_gradients_lib
+tf_gradient_function = tf_gradients_lib.gradients
 
 # ISSUE: https://github.com/cybertronai/gradient-checkpointing/issues/38
 def tf_gradients(ys, *args, **kwargs):
     """Decorate tf.gradients calls with explicit device placement to avoid memory
     leaks when splitting model across multiple GPUs"""
     source = ys[0] if isinstance(ys, (list, tuple)) else ys
     device = source.op.node_def.device if isinstance(source, tf.Tensor) else None
     print("SETTING DEVICE:", device)
     with tf.device(device):
-        return tf_gradients_lib.gradients(ys, *args, **kwargs)
+        return tf_gradient_function(ys, *args, **kwargs)
 
 
 MIN_CHECKPOINT_NODE_SIZE=1024 # use lower value during testing
```
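The saved `tf_gradient_function` reference is what makes the replacement robust: if `tf_gradients_lib.gradients` is later monkey-patched to point at the wrapper itself, calling the module attribute from inside the wrapper would recurse forever. A minimal pure-Python sketch of the hazard and the fix (no TensorFlow required; `library`, `wrapper`, and `saved_gradient_function` are hypothetical stand-ins for `tf_gradients_lib`, `tf_gradients`, and `tf_gradient_function`):

```python
import types

# Stand-in for tf_gradients_lib: a module-like object with a gradients function.
library = types.SimpleNamespace()
library.gradients = lambda ys: "original gradients of " + ys

# Save the original implementation at import time, as the commit does.
saved_gradient_function = library.gradients

def wrapper(ys):
    # Calling library.gradients here would recurse forever once the attribute
    # is monkey-patched to point at this wrapper; the saved reference does not.
    return saved_gradient_function(ys)

# A user monkey-patches the library attribute to the wrapper...
library.gradients = wrapper

# ...and the call still terminates, because wrapper holds the saved reference.
print(library.gradients("loss"))  # -> original gradients of loss
```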
