[misc] update distributed optim docs #5701

Merged · 3 commits · May 14, 2024
4 changes: 2 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -27,7 +27,7 @@
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
-from colossalai.nn.optimizer import DistGaloreAwamW8bit
+from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -1179,7 +1179,7 @@ def configure(
# TODO: Support Galore + ZeRO
zero_stage = self.zero_stage
zero_config = deepcopy(self.zero_config)
-if isinstance(optimizer, DistGaloreAwamW8bit) and zero_stage > 0 and self.dp_size > 0:
+if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
zero_config["partition_grad"] = False
zero_stage = 0
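The change above swaps the optimizer class name in HybridParallelPlugin's ZeRO fallback check. A hypothetical sketch of a configuration that triggers it, assuming a 4-GPU launch with the distributed environment already initialized, so that dp_size = world_size / (tp_size * pp_size) = 2; exact constructor defaults may differ across versions:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import DistGaloreAwamW

# tp_size=2, pp_size=1 on 4 GPUs leaves a data-parallel group of size 2,
# and zero_stage=1 requests gradient partitioning on top of it.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1)
booster = Booster(plugin=plugin)
# Boosting a DistGaloreAwamW optimizer with this plugin reaches the branch above:
# the warning is emitted and zero_stage is reset to 0 (TP + vanilla DP only).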
4 changes: 2 additions & 2 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -32,7 +32,7 @@
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
-from colossalai.nn.optimizer import DistGaloreAwamW8bit
+from colossalai.nn.optimizer import DistGaloreAwamW
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.zero import LowLevelZeroOptimizer

@@ -437,7 +437,7 @@ def configure(
zero_stage = self.stage
zero_optim_kwargs = {**self.zero_optim_kwargs}
dp_size = dist.get_world_size()
-if isinstance(optimizer, DistGaloreAwamW8bit) and zero_stage > 0 and dp_size > 0:
+if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
zero_optim_kwargs["partition_grad"] = False
zero_stage = 0
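The LowLevelZeroPlugin hunk applies the same fallback as the HybridParallelPlugin change earlier: with ZeRO requested and a DistGaloreAwamW optimizer, gradient partitioning is switched off with a warning, since Galore currently supports only Tensor Parallel and vanilla Data Parallel. A minimal end-to-end sketch, assuming a multi-GPU launch with an already-initialized process group; the GaLore-specific parameter groups and the training loop are omitted, so the constructor call is illustrative rather than complete:

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import DistGaloreAwamW

# Assumes the process group is already set up, e.g. via colossalai.launch_from_torch under torchrun.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = DistGaloreAwamW(model.parameters(), lr=1e-3)  # GaLore param groups omitted (assumption)

plugin = LowLevelZeroPlugin(stage=1)  # ZeRO stage 1 requested
booster = Booster(plugin=plugin)
# configure() detects the DistGaloreAwamW instance, emits the warning shown above,
# and falls back to stage 0, i.e. vanilla Data Parallel without gradient partitioning.
model, optimizer, *_ = booster.boost(model, optimizer)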
4 changes: 2 additions & 2 deletions colossalai/nn/optimizer/__init__.py
@@ -4,7 +4,7 @@
from .cpu_adam import CPUAdam
from .distributed_adafactor import DistributedAdaFactor
from .distributed_came import DistributedCAME
-from .distributed_galore import DistGaloreAwamW8bit
+from .distributed_galore import DistGaloreAwamW
from .distributed_lamb import DistributedLamb
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
@@ -25,7 +25,7 @@
"CPUAdam",
"HybridAdam",
"DistributedLamb",
"DistGaloreAwamW8bit",
"DistGaloreAwamW",
"GaLoreAdamW",
"GaLoreAdafactor",
"GaLoreAdamW8bit",
16 changes: 9 additions & 7 deletions colossalai/nn/optimizer/distributed_galore.py
@@ -18,7 +18,7 @@
# Mark sharded dimension


-class DistGaloreAwamW8bit(DistributedOptim, Optimizer2State):
+class DistGaloreAwamW(DistributedOptim, Optimizer2State):
r"""Implements Galore, a optimizer-agonistic gradient compression technique on 8-bit AdamW.
It largely compresses gradient via low-rank projection and is claimed to be insensitive to hyperparams like lr.
Supports Tensor Parallel and ZeRO stage 1 and 2 via booster and plugin.
@@ -35,12 +35,14 @@ class DistGaloreAwamW8bit(DistributedOptim, Optimizer2State):
numerical stability. (default: 1e-6)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
nbits: Number of bits for quantization optim states. Only 32 and 8 are supported.
-Example:
->>> optim = DistributedLamb(model.parameters(), lr=1e-3)
->>> proc_mesh = ProcessGroupMesh(tp_size, zero_size)
->>> tp_group = proc_mesh.get_group_along_axis(0)
->>> dp_group = proc_mesh.get_group_along_axis(1)
->>> optim.setup_distributed(tp_group, dp_group)
+min_8bit_size (`int`, defaults to 4096):
+    The minimum number of elements of the parameter tensors for 8-bit optimization.
+percentile_clipping (`int`, defaults to 100):
+    Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
+block_wise (`bool`, defaults to `True`):
+    Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
+is_paged (`bool`, defaults to `False`):
+    Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
"""

def __init__(
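The docstring update above replaces the old Example block, which had been copied from DistributedLamb, with documentation for the bitsandbytes-style quantization arguments. For reference, a sketch of the same wiring pattern adapted to DistGaloreAwamW, assuming model, tp_size, and zero_size are defined and that setup_distributed accepts the two process groups exactly as in the removed example; the real signature may take additional arguments, and when the optimizer is used through a booster plugin this call is expected to be made by the plugin itself:

from colossalai.cluster import ProcessGroupMesh
from colossalai.nn.optimizer import DistGaloreAwamW

optim = DistGaloreAwamW(model.parameters(), lr=1e-3)  # constructor args abbreviated (assumption)
proc_mesh = ProcessGroupMesh(tp_size, zero_size)      # axis 0: tensor parallel, axis 1: ZeRO/data parallel
tp_group = proc_mesh.get_group_along_axis(0)
dp_group = proc_mesh.get_group_along_axis(1)
optim.setup_distributed(tp_group, dp_group)           # two-group call taken from the removed example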
20 changes: 10 additions & 10 deletions colossalai/zero/low_level/low_level_optim.py
@@ -396,15 +396,15 @@ def run_reduction(bucket_store: BucketStore, grad_store: GradientStore):
else:
if bucket_store.moe_extra_dp_pg is None:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.zero_world_size))
-recieved_grad = torch.zeros_like(flat_grads_list[0])
-dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
+received_grad = torch.zeros_like(flat_grads_list[0])
+dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)

if received_grad.dtype != grad_dtype:
received_grad = received_grad.to(grad_dtype)

grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.zero_local_rank]
LowLevelZeroOptimizer.update_partitoned_grad(
-bucket_store, grad_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1
+bucket_store, grad_store, grad_in_bucket_current_rank, received_grad, group_id, 1
)
else:
# categorize moe and non moe param
@@ -421,13 +421,13 @@ def run_reduction(bucket_store: BucketStore, grad_store: GradientStore):
flat_grads_list = list(
non_moe_flat_grads.split(len(non_moe_flat_grads) // bucket_store.zero_world_size)
)
-recieved_grad = torch.zeros_like(flat_grads_list[0])
-dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
+received_grad = torch.zeros_like(flat_grads_list[0])
+dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
LowLevelZeroOptimizer.update_partitoned_grad(
bucket_store,
grad_store,
non_moe_grad_in_bucket_current_rank,
-recieved_grad,
+received_grad,
group_id,
1,
)
@@ -436,15 +436,15 @@ def run_reduction(bucket_store: BucketStore, grad_store: GradientStore):
flat_grads_list = list(
moe_flat_grads.split(len(moe_flat_grads) // bucket_store.moe_extra_dp_pg_size)
)
-recieved_grad = torch.zeros_like(flat_grads_list[0])
+received_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(
-recieved_grad,
+received_grad,
flat_grads_list,
group=bucket_store.moe_extra_dp_pg,
)
param_slice = bucket_store.zero_world_size // bucket_store.moe_extra_dp_pg_size
-recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
-for split_recieved_grad in recieved_grad:
+received_grad = list(received_grad.split(len(received_grad) // param_slice))
+for split_recieved_grad in received_grad:
split_recieved_grad = _unflatten_dense_tensors(
split_recieved_grad, moe_grad_in_bucket_current_rank
)
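The hunk above only corrects the recieved_grad spelling; the collective pattern is unchanged: the flattened bucket gradient is split into zero_world_size (or moe_extra_dp_pg_size) chunks and reduce-scattered so that each rank ends up holding the sum of its own chunk. A self-contained sketch of that collective, runnable under torchrun with the NCCL backend; the tensor names and sizes are illustrative, not taken from the codebase:

import torch
import torch.distributed as dist

def reduce_scatter_demo():
    # Assumes torch.distributed is already initialized with the NCCL backend
    # (gloo does not support reduce_scatter).
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

    # A flat "gradient" with 4 elements per rank, filled with rank + 1.
    flat_grads = torch.full((4 * world_size,), float(rank + 1), device=device)
    flat_grads_list = list(flat_grads.split(len(flat_grads) // world_size))

    received_grad = torch.zeros_like(flat_grads_list[0])
    dist.reduce_scatter(received_grad, flat_grads_list)  # default reduction op is SUM

    # Every rank now holds 1 + 2 + ... + world_size in each element of its chunk.
    print(f"rank {rank}: {received_grad.tolist()}")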
156 changes: 0 additions & 156 deletions docs/source/en/features/distributed_adafactor.md

This file was deleted.