[feature] [experimental] Layerwise Gradient Scaler (#879)
* [skip ci] first commit

* [skip ci] gradient scaler example

* [skip ci] adding feed forward toy example

* [skip ci] adding types

* [skip ci] adding backward hook

* [skip ci] update

* [skip ci] working feed forward example

* [skip ci] working feed forward example

* [skip ci] use named_modules instead of named_children

* [skip ci] adding new file

* [skip ci] clean up

* [skip ci] implement unscale function

* [skip ci] implement unscale function

* [skip ci] removing old file

* [skip ci] removing some more old files

* [skip ci] making unscale function generic

* [skip ci] adding test for vision model

* [skip ci] adding identity layer

* [skip ci] cleanup files

* [skip ci] refactoring

* [skip ci] more refactoring

* [skip ci] added functionality to update scale

* [skip ci] data loader clean up

* [skip ci] implemented inf checks and update scale functions

* [skip ci]code clean up. added test with autocast. does not work atm

* adding documentation

* adding dependency in requirements-dev.txt

* updating pytorch nightly version

* updating changelog

* adding is_cuda_available to test_vision_model

* set same timeout on cpu and gpu

* reverting cpu timeout, skip vision test on cpu

* addressing comments, fixing vision test

* unscale uses in-place matmul

* some more cleanup
Anupam Bhatnagar committed Jan 13, 2022
1 parent fb4eca1 commit 52d066a
Showing 7 changed files with 512 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.4.5] - TBD

### Added
- Layer-wise Gradient Scaling [new feature][experimental]: layer-wise gradient
scaling helps overcome gradient overflow issues. When used in conjunction with
mixed precision, it enables training larger models and makes the training
process more stable, especially in deep networks. [#879]
- FSDP: Added state_dict_on_rank_0_only flag to allow the user to choose to return the full
state dict on rank 0 and an empty dict on all other ranks, to prevent OOM [#844]
- FSDP: Added process_group_reduce_scatter parameter to allow users to pass in the process group that is used for the reduce-scatter operation. [#897]

2 changes: 1 addition & 1 deletion fairscale/optim/grad_scaler.py
@@ -11,8 +11,8 @@

import torch
from torch.cuda import FloatTensor # type: ignore
-from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.common import amp_definitely_not_available
+from torch.cuda.amp.grad_scaler import GradScaler as TorchGradScaler
import torch.distributed as dist
from torch.optim import Optimizer
from torch.optim.sgd import SGD
289 changes: 289 additions & 0 deletions fairscale/optim/layerwise_gradient_scaler.py
@@ -0,0 +1,289 @@
import logging
from typing import List, Tuple

import torch
import torch.nn as nn


class LayerInfo:
"""
A class to record the layer attributes.
"""

def __init__(self, name: str, layer: nn.Module, scale: float = 1.0, scale_layer: bool = False) -> None:
"""
        layer_name: name of the layer, e.g. fc1, conv1, relu1
        layer: the layer module instance, e.g. Linear, Conv2d, ReLU
        scaling_factor: user-configurable scaling factor for the layer, defaults to 1.0
        found_inf_or_nan: a boolean indicating whether the gradient of any of the layer's parameters contains an inf/nan
        growth_tracker: tracks the number of steps since the last time the scale was increased
        scale_layer: a boolean indicating whether the layer should be scaled or not
"""
self.layer_name = name
self.layer = layer
self.scaling_factor = scale
self.found_inf_or_nan = False
self.growth_tracker = 0
self.scale_layer = scale_layer


class GradientHelper:
"""
A helper class to create instances of backward hooks. The hooks are registered in the
scale method of LayerwiseGradientScaler.
"""

def __init__(self, name: str, inputs_multiplier: float, outputs_multiplier: float):
self.layer_name = name
self.inputs_multiplier = inputs_multiplier
self.outputs_multiplier = outputs_multiplier

def scale_gradients(self, m: nn.Module, inputs: Tuple, outputs: Tuple) -> Tuple[torch.Tensor]:
"""
Backward hook that is attached to the layers to scale the gradients.
"""
scaled_up_grads = list()
for idx in range(len(inputs)):
if inputs[idx] is not None:
if self.inputs_multiplier != 1.0 or self.outputs_multiplier != 1.0:
logging.debug(
"layer = %s \t scale = %s \t scale_down = %s"
% (self.layer_name, self.inputs_multiplier, self.outputs_multiplier)
)
scaled_up_grads.append(inputs[idx].mul(self.inputs_multiplier * self.outputs_multiplier))
else:
logging.debug("next layer is None")
scaled_up_grads.append(inputs[idx])
return tuple(scaled_up_grads) # type: ignore


class LayerwiseGradientScaler:
"""
LayerwiseGradientScaler enables using distinct scaling factors for each layer
of the network.
    Example:

        # Create a convolutional network
        class ConvNet(nn.Module):
            def __init__(self):
                ...
            def forward(self, x):
                ...

        # Create an instance of the model
        model = ConvNet()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # specify the layers to scale and their scaling factors
        layer_scale_dict = {"conv1": 2**10, "conv2": 2**8, "fc1": 2**10, "fc2": 2**9}
        scaler = LayerwiseGradientScaler(model, layer_scale_dict)

        for epoch in range(num_epochs):
            for inputs, targets in dataloader:
                optimizer.zero_grad()

                # register the per-layer scaling hooks
                scaler.scale()

                # enable mixed precision training
                with autocast():
                    predictions = model(inputs)
                    loss = loss_function(predictions, targets)
                loss.backward()

                # unscale the gradients
                scaler.unscale()

                # a step is taken only if there are no inf/nan in the gradients;
                # the scaling factor of each layer is then updated
                scaler.step(optimizer)

    Args:
        model: an instance of a model class, such as ConvNet above
        layer_scale_dict (dict): dictionary with key = layer_name and value = scaling_factor
        growth_factor (float): per-layer scaling factor multiplier
        backoff_factor (float): per-layer scaling factor multiplier when an inf/nan is found
        growth_interval (int): number of steps after which the scale is multiplied by growth_factor
        min_scale (float): smallest allowed scaling factor
        max_scale (float): largest allowed scaling factor
"""

def __init__( # type: ignore
self,
model,
layer_scale_dict: dict,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 10000,
min_scale: float = torch.finfo(torch.float32).tiny, # type: ignore
max_scale: float = torch.finfo(torch.float32).max, # type: ignore
) -> None:
self._model = model
self._layer_scale_dict: dict = layer_scale_dict
self._growth_factor: float = growth_factor
self._backoff_factor: float = backoff_factor
self._growth_interval: int = growth_interval
self._apply_layerwise_scaling: bool = True if len(layer_scale_dict.keys()) > 0 else False
self._min_scale = min_scale
self._max_scale = max_scale
self._handles: List = []
self.layer_info: List = []

if self._apply_layerwise_scaling:
assert self._growth_factor > 1.0, "The growth factor must be > 1.0."
assert self._backoff_factor < 1.0, "The backoff factor must be < 1.0."
self.layer_info = self._build_layer_info()

def _build_layer_info(self) -> List:
"""
Helper function to create a list of LayerInfo instances.
"""
layer_info_list = list()

for name, layer in self._model.named_modules():
if name != "":
if name not in self._layer_scale_dict.keys():
logging.debug("name = %s, layer = %s, scaling_factor = %s" % (name, layer, 1.0))
layer_info_list.append(LayerInfo(name, layer, 1.0))
else:
logging.debug(
"name = %s, layer = %s, scaling_factor = %s" % (name, layer, self._layer_scale_dict[name])
)
layer_info_list.append(LayerInfo(name, layer, self._layer_scale_dict[name], True))
return layer_info_list

def scale(self) -> None:
"""
        For each layer, compute the scaling factor applied to the preceding layer's grad inputs
        and the current layer's grad outputs. These values are used to register a full
        backward hook on the layer. The handle returned from registering the hook is
        appended to a list of handles. New hooks are created and registered at every step,
        and the handles are removed again in the unscale function.
"""
if not self._apply_layerwise_scaling:
return

for idx in range(len(self.layer_info)):
elt = self.layer_info[idx]
layer_name, layer = elt.layer_name, elt.layer

inputs_multiplier = 1.0
if idx > 0:
inputs_multiplier = self.layer_info[idx - 1].scaling_factor

outputs_multiplier = 1.0 / elt.scaling_factor
helper = GradientHelper(layer_name, inputs_multiplier, outputs_multiplier)
layer_handle = layer.register_full_backward_hook(helper.scale_gradients)
self._handles.append(layer_handle)
logging.debug("name = %s \t scale = %s" % (layer_name, elt.scaling_factor))

def _get_layers_with_finite_values(self) -> List[LayerInfo]:
layers_with_finite_values: List = []
for item in self.layer_info:
if not item.found_inf_or_nan:
layers_with_finite_values.append(item)
return layers_with_finite_values

def unscale(self) -> None:
"""
        For each layer, check whether any of the layer's parameters contain an inf/nan.
        If no inf/nan is found, the gradients of that layer are multiplied by the
        reciprocal of the layer's scaling factor.
        Finally, all handles recorded while registering the hooks are removed.
"""
if not self._apply_layerwise_scaling:
return

layers_with_finite_values = self._get_layers_with_finite_values()
for item in layers_with_finite_values:
for param_name, param in item.layer.named_parameters():
                if param.grad is not None:
logging.debug("%s scaling down %s by %s" % (item.layer_name, param_name, 1.0 / item.scaling_factor))
param.grad.mul_(1.0 / item.scaling_factor)

while len(self._handles) > 0:
elt = self._handles.pop()
elt.remove()

def _check_for_inf_or_nan(self) -> None:
"""
        For each layer, check whether any parameter with a non-None gradient
        contains an inf/nan. If any parameter's gradient contains an inf/nan,
        that layer's found_inf_or_nan attribute is set to True and all
        remaining parameters of that layer are skipped.
"""
for elt in self.layer_info:
elt.found_inf_or_nan = False
for _, param in elt.layer.named_parameters():
if hasattr(param, "grad") and param.grad is not None:
if torch.isinf(param.grad).any().item() or torch.isnan(param.grad).any().item(): # type: ignore
elt.found_inf_or_nan = True
break # skip all remaining named parameters

def step(self, optimizer) -> None: # type: ignore
"""
        If no layer's gradients contain an inf/nan, the optimizer takes a step;
        otherwise the step is skipped. The scaling factor of each layer is then updated.
"""
# using layerwise gradient scaling
if self._apply_layerwise_scaling:
self._check_for_inf_or_nan()
inf_nan_found = any(elt.found_inf_or_nan for elt in self.layer_info)

if not inf_nan_found:
optimizer.step()
self._update_scale()
# not using layerwise gradient scaling
else:
optimizer.step()

def _update_scale(self) -> None:
"""
        For each layer, if an inf/nan is found, multiply the scaling factor
        of that layer by the backoff factor and set the growth tracker of that
        layer to 0. Otherwise, increment the growth tracker of the layer. If the
        growth tracker equals the growth interval, multiply the scaling factor of
        the layer by the growth factor and reset the layer's growth tracker to 0.
        Finally, clip the scaling factor to the range
        [self._min_scale, self._max_scale]. The min/max scaling factor
        values are user configurable.
"""
if not self._apply_layerwise_scaling:
return

for layer in self.layer_info:
if layer.found_inf_or_nan:
if layer.scale_layer:
layer.scaling_factor = max(
self._min_scale,
min(self._backoff_factor * layer.scaling_factor, self._max_scale),
)
layer.growth_tracker = 0
else:
layer.growth_tracker += 1
if layer.scale_layer and layer.growth_tracker == self._growth_interval:
layer.scaling_factor = max(
self._min_scale,
min(self._growth_factor * layer.scaling_factor, self._max_scale),
)
layer.growth_tracker = 0

def get_layer_info(self) -> List[LayerInfo]:
"""
Returns a list of LayerInfo instances of the model.
"""
return self.layer_info

def get_backward_hooks(self) -> List:
"""
Returns a list of tuples. Each tuple contains the layer name and the
hook attached to it.
"""
layer_name_and_hooks = list()
for name, layer in self._model.named_modules():
if name != "":
layer_name_and_hooks.append((name, layer._get_backward_hooks()))
return layer_name_and_hooks
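
To complement the docstring example in the new module above, here is a minimal end-to-end sketch of how the scaler might be driven in a training loop. It assumes the class is imported directly from its module path `fairscale.optim.layerwise_gradient_scaler`; the `TinyNet` model, the layer names, and the chosen scaling factor are illustrative placeholders, not part of this commit.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from fairscale.optim.layerwise_gradient_scaler import LayerwiseGradientScaler


class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.relu(self.fc1(x)))


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Keys must match the names returned by model.named_modules();
# layers that are not listed keep a scaling factor of 1.0.
layer_scale_dict = {"fc1": 2 ** 10}
scaler = LayerwiseGradientScaler(model, layer_scale_dict)

for _ in range(10):
    inputs = torch.randn(8, 16, device=device)
    targets = torch.randint(0, 4, (8,), device=device)

    optimizer.zero_grad()
    scaler.scale()  # register the per-layer backward hooks
    with autocast(enabled=(device == "cuda")):
        loss = loss_fn(model(inputs), targets)
    loss.backward()  # gradients are multiplied by the per-layer factors as they flow backward
    scaler.unscale()  # divide finite gradients by their layer's scale and remove the hooks
    scaler.step(optimizer)  # skip the step if any layer saw inf/nan, then update per-layer scales
```

The scaling itself rides on `register_full_backward_hook`, which lets a hook return a modified `grad_input` that is then propagated to earlier layers. The standalone snippet below (also not part of the commit) illustrates just that mechanism with a single hypothetical hook that doubles the gradient flowing out of a `Linear` layer.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)


def double_grad_input(module, grad_input, grad_output):
    # Returning a tuple replaces grad_input, i.e. the gradient handed to the preceding layer.
    return tuple(g * 2.0 if g is not None else None for g in grad_input)


handle = layer.register_full_backward_hook(double_grad_input)
x = torch.randn(2, 4, requires_grad=True)
layer(x).sum().backward()
handle.remove()
print(x.grad)  # twice the gradient that would be seen without the hook
```
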
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,4 +27,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision", "utils"]
+known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "sklearn", "torch", "torchtext", "torchvision", "utils"]
3 changes: 3 additions & 0 deletions requirements-dev.txt
@@ -28,3 +28,6 @@ pynvml == 8.0.4

# For mypy typing
numpy >= 1.21

# For layerwise gradient scaler
sklearn >= 0.0
1 change: 1 addition & 0 deletions tests/ci_test_list_1.txt
@@ -9,3 +9,4 @@ tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp_with_checkpoint_wrapper.py
tests/optim/test_layerwise_gradient_scaler.py
