diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 0e5fb5feb7e4..f4f9f7a5021d 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -27,7 +27,7 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim -from colossalai.nn.optimizer import DistGaloreAwamW8bit +from colossalai.nn.optimizer import DistGaloreAwamW from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer @@ -1179,7 +1179,7 @@ def configure( # TODO: Support Galore + ZeRO zero_stage = self.zero_stage zero_config = deepcopy(self.zero_config) - if isinstance(optimizer, DistGaloreAwamW8bit) and zero_stage > 0 and self.dp_size > 0: + if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") zero_config["partition_grad"] = False zero_stage = 0 diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 46f3ba03dbd7..dfc743fe5f33 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -32,7 +32,7 @@ ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim -from colossalai.nn.optimizer import DistGaloreAwamW8bit +from colossalai.nn.optimizer import DistGaloreAwamW from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.zero import LowLevelZeroOptimizer @@ -437,7 +437,7 @@ def configure( zero_stage = self.stage zero_optim_kwargs = {**self.zero_optim_kwargs} dp_size = dist.get_world_size() - if isinstance(optimizer, DistGaloreAwamW8bit) and zero_stage > 0 and dp_size > 0: + if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0: warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. 
Disabling ZeRO.") zero_optim_kwargs["partition_grad"] = False zero_stage = 0 diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 905dc11fdb91..c7261b1bcf7c 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -4,7 +4,7 @@ from .cpu_adam import CPUAdam from .distributed_adafactor import DistributedAdaFactor from .distributed_came import DistributedCAME -from .distributed_galore import DistGaloreAwamW8bit +from .distributed_galore import DistGaloreAwamW from .distributed_lamb import DistributedLamb from .fused_adam import FusedAdam from .fused_lamb import FusedLAMB @@ -25,7 +25,7 @@ "CPUAdam", "HybridAdam", "DistributedLamb", - "DistGaloreAwamW8bit", + "DistGaloreAwamW", "GaLoreAdamW", "GaLoreAdafactor", "GaLoreAdamW8bit", diff --git a/colossalai/nn/optimizer/distributed_galore.py b/colossalai/nn/optimizer/distributed_galore.py index 1b67c2d462f7..3f42dd5b99c0 100644 --- a/colossalai/nn/optimizer/distributed_galore.py +++ b/colossalai/nn/optimizer/distributed_galore.py @@ -18,7 +18,7 @@ # Mark sharded dimension -class DistGaloreAwamW8bit(DistributedOptim, Optimizer2State): +class DistGaloreAwamW(DistributedOptim, Optimizer2State): r"""Implements Galore, a optimizer-agonistic gradient compression technique on 8-bit AdamW. It largely compresses gradient via low-rank projection and is claimed to be insensitive to hyperparams like lr. Supports Tensor Parallel and ZeRO stage 1 and 2 via booster and plugin. @@ -35,12 +35,14 @@ class DistGaloreAwamW8bit(DistributedOptim, Optimizer2State): numerical stability. (default: 1e-6) weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01) nbits: Number of bits for quantization optim states. Only 32 and 8 are supported. - Example: - >>> optim = DistributedLamb(model.parameters(), lr=1e-3) - >>> proc_mesh = ProcessGroupMesh(tp_size, zero_size) - >>> tp_group = proc_mesh.get_group_along_axis(0) - >>> dp_group = proc_mesh.get_group_along_axis(1) - >>> optim.setup_distributed(tp_group, dp_group) + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not. 
""" def __init__( diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 28ddde6e36a7..5f7f2a4e2249 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -396,15 +396,15 @@ def run_reduction(bucket_store: BucketStore, grad_store: GradientStore): else: if bucket_store.moe_extra_dp_pg is None: flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size)) - recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg) + received_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) if received_grad.dtype != grad_dtype: received_grad = received_grad.to(grad_dtype) grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank] LowLevelZeroOptimizer.update_partitoned_grad( - bucket_store, grad_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1 + bucket_store, grad_store, grad_in_bucket_current_rank, received_grad, group_id, 1 ) else: # categorize moe and non moe param @@ -421,13 +421,13 @@ def run_reduction(bucket_store: BucketStore, grad_store: GradientStore): flat_grads_list = list( non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size) ) - recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg) + received_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg) LowLevelZeroOptimizer.update_partitoned_grad( bucket_store, grad_store, non_moe_grad_in_bucket_current_rank, - recieved_grad, + received_grad, group_id, 1, ) @@ -436,15 +436,15 @@ def run_reduction(bucket_store: BucketStore, grad_store: GradientStore): flat_grads_list = list( moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size) ) - recieved_grad = torch.zeros_like(flat_grads_list[0]) + received_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter( - recieved_grad, + received_grad, flat_grads_list, group=bucket_store.moe_extra_dp_pg, ) param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size - recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) - for split_recieved_grad in recieved_grad: + received_grad = list(received_grad.split(len(received_grad) // param_slice)) + for split_recieved_grad in received_grad: split_recieved_grad = _unflatten_dense_tensors( split_recieved_grad, moe_grad_in_bucket_current_rank ) diff --git a/docs/source/en/features/distributed_adafactor.md b/docs/source/en/features/distributed_adafactor.md deleted file mode 100644 index 8d3691177ad6..000000000000 --- a/docs/source/en/features/distributed_adafactor.md +++ /dev/null @@ -1,156 +0,0 @@ -# Distributed Adafactor - -Author: - -**Related Paper** -- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) - -## Introduction - -Distributed Adafactor is an optimiser that supports hybrid optimisation, including 1D tensor parallelism as well as ZerO. It makes full use of computational resources through reasonable task parallelism, improves training efficiency and speed, and reduces space pressure on single card storage. 
It has a wide range of applications and currently supports a range of Transformer based models, see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details. - -## API Reference - -{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} - -## Hands-On Practice -We now demonstrate how to start Distributed Adafactor with booster API. -### step 1. Import libraries - -```python -import torch -from torch import nn -import torch.distributed as dist -from transformers import LlamaModel, LlamaConfig - -from colossalai.cluster import ProcessGroupMesh -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row -from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor -from colossal_llama2.dataset.loader import load_tokenized_dataset -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -``` - -### step 2. Initialize Distributed Environment and Parallism Group -We then need to initialize distributed environment. For demo purpose, we uses `colossalai.launch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) -for other initialization methods. We use `ProcessGroupMesh` to create tensor parallelism group and data parallelism group. - -```python -# Distributed Enviroment -config = {} -colossalai.launch(config=config, rank=rank, world_size=world_size,host="localhost", port=port, backend="nccl") -``` - -### step 3. Initialize Module and Optimizer -Build our model. We created an MLP using two Linear Layer. - -```python -# Init Llama from huggingface -configuration = LlamaConfig() -model = LlamaModel(configuration) -dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") -dataloader = plugin.prepare_dataloader(dataset, batch_size=8) -criterion = lambda x: x.mean() -dist_optim = DistributedAdaFactor(model.parameters()) - -``` - -### step 4.Init Booster - -```python -plugin = LowLevelZeroPlugin() -booster = Booster(plugin=plugin) -model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader) -``` -### step 5.Train Your Model -```python -for epoch in range(max_epochs): - for input_ids, attention_mask in dataloader: - outputs = model(input_ids.cuda(), attention_mask.cuda()) - loss = criterion(outputs.logits, input_ids) - booster.backward(loss, optimizer) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() -``` - -## Supporting Information -Model/Feature Compatibility Matrix: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| Model/Feature | Transformers Bert | Transformers Bert For Pretraining | Transformers Bert Lm Head Model | Transformers Bert For Masked Lm | Transformers Bert For Sequence Classification | Transformers Bert For Token Classification | Transformers Bert For Next Sentence | Transformers Bert For Multiple-choice Question | Transformers Bert For Question Answering |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Hybrid Parallel Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Low Level Zero Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Torch DDP Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Gemini Plugin | | | | | | | | | |
| Moe Hybrid Plugin | | | | | | | | | |
diff --git a/docs/source/en/features/distributed_optimizers.md b/docs/source/en/features/distributed_optimizers.md
new file mode 100644
index 000000000000..7590669dfd2d
--- /dev/null
+++ b/docs/source/en/features/distributed_optimizers.md
@@ -0,0 +1,141 @@
+# Distributed Optimizers
+
+Author: [Wenxuan Tan](https://github.com/Edenzzzz), [Junwen Duan](https://github.com/duanjunwen), [Renjie Mao](https://github.com/chongqichuizi875)
+
+**Related Papers**
+- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
+- [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)
+- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
+- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962)
+
+## Introduction
+Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters efficiently, and are thus not directly applicable to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communication, which integrate seamlessly with Tensor Parallel, DDP and ZeRO via booster plugins.
+
+## Optimizers
+Adafactor is a first-order Adam variant that uses Non-negative Matrix Factorization (NMF) to reduce the memory footprint. CAME improves on it by introducing a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb allows huge batch sizes without losing accuracy, via layer-wise adaptive updates bounded by the inverse of the Lipschitz constant.
+
+## API Reference
+
+{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
+{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
+{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
+{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
+
+## Hands-On Practice
+We now demonstrate how to use Distributed Adafactor with the booster API, combining Tensor Parallel and ZeRO 2 on 4 GPUs.
+### step 1. Import libraries
+
+```python
+from transformers import LlamaModel, LlamaConfig
+from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
+from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin
+import colossalai
+import torch
+```
+
+### step 2. Initialize Distributed Environment and Parallelism Group
+We need to initialize the distributed environment. For demo purposes, we use `colossal run --nproc_per_node 4`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) for other initialization methods.
+
+```python
+colossalai.launch_from_torch()
+```
+
+### step 3. Initialize Module and Optimizer
+Build our model and optimizer. Here we initialize a LLaMA model from a Hugging Face config and wrap its parameters with Distributed Adafactor.
+
+```python
+# Init Llama from huggingface
+configuration = LlamaConfig()
+model = LlamaModel(configuration).cuda()
+criterion = lambda x: x.mean()
+dist_optim = DistributedAdaFactor(model.parameters())
+```
+
+### step 4. Init Booster
+
+```python
+plugin = HybridParallelPlugin(tp_size=2, zero_stage=2, pp_size=1, enable_all_optimization=True)
+booster = Booster(plugin=plugin)
+# You should also pass in your own dataset.
+model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion)
+```
+
+### step 5. Train Your Model
+```python
+steps = 10
+for step in range(steps):
+    input_ids = torch.ones(1, 100, device="cuda", dtype=torch.int)
+    attention_mask = input_ids.clone()
+    outputs = model(input_ids.cuda(), attention_mask.cuda())
+    loss = criterion(outputs.last_hidden_state)
+    booster.backward(loss, dist_optim)
+    dist_optim.step()
+    dist_optim.zero_grad()
+```
+
+### GaLore special handling
+For GaLore, we need to specify the projection rank for each parameter group, as well as the quantization and paged-optimizer parameters. Please refer to bitsandbytes for quantization details. Support for ZeRO is underway.
+```python
+from colossalai.nn.optimizer.galore import get_galore_param_groups
+from colossalai.nn.optimizer import DistGaloreAwamW
+optim = DistGaloreAwamW(
+    get_galore_param_groups(model, decay=1e-2, rank=8),
+    lr=lr,
+    betas=(beta1, beta2),
+    eps=eps,
+    nbits=8,
+    percentile_clipping=100,
+    block_wise=True,
+    min_8bit_size=4096,
+)
+```
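+
+As a minimal illustrative sketch (not an official example), the GaLore optimizer above can be boosted the same way as in steps 4 and 5, reusing the `model` and `criterion` defined earlier. Since ZeRO support is still underway, we assume a Tensor Parallel-only plugin here (`zero_stage=0`); `lr`, `beta1`, `beta2` and `eps` are placeholders you should set yourself.
+
+```python
+# Sketch: boost DistGaloreAwamW with Tensor Parallel only, since GaLore + ZeRO is not supported yet.
+plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=0)
+booster = Booster(plugin=plugin)
+model, optim, criterion, _, _ = booster.boost(model, optim, criterion)
+# Training then follows the same loop as in step 5, using booster.backward(loss, optim).
+```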
+
+## Plugin compatibility
+
+| Model/Feature | Lamb | GaLore | Adafactor | CAME |
+| :-: | :-: | :-: | :-: | :-: |
+| Hybrid Parallel Plugin | ✔️ | ✔️ | ✔️ | ✔️ |
+| Low Level Zero Plugin | ✔️ | | ✔️ | ✔️ |
+| Torch DDP Plugin | ✔️ | ✔️ | ✔️ | ✔️ |
+| Gemini Plugin | | | | |
+| Moe Hybrid Plugin | | | | |
+ + diff --git a/docs/source/zh-Hans/features/distributed_adafactor.md b/docs/source/zh-Hans/features/distributed_adafactor.md deleted file mode 100644 index 19610a85c8c1..000000000000 --- a/docs/source/zh-Hans/features/distributed_adafactor.md +++ /dev/null @@ -1,155 +0,0 @@ -# 分布式 Adafactor - -作者: - -**相关论文** -- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) - -## 简介 - -分布式 Adafactor 是一种支持混合优化的优化器,包括 1D 张量并行和 ZerO。它通过合理的任务并行化充分利用了计算资源,提高了训练效率和速度,并减少了存储压力。它应用广泛,目前支持一系列基于 Transformer 的模型,详见 [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo). - -## API接口 - -{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} - -## 实例演示 -现在我们演示如何使用 Booster API 启动分布式 Adafactor。 -### 步骤 1. 导入相关库 - -```python -import torch -from torch import nn -import torch.distributed as dist -from transformers import LlamaModel, LlamaConfig - -from colossalai.cluster import ProcessGroupMesh -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row -from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor -from colossal_llama2.dataset.loader import load_tokenized_dataset -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin -``` - -### 步骤 2. 初始化分布式环境和参数 -然后,我们需要初始化分布式环境。为了演示的目的,我们使用了 `colossalai.launch`。您可以参考 [Launch Colossal-AI](../basics/launch_colossalai.md) 获得其他的初始化方法。这里, 我们使用 "ProcessGroupMesh"来创建张量并行组和数据并行组。 - -```python -# Distributed Enviroment -config = {} -colossalai.launch(config=config, rank=rank, world_size=world_size,host="localhost", port=port, backend="nccl") -``` - -### 步骤 3.初始化模块和优化器 -Build our model. We created an MLP using two Linear Layer. - -```python -# Init Llama from huggingface -configuration = LlamaConfig() -model = LlamaModel(configuration) -dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train") -dataloader = plugin.prepare_dataloader(dataset, batch_size=8) -criterion = lambda x: x.mean() -dist_optim = DistributedAdaFactor(model.parameters()) - -``` - -### 步骤 4.初始化Booster - -```python -plugin = LowLevelZeroPlugin() -booster = Booster(plugin=plugin) -model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion, dataloader) -``` -### 步骤 5.训练模型 -```python -for epoch in range(max_epochs): - for input_ids, attention_mask in dataloader: - outputs = model(input_ids.cuda(), attention_mask.cuda()) - loss = criterion(outputs.logits, input_ids) - booster.backward(loss, optimizer) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() -``` - -## 支持信息 -模型/功能兼容性矩阵: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| Model/Feature | Transformers Bert | Transformers Bert For Pretraining | Transformers Bert Lm Head Model | Transformers Bert For Masked Lm | Transformers Bert For Sequence Classification | Transformers Bert For Token Classification | Transformers Bert For Next Sentence | Transformers Bert For Multiple-choice Question | Transformers Bert For Question Answering |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Hybrid Parallel Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Low Level Zero Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Torch DDP Plugin | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Gemini Plugin | | | | | | | | | |
| Moe Hybrid Plugin | | | | | | | | | |
- - diff --git a/docs/source/zh-Hans/features/distributed_optimizers.md b/docs/source/zh-Hans/features/distributed_optimizers.md new file mode 100644 index 000000000000..e1d00d22ace1 --- /dev/null +++ b/docs/source/zh-Hans/features/distributed_optimizers.md @@ -0,0 +1,141 @@ +# 分布式优化器 + +Author: Wenxuan Tan, Junwen Duan, Renjie Mao + +**相关论文** +- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235) +- [CAME: Confidence-guided Adaptive Memory Efficient Optimization] (https://arxiv.org/abs/2307.02047) +- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection] (https://arxiv.org/abs/2403.03507) +- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962) + +## 介绍 +除了广泛采用的Adam和SGD外,许多现代优化器需要逐层统计信息以有效更新参数,因此无法直接应用于模型层在多个设备上分片的并行设置。我们以提供了优化的分布式实现,,并且通过插件与Tensor Parallel、DDP和ZeRO无缝集成。 +## 优化器 +Adafactor 是一种首次采用非负矩阵分解(NMF)的 Adam 变体,用于减少内存占用。CAME 通过引入一个置信度矩阵来改进 NMF 的效果。GaLore 通过将梯度投影到低秩空间,并使用 8 位块状量化进一步减少内存占用。Lamb 允许使用巨大的批量大小而不失准确性,通过按其 Lipschitz 常数的倒数界定的逐层自适应更新实现 + +## API 参考 + +{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }} +{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }} +{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }} +{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }} + +## 使用 +We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs. +### step 1. 导包 + +```python +from transformers import LlamaModel, LlamaConfig +from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +import colossalai +import torch +``` + +### step 2. 初始化分布式 +We need to initialize distributed environment. For demo purpose, we use `colossal run --nproc_per_node 4`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md) + +```python +colossalai.launch_from_torch() +``` + +### step 3. 初始化模型和优化器 +Build our model. We created an MLP using two Linear Layer. + +```python +configuration = LlamaConfig() +model = LlamaModel(configuration).cuda() +criterion = lambda x: x.mean() +dist_optim = DistributedAdaFactor(model.parameters()) + +``` + +### step 4.初始化booster和plugin + +```python +plugin = HybridParallelPlugin(tp_size=2, zero_stage=2, pp_size=1, enable_all_optimization=True) +booster = Booster(plugin=plugin) +# You should also pass in your own dataset. +model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion) + +``` +### step 5.训练 +```python +steps = 10 +for step in range(steps): + input_ids = torch.ones(1, 100, device="cuda", dtype=torch.int) + attention_mask = input_ids.clone() + outputs = model(input_ids.cuda(), attention_mask.cuda()) + loss = criterion(outputs.last_hidden_state) + booster.backward(loss, dist_optim) + dist_optim.step() + dist_optim.zero_grad() +``` +### GaLore的特殊初期 +对于 GaLore,我们需要为每个参数组指定投影rank,以及量化和分页优化器参数。有关量化的详细信息,请参考 bitandbytes. +```python +from colossalai.nn.optimizer.galore import get_galore_param_groups +from colossalai.nn.optimizer import DistGaloreAwamW +optim = DistGaloreAwamW( + get_galore_param_groups(model, decay=1e-2, rank=8), + lr=lr, + betas=(beta1, beta2), + eps=eps, + nbits=8, + percentile_clipping=100, + block_wise=True, + min_8bit_size=4096, +) +``` + +## 兼容性 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+| Model/Feature | Lamb | GaLore | Adafactor | CAME |
+| :-: | :-: | :-: | :-: | :-: |
+| Hybrid Parallel Plugin | ✔️ | ✔️ | ✔️ | ✔️ |
+| Low Level Zero Plugin | ✔️ | | ✔️ | ✔️ |
+| Torch DDP Plugin | ✔️ | ✔️ | ✔️ | ✔️ |
+| Gemini Plugin | | | | |
+| Moe Hybrid Plugin | | | | |
+ + diff --git a/tests/test_optimizer/_utils.py b/tests/test_optimizer/_utils.py index 6157437de6bf..b23e3cb03895 100644 --- a/tests/test_optimizer/_utils.py +++ b/tests/test_optimizer/_utils.py @@ -131,7 +131,7 @@ def run_bert_test(test_config, optim_class, sharded_optim_class): def _run_bert_test(rank, world_size, port, optim_class, sharded_optim_class): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bert_test(optim_class, sharded_optim_class) diff --git a/tests/test_optimizer/test_dist_adafactor.py b/tests/test_optimizer/test_dist_adafactor.py index aeee1771801b..8589dfd637c8 100644 --- a/tests/test_optimizer/test_dist_adafactor.py +++ b/tests/test_optimizer/test_dist_adafactor.py @@ -680,8 +680,7 @@ def exam_bert_test_on_hybrid_plugin(test_config): def run_dist(rank, world_size, port): disable_existing_loggers() - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_bert_test_on_lowlevelzero_plugin() exam_bert_test_on_hybrid_plugin() exam_dist_adafactor_base() diff --git a/tests/test_optimizer/test_dist_came.py b/tests/test_optimizer/test_dist_came.py index a7b7f86b5d26..f27287b88b0c 100644 --- a/tests/test_optimizer/test_dist_came.py +++ b/tests/test_optimizer/test_dist_came.py @@ -459,8 +459,7 @@ def exam_bert_test_on_hybrid_plugin(test_config): def run_dist(rank, world_size, port): disable_existing_loggers() - config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_bert_test_on_lowlevelzero_plugin() # err in TODO layer exam_bert_test_on_hybrid_plugin() # pass exam_dist_came_base() # pass diff --git a/tests/test_optimizer/test_dist_galore.py b/tests/test_optimizer/test_dist_galore.py index a0db18a518f1..71b22001d383 100644 --- a/tests/test_optimizer/test_dist_galore.py +++ b/tests/test_optimizer/test_dist_galore.py @@ -9,7 +9,7 @@ import colossalai from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.logging import disable_existing_loggers -from colossalai.nn.optimizer import DistGaloreAwamW8bit, GaLoreAdamW8bit +from colossalai.nn.optimizer import DistGaloreAwamW, GaLoreAdamW8bit from colossalai.nn.optimizer.galore import get_galore_param_groups from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor from colossalai.tensor.d_tensor.api import clear_layout_converter @@ -172,7 +172,7 @@ def run_dist_galore_basic(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_si block_wise=False, min_8bit_size=1e10, # Disable quantization ) - optim = DistGaloreAwamW8bit( + optim = DistGaloreAwamW( get_galore_param_groups(tp_model, decay, rank=8), lr=lr, betas=(beta1, beta2), @@ -236,7 +236,7 @@ def run_dist_galore_fwd_bwd(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_ block_wise=False, min_8bit_size=1e10, ) - optim = DistGaloreAwamW8bit( + optim = DistGaloreAwamW( get_galore_param_groups(tp_model, decay, rank=8), lr=lr, betas=(beta1, beta2), @@ -302,7 +302,7 @@ def run_dist_galore_fwd_bwd(p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_ def check_dist_galore(rank, world_size, port): disable_existing_loggers() - 
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") global coordinator coordinator = DistCoordinator() @@ -319,9 +319,10 @@ def check_dist_galore(rank, world_size, port): ) for config in test_config: try: - run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=DistGaloreAwamW8bit) + run_bert_test(test_config=config, optim_class=GaLoreAdamW8bit, sharded_optim_class=DistGaloreAwamW) except Exception as e: print(e) + dist.barrier() print(f"rank {rank} tests passed :)") diff --git a/tests/test_optimizer/test_dist_lamb.py b/tests/test_optimizer/test_dist_lamb.py index 6f9f551c085a..f171e952bd98 100644 --- a/tests/test_optimizer/test_dist_lamb.py +++ b/tests/test_optimizer/test_dist_lamb.py @@ -279,7 +279,7 @@ def run_dist_lamb_fwd_bwd( def check_dist_lamb(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") global coordinator coordinator = DistCoordinator() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index df72934a93e4..4febe47de534 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -17,7 +17,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.lazy import LazyInitContext -from colossalai.nn.optimizer import DistGaloreAwamW8bit +from colossalai.nn.optimizer import DistGaloreAwamW from colossalai.nn.optimizer.galore import get_galore_param_groups from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -130,7 +130,7 @@ def build_model_from_hybrid_plugin( if use_lazy_init: ctx.materialize(org_model) org_model = org_model.cuda() - if sharded_optim_class == DistGaloreAwamW8bit: + if sharded_optim_class == DistGaloreAwamW: # Disable clipping and block-wise quantization org_optimizer = optim_class( get_galore_param_groups(org_model, weight_decay=0, rank=4),