From 83b52c56cde772e32c034dbd027bfff7d14b32b1 Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Thu, 12 Oct 2023 11:32:37 +0800 Subject: [PATCH] [feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837) * Add clip_grad_norm for hibrid_parallel_plugin * polish code * add unittests * Move tp to a higher-level optimizer interface. * bug fix * polish code --- .../naive_amp/mixed_precision_optimizer.py | 75 +++- .../booster/plugin/hybrid_parallel_plugin.py | 340 +++++++++++++++++- colossalai/zero/low_level/_utils.py | 49 --- .../low_level/bookkeeping/gradient_store.py | 33 ++ colossalai/zero/low_level/low_level_optim.py | 55 ++- .../test_amp_optimizer.py | 258 +++++++++++++ .../test_naive_optimizer.py | 197 ++++++++++ .../test_zero_optimizer.py | 241 +++++++++++++ 8 files changed, 1158 insertions(+), 90 deletions(-) create mode 100644 tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py create mode 100644 tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py create mode 100644 tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 501a843f6992..9e07bdebf8fa 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -1,7 +1,7 @@ -from typing import Dict, List +from typing import Dict, List, Tuple import torch -from torch import Tensor +from torch import Tensor, inf from torch.nn import Module, Parameter from torch.optim import Optimizer @@ -68,8 +68,6 @@ def __init__( self.mixed_precision = BF16MixedPrecisionMixin() else: raise ValueError(f"Unsupported precision: {precision}") - if max_norm > 0.0: - raise NotImplementedError("max_norm is not supported yet.") self.max_norm = max_norm self.working_to_master_map: Dict[Parameter, Tensor] = {} self.master_to_working_map: Dict[Tensor, Parameter] = {} @@ -102,32 +100,65 @@ def zero_grad(self, *args, **kwargs): return super().zero_grad(*args, **kwargs) def _unscale_and_clip_grads(self, total_norm: float) -> None: + """ + Unscale and clip gradients before performing the optimization step. + + Args: + total_norm (float): The computed total gradient norm. + + Returns: + None + """ div_scale = 1.0 + + # If mixed-precision training is used, get the gradient division scale from the mixed-precision handler. if self.mixed_precision is not None: div_scale = self.mixed_precision.get_grad_div_scale() if self.max_norm > 0.0: - # norm is in fact norm*scale + # Calculate the scaling factor for gradient clipping + # The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm' clip = ((total_norm / div_scale) + 1e-6) / self.max_norm + + # If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping if clip > 1: div_scale = clip * div_scale + # Apply the scaling factor to gradients for group in self.param_groups: for p in group["params"]: if p.grad is None: continue p.grad.data.mul_(1.0 / div_scale) - def _compute_grad_norm(self) -> float: - if self.max_norm <= 0.0: - return 0.0 - grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None] - if len(grads) == 0: + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. 
+ + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + + if len(param_gradient_pairs) == 0: return 0.0 - device = grads[0].device - # TODO(ver217): support tp - total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2) - return total_norm.item() + + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type + total_norm = total_norm_exponentiated ** (1.0 / norm_type) + + return total_norm def step(self, *args, **kwargs): if self.mixed_precision.should_skip_step(): @@ -142,8 +173,22 @@ def step(self, *args, **kwargs): if working_param.grad is not None: p.grad = working_param.grad.data.float() working_param.grad = None - total_norm = self._compute_grad_norm() + + # gradient unscale and clip. + if self.max_norm <= 0: + # no need to compute gradient norm. + total_norm = 0.0 + else: + # compute the total norm. + param_gradient_pairs = [ + (self.master_to_working_map[p], p.grad) + for group in self.param_groups + for p in group["params"] + if p.grad is not None + ] + total_norm = self._compute_grad_norm(param_gradient_pairs) self._unscale_and_clip_grads(total_norm) + self.optim.step(*args, **kwargs) # update working params for group in self.optim.param_groups: diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 479ccc3eb36e..2c6237cd9a1a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,3 +1,4 @@ +import ctypes import random from contextlib import nullcontext from functools import partial @@ -7,7 +8,8 @@ import numpy as np import torch import torch.distributed as dist -from torch.distributed import ProcessGroup +from torch import Tensor, inf +from torch.distributed import ProcessGroup, get_world_size from torch.nn import Module, SyncBatchNorm from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -24,6 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.d_tensor.api import is_distributed_tensor from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -160,12 +163,143 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): + def __init__( + self, + optim: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + ): self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.max_norm = max_norm + 
self.tp_pg = tp_process_group + self.pp_pg = pp_process_group super().__init__(optim) + def step(self, *args, **kwargs): + r""" + Perform an optimization step. + + Args: + *args: Variable-length positional arguments to be passed to the optimizer's step function. + **kwargs: Keyword arguments to be passed to the optimizer's step function. + """ + + if self.max_norm > 0: + # Compute the total gradient norm. + param_gradient_pairs = [ + (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None + ] + total_norm = self._compute_grad_norm(param_gradient_pairs) + + # Clip the gradients to prevent exploding gradients. + self._clip_grad_norm(total_norm) + + # Perform the optimization step using the underlying optimizer. + self.optim.step(*args, **kwargs) + + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + + if len(param_gradient_pairs) == 0: + return 0.0 + + tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + norm_type = float(norm_type) + + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + if tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + total_norm = total_norm_cuda.item() + else: + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. 
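+                # For example, a parameter shared by two stages contributes ||g||^p / 2 on each
+                # stage, so the SUM all-reduce over the pp group below restores exactly one ||g||^p.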
+ if pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_shared_param = shared_param[self.stage_manager.stage] + if grad is stage_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) + if tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + def _clip_grad_norm(self, total_norm: float) -> None: + r""" + Clips the gradients of the model's parameters to prevent exploding gradients. + + Args: + total_norm (float): The computed total gradient norm. + + Returns: + None + """ + clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6)) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + + for group in self.optim.param_groups: + for p in group["params"]: + if p.grad is None: + continue + p.grad.data.mul_(clip_coef_clamped) + def update_master_params(self, model: Module): pass @@ -192,23 +326,108 @@ def __init__( hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp ): self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__( optim, - precision, - initial_scale, - min_scale, - growth_factor, - backoff_factor, - growth_interval, - hysteresis, - max_scale, - max_norm, + precision=precision, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + max_norm=max_norm, ) + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + if len(param_gradient_pairs) == 0: + return 0.0 + + tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + norm_type = float(norm_type) + + if norm_type == inf: + # The parent class calculates the norm of 'dp' gradients, + # so we need to calculate the norm of 'tp' and 'pp' gradients. + total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) + + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + if tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + + total_norm = total_norm_cuda.item() + + else: + # gradients used for norm calculation. 
+ gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_working_shared_param = shared_param[self.stage_manager.stage] + stage_master_shared_param = self.working_to_master_map[stage_working_shared_param] + if grad is stage_master_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) + if tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): def __init__( @@ -233,9 +452,15 @@ def __init__( cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp forced_dtype: Optional[torch.dtype] = None, ): self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.dp_pg = dp_process_group + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__( @@ -255,10 +480,90 @@ def __init__( partition_grad, cpu_offload, dp_process_group, - tp_process_group, forced_dtype, ) + def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + gradients (List[Tensor]): A list of tensors containing gradients. + norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2. + + Returns: + float: The computed gradient norm. 
+ """ + + # Check if the list of gradients is empty + if len(gradients) == 0: + return 0.0 + + dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1 + tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + norm_type = float(norm_type) + + if norm_type == inf: + # The parent class calculates the norm of 'dp' gradients, + # so we only need to calculate the norm 'tp' of 'pp' gradients. + total_norm = super()._compute_grad_norm(gradients, norm_type) + + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + if tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + + total_norm = total_norm_cuda.item() + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if tp_size > 1: + param_id_for_grad = self._grad_store.get_param_id_for_grad(grad) + param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value + + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. 
+ if pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_shared_param = shared_param[self.stage_manager.stage] + working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param)) + if grad is working_grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) + if dp_size > 1: + # compute norm in dp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg) + if tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # Compute the 'total_norm' from 'total_norm_exponentiated' + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + class HybridParallelPlugin(PipelinePluginBase): """ @@ -475,11 +780,19 @@ def configure( param_info=param_info, precision=self.precision, max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, **self.amp_config, ) else: optimizer = HybridParallelNaiveOptimizer( - optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, ) else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." @@ -491,6 +804,7 @@ def configure( param_info=param_info, dp_process_group=self.dp_group, tp_process_group=self.tp_group, + pp_process_group=self.pp_group, verbose=True, clip_grad_norm=self.max_norm, **self.zero_config, diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index 0a15f8ddd718..de08ecf3d57f 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -3,9 +3,7 @@ import torch import torch.distributed as dist -from torch import Tensor, inf from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from torch.distributed import ProcessGroup def flatten(input_): @@ -192,53 +190,6 @@ def calculate_global_norm_from_list(norm_list): total_norm += norm**2.0 return math.sqrt(total_norm) - -def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int: - """Clips gradient norm of an iterable of parameters. - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. - - Args: - gradients (Tensor): The gradients to compute norm - dp_group (ProcessGroup): The process group of ZeRO Data Parallelism - tp_group (ProcessGroup): The process group of Tensor Parallelism - norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. - - Returns: - int: The total norm of given gradients - """ - - norm_type = float(norm_type) - if norm_type == inf: - total_norm = max(g.data.abs().max() for g in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group) - - # Take max across all GPUs. 
- if tp_group is not None: - dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() - else: - total_norm = 0.0 - for g in gradients: - param_norm = g.data.double().norm(norm_type) - total_norm += param_norm.item() ** norm_type - - # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group) - - if tp_group is not None: - dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group) - - total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) - - if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm: - total_norm = -1 - - return total_norm - - def sync_tensor(flat_tensor, tensor_list): """ Synchronize the flattened tensor and unflattened tensor list. When diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 3ce688cfa930..1164532fa3a3 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -21,6 +21,8 @@ def __init__(self, *args, partition_grad: bool = False): # for zero2, it's `param_id: [grad_local_rank]` self._working_index = 0 if partition_grad else self._local_rank + self.grad_to_param_mapping = dict() + def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: """Return list of gradient slices of a specific parameter @@ -54,6 +56,8 @@ def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: in else: self._grads_of_params[group_id][param_id].append(grad) + self.grad_to_param_mapping[id(grad)] = param_id + def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int): """Add a gradient slice on an existing slice of the parameter's gradient Used when no_sync is not activated. @@ -83,8 +87,37 @@ def get_working_grads_by_group_id(self, group_id: int) -> List: return grad_list + def get_working_grad_by_param_id(self, param_id) -> Tensor: + """ + Return the working gradient for the specified parameter. + + Args: + param_id (int): The index of the parameter. + + Returns: + Tensor: The the working gradient slices for the specified param_id. 
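+
+        Raises:
+            KeyError: If no working gradient is recorded for the given param_id.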
+ """ + + for group in self._grads_of_params.values(): + if param_id in group.keys(): + return group[param_id][self._working_index] + + raise KeyError(f"Working gradient for param_id {param_id} not found.") + def reset_grads_by_group_id(self, group_id: int): self._grads_of_params[group_id] = dict() def reset_all_gradients(self): self._grads_of_params = dict() + + def get_param_id_for_grad(self, grad: Tensor) -> int: + """Return the id of a parameter which the gradient slice belongs to + + Args: + grad (Tensor): the gradient slice + + Returns: + int: the id of a parameter which the gradient slice belongs to + """ + + return self.grad_to_param_mapping[id(grad)] diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 72df93ace302..d9be7af17d15 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -2,11 +2,12 @@ import copy from contextlib import contextmanager from functools import partial -from typing import Dict, Iterator, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple import torch import torch.distributed as dist import torch.nn as nn +from torch import Tensor, inf from torch.distributed import ProcessGroup from torch.optim import Optimizer @@ -21,14 +22,7 @@ # from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device -from ._utils import ( - calculate_global_norm_from_list, - compute_norm, - flatten, - has_inf_or_nan, - release_param_grad, - sync_tensor, -) +from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor from .bookkeeping import BucketStore, GradientStore, ParameterStore @@ -80,7 +74,6 @@ def __init__( partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None, ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) @@ -101,8 +94,6 @@ def __init__( self._local_rank = dist.get_rank(group=self.dp_pg) self._world_size = dist.get_world_size(group=self.dp_pg) - self.tp_pg = tp_process_group - # working and master params for mixed precision training self._working_param_groups = dict() self._master_param_groups_of_current_rank = dict() @@ -433,7 +424,7 @@ def step(self, closure=None): # compute norm working_grads = self._grad_store.get_working_grads_by_group_id(group_id) - norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg) + norm_group = self._compute_grad_norm(gradients=working_grads) norm_groups.append(norm_group) self._grad_store.reset_grads_by_group_id(group_id) @@ -467,6 +458,44 @@ def step(self, closure=None): self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] + def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + gradients (List[Tensor]): The gradients to compute norm + norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2. 
+ + Returns: + float: The total norm of given gradients + """ + + if len(gradients) == 0: + return 0.0 + + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) + total_norm = total_norm_cuda.item() + + else: + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + total_norm_exponentiated += grad_norm_exponentiated + + # Sum across all model parallel GPUs. + total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) + torch.distributed.all_reduce( + total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg + ) + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + ############################# # Mixed Precision Utilities # ############################# diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py new file mode 100644 index 000000000000..0192afc99ae4 --- /dev/null +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py @@ -0,0 +1,258 @@ +import pytest +import torch +from torch.nn.utils.clip_grad import clip_grad_norm_ + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + + col_layer_for_check = ["encoder.layer[0].output.dense"] + row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] + + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + + # Check grads + # Save gradient tensors for comparison between the original model and the sharded model. 
+ grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + col_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + check_all_grad_tensors(grads_to_check) + + # Check gradient norm + # Convert the gradient data of the working parameter to float and assign it to the master parameter's gradient + # Note that this operation should have been done in the 'step' function, but it is performed here in advance for gradient norm calculation purposes. + # Although it will be done again in the 'step' function, it does not affect correctness. + for group in sharded_optimizer.optim.param_groups: + for p in group["params"]: + working_param = sharded_optimizer.master_to_working_map[p] + if p is working_param: + continue + if working_param.grad is not None: + p.grad = working_param.grad.data.float() + working_param.grad = None + # Create a list of parameter-gradient pairs containing working parameters and their gradients + param_gradient_pairs = [ + (sharded_optimizer.master_to_working_map[p], p.grad) + for group in sharded_optimizer.param_groups + for p in group["params"] + if p.grad is not None + ] + + origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"]) + # Calculate the gradient norm of the sharded optimizer + device = origin_norm.device + hybrid_norm = torch.tensor(sharded_optimizer._compute_grad_norm(param_gradient_pairs)).to(device) + + # If using fp16 precision, divide by the initial scale + if test_config["precision"] == "fp16": + hybrid_norm /= test_config["initial_scale"] + + # Assert that the gradient norm of the original model is close to the gradient norm of the hybrid model + assert torch.allclose( + origin_norm, hybrid_norm, atol=atol, rtol=rtol + ), f"Original model grad norm is not equal to sharded model grad norm\n{origin_norm}\n{hybrid_norm}" + + # Optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # Check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + if org_model.__class__.__name__ == "BertModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # Check weights + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + 
"use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + ], +) +def run_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + ], +) +def run_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_grad_clip_norm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_test() + + +def check_grad_clip_norm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm(): + spawn(check_grad_clip_norm, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm_3d(): + spawn(check_grad_clip_norm_3d, 8) + + +if __name__ == "__main__": + test_grad_clip_norm() + test_grad_clip_norm_3d() diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py new file mode 100644 index 000000000000..da298f5c0be1 --- /dev/null +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py @@ -0,0 +1,197 @@ +import pytest +import torch +from torch.nn.utils.clip_grad import clip_grad_norm_ + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_all_grad_tensors, + check_loss, + check_output_hidden_state, + check_weight, + get_grad_tensors_for_check, + run_forward_backward_with_hybrid_plugin, + 
unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + + col_layer_for_check = ["encoder.layer[0].output.dense"] + row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] + + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + + # Check grads + # Save gradient tensors for comparison between the original model and the sharded model. + grads_to_check = {} + if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + col_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + grads_to_check.update(col_layer_grads) + grads_to_check.update(row_layer_grads) + check_all_grad_tensors(grads_to_check) + + # Check grad norm + param_gradient_pairs = [ + (p, p.grad) for group in sharded_optimizer.param_groups for p in group["params"] if p.grad is not None + ] + origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"]) + device = origin_norm.device + hybrid_norm = torch.tensor(sharded_optimizer._compute_grad_norm(param_gradient_pairs)).to(device) + assert torch.allclose( + origin_norm, hybrid_norm, atol=atol, rtol=rtol + ), f"orgin origin model grad norm is not equal to shard model grad norm\n{origin_norm}\n{hybrid_norm}" + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + + if org_model.__class__.__name__ == "BertModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "fp32", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + 
"precision": "fp32", + "max_norm": 5, + }, + ], +) +def run_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "max_norm": 5, + }, + ], +) +def run_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_grad_clip_norm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_test() + + +def check_grad_clip_norm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm(): + spawn(check_grad_clip_norm, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm_3d(): + spawn(check_grad_clip_norm_3d, 8) + + +if __name__ == "__main__": + test_grad_clip_norm() + test_grad_clip_norm_3d() diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py new file mode 100644 index 000000000000..f1ac1de1acc9 --- /dev/null +++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py @@ -0,0 +1,241 @@ +import math + +import pytest +import torch +import torch.distributed as dist +from torch.nn.utils.clip_grad import clip_grad_norm_ + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, + unwrap_model, +) + + +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) + + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") + + col_layer_for_check = 
["encoder.layer[0].output.dense"] + + if test_config["precision"] == "fp32": + atol, rtol = 1e-4, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + + dist.barrier() + # Check gradient norm + origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"]) + + # Calculate the gradient norm of the sharded optimizer + device = origin_norm.device + norm_groups = [] + for group_id in range(sharded_optimizer.num_param_groups): + working_grads = sharded_optimizer._grad_store.get_working_grads_by_group_id(group_id) + norm_group = sharded_optimizer._compute_grad_norm(gradients=working_grads) + norm_groups.append(norm_group) + total_norm = 0.0 + for norm in norm_groups: + total_norm += norm**2.0 + hybrid_norm = torch.tensor(math.sqrt(total_norm)).to(device) + + # If using fp16 precision, divide by the initial scale + if test_config["precision"] == "fp16": + hybrid_norm /= test_config["initial_scale"] + + # Assert that the gradient norm of the original model is close to the gradient norm of the hybrid model + assert torch.allclose( + origin_norm, hybrid_norm, atol=atol, rtol=rtol + ), f"Original model grad norm is not equal to sharded model grad norm\n{origin_norm}\n{hybrid_norm}" + + # optimizer executes step + org_optimizer.step() + sharded_optimizer.step() + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config["precision"] == "fp32": + atol, rtol = 1e-5, 1e-3 + elif test_config["precision"] == "fp16": + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 2e-2, 2e-2 + if org_model.__class__.__name__ == "BertModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) + + # check weights + if test_config["precision"] == "fp32": + atol, rtol = 5e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 1, + "zero_stage": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": True, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 1, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 1, + "zero_stage": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + ], +) +def run_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) 
+ + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "bf16", + "max_norm": 5, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "zero_stage": 1, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "max_norm": 5, + "initial_scale": 1, + }, + ], +) +def run_3d_test(test_config): + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + +def check_grad_clip_norm(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_test() + + +def check_grad_clip_norm_3d(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_3d_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm(): + spawn(check_grad_clip_norm, 4) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_grad_clip_norm_3d(): + spawn(check_grad_clip_norm_3d, 8) + + +if __name__ == "__main__": + test_grad_clip_norm() + test_grad_clip_norm_3d()
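
Usage sketch: the snippet below shows how the max_norm option introduced by this patch can be exercised through HybridParallelPlugin. It assumes the standard colossalai.booster.Booster workflow; the model, data, and argument values are illustrative placeholders and may need adjustment for a given release. Internally, each wrapped optimizer accumulates the per-rank sum of |g|^p, all-reduces it over the dp/tp/pp process groups, and takes the p-th root, dividing the contribution of non-tensor-parallel parameters by tp_size so the tp all-reduce does not over-count them.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# launched with torchrun, e.g. `torchrun --nproc_per_node=2 demo.py`
colossalai.launch_from_torch(config={})

# max_norm > 0 enables the gradient clipping path added by this patch;
# the remaining arguments are illustrative.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp16", initial_scale=1, max_norm=1.0)
booster = Booster(plugin=plugin)

model = nn.Linear(32, 32).cuda()                      # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

inputs = torch.randn(8, 32).cuda()                    # placeholder data
targets = torch.randn(8, 32).cuda()
loss = criterion(model(inputs), targets)

booster.backward(loss, optimizer)
optimizer.step()   # computes the global grad norm across dp/tp/pp groups and clips to max_norm
optimizer.zero_grad()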