[feature] Add clip_grad_norm for hybrid_parallel_plugin #4837

Merged · 6 commits · Oct 12, 2023
75 changes: 60 additions & 15 deletions colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -1,7 +1,7 @@
from typing import Dict, List
from typing import Dict, List, Tuple

import torch
from torch import Tensor
from torch import Tensor, inf
from torch.nn import Module, Parameter
from torch.optim import Optimizer

@@ -68,8 +68,6 @@ def __init__(
self.mixed_precision = BF16MixedPrecisionMixin()
else:
raise ValueError(f"Unsupported precision: {precision}")
if max_norm > 0.0:
raise NotImplementedError("max_norm is not supported yet.")
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}
@@ -102,32 +100,65 @@ def zero_grad(self, *args, **kwargs):
return super().zero_grad(*args, **kwargs)

def _unscale_and_clip_grads(self, total_norm: float) -> None:
"""
Unscale and clip gradients before performing the optimization step.

Args:
total_norm (float): The computed total gradient norm.

Returns:
None
"""
div_scale = 1.0

# If mixed-precision training is used, get the gradient division scale from the mixed-precision handler.
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()

if self.max_norm > 0.0:
# norm is in fact norm*scale
# Calculate the scaling factor for gradient clipping
# The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm'
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm

# If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping
if clip > 1:
div_scale = clip * div_scale

# Apply the scaling factor to gradients
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
p.grad.data.mul_(1.0 / div_scale)
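
For readers following the math in `_unscale_and_clip_grads`: a minimal standalone sketch (illustrative names only, not part of this diff) of how folding the clip factor into `div_scale` unscales and clips in a single multiply.

```python
# Minimal sketch of the unscale-and-clip rule above; names are illustrative.
import torch

def unscale_and_clip_(grads, total_norm, div_scale=1.0, max_norm=1.0, eps=1e-6):
    # total_norm is the norm of the still-scaled gradients, so divide out the
    # loss scale before comparing against max_norm.
    clip = ((total_norm / div_scale) + eps) / max_norm
    if clip > 1:
        # Folding the clip factor into div_scale unscales and clips in one multiply.
        div_scale = clip * div_scale
    for g in grads:
        g.mul_(1.0 / div_scale)

grads = [torch.full((4,), 3.0), torch.full((4,), 4.0)]   # flat L2 norm = 10
unscale_and_clip_(grads, total_norm=10.0, div_scale=1.0, max_norm=1.0)
# Each gradient is now multiplied by ~0.1, bringing the global norm down to ~max_norm.
```
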

def _compute_grad_norm(self) -> float:
if self.max_norm <= 0.0:
return 0.0
grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
if len(grads) == 0:
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor, Tensor]], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.

Args:
param_gradient_pairs (List[Tuple[Tensor, Tensor]]): List of (parameter, gradient) pairs; only the gradients are used for the norm calculation.
norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.

Returns:
float: The total norm of the given gradients.
"""

if len(param_gradient_pairs) == 0:
return 0.0
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
return total_norm.item()

# gradients used for norm calculation.
gradients = [grad for param, grad in param_gradient_pairs]

if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)

else:
total_norm_exponentiated = 0.0
for grad in gradients:
total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type
total_norm = total_norm_exponentiated ** (1.0 / norm_type)

return total_norm
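
A quick illustrative check (not part of this diff) that the per-tensor exponentiated sum above matches a single p-norm over all gradients flattened together, with the infinity norm handled as a plain max:

```python
# Illustrative check that the loop above equals a flat p-norm over all gradients.
import torch
from torch import inf

def grad_norm(gradients, norm_type=2):
    if norm_type == inf:
        # Infinity norm: largest absolute entry across all gradients.
        return max(float(g.abs().max()) for g in gradients)
    # Sum of per-tensor norms raised to norm_type, then the norm_type-th root.
    total = sum(float(g.double().norm(norm_type)) ** norm_type for g in gradients)
    return total ** (1.0 / norm_type)

grads = [torch.randn(8), torch.randn(3, 3)]
flat = torch.cat([g.flatten() for g in grads]).double()
assert abs(grad_norm(grads, 2) - float(flat.norm(2))) < 1e-9
assert abs(grad_norm(grads, inf) - float(flat.abs().max())) < 1e-12
```
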

def step(self, *args, **kwargs):
if self.mixed_precision.should_skip_step():
@@ -142,8 +173,22 @@ def step(self, *args, **kwargs):
if working_param.grad is not None:
p.grad = working_param.grad.data.float()
working_param.grad = None
total_norm = self._compute_grad_norm()

# gradient unscale and clip.
if self.max_norm <= 0:
# no need to compute gradient norm.
total_norm = 0.0
else:
# compute the total norm.
param_gradient_pairs = [
(self.master_to_working_map[p], p.grad)
for group in self.param_groups
for p in group["params"]
if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
self._unscale_and_clip_grads(total_norm)

self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
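
From the user's side, the new clipping is driven by the plugin's `max_norm` argument. A hedged usage sketch follows; the constructor and launch arguments are assumptions based on this PR and may differ between ColossalAI releases, and the script needs a distributed launcher such as torchrun.

```python
# Hedged sketch of enabling gradient clipping via the hybrid parallel plugin.
# max_norm and the other constructor arguments are assumptions based on this PR;
# check the HybridParallelPlugin signature in your ColossalAI release.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="fp16", max_norm=1.0)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

loss = model(torch.randn(4, 16, device="cuda")).sum()
booster.backward(loss, optimizer)
optimizer.step()  # gradients are unscaled and clipped to max_norm before the update
```
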