Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #120 from bstriner/log_histogram
Browse files Browse the repository at this point in the history
log_histogram
  • Loading branch information
nasimrahaman committed Aug 15, 2018
2 parents ebeea3c + 5acb045 commit b227450
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 8 deletions.
58 changes: 52 additions & 6 deletions inferno/trainers/callbacks/logging/tensorboard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import tensorboardX as tX
import numpy as np
import warnings
from scipy.misc import toimage
from .base import Logger
from ....utils import torch_utils as tu
from ....utils import python_utils as pyu
Expand All @@ -18,8 +20,8 @@ class TensorboardLogger(Logger):
Currently supports logging scalars and images.
"""

def __init__(self, log_directory=None, log_scalars_every=None, log_images_every=None,
def __init__(self, log_directory=None,
log_scalars_every=None, log_images_every=None, log_histograms_every=None,
send_image_at_batch_indices='all', send_image_at_channel_indices='all',
send_volume_at_z_indices='mid'):
"""
Expand All @@ -31,6 +33,8 @@ def __init__(self, log_directory=None, log_scalars_every=None, log_images_every=
How often scalars should be logged to Tensorboard. By default, once every iteration.
log_images_every : str or tuple or inferno.utils.train_utils.Frequency
How often images should be logged to Tensorboard. By default, once every iteration.
log_histograms_every : str or tuple or inferno.utils.train_utils.Frequency
How often histograms should be logged to Tensorboard. By default, once every iteration.
send_image_at_batch_indices : list or str
The indices of the batches to be logged. An `image_batch` usually has the shape
(num_samples, num_channels, num_rows, num_cols). By setting this argument to say
Expand All @@ -51,6 +55,7 @@ def __init__(self, log_directory=None, log_scalars_every=None, log_images_every=
super(TensorboardLogger, self).__init__(log_directory=log_directory)
self._log_scalars_every = None
self._log_images_every = None
self._log_histograms_every = None
self._writer = None
self._config = {'image_batch_indices': send_image_at_batch_indices,
'image_channel_indices': send_image_at_channel_indices,
Expand All @@ -69,6 +74,8 @@ def __init__(self, log_directory=None, log_scalars_every=None, log_images_every=
self.log_scalars_every = log_scalars_every
if log_images_every is not None:
self.log_images_every = log_images_every
if log_histograms_every is not None:
self.log_histograms_every = log_histograms_every

@property
def writer(self):
Expand Down Expand Up @@ -112,6 +119,24 @@ def log_images_now(self):
epoch_count=self.trainer.epoch_count,
persistent=True)

@property
def log_histograms_every(self):
if self._log_histograms_every is None:
self._log_histograms_every = tru.Frequency(1, 'iterations')
return self._log_histograms_every

@log_histograms_every.setter
def log_histograms_every(self, value):
self._log_histograms_every = tru.Frequency.build_from(value)

@property
def log_histograms_now(self):
# Using persistent=True in a property getter is probably not a very good idea...
# We need to make sure that this getter is called only once per callback-call.
return self.log_histograms_every.match(iteration_count=self.trainer.iteration_count,
epoch_count=self.trainer.epoch_count,
persistent=True)

def observe_state(self, key, observe_while='training'):
# Validate arguments
keyword_mapping = {'train': 'training',
Expand All @@ -135,6 +160,20 @@ def observe_state(self, key, observe_while='training'):
raise NotImplementedError
return self

def unobserve_state(self, key, observe_while='training'):
if observe_while == 'training':
self._trainer_states_being_observed_while_training.remove(key)
elif observe_while == 'validating':
self._trainer_states_being_observed_while_validating.remove(key)
else:
raise NotImplementedError
return self

def unobserve_states(self, keys, observe_while='training'):
for key in keys:
self.unobserve_state(key, observe_while=observe_while)
return self

def observe_training_and_validation_state(self, key):
for mode in ['training', 'validation']:
self.observe_state('{}_{}'.format(mode, key), observe_while=mode)
Expand All @@ -149,14 +188,16 @@ def observe_training_and_validation_states(self, keys):
self.observe_training_and_validation_state(key)
return self

def log_object(self, tag, object_, allow_scalar_logging=True, allow_image_logging=True):
def log_object(self, tag, object_,
allow_scalar_logging=True, allow_image_logging=True, allow_histogram_logging=True):
assert isinstance(tag, str)
if isinstance(object_, (list, tuple)):
for object_num, _object in enumerate(object_):
self.log_object("{}_{}".format(tag, object_num),
_object,
allow_scalar_logging,
allow_image_logging)
allow_image_logging,
allow_histogram_logging)
return
# Check whether object is a scalar
if tu.is_scalar_tensor(object_) and allow_scalar_logging:
Expand All @@ -173,9 +214,14 @@ def log_object(self, tag, object_, allow_scalar_logging=True, allow_image_loggin
elif tu.is_image_or_volume_tensor(object_) and allow_image_logging:
# Log images
self.log_image_or_volume_batch(tag, object_, self.trainer.iteration_count)
elif tu.is_vector_tensor(object_) and allow_histogram_logging:
# Log histograms
values = tu.unwrap(object_, as_numpy=True)
self.log_histogram(tag, values, self.trainer.iteration_count)
else:
# Object is neither a scalar nor an image, there's nothing we can do
pass
# Object is neither a scalar nor an image nor a vector, there's nothing we can do
if tu.is_tensor(object_):
warnings.warn("Unsupported attempt to log tensor `{}` of shape `{}`".format(tag, object_.size()))

def end_of_training_iteration(self, **_):
log_scalars_now = self.log_scalars_now
Expand Down
8 changes: 6 additions & 2 deletions inferno/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def unwrap(tensor_or_variable, to_cpu=True, as_numpy=False, extract_item=False):


def is_tensor(object_):
missed_tensor_classes = {torch.HalfTensor}
return torch.is_tensor(object_) or type(object_) in missed_tensor_classes
missed_tensor_classes = (torch.HalfTensor,)
return torch.is_tensor(object_) or isinstance(object_, missed_tensor_classes)


def is_label_tensor(object_):
Expand Down Expand Up @@ -78,6 +78,10 @@ def is_scalar_tensor(object_):
return is_tensor(object_) and object_.dim() <= 1 and object_.numel() == 1


def is_vector_tensor(object_):
return is_tensor(object_) and object_.dim() == 1 and object_.numel() > 1


def assert_same_size(tensor_1, tensor_2):
assert_(list(tensor_1.size()) == list(tensor_2.size()),
"Tensor sizes {} and {} do not match.".format(tensor_1.size(), tensor_2.size()),
Expand Down

0 comments on commit b227450

Please sign in to comment.