## basics

- https://arxiv.org/pdf/1604.06174.pdf
- https://medium.com/tensorflow/fitting-larger-networks-into-memory-
    - **computation graph examples**
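    - Nodes are loaded into memory only when they are needed.
    - At the peak, vanilla backprop stores all activations,
        - i.e. everything computed in the forward pass;
        - $O(n)$ memory requirement for a network of depth $n$.
    - CHECKPOINT: keep only some activations and recompute the others later, when the backward pass needs them.
        - At the extreme, the “memory-poor” strategy, which keeps nothing and recomputes every activation on demand, needs $O(1)$ memory but requires $O(n^2)$ computation steps.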
- https://huggingface.co/docs/transformers/v4.18.0/en/performance
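
The memory/compute trade-off in the notes above can be made concrete with a small counting sketch. It assumes a plain chain of `n` layers in which backpropagating through a layer needs that layer's input; the function names and the way activations are counted are illustrative choices, not taken from the linked articles.

```python
import math


def store_all(n):
    # Vanilla backprop: one forward pass, the inputs of all n layers stay in memory.
    return n, n  # (peak activations held, layer evaluations)


def memory_poor(n):
    # Keep only the network input. To backprop through layer i, re-run
    # layers 1..i-1 from the input to rebuild that layer's input.
    evals = n + sum(i - 1 for i in range(1, n + 1))
    return 2, evals  # the network input plus one running activation


def checkpointed(n, segment):
    # Forward pass: keep the input of every `segment`-th layer (the checkpoints).
    lengths = [min(segment, n - start) for start in range(0, n, segment)]
    # Backward pass: each segment is re-run once from its checkpoint; the
    # output of its last layer is not needed, hence length - 1 evaluations.
    evals = n + sum(length - 1 for length in lengths)
    # Peak: all checkpoints plus the rebuilt activations of one segment.
    peak = len(lengths) + max(lengths) - 1
    return peak, evals


n = 100
strategies = {
    "store all": store_all(n),
    "memory-poor": memory_poor(n),
    "sqrt(n) checkpoints": checkpointed(n, math.isqrt(n)),
}
for name, (peak, evals) in strategies.items():
    print(f"{name:>20}: peak activations = {peak:3d}, layer evaluations = {evals}")
```

Under these counting assumptions, `n = 100` with a checkpoint every 10 layers holds 19 activations at the peak and costs 190 layer evaluations, compared with 100 and 100 for storing everything and 2 and 5050 for the memory-poor strategy.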

- Again, this is a technique for reducing GPU memory usage.
    - As usual, it is a trade-off between memory usage and computation time, here during backpropagation.
- In deep neural networks, backpropagation requires storing **intermediate activations** for computing gradients during the backward pass. 
    - However, as the number of layers grows, storing the activations of every intermediate layer takes up a great deal of GPU memory.
- Gradient checkpointing selectively recomputes part of the intermediate activations during backpropagation to relieve the pressure on GPU memory:
    - Instead of storing all activations (**during the forward pass**), only a **subset** of them (the checkpoints, typically the inputs of each checkpointed block) is cached. 
    - The remaining intermediate activations are recomputed on-the-fly **during the backward pass**. By recomputing rather than storing all intermediate activations, memory usage is reduced at the cost of increased computation time.
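
As a minimal sketch of the idea in the bullets above, the following pure-Python example backpropagates through a chain of scalar layers `tanh(w * x)` twice: once storing every layer input, and once keeping only one checkpoint per segment and recomputing the rest during the backward pass. The layer choice, function names, and segment size are assumptions made for illustration; real frameworks do the same thing on tensors.

```python
import math


def layer(w, x):
    return math.tanh(w * x)


def layer_grads(w, x, grad_out):
    # Gradients of tanh(w * x) with respect to w and x, times the upstream gradient.
    y = math.tanh(w * x)
    local = (1.0 - y * y) * grad_out
    return local * x, local * w


def backprop_store_all(weights, x0):
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)  # keep the input of every layer
        x = layer(w, x)
    grad, grads = 1.0, [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grads[i], grad = layer_grads(weights[i], inputs[i], grad)
    return x, grads


def backprop_checkpointed(weights, x0, segment):
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x  # keep only the input of each segment
        x = layer(w, x)
    grad, grads = 1.0, [0.0] * len(weights)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        # Recompute the layer inputs of this segment from its checkpoint.
        inputs = [checkpoints[start]]
        for i in range(start, end - 1):
            inputs.append(layer(weights[i], inputs[-1]))
        for i in reversed(range(start, end)):
            grads[i], grad = layer_grads(weights[i], inputs[i - start], grad)
    return x, grads


weights = [0.5 + 0.1 * i for i in range(10)]
_, full = backprop_store_all(weights, 0.3)
_, ckpt = backprop_checkpointed(weights, 0.3, segment=3)
print("max gradient difference:", max(abs(a - b) for a, b in zip(full, ckpt)))
```

Both variants produce the same gradients; the checkpointed one holds at most one segment of recomputed inputs plus the checkpoints at any time, and pays for it with extra forward evaluations.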

## Trainer Arguments

### dataset

In [1]:
import numpy as np
from datasets import Dataset


seq_len, dataset_size = 512, 512
dummy_data = {
    "input_ids": np.random.randint(100, 30000, (dataset_size, seq_len)),
    "labels": np.random.randint(0, 2, (dataset_size)),
}
ds = Dataset.from_dict(dummy_data)
ds.set_format("pt")

In [7]:
print(dummy_data['input_ids'].shape, dummy_data['labels'].shape)

(512, 512) (512,)


In [8]:
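from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo


def print_gpu_utilization():
    # Report the memory currently in use on GPU 0, as seen by NVML.
    # This covers everything allocated on the device, including memory that
    # PyTorch's caching allocator holds on to without actively using it.
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")


def print_summary(result):
    # `result` is the value returned by `trainer.train()`.
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()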

In [10]:
print_gpu_utilization()

GPU memory occupied: 352 MB.


### model

In [2]:
from transformers import AutoModelForSequenceClassification

In [3]:
model = AutoModelForSequenceClassification.from_pretrained('bert-large-uncased').to('cuda')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-large-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
print_gpu_utilization()

GPU memory occupied: 2511 MB.


### training without checkpoint

In [6]:
from transformers import TrainingArguments, Trainer, logging

default_args = {
    "output_dir": "tmp",
    "evaluation_strategy": "steps",
    "num_train_epochs": 1,
    "log_level": "error",
    "report_to": "none",
}

In [15]:

logging.set_verbosity_error()


training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)



{'train_runtime': 22.528, 'train_samples_per_second': 22.727, 'train_steps_per_second': 2.841, 'train_loss': 0.7311427593231201, 'epoch': 1.0}
Time: 22.53
Samples/second: 22.73
GPU memory occupied: 12507 MB.


### training with checkpoint

In [7]:
training_args = TrainingArguments(
    per_device_train_batch_size=1, gradient_accumulation_steps=4, gradient_checkpointing=True, **default_args
)

trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()





In [9]:
print_summary(result)

Time: 819.17
Samples/second: 0.62
GPU memory occupied: 12865 MB.


In [10]:
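# optimizer steps = dataset_size / (per_device_batch_size * n_gpus * gradient_accumulation_steps)
# (the factor 2 is presumably the number of GPUs in use)
512/((1*2) * 4)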

64.0

In [13]:
for name, param in model.named_parameters():
    print(name, param.device)

bert.embeddings.word_embeddings.weight cuda:0
bert.embeddings.position_embeddings.weight cuda:0
bert.embeddings.token_type_embeddings.weight cuda:0
bert.embeddings.LayerNorm.weight cuda:0
bert.embeddings.LayerNorm.bias cuda:0
bert.encoder.layer.0.attention.self.query.weight cuda:0
bert.encoder.layer.0.attention.self.query.bias cuda:0
bert.encoder.layer.0.attention.self.key.weight cuda:0
bert.encoder.layer.0.attention.self.key.bias cuda:0
bert.encoder.layer.0.attention.self.value.weight cuda:0
bert.encoder.layer.0.attention.self.value.bias cuda:0
bert.encoder.layer.0.attention.output.dense.weight cuda:0
bert.encoder.layer.0.attention.output.dense.bias cuda:0
bert.encoder.layer.0.attention.output.LayerNorm.weight cuda:0
bert.encoder.layer.0.attention.output.LayerNorm.bias cuda:0
bert.encoder.layer.0.intermediate.dense.weight cuda:0
bert.encoder.layer.0.intermediate.dense.bias cuda:0
bert.encoder.layer.0.output.dense.weight cuda:0
bert.encoder.layer.0.output.dense.bias cuda:0
bert.encoder

## pytorch

In [1]:
from torch.utils.checkpoint import checkpoint_sequential
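
A minimal sketch of how `checkpoint_sequential` might be used, assuming PyTorch 2.x (where the function accepts `use_reentrant`). The toy model, its sizes, and the number of segments are hypothetical and chosen only so the example runs quickly on CPU; it checks that the checkpointed run yields the same gradients as an ordinary one.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# Hypothetical toy model: 8 small blocks standing in for a deep network.
model = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(8)])
x = torch.randn(16, 64, requires_grad=True)

# Reference run: ordinary forward/backward, all activations are stored.
model(x).sum().backward()
reference = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# Checkpointed run: the 8 blocks are split into 2 segments. Activations inside
# a checkpointed segment are not kept in the forward pass and are recomputed
# when the backward pass reaches that segment.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()
checkpointed = [p.grad for p in model.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(reference, checkpointed))
print("gradients match:", same)
```

More segments mean fewer activations held at once but more checkpoints to store, so the segment count is the knob for the memory/compute trade-off.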

## huggingface

- Which module is checkpointed by default: for BERT, only `BertEncoder`

```
class BertPreTrainedModel(PreTrainedModel):
    def _set_gradient_checkpointing(self, module, value=False):
        # Only BertEncoder is checkpointed
        if isinstance(module, BertEncoder):
            module.gradient_checkpointing = value
```
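
Outside of `Trainer`, the same flag can be toggled through `gradient_checkpointing_enable()` on the model. The sketch below assumes a reasonably recent `transformers` release and uses a hypothetical tiny `BertConfig` so that nothing has to be downloaded; it then lists which submodules ended up with the flag set, which may differ between library versions.

```python
from transformers import BertConfig, BertForSequenceClassification

# Hypothetical tiny configuration, chosen only so the example runs offline.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertForSequenceClassification(config)

before = model.is_gradient_checkpointing
model.gradient_checkpointing_enable()
after = model.is_gradient_checkpointing

# Submodules whose `gradient_checkpointing` flag is now set.
flagged = [
    name
    for name, module in model.named_modules()
    if getattr(module, "gradient_checkpointing", False)
]
print(before, after, flagged)
```

Passing `gradient_checkpointing=True` in `TrainingArguments`, as in the training run above, lets `Trainer` enable checkpointing on the model for you.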