Skip to content

hiworldwzj/keras_recompute

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

14 Commits
 
 
 
 

Repository files navigation

recompute

通过重计算来节省显存,参考论文《Training Deep Nets with Sublinear Memory Cost》

本程序已经内置在bert4keras

使用方法

首先,确保环境变量加上RECOMPUTE=1

然后,在自定义层的时候,用recompute_grad装饰call函数即可:

from recompute import recompute_grad

class MyLayer(Layer):
    @recompute_grad
    def call(self, inputs):
        return inputs * 2

如果是现成的层,可以通过继承的方式来装饰:

from recompute import recompute_grad

class MyDense(Dense):
    @recompute_grad
    def call(self, inputs):
        return super(MyDense, self).call(inputs)

环境依赖

在下面的环境下测试通过:

tensorflow 1.14 + keras 2.3.1
tensorflow 1.15 + keras 2.3.1
tensorflow 2.0 + keras 2.3.1
tensorflow 2.1 + keras 2.3.1
tensorflow 2.0 + 自带tf.keras
tensorflow 2.1 + 自带tf.keras

确认不支持的环境:

tensorflow 1.x + 自带tf.keras

欢迎报告更多的测试结果。

强烈建议用keras 2.3.1配合tensorflow来跑,强烈不建议使用tensorflow 2.x自带的tf.keras来跑

使用效果

  • 在BERT Base版本下,batch_size可以增大为原来的3倍左右;
  • 在BERT Large版本下,batch_size可以增大为原来的4倍左右;
  • 平均每个样本的训练时间大约增加25%;
  • 理论上,层数越多,batch_size可以增大的倍数越大。

参考内容

About

saving memory by recomputing for keras

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%