Add walltime override to SummaryWriter methods (#207)
* Add optional "walltime" argument to SummaryWriter methods, which will override
the default (current) walltime

* Fix missing walltime in add_scalars (fw.add_summary)

* Remove unused arg in SummaryWriter.add_embedding()

* Fix Flake8 error

* Match coding style
Andrew Ho authored and lanpa committed Aug 8, 2018
1 parent 797e610 commit 3e9913c
Showing 2 changed files with 70 additions and 45 deletions.
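For orientation, here is a minimal usage sketch of the override this commit introduces (tag names and values are illustrative; assumes a tensorboardX build that includes this change):

import time
from tensorboardX import SummaryWriter

writer = SummaryWriter()
# Default: the event is stamped with time.time() at the moment of writing.
writer.add_scalar('loss', 0.25, global_step=1)
# Override: stamp the event with an explicit wall-clock time (seconds since the epoch),
# for example a timestamp captured an hour earlier.
writer.add_scalar('loss', 0.21, global_step=2, walltime=time.time() - 3600)
writer.close()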
6 changes: 3 additions & 3 deletions demo.py
@@ -4,8 +4,7 @@
 import torchvision.models as models
 from torchvision import datasets
 from tensorboardX import SummaryWriter
-#import skimage
-#from skimage import data, io
+import datetime


 resnet18 = models.resnet18(False)
 writer = SummaryWriter()
@@ -23,7 +22,8 @@
 for n_iter in range(100):
     s1 = torch.rand(1)  # value to keep
     s2 = torch.rand(1)
-    writer.add_scalar('data/scalar1', s1[0], n_iter)  # data grouping by `slash`
+    writer.add_scalar('data/scalar_systemtime', s1[0], n_iter)  # data grouping by `slash`
+    writer.add_scalar('data/scalar_customtime', s1[0], n_iter, walltime=n_iter)  # data grouping by `slash`
     writer.add_scalars('data/scalar_group', {"xsinx":n_iter*np.sin(n_iter),
                                              "xcosx":n_iter*np.cos(n_iter),
                                              "arctanx": np.arctan(n_iter)}, n_iter)
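Note that the demo passes n_iter itself as the walltime, so the custom-time curve is stamped with tiny epoch offsets rather than real clock times; the newly added "import datetime" hints that a real timestamp could be derived instead. A hedged sketch of that (Python 3, not part of the commit):

import datetime
import time

# walltime is plain seconds since the epoch, so either form works:
event_time = time.time() - 60                                      # one minute ago
event_time = datetime.datetime(2018, 8, 1, 12, 0, 0).timestamp()   # a fixed past moment

writer.add_scalar('data/scalar_customtime', s1[0], n_iter, walltime=event_time)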
109 changes: 67 additions & 42 deletions tensorboardX/writer.py
@@ -79,7 +79,7 @@ def __init__(self, event_writer, graph=None, graph_def=None):
         # TODO(zihaolucky). pass this an empty graph to check whether it's necessary.
         # currently we don't support graph in MXNet using tensorboard.

-    def add_summary(self, summary, global_step=None):
+    def add_summary(self, summary, global_step=None, walltime=None):
         """Adds a `Summary` protocol buffer to the event file.
         This method wraps the provided summary in an `Event` protocol buffer
         and adds it to the event file.
@@ -93,33 +93,35 @@ def add_summary(self, summary, global_step=None):
           summary: A `Summary` protocol buffer, optionally serialized as a string.
           global_step: Number. Optional global step value to record with the
             summary.
+          walltime: float. Optional walltime to override the default (current)
+            walltime (from time.time())
         """
         if isinstance(summary, bytes):
             summ = summary_pb2.Summary()
             summ.ParseFromString(summary)
             summary = summ
         event = event_pb2.Event(summary=summary)
-        self._add_event(event, global_step)
+        self._add_event(event, global_step, walltime)

-    def add_graph(self, graph_profile):
+    def add_graph(self, graph_profile, walltime=None):
         graph = graph_profile[0]
         stepstats = graph_profile[1]
         """Adds a `Graph` protocol buffer to the event file.
         """
         event = event_pb2.Event(graph_def=graph.SerializeToString())
-        self._add_event(event, None)
+        self._add_event(event, None, walltime)

         trm = event_pb2.TaggedRunMetadata(tag='step1', run_metadata=stepstats.SerializeToString())
         event = event_pb2.Event(tagged_run_metadata=trm)
-        self._add_event(event, None)
+        self._add_event(event, None, walltime)

-    def add_graph_onnx(self, graph):
+    def add_graph_onnx(self, graph, walltime=None):
         """Adds a `Graph` protocol buffer to the event file.
         """
         event = event_pb2.Event(graph_def=graph.SerializeToString())
-        self._add_event(event, None)
+        self._add_event(event, None, walltime)

-    def add_session_log(self, session_log, global_step=None):
+    def add_session_log(self, session_log, global_step=None, walltime=None):
         """Adds a `SessionLog` protocol buffer to the event file.
         This method wraps the provided session in an `Event` protocol buffer
         and adds it to the event file.
@@ -129,10 +131,10 @@ def add_session_log(self, session_log, global_step=None):
             summary.
         """
         event = event_pb2.Event(session_log=session_log)
-        self._add_event(event, global_step)
+        self._add_event(event, global_step, walltime)

-    def _add_event(self, event, step):
-        event.wall_time = time.time()
+    def _add_event(self, event, step, walltime):
+        event.wall_time = time.time() if walltime is None else walltime
         if step is not None:
             event.step = int(step)
         self.event_writer.add_event(event)
@@ -315,19 +317,20 @@ def _check_caffe2(self, item):
         # TODO (ml7): Remove caffe2_enabled check when PyTorch 1.0 merges PyTorch and Caffe2
         return self.caffe2_enabled and isinstance(item, six.string_types)

-    def add_scalar(self, tag, scalar_value, global_step=None):
+    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
         """Add scalar data to summary.

         Args:
             tag (string): Data identifier
             scalar_value (float or string/blobname): Value to save
             global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time()) of event
         """
         if self._check_caffe2(scalar_value):
             scalar_value = workspace.FetchBlob(scalar_value)
-        self.file_writer.add_summary(scalar(tag, scalar_value), global_step)
+        self.file_writer.add_summary(scalar(tag, scalar_value), global_step, walltime)

-    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
+    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
         """Adds many scalar data to summary.

         Note that this function also keeps logged scalars in memory. In extreme case it explodes your RAM.
@@ -336,6 +339,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
             main_tag (string): The parent name for the tags
             tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
             global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time()) of event

         Examples::
@@ -345,7 +349,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
             # This call adds three values to the same scalar plot with the tag
             # 'run_14h' in TensorBoard's scalar section.
         """
-        timestamp = time.time()
+        walltime = time.time() if walltime is None else walltime
         fw_logdir = self.file_writer.get_logdir()
         for tag, scalar_value in tag_scalar_dict.items():
             fw_tag = fw_logdir + "/" + main_tag + "/" + tag
@@ -356,8 +360,8 @@
                 self.all_writers[fw_tag] = fw
             if self._check_caffe2(scalar_value):
                 scalar_value = workspace.FetchBlob(scalar_value)
-            fw.add_summary(scalar(main_tag, scalar_value), global_step)
-            self.__append_to_scalar_dict(fw_tag, scalar_value, global_step, timestamp)
+            fw.add_summary(scalar(main_tag, scalar_value), global_step, walltime)
+            self.__append_to_scalar_dict(fw_tag, scalar_value, global_step, walltime)

     def export_scalars_to_json(self, path):
         """Exports to the given path an ASCII file containing all the scalars written
@@ -370,7 +374,7 @@ def export_scalars_to_json(self, path):
             json.dump(self.scalar_dict, f)
         self.scalar_dict = {}

-    def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
+    def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None):
         """Add histogram to summary.

         Args:
@@ -379,14 +383,15 @@ def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
             global_step (int): Global step value to record
             bins (string): one of {'tensorflow','auto', 'fd', ...}, this determines how the bins are made. You can find
               other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
+            walltime (float): Optional override default walltime (time.time()) of event
         """
         if self._check_caffe2(values):
             values = workspace.FetchBlob(values)
         if isinstance(bins, six.string_types) and bins == 'tensorflow':
             bins = self.default_bins
-        self.file_writer.add_summary(histogram(tag, values, bins), global_step)
+        self.file_writer.add_summary(histogram(tag, values, bins), global_step, walltime)

-    def add_image(self, tag, img_tensor, global_step=None):
+    def add_image(self, tag, img_tensor, global_step=None, walltime=None):
         """Add image data to summary.

         Note that this requires the ``pillow`` package.
@@ -395,29 +400,32 @@ def add_image(self, tag, img_tensor, global_step=None):
             tag (string): Data identifier
             img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
             global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time()) of event

         Shape:
             img_tensor: :math:`(3, H, W)`. Use ``torchvision.utils.make_grid()`` to prepare it is a good idea.
         """
         if self._check_caffe2(img_tensor):
             img_tensor = workspace.FetchBlob(img_tensor)
-        self.file_writer.add_summary(image(tag, img_tensor), global_step)
+        self.file_writer.add_summary(image(tag, img_tensor), global_step, walltime)

-    def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None, **kwargs):
+    def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
+                             walltime=None, **kwargs):
         """Add image boxes data to summary (useful for models such as Detectron).

         Args:
             tag (string): Data identifier
             img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
             box_tensor (torch.Tensor, numpy.array, or string/blobname): Box data (for detected objects)
             global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time()) of event
         """
         if self._check_caffe2(img_tensor):
             img_tensor = workspace.FetchBlob(img_tensor)
         if self._check_caffe2(box_tensor):
             box_tensor = workspace.FetchBlob(box_tensor)
-        self._file_writer.add_summary(image_boxes(tag, img_tensor, box_tensor, **kwargs), global_step)
+        self.file_writer.add_summary(image_boxes(tag, img_tensor, box_tensor, **kwargs), global_step, walltime)
-    def add_figure(self, tag, figure, global_step=None, close=True):
+    def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
         """Render matplotlib figure into an image and add it to summary.

         Note that this requires the ``matplotlib`` package.
@@ -427,10 +435,11 @@ def add_figure(self, tag, figure, global_step=None, close=True):
             figure (matplotlib.pyplot.figure) or list of figures: figure or a list of figures
             global_step (int): Global step value to record
             close (bool): Flag to automatically close the figure
+            walltime (float): Optional override default walltime (time.time()) of event
         """
-        self.add_image(tag, figure_to_image(figure, close), global_step)
+        self.add_image(tag, figure_to_image(figure, close), global_step, walltime)

-    def add_video(self, tag, vid_tensor, global_step=None, fps=4):
+    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
         """Add video data to summary.

         Note that this requires the ``moviepy`` package.
@@ -440,41 +449,42 @@ def add_video(self, tag, vid_tensor, global_step=None, fps=4):
             vid_tensor (torch.Tensor): Video data
             global_step (int): Global step value to record
             fps (float or int): Frames per second
+            walltime (float): Optional override default walltime (time.time()) of event

         Shape:
             vid_tensor: :math:`(B, C, T, H, W)`.
         """
-        self.file_writer.add_summary(video(tag, vid_tensor, fps), global_step)
+        self.file_writer.add_summary(video(tag, vid_tensor, fps), global_step, walltime)

-    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100):
+    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
         """Add audio data to summary.

         Args:
             tag (string): Data identifier
             snd_tensor (torch.Tensor): Sound data
             global_step (int): Global step value to record
             sample_rate (int): sample rate in Hz
+            walltime (float): Optional override default walltime (time.time()) of event

         Shape:
             snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
         """
         if self._check_caffe2(snd_tensor):
             snd_tensor = workspace.FetchBlob(snd_tensor)
-        self.file_writer.add_summary(audio(tag, snd_tensor, sample_rate=sample_rate), global_step)
+        self.file_writer.add_summary(audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime)

-    def add_text(self, tag, text_string, global_step=None):
+    def add_text(self, tag, text_string, global_step=None, walltime=None):
         """Add text data to summary.

         Args:
             tag (string): Data identifier
             text_string (string): String to save
             global_step (int): Global step value to record
+            walltime (float): Optional override default walltime (time.time()) of event

         Examples::

             writer.add_text('lstm', 'This is an lstm', 0)
             writer.add_text('rnn', 'This is an rnn', 10)
         """
-        self.file_writer.add_summary(text(tag, text_string), global_step)
+        self.file_writer.add_summary(text(tag, text_string), global_step, walltime)

     def add_graph_onnx(self, prototxt):
         self.file_writer.add_graph_onnx(gg(prototxt))
@@ -594,7 +604,8 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
         # new funcion to append to the config file a new embedding
         append_pbtxt(metadata, label_img, self.file_writer.get_logdir(), subdir, global_step, tag)

-    def add_pr_curve(self, tag, labels, predictions, global_step=None, num_thresholds=127, weights=None):
+    def add_pr_curve(self, tag, labels, predictions, global_step=None,
+                     num_thresholds=127, weights=None, walltime=None):
         """Adds precision recall curve.

@@ -604,18 +615,25 @@ def add_pr_curve(self, tag, labels, predictions, global_step=None, num_threshold
                 The probability that an element be classified as true. Value should in [0, 1]
             global_step (int): Global step value to record
             num_thresholds (int): Number of thresholds used to draw the curve.
+            walltime (float): Optional override default walltime (time.time()) of event
         """
         from .x2num import make_np
         labels, predictions = make_np(labels), make_np(predictions)
-        self.file_writer.add_summary(pr_curve(tag, labels, predictions, num_thresholds, weights), global_step)
+        self.file_writer.add_summary(
+            pr_curve(tag, labels, predictions, num_thresholds, weights),
+            global_step, walltime)

     def add_pr_curve_raw(self, tag, true_positive_counts,
                          false_positive_counts,
                          true_negative_counts,
                          false_negative_counts,
                          precision,
-                         recall, global_step=None, num_thresholds=127, weights=None):
+                         recall,
+                         global_step=None,
+                         num_thresholds=127,
+                         weights=None,
+                         walltime=None):
         """Adds precision recall curve with raw data.
@@ -628,14 +646,21 @@ def add_pr_curve_raw(self, tag, true_positive_counts,
             recall (torch.Tensor, numpy.array, or string/blobname): recall
             global_step (int): Global step value to record
             num_thresholds (int): Number of thresholds used to draw the curve.
+            walltime (float): Optional override default walltime (time.time()) of event
             see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md
         """
-        self.file_writer.add_summary(pr_curve_raw(tag, true_positive_counts,
-                                                  false_positive_counts,
-                                                  true_negative_counts,
-                                                  false_negative_counts,
-                                                  precision,
-                                                  recall, num_thresholds, weights), global_step)
+        self.file_writer.add_summary(
+            pr_curve_raw(tag,
+                         true_positive_counts,
+                         false_positive_counts,
+                         true_negative_counts,
+                         false_negative_counts,
+                         precision,
+                         recall,
+                         num_thresholds,
+                         weights),
+            global_step,
+            walltime)

     def close(self):
         if self.file_writer is None:
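Taken together, the FileWriter._add_event change means every event's wall_time can now be supplied by the caller. A speculative example of why that is useful — replaying metrics that were recorded elsewhere while keeping their original timestamps (log directory, tag, and data are illustrative):

from tensorboardX import SummaryWriter

# Illustrative pre-recorded history: (wall-clock time, step, value) triples.
history = [
    (1533686400.0, 0, 1.00),
    (1533690000.0, 1, 0.73),
    (1533693600.0, 2, 0.51),
]

writer = SummaryWriter(log_dir='runs/backfill_demo')
for wall, step, value in history:
    # Without the override, these points would all share "now" as their wall time.
    writer.add_scalar('replayed/loss', value, global_step=step, walltime=wall)
writer.close()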
