[doc] polish shardformer doc (#4779)
* fix example format in docstring

* polish shardformer doc
Fridge003 committed Sep 26, 2023
1 parent 26cd6d8 commit a2db755
Showing 8 changed files with 220 additions and 73 deletions.
21 changes: 11 additions & 10 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -229,16 +229,17 @@ class GeminiPlugin(DPPluginBase):
"""
Plugin for Gemini.
Example:
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

model, train_dataset, optimizer, criterion = ...
plugin = GeminiPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```
Args:
chunk_config_dict (dict, optional): chunk configuration dictionary.
17 changes: 9 additions & 8 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -266,16 +266,17 @@ class HybridParallelPlugin(PipelinePluginBase):
Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
The sizes of tp and pp should be passed in by the user; the size of dp is then derived automatically as dp_size = world_size / (tp_size * pp_size).
Example:
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

model, train_dataset, optimizer, criterion = ...
plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
```
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
21 changes: 11 additions & 10 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -213,16 +213,17 @@ class LowLevelZeroPlugin(DPPluginBase):
"""
Plugin for low level zero.
Example:
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

model, train_dataset, optimizer, criterion = ...
plugin = LowLevelZeroPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```
Args:
stage (int, optional): ZeRO stage. Defaults to 1.
21 changes: 11 additions & 10 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -130,16 +130,17 @@ class TorchDDPPlugin(DPPluginBase):
"""
Plugin for PyTorch DDP.
Example:
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model, train_dataset, optimizer, criterion = ...
plugin = TorchDDPPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```
Args:
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.
21 changes: 11 additions & 10 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -143,16 +143,17 @@ class TorchFSDPPlugin(DPPluginBase):
"""
Plugin for PyTorch FSDP.
Example:
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

model, train_dataset, optimizer, criterion = ...
plugin = TorchFSDPPlugin()
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```
Args:
See https://pytorch.org/docs/stable/fsdp.html for details.
45 changes: 25 additions & 20 deletions colossalai/cluster/dist_coordinator.py
@@ -20,14 +20,16 @@ class in the whole program.
- master: the process with rank 0
- node master: the process with local rank 0 on the current node
Example:
```python
from colossalai.cluster.dist_coordinator import DistCoordinator
coordinator = DistCoordinator()

if coordinator.is_master():
    do_something()

coordinator.print_on_master('hello world')
```
Attributes:
rank (int): the rank of the current process
@@ -131,11 +133,13 @@ def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.
Example:
```python
from colossalai.cluster import DistCoordinator
dist_coordinator = DistCoordinator()
with dist_coordinator.priority_execution():
    dataset = CIFAR10(root='./data', download=True)
```
Args:
executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
@@ -174,13 +178,14 @@ def on_master_only(self, process_group: ProcessGroup = None):
"""
A function wrapper that only executes the wrapped function on the master process (rank 0).
Example:
```python
from colossalai.cluster import DistCoordinator
dist_coordinator = DistCoordinator()

@dist_coordinator.on_master_only()
def print_on_master(msg):
    print(msg)
```
"""
is_master = self.is_master(process_group)

74 changes: 71 additions & 3 deletions docs/source/en/features/shardformer.md
@@ -214,17 +214,83 @@ In addition, xFormers' `cutlass_op` can serve as a backup for flash attention.
Enabling `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the recommended way to unlock the power of Shardformer.
The main reason is that pipeline parallelism cannot work correctly without calling the `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` makes it possible to combine the features of `Shardformer` with other useful features, such as mixed precision training or ZeRO.

[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of how to trigger `Shardformer` through `HybridParallelPlugin`. Move to the root directory of this example and execute
```bash
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"
```
Then you can start finetuning a BERT model wrapped by `Shardformer`. The wrapping is handled by `HybridParallelPlugin`.

Let's delve into the code of `finetune.py`:

In the `main` function, the plugin is created with the following code:
```python
...
elif args.plugin == "hybrid_parallel":
# modify the param accordingly for finetuning test cases
plugin = HybridParallelPlugin(
tp_size=1,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
zero_stage=1,
precision="fp16",
initial_scale=1,
)
```
Here you can change the plugin configuration by setting `tp_size`, `pp_size` or `zero_stage` to other values. More details about plugin configuration can be found in the [Booster Plugins Doc](../basics/booster_plugins.md).
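For instance, a tensor-parallel-only setup (a hypothetical variant for illustration, not one of the configurations shipped in `finetune.py`) might look like:
```python
# Hypothetical alternative: pure tensor parallelism with ZeRO-1, no pipeline stages
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=1,
    precision="fp16",
    initial_scale=1,
)
```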

If pipeline parallelism is not enabled, training works the same way as with other booster plugins: first boost the model with `Booster`, then run the forward and backward passes in the usual way.
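A minimal sketch of this non-pipeline path (the loop body is an illustrative assumption, not code copied from `finetune.py`):
```python
# Hypothetical non-pipeline training step; assumes dict-style batches already on the right device
model, optimizer, _criterion, train_dataloader, _ = booster.boost(
    model, optimizer, criterion=_criterion, dataloader=train_dataloader
)

for batch in train_dataloader:
    outputs = model(**batch)
    loss = _criterion(outputs, batch)
    booster.backward(loss, optimizer)  # use booster.backward instead of loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```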
However, if pipeline parallelism is enabled, there are several usages that differ from the normal case:

1. Before running the forward or backward pass, the criterion function (loss function) is wrapped to meet the argument requirements of the pipeline schedule:
```python
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
```

2. In the `train_epoch` function, the dataloader is converted into an iterator before running the pipeline:
```python
train_dataloader_iter = iter(train_dataloader)
```

3. Run the forward and backward passes by calling the `Booster.execute_pipeline` method:
```python
outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
)
```
   The backward pass is completed inside this method, so there is no need to call `loss.backward()` afterwards.
   More details about `Booster.execute_pipeline` can be found in the [Booster API Doc](../basics/booster_api.md).
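Putting the three pieces together, one training step under pipeline parallelism might look like the following sketch (the surrounding loop and optimizer calls are assumptions based on the snippets above, not a verbatim excerpt of `finetune.py`):
```python
# Hypothetical pipeline training loop assembled from the snippets above
train_dataloader_iter = iter(train_dataloader)

for _ in range(len(train_dataloader)):
    outputs = booster.execute_pipeline(
        train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
    )
    # execute_pipeline has already run the backward pass; the returned loss is
    # only meaningful on the last pipeline stage.
    if booster.plugin.stage_manager.is_last_stage():
        loss = outputs["loss"]
    optimizer.step()
    optimizer.zero_grad()
```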


#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)

You can also use Shardformer by calling the Shardformer APIs manually. However, this usage is not recommended, since pipeline parallelism cannot run without `Booster`.

[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
is an example of how to trigger `Shardformer` by calling the Shardformer APIs. In the `train` function of the example code, the model is wrapped by `Shardformer` with the following lines:
```python
...
if dist.get_world_size() > 1:
tp_group = dist.new_group(backend="nccl")

# First create configuration for Shardformer
shard_config = ShardConfig(
tensor_parallel_process_group=tp_group,
enable_tensor_parallelism=True,
enable_all_optimization=True
)

# Then create ShardFormer object with created config
shard_former = ShardFormer(shard_config=shard_config)

# Finally shard the model using ShardFormer.optimize method
model, _ = shard_former.optimize(model)
...
```
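After `optimize` returns, the sharded model is driven like a regular `nn.Module`. A minimal training step might look like this sketch (the dataloader, criterion and optimizer here are illustrative assumptions, not code from the example script):
```python
# Hypothetical training step with the sharded model
for batch in dataloader:
    outputs = model(**batch)
    loss = criterion(outputs)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```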

### Precautions

@@ -241,6 +307,8 @@

## How Shardformer Works

### Main Idea

Generally, Shardformer works through the following four kinds of *replacements*:

1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
73 changes: 71 additions & 2 deletions docs/source/zh-Hans/features/shardformer.md
@@ -207,16 +207,83 @@ The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:

Enabling `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the most recommended usage. The main reason is that pipeline parallelism cannot work correctly without calling the `execute_pipeline` method of `Booster`. In addition, `HybridParallelPlugin` makes it possible to combine the features of `Shardformer` with other features such as mixed precision training or ZeRO.

[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of how to trigger `Shardformer` through `HybridParallelPlugin`.
Move to the root directory of this example and execute the command:
```bash
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"
```
This starts finetuning a BERT model wrapped by `Shardformer`; the wrapping is handled by `HybridParallelPlugin`.

Next, let's delve into the code of `finetune.py`:

In the `main` function, the hybrid parallel plugin is created with the following code:
```python
...
elif args.plugin == "hybrid_parallel":
# modify the param accordingly for finetuning test cases
plugin = HybridParallelPlugin(
tp_size=1,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
zero_stage=1,
precision="fp16",
initial_scale=1,
)
```
Here you can change the plugin configuration by setting different values for `tp_size`, `pp_size` or `zero_stage`. More details about plugin configuration can be found in the [Booster Plugins documentation](../basics/booster_plugins.md).

When pipeline parallelism is not enabled, training proceeds the same way as with other plugins (first wrap the model and optimizer with Booster, then run forward and backward passes in the usual way). However, when pipeline parallelism is enabled, there are several usages that differ from the normal case:

1. Before running the forward and backward passes, the criterion function (loss function) needs to be wrapped to meet the argument requirements of pipeline parallelism:
```python
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
```

2. In the `train_epoch` function, the dataloader needs to be converted into an iterator before running the pipeline forward and backward passes:
```python
train_dataloader_iter = iter(train_dataloader)
```

3. Run the forward and backward passes by calling the `Booster.execute_pipeline` method:
```python
outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
)
```
This method performs the backward pass automatically, so there is no need to call `loss.backward()` after it.
More details about `Booster.execute_pipeline` can be found in the [Booster API documentation](../basics/booster_api.md).

#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)

You can also enable Shardformer by manually calling the Shardformer APIs. However, we do not recommend this usage, since pipeline parallelism cannot run correctly without `Booster`.

[这里](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
is an example of how to trigger `Shardformer` by calling the Shardformer APIs.

In the `train` function of the example code, the model is wrapped with the following lines:
```python
...
if dist.get_world_size() > 1:
tp_group = dist.new_group(backend="nccl")

# First create configuration for Shardformer
shard_config = ShardConfig(
tensor_parallel_process_group=tp_group,
enable_tensor_parallelism=True,
enable_all_optimization=True
)

# Then create ShardFormer object with created config
shard_former = ShardFormer(shard_config=shard_config)

# Finally shard the model using ShardFormer.optimize method
model, _ = shard_former.optimize(model)
...
```

### Precautions

@@ -234,6 +301,8 @@

## How Shardformer Works

### Main Idea

Generally speaking, Shardformer works through the following four kinds of *replacements*:

1. Replacing the original PyTorch modules (e.g. `nn.Linear`, `nn.Embedding`) with distributed modules of our own design.
