Permalink
Browse files

Add walltime override to SummaryWriter methods (#207)

* Add optional "walltime" to summary_writer which will override
default (current) walltime

* Fix missing walltime in add_scalars (fw.add_summary)

* Remove unused arg in SummaryWriter.add_embedding()

* Fix Flake8 error

* match codeing style
  • Loading branch information...
andrewkho authored and lanpa committed Aug 8, 2018
1 parent 797e610 commit 3e9913cb65769ccaff38ccef902ad33774251198
Showing with 70 additions and 45 deletions.
  1. +3 −3 demo.py
  2. +67 −42 tensorboardX/writer.py
View
@@ -4,8 +4,7 @@
import torchvision.models as models
from torchvision import datasets
from tensorboardX import SummaryWriter
#import skimage
#from skimage import data, io
import datetime
resnet18 = models.resnet18(False)
writer = SummaryWriter()
@@ -23,7 +22,8 @@
for n_iter in range(100):
s1 = torch.rand(1) # value to keep
s2 = torch.rand(1)
writer.add_scalar('data/scalar1', s1[0], n_iter) # data grouping by `slash`
writer.add_scalar('data/scalar_systemtime', s1[0], n_iter) # data grouping by `slash`
writer.add_scalar('data/scalar_customtime', s1[0], n_iter, walltime=n_iter) # data grouping by `slash`
writer.add_scalars('data/scalar_group', {"xsinx":n_iter*np.sin(n_iter),
"xcosx":n_iter*np.cos(n_iter),
"arctanx": np.arctan(n_iter)}, n_iter)
View
@@ -79,7 +79,7 @@ def __init__(self, event_writer, graph=None, graph_def=None):
# TODO(zihaolucky). pass this an empty graph to check whether it's necessary.
# currently we don't support graph in MXNet using tensorboard.
def add_summary(self, summary, global_step=None):
def add_summary(self, summary, global_step=None, walltime=None):
"""Adds a `Summary` protocol buffer to the event file.
This method wraps the provided summary in an `Event` protocol buffer
and adds it to the event file.
@@ -93,33 +93,35 @@ def add_summary(self, summary, global_step=None):
summary: A `Summary` protocol buffer, optionally serialized as a string.
global_step: Number. Optional global step value to record with the
summary.
walltime: float. Optional walltime to override the default (current)
walltime (from time.time())
"""
if isinstance(summary, bytes):
summ = summary_pb2.Summary()
summ.ParseFromString(summary)
summary = summ
event = event_pb2.Event(summary=summary)
self._add_event(event, global_step)
self._add_event(event, global_step, walltime)
def add_graph(self, graph_profile):
def add_graph(self, graph_profile, walltime=None):
graph = graph_profile[0]
stepstats = graph_profile[1]
"""Adds a `Graph` protocol buffer to the event file.
"""
event = event_pb2.Event(graph_def=graph.SerializeToString())
self._add_event(event, None)
self._add_event(event, None, walltime)
trm = event_pb2.TaggedRunMetadata(tag='step1', run_metadata=stepstats.SerializeToString())
event = event_pb2.Event(tagged_run_metadata=trm)
self._add_event(event, None)
self._add_event(event, None, walltime)
def add_graph_onnx(self, graph):
def add_graph_onnx(self, graph, walltime=None):
"""Adds a `Graph` protocol buffer to the event file.
"""
event = event_pb2.Event(graph_def=graph.SerializeToString())
self._add_event(event, None)
self._add_event(event, None, walltime)
def add_session_log(self, session_log, global_step=None):
def add_session_log(self, session_log, global_step=None, walltime=None):
"""Adds a `SessionLog` protocol buffer to the event file.
This method wraps the provided session in an `Event` protocol buffer
and adds it to the event file.
@@ -129,10 +131,10 @@ def add_session_log(self, session_log, global_step=None):
summary.
"""
event = event_pb2.Event(session_log=session_log)
self._add_event(event, global_step)
self._add_event(event, global_step, walltime)
def _add_event(self, event, step):
event.wall_time = time.time()
def _add_event(self, event, step, walltime):
event.wall_time = time.time() if walltime is None else walltime
if step is not None:
event.step = int(step)
self.event_writer.add_event(event)
@@ -315,19 +317,20 @@ def _check_caffe2(self, item):
# TODO (ml7): Remove caffe2_enabled check when PyTorch 1.0 merges PyTorch and Caffe2
return self.caffe2_enabled and isinstance(item, six.string_types)
def add_scalar(self, tag, scalar_value, global_step=None):
def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
"""Add scalar data to summary.
Args:
tag (string): Data identifier
scalar_value (float or string/blobname): Value to save
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time()) of event
"""
if self._check_caffe2(scalar_value):
scalar_value = workspace.FetchBlob(scalar_value)
self.file_writer.add_summary(scalar(tag, scalar_value), global_step)
self.file_writer.add_summary(scalar(tag, scalar_value), global_step, walltime)
def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
"""Adds many scalar data to summary.
Note that this function also keeps logged scalars in memory. In extreme case it explodes your RAM.
@@ -336,6 +339,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
main_tag (string): The parent name for the tags
tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time()) of event
Examples::
@@ -345,7 +349,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
# This call adds three values to the same scalar plot with the tag
# 'run_14h' in TensorBoard's scalar section.
"""
timestamp = time.time()
walltime = time.time() if walltime is None else walltime
fw_logdir = self.file_writer.get_logdir()
for tag, scalar_value in tag_scalar_dict.items():
fw_tag = fw_logdir + "/" + main_tag + "/" + tag
@@ -356,8 +360,8 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
self.all_writers[fw_tag] = fw
if self._check_caffe2(scalar_value):
scalar_value = workspace.FetchBlob(scalar_value)
fw.add_summary(scalar(main_tag, scalar_value), global_step)
self.__append_to_scalar_dict(fw_tag, scalar_value, global_step, timestamp)
fw.add_summary(scalar(main_tag, scalar_value), global_step, walltime)
self.__append_to_scalar_dict(fw_tag, scalar_value, global_step, walltime)
def export_scalars_to_json(self, path):
"""Exports to the given path an ASCII file containing all the scalars written
@@ -370,7 +374,7 @@ def export_scalars_to_json(self, path):
json.dump(self.scalar_dict, f)
self.scalar_dict = {}
def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None):
"""Add histogram to summary.
Args:
@@ -379,14 +383,15 @@ def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
global_step (int): Global step value to record
bins (string): one of {'tensorflow','auto', 'fd', ...}, this determines how the bins are made. You can find
other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
walltime (float): Optional override default walltime (time.time()) of event
"""
if self._check_caffe2(values):
values = workspace.FetchBlob(values)
if isinstance(bins, six.string_types) and bins == 'tensorflow':
bins = self.default_bins
self.file_writer.add_summary(histogram(tag, values, bins), global_step)
self.file_writer.add_summary(histogram(tag, values, bins), global_step, walltime)
def add_image(self, tag, img_tensor, global_step=None):
def add_image(self, tag, img_tensor, global_step=None, walltime=None):
"""Add image data to summary.
Note that this requires the ``pillow`` package.
@@ -395,29 +400,32 @@ def add_image(self, tag, img_tensor, global_step=None):
tag (string): Data identifier
img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time()) of event
Shape:
img_tensor: :math:`(3, H, W)`. Use ``torchvision.utils.make_grid()`` to prepare it is a good idea.
"""
if self._check_caffe2(img_tensor):
img_tensor = workspace.FetchBlob(img_tensor)
self.file_writer.add_summary(image(tag, img_tensor), global_step)
self.file_writer.add_summary(image(tag, img_tensor), global_step, walltime)
def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None, **kwargs):
def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
walltime=None, **kwargs):
"""Add image boxes data to summary (useful for models such as Detectron).
Args:
tag (string): Data identifier
img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
box_tensor (torch.Tensor, numpy.array, or string/blobname): Box data (for detected objects)
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time()) of event
"""
if self._check_caffe2(img_tensor):
img_tensor = workspace.FetchBlob(img_tensor)
if self._check_caffe2(box_tensor):
box_tensor = workspace.FetchBlob(box_tensor)
self._file_writer.add_summary(image_boxes(tag, img_tensor, box_tensor, **kwargs), global_step)
self.file_writer.add_summary(image_boxes(tag, img_tensor, box_tensor, **kwargs), global_step, walltime)
def add_figure(self, tag, figure, global_step=None, close=True):
def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
"""Render matplotlib figure into an image and add it to summary.
Note that this requires the ``matplotlib`` package.
@@ -427,10 +435,11 @@ def add_figure(self, tag, figure, global_step=None, close=True):
figure (matplotlib.pyplot.figure) or list of figures: figure or a list of figures
global_step (int): Global step value to record
close (bool): Flag to automatically close the figure
walltime (float): Optional override default walltime (time.time()) of event
"""
self.add_image(tag, figure_to_image(figure, close), global_step)
self.add_image(tag, figure_to_image(figure, close), global_step, walltime)
def add_video(self, tag, vid_tensor, global_step=None, fps=4):
def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
"""Add video data to summary.
Note that this requires the ``moviepy`` package.
@@ -440,41 +449,42 @@ def add_video(self, tag, vid_tensor, global_step=None, fps=4):
vid_tensor (torch.Tensor): Video data
global_step (int): Global step value to record
fps (float or int): Frames per second
walltime (float): Optional override default walltime (time.time()) of event
Shape:
vid_tensor: :math:`(B, C, T, H, W)`.
"""
self.file_writer.add_summary(video(tag, vid_tensor, fps), global_step)
self.file_writer.add_summary(video(tag, vid_tensor, fps), global_step, walltime)
def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100):
def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
"""Add audio data to summary.
Args:
tag (string): Data identifier
snd_tensor (torch.Tensor): Sound data
global_step (int): Global step value to record
sample_rate (int): sample rate in Hz
walltime (float): Optional override default walltime (time.time()) of event
Shape:
snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
"""
if self._check_caffe2(snd_tensor):
snd_tensor = workspace.FetchBlob(snd_tensor)
self.file_writer.add_summary(audio(tag, snd_tensor, sample_rate=sample_rate), global_step)
self.file_writer.add_summary(audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime)
def add_text(self, tag, text_string, global_step=None):
def add_text(self, tag, text_string, global_step=None, walltime=None):
"""Add text data to summary.
Args:
tag (string): Data identifier
text_string (string): String to save
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time()) of event
Examples::
writer.add_text('lstm', 'This is an lstm', 0)
writer.add_text('rnn', 'This is an rnn', 10)
"""
self.file_writer.add_summary(text(tag, text_string), global_step)
self.file_writer.add_summary(text(tag, text_string), global_step, walltime)
def add_graph_onnx(self, prototxt):
self.file_writer.add_graph_onnx(gg(prototxt))
@@ -594,7 +604,8 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
# new funcion to append to the config file a new embedding
append_pbtxt(metadata, label_img, self.file_writer.get_logdir(), subdir, global_step, tag)
def add_pr_curve(self, tag, labels, predictions, global_step=None, num_thresholds=127, weights=None):
def add_pr_curve(self, tag, labels, predictions, global_step=None,
num_thresholds=127, weights=None, walltime=None):
"""Adds precision recall curve.
Args:
@@ -604,18 +615,25 @@ def add_pr_curve(self, tag, labels, predictions, global_step=None, num_threshold
The probability that an element be classified as true. Value should in [0, 1]
global_step (int): Global step value to record
num_thresholds (int): Number of thresholds used to draw the curve.
walltime (float): Optional override default walltime (time.time()) of event
"""
from .x2num import make_np
labels, predictions = make_np(labels), make_np(predictions)
self.file_writer.add_summary(pr_curve(tag, labels, predictions, num_thresholds, weights), global_step)
self.file_writer.add_summary(
pr_curve(tag, labels, predictions, num_thresholds, weights),
global_step, walltime)
def add_pr_curve_raw(self, tag, true_positive_counts,
false_positive_counts,
true_negative_counts,
false_negative_counts,
precision,
recall, global_step=None, num_thresholds=127, weights=None):
recall,
global_step=None,
num_thresholds=127,
weights=None,
walltime=None):
"""Adds precision recall curve with raw data.
Args:
@@ -628,14 +646,21 @@ def add_pr_curve_raw(self, tag, true_positive_counts,
recall (torch.Tensor, numpy.array, or string/blobname): recall
global_step (int): Global step value to record
num_thresholds (int): Number of thresholds used to draw the curve.
walltime (float): Optional override default walltime (time.time()) of event
see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md
"""
self.file_writer.add_summary(pr_curve_raw(tag, true_positive_counts,
false_positive_counts,
true_negative_counts,
false_negative_counts,
precision,
recall, num_thresholds, weights), global_step)
self.file_writer.add_summary(
pr_curve_raw(tag,
true_positive_counts,
false_positive_counts,
true_negative_counts,
false_negative_counts,
precision,
recall,
num_thresholds,
weights),
global_step,
walltime)
def close(self):
if self.file_writer is None:

0 comments on commit 3e9913c

Please sign in to comment.