Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] auto-cast optimizers to distributed version #5746

Merged
merged 4 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
Expand Down Expand Up @@ -1179,6 +1179,10 @@ def configure(
# TODO: Support Galore + ZeRO
zero_stage = self.zero_stage
zero_config = deepcopy(self.zero_config)

# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
zero_config["partition_grad"] = False
Expand Down
6 changes: 5 additions & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.zero import LowLevelZeroOptimizer

Expand Down Expand Up @@ -437,6 +437,10 @@ def configure(
zero_stage = self.stage
zero_optim_kwargs = {**self.zero_optim_kwargs}
dp_size = dist.get_world_size()

# Replace with the distributed implementation if exists
optimizer = cast_to_distributed(optimizer)

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
zero_optim_kwargs["partition_grad"] = False
Expand Down
21 changes: 21 additions & 0 deletions colossalai/nn/optimizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from galore_torch import GaLoreAdafactor, GaLoreAdamW

from colossalai.logging import get_dist_logger

from .came import CAME
from .cpu_adam import CPUAdam
from .distributed_adafactor import DistributedAdaFactor
Expand Down Expand Up @@ -34,3 +36,22 @@
"Adafactor",
"DistributedAdaFactor",
]

optim2DistOptim = {
GaLoreAdamW8bit: DistGaloreAwamW,
Lamb: DistributedLamb,
CAME: DistributedCAME,
Adafactor: DistributedAdaFactor,
}
_logger = get_dist_logger()


def cast_to_distributed(optim):
if optim.__class__ in optim2DistOptim:
_logger.info(f"Converting optimizer {optim.__class__.__name__} to its distributed version.", ranks=[0])

if isinstance(optim, GaLoreAdamW8bit):
return optim2DistOptim[GaLoreAdamW8bit](optim.param_groups, args=optim.args)
return optim2DistOptim[optim.__class__](optim.param_groups)

return optim
3 changes: 0 additions & 3 deletions colossalai/nn/optimizer/distributed_came.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def __init__(
betas=(0.9, 0.999, 0.9999),
weight_decay=0.0,
):
assert lr > 0.0
assert all([0.0 <= beta <= 1.0 for beta in betas])

defaults = dict(
lr=lr,
eps=eps,
Expand Down
15 changes: 9 additions & 6 deletions colossalai/nn/optimizer/distributed_galore.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
args (dict, optional): quantization-related arguments. If passed, will override all quantization args above.
"""

def __init__(
self,
params,
lr=1e-3,
lr=1e-2,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
Expand All @@ -57,6 +58,7 @@ def __init__(
percentile_clipping=100,
block_wise=True,
is_paged=False,
args=None,
):
super().__init__(
"adam",
Expand All @@ -65,13 +67,14 @@ def __init__(
betas,
eps,
weight_decay,
nbits,
None,
min_8bit_size,
percentile_clipping,
block_wise,
optim_bits=nbits,
args=args,
min_8bit_size=min_8bit_size,
percentile_clipping=percentile_clipping,
block_wise=block_wise,
is_paged=is_paged,
)

self.tp_size = 1
self.dp_size = 1
self.is_dist = {}
Expand Down
12 changes: 7 additions & 5 deletions colossalai/nn/optimizer/galore.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class GaLoreAdamW8bit(Optimizer2State):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
args (dict, optional): quantization-related arguments. If passed, will override all quantization args above.
Example:

"""
Expand All @@ -200,6 +201,7 @@ def __init__(
percentile_clipping=100,
block_wise=True,
is_paged=False,
args=None,
):
super().__init__(
"adam",
Expand All @@ -208,11 +210,11 @@ def __init__(
betas,
eps,
weight_decay,
nbits,
None,
min_8bit_size,
percentile_clipping,
block_wise,
optim_bits=nbits,
args=args,
min_8bit_size=min_8bit_size,
percentile_clipping=percentile_clipping,
block_wise=block_wise,
is_paged=is_paged,
)

Expand Down
5 changes: 3 additions & 2 deletions docs/source/en/features/distributed_optimizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ Author: [Wenxuan Tan](https://github.com/Edenzzzz), [Junwen Duan](https://github
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962)

## Introduction
Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to efficiently update parameters, and are thus not directly applicable to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communications, and seamless integrations with Tensor Parallel, DDP and ZeRO using plugins.
Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters, and thus aren't directly applicable to settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communications, and seamless integrations with Tensor Parallel, DDP and ZeRO plugins, which automatically uses distributed optimizers with 0 code change.

## Optimizers
Adafactor is a first-order Adam variant using Non-negative Matrix Factorization(NMF) to reduce memory footprint. CAME improves by introducting a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and 8-bit block-wise quantization. Lamb allows huge batch sizes without lossing accuracy via layer-wise adaptive update bounded by the inverse of its Lipschiz constant.

Expand All @@ -21,7 +22,7 @@ Adafactor is a first-order Adam variant using Non-negative Matrix Factorization(
{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}

## Hands-On Practice
We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs.
We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs. **Note that even if you're not aware of distributed optimizers, the plugins automatically casts yours to the distributed version for convenience.**
### step 1. Import libraries

```python
Expand Down
8 changes: 3 additions & 5 deletions docs/source/zh-Hans/features/distributed_optimizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Author: Wenxuan Tan, Junwen Duan, Renjie Mao
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962)

## 介绍
除了广泛采用的Adam和SGD外,许多现代优化器需要逐层统计信息以有效更新参数,因此无法直接应用于模型层在多个设备上分片的并行设置。我们以提供了优化的分布式实现,,并且通过插件与Tensor Parallel、DDP和ZeRO无缝集成。
除了广泛采用的Adam和SGD外,许多现代优化器需要逐层统计信息以有效更新参数,因此无法直接应用于模型层在多个设备上分片的并行设置。我们以提供了优化的分布式实现,,并且通过plugin与Tensor Parallel、DDP和ZeRO无缝集成。
## 优化器
Adafactor 是一种首次采用非负矩阵分解(NMF)的 Adam 变体,用于减少内存占用。CAME 通过引入一个置信度矩阵来改进 NMF 的效果。GaLore 通过将梯度投影到低秩空间,并使用 8 位块状量化进一步减少内存占用。Lamb 允许使用巨大的批量大小而不失准确性,通过按其 Lipschitz 常数的倒数界定的逐层自适应更新实现

Expand All @@ -21,7 +21,7 @@ Adafactor 是一种首次采用非负矩阵分解(NMF)的 Adam 变体,用
{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}

## 使用
We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs.
现在我们展示如何使用分布式 Adafactor booster API 结合 Tensor Parallel ZeRO 2。即使您不使用distributed optimizer,plugin 也会自动将optimizer转换为分布式版本以方便使用。
### step 1. 导包

```python
Expand All @@ -34,15 +34,13 @@ import torch
```

### step 2. 初始化分布式
We need to initialize distributed environment. For demo purpose, we use `colossal run --nproc_per_node 4`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md)
我们需要先初始化分布式环境. 为了展示, 我们使用 `colossal run --nproc_per_node 4`. 更多初始化方式请参考 [Launch Colossal-AI](../basics/launch_colossalai.md)

```python
colossalai.launch_from_torch()
```

### step 3. 初始化模型和优化器
Build our model. We created an MLP using two Linear Layer.

```python
configuration = LlamaConfig()
model = LlamaModel(configuration).cuda()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_optimizer/test_dist_adafactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
sharded_optimizer,
criterion,
booster,
) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor)
) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, Adafactor)

