Skip to content

Commit

Permalink
closes #24
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Sep 24, 2017
1 parent 3eac9b5 commit 8f8d6c9
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
1 change: 1 addition & 0 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@
label = dataset.test_labels[:100]
features = images.view(100, 784)
writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1))
writer.add_embedding(features, global_step=1, tag='noMetadata')
writer.close()
2 changes: 1 addition & 1 deletion demo_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_data(value, shape):
#settings for train and log
num_epochs = 20
embedding_log = 5
writer = SummaryWriter()
writer = SummaryWriter(comment='mnist_embedding_training')

#TRAIN
for epoch in range(num_epochs):
Expand Down
4 changes: 2 additions & 2 deletions tensorboardX/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def make_sprite(label_img, save_path):
else:
torchvision.utils.save_image(label_img, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0)

def append_pbtxt(metadata, label_img, save_path, global_step):
def append_pbtxt(metadata, label_img, save_path, global_step, tag):
with open(os.path.join(save_path, 'projector_config.pbtxt'), 'a') as f:
#step = os.path.split(save_path)[-1]
f.write('embeddings {\n')
f.write('tensor_name: "embedding:{}"\n'.format(global_step))
f.write('tensor_name: "{}:{}"\n'.format(tag, global_step))
f.write('tensor_path: "{}"\n'.format(os.path.join(global_step,"tensors.tsv")))
if metadata is not None:
f.write('metadata_path: "{}"\n'.format(os.path.join(global_step,"metadata.tsv")))
Expand Down
5 changes: 3 additions & 2 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,14 +376,15 @@ def add_graph(self, model, lastVar):
return
self.file_writer.add_graph(graph(model, lastVar))

def add_embedding(self, mat, metadata=None, label_img=None, global_step=None):
def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, tag='default'):
"""Add embedding projector data to summary.
Args:
mat (torch.Tensor): A matrix which each row is the feature vector of the data point
metadata (list): A list of labels, each element will be convert to string
label_img (torch.Tensor): Images correspond to each data point
global_step (int): Global step value to record
tag (string): Name for the embedding
Shape:
mat: :math:`(N, D)`, where N is number of data and D is feature dimension
Expand Down Expand Up @@ -426,7 +427,7 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None):
assert mat.dim() == 2, 'mat should be 2D, where mat.size(0) is the number of data points'
make_mat(mat.tolist(), save_path)
#new function to append a new embedding entry to the projector config file
append_pbtxt(metadata, label_img, self.file_writer.get_logdir(), str(global_step).zfill(5))
append_pbtxt(metadata, label_img, self.file_writer.get_logdir(), str(global_step).zfill(5), tag)

def close(self):
self.file_writer.flush()
Expand Down

0 comments on commit 8f8d6c9

Please sign in to comment.