## basics

- https://arxiv.org/pdf/1604.06174.pdf
- https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
- https://huggingface.co/docs/transformers/v4.18.0/en/performance

- Yet another technique for reducing GPU memory usage
    - As usual, it is a tradeoff between memory usage and computation time during backpropagation.
- In deep neural networks, backpropagation requires storing **intermediate activations** for computing gradients during the backward pass. 
    - But as the network gets deeper, storing all intermediate activations takes up a large amount of GPU memory.
- Gradient checkpointing relieves this memory pressure by selectively recomputing a subset of the intermediate activations during the backward pass.
    - Instead of storing all activations **during the forward pass**, only a **subset** of them (the checkpoints, e.g. the inputs to selected segments of the network) is cached.
    - The remaining intermediate activations are recomputed on-the-fly **during the backward pass**. By recomputing rather than storing all intermediate activations, memory usage is reduced at the cost of increased computation time.
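To make the "recompute instead of store" idea concrete, here is a minimal pure-Python sketch on a chain of scalar functions, used as a stand-in for a deep network. The helper names (`grad_full`, `grad_checkpointed`) are hypothetical and not from any library. The forward pass keeps only the input of every `segment`-th layer; the backward pass recomputes each segment's activations from its checkpoint right before backpropagating through it. With `n` layers and segment length `k`, roughly `n/k + k` activations are alive at once, which is smallest around `k ≈ sqrt(n)` (the sublinear-memory setting of the first paper linked above), at the cost of about one extra forward pass.

```python
import math
from typing import Callable, List, Tuple

# Each "layer" is a scalar function paired with its derivative: (f, f').
Layer = Tuple[Callable[[float], float], Callable[[float], float]]


def grad_full(layers: List[Layer], x0: float) -> Tuple[float, float, int]:
    """Plain backprop: store every layer input, return (output, dy/dx0, #stored)."""
    acts = [x0]
    for f, _ in layers:
        acts.append(f(acts[-1]))
    grad = 1.0
    for i in reversed(range(len(layers))):
        grad *= layers[i][1](acts[i])  # chain rule using the stored input x_i
    return acts[-1], grad, len(layers)  # x_0 .. x_{n-1} are needed by backward


def grad_checkpointed(
    layers: List[Layer], x0: float, segment: int
) -> Tuple[float, float, int, int]:
    """Checkpointed backprop: return (output, dy/dx0, peak #stored, #recomputed)."""
    n = len(layers)
    checkpoints = {}
    x = x0
    for i, (f, _) in enumerate(layers):
        if i % segment == 0:
            checkpoints[i] = x  # keep only the input of each segment
        x = f(x)
    out = x

    grad, peak, recomputed = 1.0, len(checkpoints), 0
    for start in reversed(range(0, n, segment)):
        end = min(start + segment, n)
        # Recompute this segment's inputs x_start .. x_{end-1} from its checkpoint.
        acts = [checkpoints[start]]
        for f, _ in layers[start:end - 1]:
            acts.append(f(acts[-1]))
            recomputed += 1
        # Alive now: all checkpoints plus this segment's recomputed activations.
        peak = max(peak, len(checkpoints) + len(acts) - 1)
        # Backprop through the segment; its recomputed activations can then be dropped.
        for j in reversed(range(start, end)):
            grad *= layers[j][1](acts[j - start])
    return out, grad, peak, recomputed


layers = [(math.sin, math.cos), (math.tanh, lambda x: 1 - math.tanh(x) ** 2)] * 8  # 16 layers
y_full, g_full, stored_full = grad_full(layers, 0.5)
y_ckpt, g_ckpt, stored_ckpt, extra = grad_checkpointed(layers, 0.5, segment=4)
print(math.isclose(g_full, g_ckpt), stored_full, stored_ckpt, extra)  # True 16 7 12
```

With 16 layers and `segment=4`, the peak number of stored activations drops from 16 to 7 while the gradient is unchanged; the price is 12 extra layer evaluations during backward.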

## pytorch

```python
from torch.utils.checkpoint import checkpoint_sequential
```
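A minimal sketch of `checkpoint_sequential` on a toy stack of blocks; the model, sizes, and segment count are arbitrary choices for illustration. It splits the `nn.Sequential` into `segments` chunks and keeps only each chunk's input during the forward pass, rerunning the chunk during backward (per the PyTorch docs, the last segment is not checkpointed). The gradients should match the plain run. It assumes a recent PyTorch (2.x) where `use_reentrant` is passed explicitly; `use_reentrant=False` is the variant the PyTorch docs recommend.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# 8 Linear+Tanh blocks whose intermediate activations we would rather not keep.
model = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.Tanh()) for _ in range(8)])
x = torch.randn(32, 64)


def grads(use_checkpoint: bool):
    model.zero_grad()
    if use_checkpoint:
        # 4 segments of 2 blocks each: only segment inputs are stored in forward,
        # everything else is recomputed during backward.
        out = checkpoint_sequential(model, 4, x, use_reentrant=False)
    else:
        out = model(x)
    out.pow(2).mean().backward()
    return [p.grad.clone() for p in model.parameters()]


plain = grads(False)
ckpt = grads(True)
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))  # True
```

The model has no dropout, so the recomputed forward is identical to the original one; with dropout, PyTorch's checkpoint restores the RNG state by default so the recomputation sees the same masks.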

## huggingface

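- Which module is checkpointed by default: in the BERT implementation (transformers around v4.18), `_set_gradient_checkpointing` only enables checkpointing on `BertEncoder`, so the embeddings and the pooler are not checkpointed.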

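```python
class BertPreTrainedModel(PreTrainedModel):
    # ... (other class attributes and methods omitted)
    supports_gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        # Only BertEncoder is checkpointed; its forward pass then checkpoints each BertLayer
        if isinstance(module, BertEncoder):
            module.gradient_checkpointing = value
```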
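From the user side, the hook above is not called directly. A minimal sketch, assuming a transformers version that provides `PreTrainedModel.gradient_checkpointing_enable()` (v4.18 and later do): it builds a tiny randomly initialised BERT (the config values are arbitrary, chosen so nothing has to be downloaded), turns checkpointing on, and checks the encoder flag. Checkpointing only takes effect in training mode, because `BertEncoder.forward` checks `self.gradient_checkpointing and self.training`.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny randomly initialised BERT so nothing has to be downloaded.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertModel(config)

model.gradient_checkpointing_enable()  # public switch; sets the encoder's flag via the hook
model.train()  # checkpointing is skipped in eval mode

print(model.encoder.gradient_checkpointing)  # True

# A forward/backward pass now recomputes each BertLayer during backward.
input_ids = torch.randint(0, config.vocab_size, (2, 8))
out = model(input_ids=input_ids).last_hidden_state
out.sum().backward()
```

The internal hook's signature changed in later transformers releases, so `gradient_checkpointing_enable()` is the safer entry point than calling `_set_gradient_checkpointing` yourself.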