This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Merge pull request #164 from inferno-pytorch/gradient_callback
add gradient logging callback
Steffen-Wolf committed Jan 31, 2019
2 parents (b0ed9bd + ef2de4e), commit fa53dcb
Showing 2 changed files with 51 additions and 1 deletion.
inferno/trainers/callbacks/__init__.py: 3 changes (2 additions, 1 deletion)
@@ -1,9 +1,10 @@
-__all__ = ['CallbackEngine','Callback', 'Console','essentials','scheduling']
+__all__ = ['CallbackEngine', 'Callback', 'Console', 'essentials', 'scheduling', 'gradients']
 
 from .base import CallbackEngine, Callback
 from .console import Console
 from . import essentials
 from . import scheduling
+from . import gradients
 
 try:
     from .tqdm import TQDMProgressBar
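
Aside (not part of the diff): since the package __init__ now imports gradients and lists it in __all__, the new module is reachable directly from the callbacks package. A hypothetical import, with a made-up frequency value, would look like:

from inferno.trainers.callbacks import gradients

callback = gradients.LogOutputGradients(frequency=100)  # frequency value chosen for illustration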
inferno/trainers/callbacks/gradients.py (new file): 49 additions, 0 deletions
@@ -0,0 +1,49 @@
+from ...utils.train_utils import Frequency
+from ...utils.exceptions import assert_, FrequencyValueError
+from .base import Callback
+
+
+class LogOutputGradients(Callback):
+    """Logs the gradient of the network output"""
+
+    def __init__(self, frequency):
+        super(LogOutputGradients, self).__init__()
+        self.log_every = frequency
+        self.registered = False
+        self.hook_handle = None
+
+    @property
+    def log_every(self):
+        return self._log_every
+
+    @log_every.setter
+    def log_every(self, value):
+        self._log_every = Frequency(value, 'iterations')
+        assert_(self.log_every.is_consistent,
+                "Log frequency is not consistent.",
+                FrequencyValueError)
+
+    def add_hook(self):
+        def hook(module, grad_input, grad_output):
+            if self.log_every.match(iteration_count=self.trainer.iteration_count,
+                                    epoch_count=self.trainer.epoch_count,
+                                    persistent=True, match_zero=True):
+                self.trainer.update_state('output_gradient', grad_output[0].detach().cpu())
+
+        self.hook_handle = self.trainer.model.register_backward_hook(hook)
+
+    def begin_of_fit(self, **kwargs):
+        self._trainer.logger.observe_state("output_gradient",
+                                           observe_while='training')
+        self.add_hook()
+
+    def begin_of_save(self, **_):
+        # remove hook from model, because you can't pickle it.
+        if self.hook_handle is not None:
+            self.hook_handle.remove()
+            self.hook_handle = None
+
+
+    def end_of_save(self, **_):
+        # add hook after model save
+        self.add_hook()
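
For context, and not part of the commit itself: the callback registers a backward hook on the model, copies the gradient of the network output to the CPU at the configured frequency, and publishes it in the trainer state under the key 'output_gradient', which begin_of_fit asks the logger to observe during training. The sketch below shows how it might be wired up; it follows inferno's documented Trainer API (build_criterion, build_optimizer, build_logger, register_callback, fit) and a TensorboardLogger, while the toy model, optimizer choice, log directory, and frequency value are made up for illustration.

import torch.nn as nn

from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.gradients import LogOutputGradients
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger

# Toy model, purely illustrative.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

trainer = Trainer(model) \
    .build_criterion('CrossEntropyLoss') \
    .build_optimizer('Adam') \
    .build_logger(TensorboardLogger(), log_directory='logs/') \
    .set_max_num_iterations(1000)

# Log the output gradient every 100 iterations; the value is passed to the
# callback's log_every setter, which wraps it as Frequency(value, 'iterations').
trainer.register_callback(LogOutputGradients(frequency=100))

# trainer.bind_loader('train', train_loader)  # a torch DataLoader would go here
# trainer.fit()

During training, the logger then picks up 'output_gradient' from the trainer state whenever the frequency matches, alongside whatever else it already observes.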
