
Commit

Merge pull request #28 from andres-fr/master
Support for multi-scalars and JSON export
lanpa committed Sep 24, 2017
2 parents 4ac9de7 + 4be567d commit b14ffd6
Showing 2 changed files with 55 additions and 3 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -43,6 +43,9 @@ for n_iter in range(100):
    s2 = torch.rand(1)
    writer.add_scalar('data/scalar1', s1[0], n_iter) # data grouping by `slash`
    writer.add_scalar('data/scalar2', s2[0], n_iter)
    writer.add_scalars('data/scalar_group', {"xsinx": n_iter*np.sin(n_iter),
                                             "xcosx": n_iter*np.cos(n_iter),
                                             "arctanx": np.arctan(n_iter)}, n_iter)
    x = torch.rand(32, 3, 64, 64)  # output from network
    if n_iter % 10 == 0:
        x = vutils.make_grid(x, normalize=True, scale_each=True)
@@ -60,6 +63,10 @@ images = dataset.test_data[:100].float()
label = dataset.test_labels[:100]
features = images.view(100, 784)
writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1))

# export scalar data to JSON for external processing
writer.export_scalars_to_json("./all_scalars.json")

writer.close()
```
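
The exported file is plain JSON, so any external tool can consume it. A minimal sketch of reading `all_scalars.json` back, assuming the snippet above has run (the variable names here are illustrative, not part of the library's API):

```python
import json

with open("./all_scalars.json") as f:
    scalars = json.load(f)  # {writer_id: [[timestamp, step, value], ...], ...}

for writer_id, events in scalars.items():
    for timestamp, step, value in events:
        print(writer_id, step, value)
```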

51 changes: 48 additions & 3 deletions tensorboardX/writer.py
@@ -220,7 +220,6 @@ class SummaryWriter(object):
"""
def __init__(self, log_dir=None, comment=''):
"""
Args:
log_dir (string): save location, default is: runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run. Use hierarchical folder structure to compare between runs easily. e.g. 'runs/exp1', 'runs/exp2'
comment (string): comment that appends to the default log_dir
@@ -239,9 +238,22 @@ def __init__(self, log_dir=None, comment=''):
            v *= 1.1
        self.default_bins = neg_buckets[::-1] + [0] + buckets
        self.text_tags = []
        # map from logdir to FileWriter; starts with the default writer and
        # grows as add_scalars creates one writer per sub-tag
        self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
        self.scalar_dict = {}  # {writer_id : [[timestamp, step, value], ...], ...}


    def __append_to_scalar_dict(self, tag, scalar_value, global_step,
                                timestamp):
        """Adds an entry to the self.scalar_dict data structure with format
        {writer_id : [[timestamp, step, value], ...], ...}.
        """
        if tag not in self.scalar_dict:
            self.scalar_dict[tag] = []
        self.scalar_dict[tag].append([timestamp, global_step, scalar_value])
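
    # Illustrative only: after two hypothetical calls such as
    #   writer.add_scalar('data/scalar1', 0.5, 0)
    #   writer.add_scalar('data/scalar1', 0.7, 1)
    # scalar_dict would hold (timestamps made up):
    #   {'data/scalar1': [[1506240000.0, 0, 0.5],
    #                     [1506240001.0, 1, 0.7]]}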

    def add_scalar(self, tag, scalar_value, global_step=None):
        """Add scalar data to summary.
        Args:
            tag (string): Data identifier
            scalar_value (float): Value to save
@@ -251,6 +263,35 @@ def add_scalar(self, tag, scalar_value, global_step=None):
        scalar_value = makenp(scalar_value)
        assert scalar_value.squeeze().ndim == 0, 'input of add_scalar should be 0D'
        self.file_writer.add_summary(scalar(tag, scalar_value), global_step)
        self.__append_to_scalar_dict(tag, float(scalar_value), global_step, time.time())

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
        """Adds several scalar values to the same plot.
        Args:
            main_tag (string): Tag that groups the scalars into one plot
            tag_scalar_dict (dict): Key-value pairs of sub-tag names and scalar values
            global_step (int): Global step value to record
        Usage example:
            writer.add_scalars('run_14h', {'xsinx': i*np.sin(i/r),
                                           'xcosx': i*np.cos(i/r),
                                           'arctanx': numsteps*np.arctan(i/r)}, i)
        This call adds three curves to the same plot, grouped under the tag
        'run_14h' in TensorBoard's scalar section.
        """
        timestamp = time.time()
        fw_logdir = self.file_writer.get_logdir()
        for tag, scalar_value in tag_scalar_dict.items():
            fw_tag = fw_logdir + "/" + main_tag + "/" + tag
            if fw_tag in self.all_writers:
                fw = self.all_writers[fw_tag]
            else:
                # one FileWriter (i.e. one run directory) per sub-tag, so
                # TensorBoard draws the values as separate curves on one chart
                fw = FileWriter(logdir=fw_tag)
                self.all_writers[fw_tag] = fw
            fw.add_summary(scalar(main_tag, scalar_value), global_step)
            self.__append_to_scalar_dict(fw_tag, scalar_value, global_step, timestamp)
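
    # Illustrative example of the grouping this produces, assuming the writer
    # was constructed with log_dir='runs/exp':
    #   add_scalars('data/scalar_group', {'xsinx': 0.1, 'xcosx': 0.9}, 0)
    # registers FileWriters under
    #   runs/exp/data/scalar_group/xsinx
    #   runs/exp/data/scalar_group/xcosx
    # and pointing TensorBoard at 'runs/' overlays both curves on one chart.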

    def export_scalars_to_json(self, path):
        """Exports all the scalars written so far by this instance to a JSON
        file at the given path, with the format:
        {writer_id : [[timestamp, step, value], ...], ...}
        """
        with open(path, "w") as f:
            json.dump(self.scalar_dict, f)

    def add_histogram(self, tag, values, global_step=None, bins='tensorflow'):
        """Add histogram to summary.
@@ -397,8 +438,12 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None):
    def close(self):
        self.file_writer.flush()
        self.file_writer.close()
        # also flush and close every writer created by add_scalars
        # (self.all_writers includes the default writer as well)
        for path, writer in self.all_writers.items():
            writer.flush()
            writer.close()

    def __del__(self):
        if self.file_writer is not None:
            self.file_writer.close()

        for writer in self.all_writers.values():
            writer.close()
