Skip to content

Commit

Permalink
Merge pull request #110 from h0rm/master
Browse files Browse the repository at this point in the history
add_embedding: added support for multiple labels
  • Loading branch information
lanpa committed Mar 23, 2018
2 parents 29c6194 + 2c707fe commit 87dbc51
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
15 changes: 15 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,19 @@
features = images.view(100, 784)
writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1))
writer.add_embedding(features, global_step=1, tag='noMetadata')
dataset = datasets.MNIST('mnist', train=True, download=True)
images_train = dataset.train_data[:100].float()
labels_train = dataset.train_labels[:100]
features_train = images_train.view(100, 784)

all_features = torch.cat((features, features_train))
all_labels = torch.cat((label, labels_train))
all_images = torch.cat((images, images_train))
dataset_label = ['test']*100 + ['train']*100
all_labels = list(zip(all_labels, dataset_label))

writer.add_embedding(all_features, metadata=all_labels, label_img=all_images.unsqueeze(1),
metadata_header=['digit', 'dataset'], global_step=2)


writer.close()
10 changes: 8 additions & 2 deletions tensorboardX/embedding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import os


def make_tsv(metadata, save_path):
metadata = [str(x) for x in metadata]
def make_tsv(metadata, save_path, metadata_header=None):
if not metadata_header:
metadata = [str(x) for x in metadata]
else:
assert len(metadata_header) == len(metadata[0]), \
'len of header must be equal to the number of columns in metadata'
metadata = ['\t'.join(str(e) for e in l) for l in [metadata_header] + metadata]

with open(os.path.join(save_path, 'metadata.tsv'), 'w') as f:
for x in metadata:
f.write(x + '\n')
Expand Down
4 changes: 2 additions & 2 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def _encode(rawstr):
retval = retval.replace("\\", "%%%02x" % (ord("\\")))
return retval

def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, tag='default'):
def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, tag='default', metadata_header=None):
"""Add embedding projector data to summary.
Args:
Expand Down Expand Up @@ -455,7 +455,7 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
print('warning: Embedding dir exists, did you set global_step for add_embedding()?')
if metadata is not None:
assert mat.size(0) == len(metadata), '#labels should equal with #data points'
make_tsv(metadata, save_path)
make_tsv(metadata, save_path, metadata_header=metadata_header)
if label_img is not None:
assert mat.size(0) == label_img.size(0), '#images should equal with #data points'
make_sprite(label_img, save_path)
Expand Down

0 comments on commit 87dbc51

Please sign in to comment.