Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- added gradient clipping
Browse files Browse the repository at this point in the history
  • Loading branch information
nasimrahaman committed Nov 3, 2018
1 parent 2ef65f6 commit 8487b6a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
25 changes: 25 additions & 0 deletions inferno/trainers/callbacks/essentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,28 @@ def apply(self):

def end_of_training_iteration(self, **_):
self.maintain()


class GradientClip(Callback):
    """
    Callback that clips model gradients in-place after the model and loss
    have been applied (i.e. once gradients exist).

    Exactly one of `clip_value` or `clip_norm` must be provided:
    `clip_value` clamps each gradient element, `clip_norm` rescales
    gradients by their total norm (delegated to `tu.clip_gradients_`).
    """

    def __init__(self, clip_value=None, clip_norm=None):
        """
        Parameters
        ----------
        clip_value : float or None
            Clamp every gradient element to [-clip_value, clip_value].
        clip_norm : float or None
            Rescale gradients so their total norm is at most clip_norm.

        Raises
        ------
        ValueError
            If neither argument is given.
        RuntimeError
            If both arguments are given.
        """
        super(GradientClip, self).__init__()
        # Reject the all-None and the both-given configurations up front.
        assert_(not (clip_value is None and clip_norm is None),
                "Must provide either clip_value or clip_norm.",
                ValueError)
        assert_(clip_value is None or clip_norm is None,
                f"Must provide only one, but not both: "
                f"clip_value ({clip_value}) or clip_norm ({clip_norm}).",
                RuntimeError)
        self._clip_value = clip_value
        self._clip_norm = clip_norm

    @property
    def mode(self):
        # Which clipping strategy is active ('value' wins iff it was given).
        if self._clip_value is not None:
            return 'value'
        return 'norm'

    @property
    def norm_or_value(self):
        # The threshold for whichever mode is active.
        if self._clip_value is not None:
            return self._clip_value
        return self._clip_norm

    def after_model_and_loss_is_applied(self, **_):
        # Clip the model's gradients in-place via the torch-utils helper.
        tu.clip_gradients_(self.trainer.model.parameters(),
                           self.mode, self.norm_or_value)
12 changes: 12 additions & 0 deletions inferno/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,15 @@ def flatten_samples(tensor_or_variable):
# Now flatten out all but the first axis and return
flattened = permuted.view(num_channels, -1)
return flattened


def clip_gradients_(parameters, mode, norm_or_value):
    """
    Clip the gradients of `parameters` in-place.

    Parameters
    ----------
    parameters : iterable of torch.Tensor
        Tensors whose `.grad` attributes are clipped in-place.
    mode : {'norm', 'value'}
        'norm' rescales the gradients so their total norm is at most
        `norm_or_value`; 'value' clamps each gradient element to the
        range [-norm_or_value, norm_or_value].
    norm_or_value : float
        Maximum total norm (mode 'norm') or maximum absolute element
        value (mode 'value').

    Raises
    ------
    ValueError
        If `mode` is neither 'norm' nor 'value'.
    """
    if mode == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, norm_or_value)
    elif mode == 'value':
        torch.nn.utils.clip_grad_value_(parameters, norm_or_value)
    else:
        # Single validation point: the previous version checked `mode` twice
        # (an assert_ up front plus an unreachable `raise NotImplementedError`
        # branch). The error type and message are unchanged for callers.
        raise ValueError(f"Mode must be 'norm' or 'value', got '{mode}' instead.")

0 comments on commit 8487b6a

Please sign in to comment.