[Docs] Translate save_gpu_memory.md #803

Merged · 11 commits · Dec 12, 2022
111 changes: 110 additions & 1 deletion docs/en/examples/save_gpu_memory.md
@@ -1,3 +1,112 @@
# Save Memory on GPU

Memory capacity is critical in deep learning training and inference: it determines whether a model can run at all. Common memory-saving approaches include:

- Gradient Accumulation

  Gradient accumulation is a mechanism that accumulates gradients for a configured number of iterations instead of updating the parameters at every step; the parameters are then updated once and the gradients are cleared. With this delayed parameter update, the result is similar to using a large batch size, while the activation memory that a genuinely larger batch would require is saved. Note, however, that if the model contains batch normalization layers, gradient accumulation may affect performance, since the normalization statistics are still computed over the small per-iteration batches. A plain-PyTorch sketch of the mechanism is given after this list.

- Gradient Checkpointing

  Gradient checkpointing is a time-for-space method that reduces memory by saving fewer activations; the unsaved activations must be recomputed during the backward pass. The corresponding functionality is implemented in the `torch.utils.checkpoint` package. In brief, the forward function passed to `checkpoint` runs in `torch.no_grad` mode during the forward phase and saves only its inputs and outputs; the intermediate activations are recomputed in the backward phase. A short sketch is also given after this list.

- Large Model Training Techniques

  Recent research has shown that training larger models improves performance, but training at such a scale requires huge resources, and it is hard to fit an entire large model into the memory of a single GPU. Large model training techniques were therefore introduced, typified by [DeepSpeed ZeRO](https://www.deepspeed.ai/tutorials/zero/#zero-overview) and the Fully Sharded Data Parallel ([FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/)) technique introduced in FairScale. These techniques shard the parameters, gradients, and optimizer states across the parallel processes while maintaining the simplicity of data parallelism.
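
For readers who want to see the underlying mechanism, below is a minimal plain-PyTorch sketch of gradient accumulation; the toy model, data, and hyper-parameters are made up for illustration. MMEngine's `OptimWrapper` automates this, as shown in the next section.

```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1)  # toy model, for illustration only
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data = [(torch.ones(1, 1), torch.ones(1, 1))] * 8
accumulative_counts = 4  # update parameters once every 4 iterations

for idx, (img, label) in enumerate(data):
    loss = (model(img) - label).pow(2).mean()
    # Scale the loss so that the accumulated gradient matches the
    # gradient of one batch that is accumulative_counts times larger.
    (loss / accumulative_counts).backward()
    if (idx + 1) % accumulative_counts == 0:
        optimizer.step()       # delayed parameter update
        optimizer.zero_grad()  # clear the accumulated gradients
```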
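
Similarly, here is a minimal sketch of gradient checkpointing with `torch.utils.checkpoint`; the two-stage network is made up for illustration. Only the input and output of `stage1` are stored during the forward pass, and its intermediate activations are recomputed during the backward pass.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

stage1 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())  # checkpointed part
stage2 = nn.Linear(16, 1)

x = torch.randn(4, 16, requires_grad=True)
feat = checkpoint(stage1, x)   # stage1 runs in no_grad mode internally
loss = stage2(feat).sum()
loss.backward()                # stage1 activations are recomputed here
```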

MMEngine supports gradient accumulation and large model training with FSDP; their usage is described below.

## Gradient Accumulation

The configuration can be written in this way:

```python
optim_wrapper_cfg = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.001, momentum=0.9),
    # update parameters once every four iterations
    accumulative_counts=4)
```
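
With this configuration, `OptimWrapper` scales and accumulates the loss internally and triggers `optimizer.step()` once every `accumulative_counts` iterations, mirroring the manual loop sketched earlier.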

The full example working with `Runner` is as follows.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmengine.runner import Runner
from mmengine.model import BaseModel

train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class ToyModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        return dict(loss1=loss1, loss2=loss2)


runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01),
                       accumulative_counts=4)
)
runner.train()
```
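
In this example the dataloader batch size is 2 and `accumulative_counts` is 4, so the parameter updates are equivalent to training with a batch size of 8.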

## Large Model Training

`FSDP` has been officially supported since PyTorch 1.11. The config can be written as follows:

```python
# located in the config file
model_wrapper_cfg = dict(type='MMFullyShardedDataParallel', cpu_offload=True)
```
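
Assuming the wrapper forwards the option to PyTorch's FSDP, `cpu_offload=True` additionally offloads the sharded parameters to the CPU when they are not involved in computation, saving more GPU memory at the cost of extra host-device transfers.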

The full example working with `Runner` is as follows.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmengine.runner import Runner
from mmengine.model import BaseModel

train_dataset = [(torch.ones(1, 1), torch.ones(1, 1))] * 50
train_dataloader = DataLoader(train_dataset, batch_size=2)


class ToyModel(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, img, label, mode):
        feat = self.linear(img)
        loss1 = (feat - label).pow(2)
        loss2 = (feat - label).abs()
        return dict(loss1=loss1, loss2=loss2)


runner = Runner(
    model=ToyModel(),
    work_dir='tmp_dir',
    train_dataloader=train_dataloader,
    train_cfg=dict(by_epoch=True, max_epochs=1),
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
    cfg=dict(model_wrapper_cfg=dict(type='MMFullyShardedDataParallel',
                                    cpu_offload=True))
)
runner.train()
```

Please note that `FSDP` only works in distributed training environments.
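
For example, assuming the script above is saved as a hypothetical `train.py` and `launcher='pytorch'` is passed to the `Runner`, it could be launched on two GPUs with `torchrun --nproc_per_node=2 train.py`.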
2 changes: 1 addition & 1 deletion docs/zh_cn/examples/save_gpu_memory.md
@@ -8,7 +8,7 @@

- Gradient Checkpointing

Gradient checkpointing is a time-for-space method that compresses the model's memory footprint by reducing the number of saved activations, but the unsaved activations must be recomputed when computing the gradients. The corresponding functionality is implemented in the torch.utils.checkpoint package. In brief: the forward function passed to checkpoint runs in `torch.no_grad` mode during the forward phase and saves only the input arguments and the forward function, then recomputes its forward outputs in the backward phase.
Gradient checkpointing is a time-for-space method that compresses the model's memory footprint by reducing the number of saved activations, but the unsaved activations must be recomputed when computing the gradients. The corresponding functionality is implemented in the torch.utils.checkpoint package. In brief: the forward function passed to checkpoint runs in `torch.no_grad` mode during the forward phase and saves only the inputs and outputs of the forward function, then recomputes the intermediate activations in the backward phase.

- Large Model Training Techniques