org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
Expand Down
2 changes: 1 addition & 1 deletion tests/test_optimizer/test_dist_came.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def exam_bert_test_on_hybrid_plugin(test_config):
sharded_optimizer,
criterion,
booster,
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, DistributedCAME)
) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, CAME)

org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
Expand Down
6 changes: 3 additions & 3 deletions tests/test_optimizer/test_dist_galore.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def check_dist_galore(rank, world_size, port):
global coordinator
coordinator = DistCoordinator()

run_dist_galore_basic()
coordinator.print_on_master("Basic backward tests passed")
# run_dist_galore_basic()
# coordinator.print_on_master("Basic backward tests passed")

coordinator.print_on_master("Skipping forward-backward tests due to SVD instability")
# run_dist_galore_fwd_bwd()
Expand All @@ -319,7 +319,7 @@ def check_dist_galore(rank, world_size, port):
)
for config in test_config:
try:
run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=DistGaloreAwamW)
run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=GaLoreAdamW8bit)
except Exception as e:
print(e)
dist.barrier()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_optimizer/test_dist_lamb.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def check_dist_lamb(rank, world_size, port):
run_dist_lamb_fwd_bwd()
coordinator.print_on_master("Forward-backward tests passed")

run_bert_test(optim_class=Lamb, sharded_optim_class=DistributedLamb)
run_bert_test(optim_class=Lamb, sharded_optim_class=Lamb)
print(f"rank {rank} tests passed :)")


Expand Down
4 changes: 2 additions & 2 deletions tests/test_shardformer/test_model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.nn.optimizer import GaLoreAdamW8bit
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
Expand Down Expand Up @@ -130,7 +130,7 @@ def build_model_from_hybrid_plugin(
if use_lazy_init:
ctx.materialize(org_model)
org_model = org_model.cuda()
if sharded_optim_class == DistGaloreAwamW:
if optim_class == GaLoreAdamW8bit:
# Disable clipping and block-wise quantization
org_optimizer = optim_class(
get_galore_param_groups(org_model, weight_decay=0, rank=4),
Expand Down
Loading