Skip to content

Commit

Permalink
Merge pull request #13 from lucabergamini/master
Browse files Browse the repository at this point in the history
fix #12
  • Loading branch information
lanpa committed Aug 15, 2017
2 parents 0fd6c52 + 7001a02 commit 9802be0
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions tensorboard/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,26 @@ def make_tsv(metadata, save_path):
metadata = [str(x) for x in metadata]
with open(os.path.join(save_path, 'metadata.tsv'), 'w') as f:
for x in metadata:
f.write(x+'\n')
f.write(x + '\n')


# https://github.com/tensorflow/tensorboard/issues/44 image label will be squared
def make_sprite(label_img, save_path):
import math
import torch
import torchvision
nrow = int(math.floor(math.sqrt(label_img.size(0))))
xx = torchvision.utils.make_grid(torch.Tensor(1,3,32,32), padding=0)
if xx.size(2)==33: # https://github.com/pytorch/vision/issues/206
# this ensures the sprite image has correct dimension as described in
# https://www.tensorflow.org/get_started/embedding_viz
nrow = int(math.ceil((label_img.size(0)) ** 0.5))

# augment images so that #images equals nrow*nrow
label_img = torch.cat((label_img, torch.randn(nrow ** 2 - label_img.size(0), *label_img.size()[1:]) * 255), 0)

# Dirty fix: no pixel are appended by make_grid call in save_image (https://github.com/pytorch/vision/issues/206)
xx = torchvision.utils.make_grid(torch.Tensor(1, 3, 32, 32), padding=0)
if xx.size(2) == 33:
sprite = torchvision.utils.make_grid(label_img, nrow=nrow, padding=0)
sprite = sprite[:,1:,1:]
sprite = sprite[:, 1:, 1:]
torchvision.utils.save_image(sprite, os.path.join(save_path, 'sprite.png'))
else:
torchvision.utils.save_image(label_img, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0)
Expand All @@ -40,8 +47,7 @@ def make_mat(matlist, save_path):
with open(os.path.join(save_path, 'tensors.tsv'), 'w') as f:
for x in matlist:
x = [str(i) for i in x]
f.write('\t'.join(x)+'\n')

f.write('\t'.join(x) + '\n')

def add_embedding(mat, save_path, metadata=None, label_img=None):
"""add embedding
Expand All @@ -60,8 +66,8 @@ def add_embedding(mat, save_path, metadata=None, label_img=None):
~~This function needs tensorflow installed. It invokes tensorflow to dump data. ~~
Therefore I separate it from the SummaryWriter class. Please pass ``writer.file_writer.get_logdir()`` to ``save_path`` to prevent glitches.
If ``save_path`` is different than SummaryWritter's save path, you need to pass the leave directory to tensorboard's logdir argument,
otherwise it cannot display anything. e.g. if ``save_path`` equals 'path/to/embedding',
If ``save_path`` is different than SummaryWritter's save path, you need to pass the leave directory to tensorboard's logdir argument,
otherwise it cannot display anything. e.g. if ``save_path`` equals 'path/to/embedding',
you need to call 'tensorboard --logdir=path/to/embedding', instead of 'tensorboard --logdir=path'.
Expand All @@ -81,7 +87,7 @@ def add_embedding(mat, save_path, metadata=None, label_img=None):
label_img = torch.rand(100, 3, 10, 32)
for i in range(100):
label_img[i]*=i/100.0
add_embedding(torch.randn(100, 5), 'embedding1', metadata=meta, label_img=label_img)
add_embedding(torch.randn(100, 5), 'embedding2', label_img=label_img)
add_embedding(torch.randn(100, 5), 'embedding3', metadata=meta)
Expand Down

0 comments on commit 9802be0

Please sign in to comment.