Gradient checkpointing

Gradient checkpointing for graph-mode execution in TensorFlow 2.

This is a standalone version extracted from the original implementation in tf-slim.

If you are using eager execution, use tf.recompute_grad instead.
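
For reference, here is a minimal sketch of the eager-mode alternative; the weights and shapes below are illustrative, not part of this repository:

import tensorflow as tf

w1 = tf.random.normal([128, 128])
w2 = tf.random.normal([128, 128])

@tf.recompute_grad
def block(x):
    # Activations inside this block are discarded after the forward
    # pass and recomputed during backprop, trading compute for memory.
    h = tf.nn.relu(tf.matmul(x, w1))
    return tf.nn.relu(tf.matmul(h, w2))

x = tf.random.normal([8, 128])
with tf.GradientTape() as tape:
    tape.watch(x)
    loss = tf.reduce_sum(block(x))
dx = tape.gradient(loss, x)  # gradient w.r.t. the input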

For more information on recomputing gradients between graph nodes during backpropagation, see the original gradient checkpointing repository.


Tested with tf-nightly==2.2.0.dev20200303 in graph mode on TPU.

Example usage inside the call method of a Keras layer:

import gradient_checkpointing

def call(self, x, past):
    @gradient_checkpointing.recompute_grad
    def inner(x):
        # Ops go here; their activations are discarded after the
        # forward pass and recomputed during backpropagation.
        y = ...  # compute y from x
        return y
    return inner(x)
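
For context, this is roughly how the pattern might look inside a complete custom layer; CheckpointedBlock and its Dense sublayer are illustrative, not part of this repository:

import tensorflow as tf
import gradient_checkpointing

class CheckpointedBlock(tf.keras.layers.Layer):
    # Hypothetical layer wrapping its forward ops in recompute_grad.
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(units, activation="relu")

    def call(self, x):
        @gradient_checkpointing.recompute_grad
        def inner(x):
            # Activations of these ops are recomputed on backprop
            # rather than kept in memory for the backward pass.
            return self.dense(x)
        return inner(x)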

Note: gradient checkpointing trades compute for memory. Because activations are recomputed during the backward pass rather than stored, it can significantly slow down training.
