Skip to content

Commit

Permalink
Merge pull request #632 from mv1388/fix-grad-clip-multi-optimizer-cal…
Browse files Browse the repository at this point in the history
…lback

Fix grad clip multi optimizer callback
  • Loading branch information
mv1388 committed Oct 31, 2020
2 parents 00d98d9 + 7bf2126 commit 8d0ee4f
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions aitoolbox/torchtrain/callbacks/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch

from aitoolbox.torchtrain.callbacks.abstract import AbstractCallback, AbstractExperimentCallback
from aitoolbox.torchtrain.multi_loss_optim import MultiOptimizer
from aitoolbox.experiment.local_save.local_results_save import BaseLocalResultsSaver
from aitoolbox.experiment.result_reporting.report_generator import GradientPlotter
from aitoolbox.cloud.AWS.results_save import BaseResultsSaver as BaseResultsS3Saver
Expand Down Expand Up @@ -41,7 +42,11 @@ def __init__(self, max_grad_value):

def on_after_gradient_update(self, optimizer_idx):
if self.train_loop_obj.use_amp:
self.train_loop_obj.amp_scaler.unscale_(self.train_loop_obj.optimizer)
optimizer = self.train_loop_obj.optimizer
if isinstance(optimizer, MultiOptimizer):
optimizer = optimizer[optimizer_idx]

self.train_loop_obj.amp_scaler.unscale_(optimizer)

torch.nn.utils.clip_grad_value_(self.train_loop_obj.model.parameters(), self.max_grad_value)

Expand All @@ -60,7 +65,11 @@ def __init__(self, max_grad_norm, **kwargs):

def on_after_gradient_update(self, optimizer_idx):
if self.train_loop_obj.use_amp:
self.train_loop_obj.amp_scaler.unscale_(self.train_loop_obj.optimizer)
optimizer = self.train_loop_obj.optimizer
if isinstance(optimizer, MultiOptimizer):
optimizer = optimizer[optimizer_idx]

self.train_loop_obj.amp_scaler.unscale_(optimizer)

torch.nn.utils.clip_grad_norm_(self.train_loop_obj.model.parameters(), self.max_grad_norm, **self.kwargs)

Expand Down
Binary file modified dist/aitoolbox-1.2.0-py3-none-any.whl
Binary file not shown.
Binary file modified dist/aitoolbox-1.2.0.tar.gz
Binary file not shown.

0 comments on commit 8d0ee4f

Please sign in to comment.