Commit
[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)
* Add clip_grad_norm for hybrid_parallel_plugin

* polish code

* add unittests

* Move tp to a higher-level optimizer interface.

* bug fix

* polish code
littsk committed Oct 12, 2023
1 parent df63564 commit 83b52c5
Showing 8 changed files with 1,158 additions and 90 deletions.
75 changes: 60 additions & 15 deletions colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -1,7 +1,7 @@
from typing import Dict, List
from typing import Dict, List, Tuple

import torch
from torch import Tensor
from torch import Tensor, inf
from torch.nn import Module, Parameter
from torch.optim import Optimizer

@@ -68,8 +68,6 @@ def __init__(
            self.mixed_precision = BF16MixedPrecisionMixin()
        else:
            raise ValueError(f"Unsupported precision: {precision}")
        if max_norm > 0.0:
            raise NotImplementedError("max_norm is not supported yet.")
        self.max_norm = max_norm
        self.working_to_master_map: Dict[Parameter, Tensor] = {}
        self.master_to_working_map: Dict[Tensor, Parameter] = {}
@@ -102,32 +100,65 @@ def zero_grad(self, *args, **kwargs):
        return super().zero_grad(*args, **kwargs)

    def _unscale_and_clip_grads(self, total_norm: float) -> None:
        """
        Unscale and clip gradients before performing the optimization step.
        Args:
            total_norm (float): The computed total gradient norm.
        Returns:
            None
        """
        div_scale = 1.0

        # If mixed-precision training is used, get the gradient division scale from the mixed-precision handler.
        if self.mixed_precision is not None:
            div_scale = self.mixed_precision.get_grad_div_scale()

        if self.max_norm > 0.0:
            # norm is in fact norm*scale
            # Calculate the scaling factor for gradient clipping
            # The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm'
            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm

            # If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping
            if clip > 1:
                div_scale = clip * div_scale

        # Apply the scaling factor to gradients
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.grad.data.mul_(1.0 / div_scale)
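For intuition, the unscale-and-clip arithmetic above can be exercised on its own. The sketch below is illustrative only; the helper name `unscale_and_clip_` and the explicit `loss_scale` argument are not part of the ColossalAI API. `total_norm` is computed on the still-scaled gradients, dividing it by the loss scale recovers the true norm, and a clip coefficient greater than 1 is folded into the same divisor so that a single in-place multiply both unscales and clips.

```python
import torch

def unscale_and_clip_(grads, total_norm, loss_scale, max_norm, eps=1e-6):
    """Fold loss-scale removal and norm clipping into one in-place multiply.

    `total_norm` is the norm of the *scaled* gradients, so dividing it by the
    loss scale recovers the true norm before comparing it against `max_norm`.
    """
    div_scale = loss_scale
    if max_norm > 0.0:
        clip = (total_norm / div_scale + eps) / max_norm
        if clip > 1.0:
            # Folding the clip coefficient into the divisor clips and unscales at once.
            div_scale = clip * div_scale
    for g in grads:
        g.mul_(1.0 / div_scale)

# Toy check: gradients scaled by loss_scale=1024, clipped to a global norm of 1.0.
grads = [torch.full((4,), 1024.0), torch.full((4,), 2048.0)]
total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2).item()
unscale_and_clip_(grads, total_norm, loss_scale=1024.0, max_norm=1.0)
print(torch.norm(torch.stack([g.norm(2) for g in grads]), 2))  # ~1.0 after clipping
```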

    def _compute_grad_norm(self) -> float:
        if self.max_norm <= 0.0:
            return 0.0
        grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
        if len(grads) == 0:
    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
        r"""
        Compute and return the gradient norm for gradient clipping.
        Args:
            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
        Returns:
            float: The total norm of the given gradients.
        """

        if len(param_gradient_pairs) == 0:
            return 0.0
        device = grads[0].device
        # TODO(ver217): support tp
        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        return total_norm.item()

        # gradients used for norm calculation.
        gradients = [grad for param, grad in param_gradient_pairs]

        if norm_type == inf:
            total_norm = max(grad.data.abs().max() for grad in gradients)

        else:
            total_norm_exponentiated = 0.0
            for grad in gradients:
                total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type
            total_norm = total_norm_exponentiated ** (1.0 / norm_type)

        return total_norm
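Accumulating exponentiated per-gradient norms and taking the root at the end gives the same result as a norm over the concatenated gradients, and keeps the partial sums in a form that can be combined later (for example across parallel ranks). A standalone sketch of that computation, illustrative only and not the ColossalAI implementation:

```python
import torch
from torch import inf

def grad_norm(gradients, norm_type: float = 2.0) -> float:
    """Global gradient norm from per-tensor norms: sum the p-th powers,
    then take the (1/p)-th root. Illustrative sketch only."""
    if len(gradients) == 0:
        return 0.0
    if norm_type == inf:
        return max(g.detach().abs().max().item() for g in gradients)
    total = sum(g.detach().double().norm(norm_type).item() ** norm_type for g in gradients)
    return total ** (1.0 / norm_type)

# Matches the norm of the flattened concatenation up to floating-point error.
gs = [torch.randn(3, 3), torch.randn(5)]
assert abs(grad_norm(gs) - torch.cat([g.flatten() for g in gs]).norm(2).item()) < 1e-4
```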

    def step(self, *args, **kwargs):
        if self.mixed_precision.should_skip_step():
@@ -142,8 +173,22 @@ def step(self, *args, **kwargs):
                if working_param.grad is not None:
                    p.grad = working_param.grad.data.float()
                    working_param.grad = None
        total_norm = self._compute_grad_norm()

        # gradient unscale and clip.
        if self.max_norm <= 0:
            # no need to compute gradient norm.
            total_norm = 0.0
        else:
            # compute the total norm.
            param_gradient_pairs = [
                (self.master_to_working_map[p], p.grad)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
            total_norm = self._compute_grad_norm(param_gradient_pairs)
        self._unscale_and_clip_grads(total_norm)

        self.optim.step(*args, **kwargs)
        # update working params
        for group in self.optim.param_groups:
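The remaining changed files (not shown here) wire this norm computation into the hybrid-parallel optimizer wrappers; per the commit message, tensor-parallel (tp) handling moves to a higher-level optimizer interface. As a rough illustration of the general idea only, assuming an initialized torch.distributed process group and not reflecting the actual plugin code, the exponentiated partial norms can be all-reduced across ranks before taking the root:

```python
import torch
import torch.distributed as dist

def distributed_grad_norm(gradients, norm_type: float = 2.0, process_group=None) -> float:
    """Illustrative sketch: each rank sums the p-th powers of its local
    (shard) gradient norms, the partial sums are summed across ranks,
    and the root is taken afterwards. For the inf norm, ReduceOp.MAX
    would be used instead of SUM."""
    device = gradients[0].device if gradients else torch.device("cpu")
    local = torch.zeros(1, dtype=torch.float64, device=device)
    for g in gradients:
        local += g.detach().double().norm(norm_type) ** norm_type
    if process_group is not None and dist.is_initialized():
        dist.all_reduce(local, op=dist.ReduceOp.SUM, group=process_group)
    return local.item() ** (1.0 / norm_type)
```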