Skip to content

Commit

Permalink
Merge pull request #94 from bnsh/master
Browse files Browse the repository at this point in the history
Added the ability to handle multiple embeddings per epoch.
  • Loading branch information
lanpa committed Mar 16, 2018
2 parents 895f687 + ef80c7f commit 7bdda22
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 7 deletions.
50 changes: 50 additions & 0 deletions demo_multiple_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import math
import torch
from tensorboardX import SummaryWriter

def make_trig_embedding(trig_fn, shift, degrees, frequencies=(2, 3, 5, 7, 11)):
    """Build an (N, len(frequencies)) embedding matrix.

    Each column is ``trig_fn(shift + degrees * k * pi / 180)`` for one
    multiplier ``k``.  The prime multipliers give every embedding
    dimension a different frequency, so PCA cannot trivially collapse
    the point cloud.

    Args:
        trig_fn: elementwise tensor function, e.g. ``torch.sin``.
        shift: scalar phase offset (radians).
        degrees: (N, 1) column tensor of angles.
        frequencies: per-dimension frequency multipliers.

    Returns:
        torch.Tensor of shape (N, len(frequencies)).
    """
    columns = [trig_fn(shift + degrees * k * math.pi / 180.0) for k in frequencies]
    return torch.cat(columns, dim=1)


def main():
    """Log sin/cos/tan embeddings for 16 epochs.

    Demonstrates writing multiple tagged embeddings per global step so
    the TensorBoard projector shows several tensors per epoch.
    """
    num_points = 3600
    num_epochs = 16

    # Column vector of angles, one row per data point.
    degrees = torch.arange(num_points).reshape(num_points, 1) * math.pi / 180.0
    labels = ["%d" % i for i in range(num_points)]

    with SummaryWriter() as writer:
        # Data that is always phase-shifted per epoch, and hard for PCA
        # to turn into a sphere.
        for epoch in range(num_epochs):
            # Phase advances one full revolution over all epochs.
            shift = epoch * 2 * math.pi / num_epochs
            for tag, trig_fn in (("sin", torch.sin),
                                 ("cos", torch.cos),
                                 ("tan", torch.tan)):
                mat = make_trig_embedding(trig_fn, shift, degrees)
                writer.add_embedding(mat=mat, metadata=labels,
                                     tag=tag, global_step=epoch)

if __name__ == "__main__":
    main()

# To view the result: tensorboard --logdir runs
# Under the "Projector" tab you should see 48 tensors (16 epochs x 3 tags),
# with global steps zero-padded to 5 digits:
#   cos:cos-00000 to cos:cos-00015
#   sin:sin-00000 to sin:sin-00015
#   tan:tan-00000 to tan:tan-00015
10 changes: 5 additions & 5 deletions tensorboardX/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ def make_sprite(label_img, save_path):
torchvision.utils.save_image(label_img, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0)


def append_pbtxt(metadata, label_img, save_path, global_step, tag):
def append_pbtxt(metadata, label_img, save_path, subdir, global_step, tag):
with open(os.path.join(save_path, 'projector_config.pbtxt'), 'a') as f:
# step = os.path.split(save_path)[-1]
f.write('embeddings {\n')
f.write('tensor_name: "{}:{}"\n'.format(tag, global_step))
f.write('tensor_path: "{}"\n'.format(os.path.join(global_step, 'tensors.tsv')))
f.write('tensor_name: "{}:{}"\n'.format(tag, str(global_step).zfill(5)))
f.write('tensor_path: "{}"\n'.format(os.path.join(subdir, 'tensors.tsv')))
if metadata is not None:
f.write('metadata_path: "{}"\n'.format(os.path.join(global_step, 'metadata.tsv')))
f.write('metadata_path: "{}"\n'.format(os.path.join(subdir, 'metadata.tsv')))
if label_img is not None:
f.write('sprite {\n')
f.write('image_path: "{}"\n'.format(os.path.join(global_step, 'sprite.png')))
f.write('image_path: "{}"\n'.format(os.path.join(subdir, 'sprite.png')))
f.write('single_image_dim: {}\n'.format(label_img.size(3)))
f.write('single_image_dim: {}\n'.format(label_img.size(2)))
f.write('}\n')
Expand Down
16 changes: 14 additions & 2 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,15 @@ def add_graph(self, model, input_to_model, verbose=False):
return
self.file_writer.add_graph(graph(model, input_to_model, verbose))

@staticmethod
def _encode(rawstr):
    """Percent-encode characters that would break a filesystem path.

    Escapes '%', '/', and '\\' as '%xx' lowercase-hex sequences, a
    minimal stand-in for urllib quoting that behaves identically on
    Python 2 and 3.  '%' is escaped first so the escapes themselves
    are never double-encoded.
    """
    encoded = rawstr
    for unsafe in ("%", "/", "\\"):
        encoded = encoded.replace(unsafe, "%%%02x" % ord(unsafe))
    return encoded

def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, tag='default'):
"""Add embedding projector data to summary.
Expand Down Expand Up @@ -436,7 +445,10 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
if global_step is None:
global_step = 0
# clear pbtxt?
save_path = os.path.join(self.file_writer.get_logdir(), str(global_step).zfill(5))
# Maybe we should encode the tag so slashes don't trip us up?
# I don't think this will mess us up, but better safe than sorry.
subdir = "%s/%s" % (str(global_step).zfill(5), self._encode(tag))
save_path = os.path.join(self.file_writer.get_logdir(), subdir)
try:
os.makedirs(save_path)
except OSError:
Expand All @@ -450,7 +462,7 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
assert mat.dim() == 2, 'mat should be 2D, where mat.size(0) is the number of data points'
make_mat(mat.tolist(), save_path)
# new function that appends a new embedding to the config file
append_pbtxt(metadata, label_img, self.file_writer.get_logdir(), str(global_step).zfill(5), tag)
append_pbtxt(metadata, label_img, self.file_writer.get_logdir(), subdir, global_step, tag)

def add_pr_curve(self, tag, labels, predictions, global_step=None, num_thresholds=127, weights=None):
"""Adds precision recall curve.
Expand Down

0 comments on commit 7bdda22

Please sign in to comment.