[feature] [experimental] Layerwise Gradient Scaler (#879)
* [skip ci] first commit
* [skip ci] gradient scaler example
* [skip ci] adding feed forward toy example
* [skip ci] adding types
* [skip ci] adding backward hook
* [skip ci] update
* [skip ci] working feed forward example
* [skip ci] working feed forward example
* [skip ci] use named_modules instead of named_children
* [skip ci] adding new file
* [skip ci] clean up
* [skip ci] implement unscale function
* [skip ci] implement unscale function
* [skip ci] removing old file
* [skip ci] removing some more old files
* [skip ci] making unscale function generic
* [skip ci] adding test for vision model
* [skip ci] adding identity layer
* [skip ci] cleanup files
* [skip ci] refactoring
* [skip ci] more refactoring
* [skip ci] added functionality to update scale
* [skip ci] data loader clean up
* [skip ci] implemented inf checks and update scale functions
* [skip ci] code clean up. added test with autocast. does not work atm
* adding documentation
* adding dependency in requirements-dev.txt
* updating pytorch nightly version
* updating changelog
* adding is_cuda_available to test_vision_model
* set same timeout on cpu and gpu
* reverting cpu timeout, skip vision test on cpu
* addressing comments, fixing vision test
* unscale uses in-place matmul
* some more cleanup
Anupam Bhatnagar committed Jan 13, 2022
1 parent fb4eca1 · commit 52d066a
Showing 7 changed files with 512 additions and 2 deletions.
@@ -0,0 +1,289 @@
import logging
from typing import List, Tuple

import torch
import torch.nn as nn


class LayerInfo:
    """
    A class to record the layer attributes.
    """

    def __init__(self, name: str, layer: nn.Module, scale: float = 1.0, scale_layer: bool = False) -> None:
        """
        layer_name: name of the layer, e.g. fc1, conv1, relu1
        layer: type of the layer, e.g. Linear, Conv2d, ReLU
        scaling_factor: user-configurable scaling factor for the layer, defaults to 1.0
        found_inf_or_nan: a boolean indicating whether any parameter of the layer's gradient contains an inf/nan
        growth_tracker: tracks the number of steps since the scale was last increased
        scale_layer: a boolean indicating whether the layer should be scaled
        """
        self.layer_name = name
        self.layer = layer
        self.scaling_factor = scale
        self.found_inf_or_nan = False
        self.growth_tracker = 0
        self.scale_layer = scale_layer


class GradientHelper:
    """
    A helper class to create instances of backward hooks. The hooks are registered in the
    scale method of LayerwiseGradientScaler.
    """

    def __init__(self, name: str, inputs_multiplier: float, outputs_multiplier: float):
        self.layer_name = name
        self.inputs_multiplier = inputs_multiplier
        self.outputs_multiplier = outputs_multiplier

    def scale_gradients(self, m: nn.Module, inputs: Tuple, outputs: Tuple) -> Tuple[torch.Tensor]:
        """
        Backward hook that is attached to the layers to scale the gradients.
        """
        scaled_up_grads = list()
        for idx in range(len(inputs)):
            if inputs[idx] is not None:
                if self.inputs_multiplier != 1.0 or self.outputs_multiplier != 1.0:
                    logging.debug(
                        "layer = %s \t scale = %s \t scale_down = %s"
                        % (self.layer_name, self.inputs_multiplier, self.outputs_multiplier)
                    )
                scaled_up_grads.append(inputs[idx].mul(self.inputs_multiplier * self.outputs_multiplier))
            else:
                logging.debug("next layer is None")
                scaled_up_grads.append(inputs[idx])
        return tuple(scaled_up_grads)  # type: ignore


class LayerwiseGradientScaler:
    """
    LayerwiseGradientScaler enables using distinct scaling factors for each layer
    of the network.

    Example:
        # Create a convolutional network
        class ConvNet(nn.Module):
            def __init__(self):
                ...
            def forward(self, x):
                ...

        # Create an instance of the model
        model = ConvNet()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        # specify the layers to scale and their scaling factors
        layer_scale_dict = {"conv1": 2**10, "conv2": 2**8, "fc1": 2**10, "fc2": 2**9}
        scaler = LayerwiseGradientScaler(model, layer_scale_dict)

        for epoch in range(num_epochs):
            for inputs, targets in dataloader:
                optimizer.zero_grad()

                # register the hooks that scale the gradients
                scaler.scale()

                # enable mixed precision training
                with autocast():
                    predictions = model(inputs)
                    loss = loss_function(predictions, targets)
                loss.backward()

                # unscale the gradients
                scaler.unscale()

                # a step is taken only if there are no inf/nan in the gradients;
                # the scaling factor of each layer is then updated
                scaler.step(optimizer)

    Args:
        model : instance of a model class, such as ConvNet above
        layer_scale_dict (dict) : dictionary with key = layer name and value = scaling factor
        growth_factor (float) : per layer scaling factor multiplier
        backoff_factor (float) : per layer scaling factor multiplier when an inf/nan is found
        growth_interval (int) : number of steps after which the scale is multiplied by growth_factor
        min_scale (float) : smallest allowed scaling factor
        max_scale (float) : largest allowed scaling factor
    """

    def __init__(  # type: ignore
        self,
        model,
        layer_scale_dict: dict,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 10000,
        min_scale: float = torch.finfo(torch.float32).tiny,  # type: ignore
        max_scale: float = torch.finfo(torch.float32).max,  # type: ignore
    ) -> None:
        self._model = model
        self._layer_scale_dict: dict = layer_scale_dict
        self._growth_factor: float = growth_factor
        self._backoff_factor: float = backoff_factor
        self._growth_interval: int = growth_interval
        self._apply_layerwise_scaling: bool = True if len(layer_scale_dict.keys()) > 0 else False
        self._min_scale = min_scale
        self._max_scale = max_scale
        self._handles: List = []
        self.layer_info: List = []

        if self._apply_layerwise_scaling:
            assert self._growth_factor > 1.0, "The growth factor must be > 1.0."
            assert self._backoff_factor < 1.0, "The backoff factor must be < 1.0."
            self.layer_info = self._build_layer_info()

    def _build_layer_info(self) -> List:
        """
        Helper function to create a list of LayerInfo instances.
        """
        layer_info_list = list()

        for name, layer in self._model.named_modules():
            if name != "":
                if name not in self._layer_scale_dict.keys():
                    logging.debug("name = %s, layer = %s, scaling_factor = %s" % (name, layer, 1.0))
                    layer_info_list.append(LayerInfo(name, layer, 1.0))
                else:
                    logging.debug(
                        "name = %s, layer = %s, scaling_factor = %s" % (name, layer, self._layer_scale_dict[name])
                    )
                    layer_info_list.append(LayerInfo(name, layer, self._layer_scale_dict[name], True))
        return layer_info_list

    def scale(self) -> None:
        """
        For each layer, compute the scaling factor applied to the preceding layer's grad
        inputs and to the current layer's grad outputs. These values are used to register
        a full backward hook on the layer. The handle returned from registering the
        backward hook is appended to a list of handles. New hooks are created and
        registered at every step; the handles are flushed out in the unscale function.
        """
        if not self._apply_layerwise_scaling:
            return

        for idx in range(len(self.layer_info)):
            elt = self.layer_info[idx]
            layer_name, layer = elt.layer_name, elt.layer

            # gradients flowing to the previous layer are scaled up by that layer's factor
            inputs_multiplier = 1.0
            if idx > 0:
                inputs_multiplier = self.layer_info[idx - 1].scaling_factor

            # and scaled down by the current layer's factor
            outputs_multiplier = 1.0 / elt.scaling_factor
            helper = GradientHelper(layer_name, inputs_multiplier, outputs_multiplier)
            layer_handle = layer.register_full_backward_hook(helper.scale_gradients)
            self._handles.append(layer_handle)
            logging.debug("name = %s \t scale = %s" % (layer_name, elt.scaling_factor))

    def _get_layers_with_finite_values(self) -> List[LayerInfo]:
        layers_with_finite_values: List = []
        for item in self.layer_info:
            if not item.found_inf_or_nan:
                layers_with_finite_values.append(item)
        return layers_with_finite_values

    def unscale(self) -> None:
        """
        For each layer, check whether any of the layer's parameters contain an inf/nan.
        If there is no inf/nan in the gradient, the gradient of that layer is unscaled
        by the reciprocal of the scaling factor for that layer. Finally, all handles
        recorded while registering the hooks are removed.
        """
        if not self._apply_layerwise_scaling:
            return

        layers_with_finite_values = self._get_layers_with_finite_values()
        for item in layers_with_finite_values:
            for param_name, param in item.layer.named_parameters():
                if hasattr(param, "grad"):
                    logging.debug("%s scaling down %s by %s" % (item.layer_name, param_name, 1.0 / item.scaling_factor))
                    param.grad.mul_(1.0 / item.scaling_factor)

        while len(self._handles) > 0:
            elt = self._handles.pop()
            elt.remove()

    def _check_for_inf_or_nan(self) -> None:
        """
        For each layer, check whether any of the parameters with a gradient
        contain an inf/nan. If any parameter's gradient contains an inf/nan,
        that layer's found_inf_or_nan attribute is set to True and all
        remaining parameters for that layer are skipped.
        """
        for elt in self.layer_info:
            elt.found_inf_or_nan = False
            for _, param in elt.layer.named_parameters():
                if hasattr(param, "grad") and param.grad is not None:
                    if torch.isinf(param.grad).any().item() or torch.isnan(param.grad).any().item():  # type: ignore
                        elt.found_inf_or_nan = True
                        break  # skip all remaining named parameters

    def step(self, optimizer) -> None:  # type: ignore
        """
        The optimizer takes a step only if there is no inf/nan in the gradients
        of any layer. The scaling factor of each layer is then updated.
        """
        # using layerwise gradient scaling
        if self._apply_layerwise_scaling:
            self._check_for_inf_or_nan()
            inf_nan_found = any(elt.found_inf_or_nan for elt in self.layer_info)

            if not inf_nan_found:
                optimizer.step()
            self._update_scale()
        # not using layerwise gradient scaling
        else:
            optimizer.step()

    def _update_scale(self) -> None:
        """
        For each layer, if an inf/nan is found, multiply the scaling factor of
        that layer by the backoff factor and set the growth tracker of that
        layer to 0. Otherwise, increment the growth tracker of the layer; if the
        growth tracker equals the growth interval, multiply the scaling factor
        of the layer by the growth factor and reset the layer's growth tracker
        to 0. Finally, clip the scaling factor to the range
        [self._min_scale, self._max_scale]. The min/max scaling factor values
        are user configurable.
        """
        if not self._apply_layerwise_scaling:
            return

        for layer in self.layer_info:
            if layer.found_inf_or_nan:
                if layer.scale_layer:
                    layer.scaling_factor = max(
                        self._min_scale,
                        min(self._backoff_factor * layer.scaling_factor, self._max_scale),
                    )
                layer.growth_tracker = 0
            else:
                layer.growth_tracker += 1
                if layer.scale_layer and layer.growth_tracker == self._growth_interval:
                    layer.scaling_factor = max(
                        self._min_scale,
                        min(self._growth_factor * layer.scaling_factor, self._max_scale),
                    )
                    layer.growth_tracker = 0

    def get_layer_info(self) -> List[LayerInfo]:
        """
        Returns a list of LayerInfo instances of the model.
        """
        return self.layer_info

    def get_backward_hooks(self) -> List:
        """
        Returns a list of tuples. Each tuple contains the layer name and the
        hook attached to it.
        """
        layer_name_and_hooks = list()
        for name, layer in self._model.named_modules():
            if name != "":
                layer_name_and_hooks.append((name, layer._get_backward_hooks()))
        return layer_name_and_hooks
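For reference, below is a minimal end-to-end sketch of how the class above could be driven, following the workflow from its docstring (scale, backward, unscale, step). The model, data, per-layer factors, and the import line are illustrative assumptions, since the diff does not show where the new file sits in the package, and autocast is left out so the sketch runs on CPU.

import torch
import torch.nn as nn

# assumption: import from wherever layerwise_gradient_scaler.py is placed in your checkout
from layerwise_gradient_scaler import LayerwiseGradientScaler


class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(16, 16)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(16, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc3(self.relu2(self.fc2(self.relu1(self.fc1(x)))))


model = TinyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# illustrative per-layer factors; layers not listed keep a scaling factor of 1.0
scaler = LayerwiseGradientScaler(model, {"fc1": 2**10, "fc2": 2**7})

for _ in range(10):
    inputs, targets = torch.randn(32, 8), torch.randn(32, 1)
    optimizer.zero_grad()
    scaler.scale()              # register the per-layer backward hooks
    loss = loss_fn(model(inputs), targets)
    loss.backward()             # hooks scale the gradients layer by layer
    scaler.unscale()            # divide the gradients back down and remove the hooks
    scaler.step(optimizer)      # steps only if no inf/nan was found, then updates the scales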
@@ -28,3 +28,6 @@ pynvml == 8.0.4

# For mypy typing
numpy >= 1.21

# For layerwise gradient scaler
sklearn >= 0.0