
Commit b4d84aa: fix #5. Unified API
lanpa committed Aug 16, 2017 (1 parent: 3f9532f)
Showing 4 changed files with 71 additions and 144 deletions.
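The unified API folds the old standalone EmbeddingWriter into SummaryWriter itself. A minimal sketch of the new call, using the import path and signature shown in the diffs below (the run directory name, tensor shapes, and labels here are illustrative):

import torch
from tensorboard import SummaryWriter

writer = SummaryWriter('runs/embedding_demo')   # illustrative run directory
features = torch.randn(100, 5)                  # (N, D): 100 points, 5-dim features
labels = [str(i % 2) for i in range(100)]       # one metadata label per point
# a single call now dumps tensors.tsv, metadata.tsv and the projector config
writer.add_embedding(features, metadata=labels, global_step=0)
writer.close()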
demo_embedding.py (22 changes: 11 additions & 11 deletions)

@@ -7,7 +7,6 @@
 from tensorboard import SummaryWriter
 from datetime import datetime
 from torch.utils.data import TensorDataset,DataLoader
-from tensorboard.embedding import EmbeddingWriter
 import os

 #EMBEDDING VISUALIZATION FOR A TWO-CLASSES PROBLEM
@@ -52,37 +51,38 @@ def get_data(value,shape):
 optimizer = Adam(params=m.parameters())
 #settings for train and log
 num_epochs = 20
-num_batches = len(gen)
 embedding_log = 5
 #WE NEED A WRITER! BECAUSE TB LOOK FOR IT!
 writer_name = datetime.now().strftime('%B%d %H:%M:%S')
 writer = SummaryWriter(os.path.join("runs",writer_name))
-#our brand new embwriter in the same dir
-embedding_writer = EmbeddingWriter(os.path.join("runs",writer_name))

 #TRAIN
-for i in range(num_epochs):
+for epoch in range(num_epochs):
     for j,sample in enumerate(gen):
+        n_iter = (epoch*len(gen))+j
         #reset grad
         m.zero_grad()
         optimizer.zero_grad()
         #get batch data
-        data_batch = Variable(sample[0],requires_grad=True).float()
-        label_batch = Variable(sample[1],requires_grad=False).long()
+        data_batch = Variable(sample[0], requires_grad=True).float()
+        label_batch = Variable(sample[1], requires_grad=False).long()
         #FORWARD
         out = m(data_batch)
-        loss_value = loss(out,label_batch)
+        loss_value = loss(out, label_batch)
         #BACKWARD
         loss_value.backward()
         optimizer.step()
         #LOGGING
         writer.add_scalar('loss', loss_value.data[0], n_iter)

         if j % embedding_log == 0:
             print("loss_value:{}".format(loss_value.data[0]))
             #we need 3 dimension for tensor to visualize it!
             out = torch.cat((out,torch.ones(len(out),1)),1)
             #write the embedding for the timestep
-            embedding_writer.add_embedding(out.data,metadata=label_batch.data,label_img=data_batch.data,timestep=(i*num_batches)+j)
+            writer.add_embedding(out.data, metadata=label_batch.data, label_img=data_batch.data, global_step=n_iter)

 writer.close()

 #tensorboard --logdir runs
-#you should now see a dropdown list with all the timestep, latest timestep should have a visible separation between the two classes
+#you should now see a dropdown list with all the timesteps,
+# the last timestep should have a visible separation between the two classes
docs/tensorboard.rst (1 change: 0 additions & 1 deletion)

@@ -6,4 +6,3 @@ tensorboard-pytorch
     :members:

 .. automethod:: __init__
-.. autofunction:: tensorboard.embedding.add_embedding
tensorboard/embedding.py (140 changes: 8 additions & 132 deletions)

@@ -28,16 +28,17 @@ def make_sprite(label_img, save_path):
     else:
         torchvision.utils.save_image(label_img, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0)

-def make_pbtxt(save_path, metadata, label_img):
-    with open(os.path.join(save_path, 'projector_config.pbtxt'), 'w') as f:
+def append_pbtxt(metadata, label_img, save_path, global_step):
+    with open(os.path.join(save_path, 'projector_config.pbtxt'), 'a') as f:
+        #step = os.path.split(save_path)[-1]
         f.write('embeddings {\n')
-        f.write('tensor_name: "embedding:0"\n')
-        f.write('tensor_path: "tensors.tsv"\n')
+        f.write('tensor_name: "embedding:{}"\n'.format(global_step))
+        f.write('tensor_path: "{}"\n'.format(os.path.join(global_step,"tensors.tsv")))
         if metadata is not None:
-            f.write('metadata_path: "metadata.tsv"\n')
+            f.write('metadata_path: "{}"\n'.format(os.path.join(global_step,"metadata.tsv")))
         if label_img is not None:
             f.write('sprite {\n')
-            f.write('image_path: "sprite.png"\n')
+            f.write('image_path: "{}"\n'.format(os.path.join(global_step,"sprite.png")))
             f.write('single_image_dim: {}\n'.format(label_img.size(3)))
             f.write('single_image_dim: {}\n'.format(label_img.size(2)))
             f.write('}\n')
@@ -47,129 +48,4 @@ def make_mat(matlist, save_path):
     with open(os.path.join(save_path, 'tensors.tsv'), 'w') as f:
         for x in matlist:
             x = [str(i) for i in x]
-            f.write('\t'.join(x) + '\n')
-
-def add_embedding(mat, save_path, metadata=None, label_img=None):
-    """add embedding
-    Args:
-        mat (torch.Tensor): A matrix which each row is the feature vector of the data point
-        save_path (string): Save path (use ``writer.file_writer.get_logdir()`` to show embedding along with other summaries)
-        metadata (list): A list of labels, each element will be convert to string
-        label_img (torch.Tensor): Images correspond to each data point
-    Shape:
-        mat: :math:`(N, D)`, where N is number of data and D is feature dimension
-        label_img: :math:`(N, C, H, W)`
-    .. note::
-        ~~This function needs tensorflow installed. It invokes tensorflow to dump data.~~
-        Therefore I separate it from the SummaryWriter class. Please pass ``writer.file_writer.get_logdir()`` to ``save_path`` to prevent glitches.
-        If ``save_path`` is different than SummaryWritter's save path, you need to pass the leave directory to tensorboard's logdir argument,
-        otherwise it cannot display anything. e.g. if ``save_path`` equals 'path/to/embedding',
-        you need to call 'tensorboard --logdir=path/to/embedding', instead of 'tensorboard --logdir=path'.
-    Examples::
-        from tensorboard.embedding import add_embedding
-        import keyword
-        import torch
-        meta = []
-        while len(meta)<100:
-            meta = meta+keyword.kwlist # get some strings
-        meta = meta[:100]
-        for i, v in enumerate(meta):
-            meta[i] = v+str(i)
-        label_img = torch.rand(100, 3, 10, 32)
-        for i in range(100):
-            label_img[i]*=i/100.0
-        add_embedding(torch.randn(100, 5), 'embedding1', metadata=meta, label_img=label_img)
-        add_embedding(torch.randn(100, 5), 'embedding2', label_img=label_img)
-        add_embedding(torch.randn(100, 5), 'embedding3', metadata=meta)
-    """
-    try:
-        os.makedirs(save_path)
-    except OSError:
-        print('warning: dir exists')
-    if metadata is not None:
-        assert mat.size(0)==len(metadata), '#labels should equal with #data points'
-        make_tsv(metadata, save_path)
-    if label_img is not None:
-        assert mat.size(0)==label_img.size(0), '#images should equal with #data points'
-        make_sprite(label_img, save_path)
-    assert mat.dim()==2, 'mat should be 2D, where mat.size(0) is the number of data points'
-    make_mat(mat.tolist(), save_path)
-    make_pbtxt(save_path, metadata, label_img)
-
-def append_pbtxt(f, metadata, label_img, path):
-    f.write('embeddings {\n')
-    f.write('tensor_name: "{}"\n'.format(os.path.join(path,"embedding")))
-    f.write('tensor_path: "{}"\n'.format(os.path.join(path,"tensors.tsv")))
-    if metadata is not None:
-        f.write('metadata_path: "{}"\n'.format(os.path.join(path,"metadata.tsv")))
-    if label_img is not None:
-        f.write('sprite {\n')
-        f.write('image_path: "{}"\n'.format(os.path.join(path,"sprite.png")))
-        f.write('single_image_dim: {}\n'.format(label_img.size(3)))
-        f.write('single_image_dim: {}\n'.format(label_img.size(2)))
-        f.write('}\n')
-    f.write('}\n')
-
-
-class EmbeddingWriter(object):
-    """
-    Class to allow writing embeddings ad defined timestep
-    """
-    def __init__(self, save_path):
-        """
-        :param save_path: should be the same path of you SummaryWriter
-        """
-        self.save_path = save_path
-        #make dir if needed, it should not
-        try:
-            os.makedirs(save_path)
-        except OSError:
-            print('warning: dir exists')
-        #create config file to store all embeddings conf
-        self.f = open(os.path.join(save_path, 'projector_config.pbtxt'), 'w')
-
-    def add_embedding(self, mat, metadata=None, label_img=None, timestep=0):
-        """
-        add an embedding at the defined timestep
-        :param mat:
-        :param metadata:
-        :param label_img:
-        :param timestep:
-        :return:
-        """
-        # TODO make doc
-        #path to the new subdir
-        timestep_path = "{}".format(timestep)
-        # TODO should this be handled?
-        os.makedirs(os.path.join(self.save_path, timestep_path))
-        #check other info
-        #save all this metadata in the new subfolder
-        if metadata is not None:
-            assert mat.size(0) == len(metadata), '#labels should equal with #data points'
-            make_tsv(metadata, os.path.join(self.save_path, timestep_path))
-        if label_img is not None:
-            assert mat.size(0) == label_img.size(0), '#images should equal with #data points'
-            make_sprite(label_img, os.path.join(self.save_path, timestep_path))
-        assert mat.dim() == 2, 'mat should be 2D, where mat.size(0) is the number of data points'
-        make_mat(mat.tolist(), os.path.join(self.save_path, timestep_path))
-        #new funcion to append to the config file a new embedding
-        append_pbtxt(self.f, metadata, label_img, timestep_path)
-
-    def __del__(self):
-        #close the file at the end of the script
-        self.f.close()
+            f.write('\t'.join(x) + '\n')
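With the unified API, each add_embedding call appends one embeddings block to a single projector_config.pbtxt at the log-dir root instead of rewriting it. A sketch of one appended block, assuming global_step=1 (zero-padded to 00001) and 32x10 label images; the writer emits the lines unindented, exactly as the f.write calls above suggest:

embeddings {
tensor_name: "embedding:00001"
tensor_path: "00001/tensors.tsv"
metadata_path: "00001/metadata.tsv"
sprite {
image_path: "00001/sprite.png"
single_image_dim: 32
single_image_dim: 10
}
}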
tensorboard/writer.py (52 changes: 52 additions & 0 deletions)

@@ -27,6 +27,7 @@
 from .event_file_writer import EventFileWriter
 from .summary import scalar, histogram, image, audio, text
 from .graph import graph
+from .embedding import make_mat, make_sprite, make_tsv, append_pbtxt


 class SummaryToEventTransformer(object):
@@ -329,6 +330,57 @@ def add_graph(self, model, lastVar):
             return
         self.file_writer.add_graph(graph(model, lastVar))

+    def add_embedding(self, mat, metadata=None, label_img=None, global_step=None):
+        """add embedding
+        Args:
+            mat (torch.Tensor): A matrix which each row is the feature vector of the data point
+            metadata (list): A list of labels, each element will be converted to string
+            label_img (torch.Tensor): Images correspond to each data point
+            global_step (int): Global step value to record
+        Shape:
+            mat: :math:`(N, D)`, where N is number of data and D is feature dimension
+            label_img: :math:`(N, C, H, W)`
+        Examples::
+            import keyword
+            import torch
+            meta = []
+            while len(meta)<100:
+                meta = meta+keyword.kwlist # get some strings
+            meta = meta[:100]
+            for i, v in enumerate(meta):
+                meta[i] = v+str(i)
+            label_img = torch.rand(100, 3, 10, 32)
+            for i in range(100):
+                label_img[i]*=i/100.0
+            writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
+            writer.add_embedding(torch.randn(100, 5), label_img=label_img)
+            writer.add_embedding(torch.randn(100, 5), metadata=meta)
+        """
+        if global_step is None:
+            global_step = 0
+        # clear pbtxt?
+        save_path = os.path.join(self.file_writer.get_logdir(), str(global_step).zfill(5))
+        try:
+            os.makedirs(save_path)
+        except OSError:
+            print('warning: Embedding dir exists, did you set global_step for add_embedding()?')
+        if metadata is not None:
+            assert mat.size(0) == len(metadata), '#labels should equal with #data points'
+            make_tsv(metadata, save_path)
+        if label_img is not None:
+            assert mat.size(0) == label_img.size(0), '#images should equal with #data points'
+            make_sprite(label_img, save_path)
+        assert mat.dim() == 2, 'mat should be 2D, where mat.size(0) is the number of data points'
+        make_mat(mat.tolist(), save_path)
+        #append the new embedding to the config file
+        append_pbtxt(metadata, label_img, self.file_writer.get_logdir(), str(global_step).zfill(5))
+
     def close(self):
         self.file_writer.flush()

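Since everything now lands under the SummaryWriter's own log directory, a single "tensorboard --logdir runs" serves scalars and embeddings together. A sketch of the resulting layout, assuming add_embedding was called at global_step 0 and 5 (the run-directory name is illustrative, following the demo's strftime pattern):

runs/August16 10:00:00/
    events.out.tfevents.*      (scalars and other summaries)
    projector_config.pbtxt     (one embeddings block per logged step)
    00000/
        tensors.tsv
        metadata.tsv
        sprite.png
    00005/
        tensors.tsv
        ...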
1 comment on commit b4d84aa

@lucabergamini (Contributor):
Great work :) nice choices, both ordering the subfolders by step and the handling of projector_config
