## basics

- https://arxiv.org/pdf/1604.06174.pdf
- https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
    - **computation graph examples**
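    - A node is brought into memory only when it is needed.
    - Vanilla backprop: at the peak, the algorithm stores all activations from the forward pass,
        - i.e. an $O(n)$ memory requirement for a network of depth $n$.
    - The opposite extreme forgets activations as they are consumed and recomputes them later, from the input, when the backward pass needs them.
        - This “memory-poor” strategy needs $O(1)$ memory but requires $O(n^2)$ computation steps.
    - Keeping a checkpoint every $\sqrt{n}$ layers sits in between: $O(\sqrt{n})$ memory for roughly one extra forward pass.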
    - https://yaroslavvb.medium.com/backprop-and-systolic-arrays-24e925d2050
- https://huggingface.co/docs/transformers/v4.18.0/en/performance
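The memory/compute regimes described in these sources can be compared with a toy counting model. It covers three cases: store everything, the memory-poor extreme, and checkpointing every k-th activation with k ≈ √n (the sublinear-memory scheme of the arXiv paper linked above). This is a sketch under stated assumptions, not code from the linked sources: `run_chain`, `k` and `store_recomputed` are hypothetical names, and the cost model (one unit of memory per held activation, one unit of compute per layer evaluation) is an assumption.

```python
def run_chain(n, k, store_recomputed=True):
    """Toy cost model (assumption) of backprop through a chain of n layers.

    Layer i (1..n) computes a[i] from a[i-1]; its backward step needs a[i-1].
    Every k-th activation (a[0], a[k], a[2k], ...) is kept as a checkpoint after
    the forward pass; k=1 keeps everything. If store_recomputed is False,
    recomputed activations are dropped as soon as the next one exists
    (the "memory-poor" strategy).
    Returns (peak number of activations held at once, total layer evaluations).
    """
    keep = set(range(0, n + 1, k))
    held, evals, peak = {0}, 0, 1

    # Forward pass: compute a[1..n], holding only checkpoints plus the newest activation.
    for i in range(1, n + 1):
        held.add(i)
        evals += 1
        peak = max(peak, len(held))
        if i - 1 not in keep:
            held.discard(i - 1)

    # Backward pass, last layer first.
    for i in range(n, 0, -1):
        if i - 1 not in held:
            # Recompute a[j+1..i-1] starting from the nearest held activation a[j].
            j = max(h for h in held if h < i - 1)
            for t in range(j + 1, i):
                held.add(t)
                evals += 1
                peak = max(peak, len(held))
                if not store_recomputed and t - 1 != j:
                    held.discard(t - 1)
        # Layer i's backward step would run here; afterwards a[i] is no longer needed.
        held.discard(i)
    return peak, evals


n = 16
print(run_chain(n, k=1))                          # store all:     (17, 16)
print(run_chain(n, k=n, store_recomputed=False))  # memory-poor:   (4, 136)
print(run_chain(n, k=4))                          # every sqrt(n): (8, 28)
```

Storing everything costs memory linear in depth. The memory-poor variant holds a constant handful of activations but pays $n + n(n-1)/2$ evaluations. With $k = \sqrt{n}$ the peak is about $2\sqrt{n}$ and the extra work is less than one additional forward pass.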

- Gradient checkpointing is another GPU-memory optimization technique.
    - Naturally, it trades off memory usage against computation time.
- In deep neural networks, backpropagation requires storing **intermediate activations** for computing gradients during the backward pass. 
    - But as the number of layers grows, storing all of these intermediate activations takes up a lot of GPU memory.
- Gradient checkpointing relieves this memory pressure by selectively recomputing part of the intermediate activations during backpropagation.
    - Instead of storing all activations (**during the forward pass**), only a **subset** of them is cached: the checkpoints from which the others can be recomputed.
    - The remaining intermediate activations are recomputed on-the-fly **during the backward pass**. By recomputing rather than storing all intermediate activations, memory usage is reduced at the cost of increased computation time.
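A minimal PyTorch sketch of the recompute-on-backward behaviour described above, assuming PyTorch ≥ 2.0 (for the `use_reentrant` argument of `torch.utils.checkpoint.checkpoint`). `CountingBlock` is a hypothetical module that counts how often its `forward` runs. Under checkpointing it runs once in the forward pass and once more during backward, when its activations are regenerated.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CountingBlock(nn.Module):
    """Hypothetical block that records how many times its forward runs."""

    def __init__(self, width):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(width, width), nn.ReLU(), nn.Linear(width, width))
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return self.net(x)


torch.manual_seed(0)
x = torch.randn(4, 32)

# Without checkpointing: the block's intermediate activations are stored during forward.
plain = CountingBlock(32)
plain(x).sum().backward()
print(plain.calls)  # 1

# With checkpointing: the activations inside the block are not kept; backward
# re-runs the block's forward to regenerate them before computing gradients.
ckpt = CountingBlock(32)
checkpoint(ckpt, x, use_reentrant=False).sum().backward()
print(ckpt.calls)  # 2
```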

## Trainer Arguments

```python
import os
# Expose only GPU 0; this must be set before CUDA is initialized in this process
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
```

### dataset

```python
import numpy as np
from datasets import Dataset


# 512 random sequences of 512 token ids, each with a random binary label
seq_len, dataset_size = 512, 512
dummy_data = {
    "input_ids": np.random.randint(100, 30000, (dataset_size, seq_len)),
    "labels": np.random.randint(0, 2, (dataset_size)),
}
ds = Dataset.from_dict(dummy_data)
ds.set_format("pt")  # return PyTorch tensors when indexing
```

```python
print(dummy_data['input_ids'].shape, dummy_data['labels'].shape)
```

```
(512, 512) (512,)
```


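```python
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo


def print_gpu_utilization():
    # Device-wide used memory as reported by NVML: includes the CUDA context and
    # PyTorch's cached blocks, not only the tensors that are currently alive
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")


def print_summary(result):
    # `result` is the TrainOutput returned by Trainer.train()
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()
```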

```python
print_gpu_utilization()
```

```
GPU memory occupied: 355 MB.
```


### model

```python
from transformers import AutoModelForSequenceClassification

# bert-large-uncased plus a newly initialized classification head, moved to the GPU
model = AutoModelForSequenceClassification.from_pretrained('bert-large-uncased').to('cuda')
print_gpu_utilization()
```

```
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-large-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

GPU memory occupied: 2511 MB.
```
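A rough back-of-the-envelope for the memory numbers above, computed from the public `bert-large-uncased` configuration (hidden size 1024, 24 layers, intermediate size 4096, vocabulary 30522, 512 positions, 2 token types) plus the 2-label classifier head. `bert_param_count` is a hypothetical helper written for this note. The 16 bytes/parameter figure is an assumption about the Trainer runs below: fp32 weights, fp32 gradients and the two AdamW moment buffers (default optimizer, no mixed precision). Activations, the CUDA context and allocator caching come on top.

```python
def bert_param_count(vocab=30522, hidden=1024, layers=24, intermediate=4096,
                     max_pos=512, type_vocab=2, num_labels=2):
    """Hypothetical helper: parameter count of a BERT encoder + pooler + classifier."""
    def linear(d_in, d_out):              # weight + bias
        return d_in * d_out + d_out

    layer_norm = 2 * hidden               # weight + bias
    embeddings = (vocab + max_pos + type_vocab) * hidden + layer_norm
    per_layer = (
        4 * linear(hidden, hidden)        # query, key, value, attention output projection
        + layer_norm
        + linear(hidden, intermediate)    # feed-forward up-projection
        + linear(intermediate, hidden)    # feed-forward down-projection
        + layer_norm
    )
    pooler = linear(hidden, hidden)
    classifier = linear(hidden, num_labels)
    return embeddings + layers * per_layer + pooler + classifier


n_params = bert_param_count()
mib = 1024 ** 2
print(n_params)               # 335143938
print(n_params * 4 // mib)    # fp32 weights alone: 1278 MiB
print(n_params * 16 // mib)   # weights + grads + 2 AdamW states: 5113 MiB
```

Under these assumptions, parameters and optimizer state account for roughly 5 GB of the ~12.9 GB measured in the first training run below. Much of the remainder is activations (plus context and cache), which is the part gradient checkpointing reduces.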


### training without checkpoint

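```python
from transformers import TrainingArguments, Trainer, logging

# Shared by both training runs below
default_args = {
    "output_dir": "tmp",
    "evaluation_strategy": "steps",  # renamed to `eval_strategy` in newer transformers releases
    "num_train_epochs": 1,
    "log_level": "error",
    "report_to": "none",
}
```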

```python
# logging.set_verbosity_error()
# Comment this cell out when running the "training with checkpoint" experiment
# (run each configuration separately so the GPU memory readings don't mix)
training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
```

```
Time: 23.78
Samples/second: 21.53
GPU memory occupied: 12917 MB.
```


### training with checkpoint

```python
# Per-device batch 1 x 4 accumulation steps = effective batch size 4, as in the run above
training_args = TrainingArguments(
    per_device_train_batch_size=1, gradient_accumulation_steps=4, gradient_checkpointing=True, **default_args
)
trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
```

```
Time: 38.96
Samples/second: 13.14
GPU memory occupied: 9305 MB.
```


```python
# optimizer steps = dataset_size / (per_device_train_batch_size * num_gpus * gradient_accumulation_steps)
512/((1*1) * 4)
```

```
128.0
```
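Putting the two runs side by side, using only the settings and numbers reported above. Note that the runs differ in two settings at once: per-device batch 4 versus batch 1 with 4 accumulation steps, and checkpointing off versus on. The memory difference therefore mixes both effects, while the effective batch size (and hence the 128 optimizer steps) is the same. The dictionary layout is just a convenient assumption for this comparison.

```python
dataset_size, n_gpus = 512, 1

# Settings and results copied from the two runs above
runs = {
    "baseline":     dict(batch=4, accum=1, runtime_s=23.78, gpu_mb=12917),
    "checkpointed": dict(batch=1, accum=4, runtime_s=38.96, gpu_mb=9305),
}

for name, r in runs.items():
    effective_batch = r["batch"] * n_gpus * r["accum"]
    steps = dataset_size // effective_batch
    print(name, effective_batch, steps, round(dataset_size / r["runtime_s"], 2))
# baseline 4 128 21.53
# checkpointed 4 128 13.14

base, ckpt = runs["baseline"], runs["checkpointed"]
slowdown = ckpt["runtime_s"] / base["runtime_s"]        # wall-clock ratio
memory_saved = 1 - ckpt["gpu_mb"] / base["gpu_mb"]      # fraction of device memory saved
print(round(slowdown, 2), round(memory_saved, 2))       # 1.64 0.28
```

On this setup the checkpointed run used about 28% less device-wide GPU memory at roughly 1.64x the training time, consistent with the memory-for-compute tradeoff described in the basics section.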

## pytorch

```python
from torch.utils.checkpoint import checkpoint_sequential
```
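A minimal sketch of `checkpoint_sequential`, assuming PyTorch ≥ 2.0 (for the `use_reentrant` keyword). The toy 16-block MLP `mlp`, the choice of 4 segments and the tensor shapes are arbitrary illustrations. The stack is split into 4 chunks. For each checkpointed chunk only its input is saved during forward, and the interior activations are recomputed during backward (PyTorch runs the last chunk normally). The resulting gradients match a plain run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# Toy model: 16 Linear+ReLU blocks (depth and width chosen arbitrarily)
depth, width = 16, 64
mlp = nn.Sequential(*[nn.Sequential(nn.Linear(width, width), nn.ReLU()) for _ in range(depth)])
x = torch.randn(8, width)

# Reference: ordinary forward/backward, every intermediate activation is stored
mlp.zero_grad()
mlp(x).sum().backward()
ref_grads = [p.grad.clone() for p in mlp.parameters()]

# Checkpointed: 4 segments; inside checkpointed segments only the segment
# input is saved, the rest is recomputed during backward
mlp.zero_grad()
out = checkpoint_sequential(mlp, 4, x, use_reentrant=False)
out.sum().backward()
ckpt_grads = [p.grad.clone() for p in mlp.parameters()]

# Recomputation reproduces the same gradients (up to floating-point tolerance)
print(all(torch.allclose(a, b) for a, b in zip(ref_grads, ckpt_grads)))  # True
```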

## huggingface

- Which module is checkpointed by default

```python
# Excerpt from BertPreTrainedModel in transformers (older releases, e.g. v4.18)
class BertPreTrainedModel(PreTrainedModel):
    def _set_gradient_checkpointing(self, module, value=False):
        # Only BertEncoder is checkpointed
        if isinstance(module, BertEncoder):
            module.gradient_checkpointing = value
```
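The excerpt's pattern can be illustrated with a hypothetical, simplified model in plain PyTorch; this is a sketch, not the transformers source. The encoder carries a `gradient_checkpointing` flag, set from outside via `Module.apply`, much as `gradient_checkpointing_enable()` applies the hook above to every submodule in those older releases. When the flag is on and the model is in training mode, each layer call is wrapped in `torch.utils.checkpoint.checkpoint`. `ToyEncoder`, `ToyModel` and `enable_gradient_checkpointing` are made-up names. In transformers itself you would call `model.gradient_checkpointing_enable()` or pass `gradient_checkpointing=True` to `TrainingArguments`, as in the run above. Assumes PyTorch ≥ 2.0.

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for BertEncoder: a stack of layers plus a checkpointing flag."""

    def __init__(self, width=32, num_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(width, width), nn.GELU()) for _ in range(num_layers)
        )
        self.gradient_checkpointing = False

    def forward(self, hidden):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Each layer's internal activations are recomputed during backward
                hidden = checkpoint(layer, hidden, use_reentrant=False)
            else:
                hidden = layer(hidden)
        return hidden


class ToyModel(nn.Module):
    def __init__(self, width=32):
        super().__init__()
        self.embed = nn.Linear(width, width)
        self.encoder = ToyEncoder(width)
        self.head = nn.Linear(width, 2)

    @staticmethod
    def _set_gradient_checkpointing(module, value=False):
        # Same idea as the excerpt: only the encoder gets the flag
        if isinstance(module, ToyEncoder):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # Module.apply visits every submodule (and self)
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def forward(self, x):
        return self.head(self.encoder(self.embed(x)))


toy = ToyModel()
toy.enable_gradient_checkpointing()
print(toy.encoder.gradient_checkpointing)  # True

toy.train()
toy(torch.randn(8, 32)).sum().backward()
print(toy.head.weight.grad.shape)  # torch.Size([2, 32])
```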