
Commit

fix #10
Show many embeddings at once
lucabergamini authored and lanpa committed Aug 16, 2017
1 parent 9802be0 commit 3f9532f
Showing 2 changed files with 151 additions and 16 deletions.
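
For context: the commit adds an EmbeddingWriter class alongside the existing add_embedding helper and switches the demo to it, so that a single run directory can hold one embedding per training step, selectable from a timestep dropdown in the TensorBoard projector. A minimal usage sketch of the new API (the run directory and tensor shapes are illustrative, not taken from the diff):

import torch
from tensorboard import SummaryWriter
from tensorboard.embedding import EmbeddingWriter

log_dir = "runs/demo"                       # illustrative path
writer = SummaryWriter(log_dir)             # TensorBoard needs an event file in the same dir
embedding_writer = EmbeddingWriter(log_dir)
for step in range(3):
    # one 100x5 embedding per step, written to its own timestep subfolder
    embedding_writer.add_embedding(torch.randn(100, 5), timestep=step)
writer.close()
# then: tensorboard --logdir runs
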
100 changes: 84 additions & 16 deletions demo_embedding.py
@@ -1,20 +1,88 @@
-from tensorboard.embedding import add_embedding
-import keyword
 import torch
-meta = []
-while len(meta)<100:
-    meta = meta+keyword.kwlist
-meta = meta[:100]
+import torch.nn as nn
+from torch.optim import Adam
+from torch.autograd.variable import Variable
+import torch.nn.functional as F
+from collections import OrderedDict
+from tensorboard import SummaryWriter
+from datetime import datetime
+from torch.utils.data import TensorDataset,DataLoader
+from tensorboard.embedding import EmbeddingWriter
+import os

-for i, v in enumerate(meta):
-    meta[i] = v+str(i)
+#EMBEDDING VISUALIZATION FOR A TWO-CLASS PROBLEM

-label_img = torch.rand(100, 3, 10, 32)
-for i in range(100):
-    label_img[i] *= i/100.0

-add_embedding(torch.randn(100, 5), save_path='embedding1', metadata=meta, label_img=label_img)
-add_embedding(torch.randn(100, 5), save_path='embedding2', label_img=label_img)
-add_embedding(torch.randn(100, 5), save_path='embedding3', metadata=meta)
+#just a bunch of layers
+class M(nn.Module):
+    def __init__(self):
+        super(M,self).__init__()
+        self.cn1 = nn.Conv2d(in_channels=1,out_channels=64,kernel_size=3)
+        self.cn2 = nn.Conv2d(in_channels=64,out_channels=32,kernel_size=3)
+        self.fc1 = nn.Linear(in_features=128,out_features=2)
+    def forward(self,i):
+        i = self.cn1(i)
+        i = F.relu(i)
+        i = F.max_pool2d(i,2)
+        i = self.cn2(i)
+        i = F.relu(i)
+        i = F.max_pool2d(i,2)
+        i = i.view(len(i),-1)
+        i = self.fc1(i)
+        i = F.log_softmax(i)
+        return i

-#tensorboard --logdir embedding1
+#get some random data around a value
+def get_data(value,shape):
+    data = torch.ones(shape)*value
+    #add some noise
+    data += torch.randn(shape)**2
+    return data

+#dataset
+#cat some data with different values
+data = torch.cat((get_data(0,(100,1,14,14)),get_data(0.5,(100,1,14,14))),0)
+#labels
+labels = torch.cat((torch.zeros(100),torch.ones(100)),0)
+#generator
+gen = DataLoader(TensorDataset(data,labels),batch_size=25,shuffle=True)
+#network
+m = M()
+#loss and optim
+loss = torch.nn.NLLLoss()
+optimizer = Adam(params=m.parameters())
+#settings for train and log
+num_epochs = 20
+num_batches = len(gen)
+embedding_log = 5
+#WE NEED A WRITER! BECAUSE TB LOOKS FOR IT!
+writer_name = datetime.now().strftime('%B%d %H:%M:%S')
+writer = SummaryWriter(os.path.join("runs",writer_name))
+#our brand new embedding writer in the same dir
+embedding_writer = EmbeddingWriter(os.path.join("runs",writer_name))
+#TRAIN
+for i in range(num_epochs):
+    for j,sample in enumerate(gen):
+        #reset grad
+        m.zero_grad()
+        optimizer.zero_grad()
+        #get batch data
+        data_batch = Variable(sample[0],requires_grad=True).float()
+        label_batch = Variable(sample[1],requires_grad=False).long()
+        #FORWARD
+        out = m(data_batch)
+        loss_value = loss(out,label_batch)
+        #BACKWARD
+        loss_value.backward()
+        optimizer.step()
+        #LOGGING
+        if j % embedding_log == 0:
+            print("loss_value:{}".format(loss_value.data[0]))
+            #we need 3 dimensions for the tensor to visualize it!
+            out = torch.cat((out,torch.ones(len(out),1)),1)
+            #write the embedding for this timestep
+            embedding_writer.add_embedding(out.data,metadata=label_batch.data,label_img=data_batch.data,timestep=(i*num_batches)+j)

+writer.close()

+#tensorboard --logdir runs
+#you should now see a dropdown list with all the timesteps; the latest timestep should show a visible separation between the two classes
67 changes: 67 additions & 0 deletions tensorboard/embedding.py
@@ -106,3 +106,70 @@ def add_embedding(mat, save_path, metadata=None, label_img=None):
    make_mat(mat.tolist(), save_path)
    make_pbtxt(save_path, metadata, label_img)

def append_pbtxt(f, metadata, label_img,path):

    f.write('embeddings {\n')
    f.write('tensor_name: "{}"\n'.format(os.path.join(path,"embedding")))
    f.write('tensor_path: "{}"\n'.format(os.path.join(path,"tensors.tsv")))
    if metadata is not None:
        f.write('metadata_path: "{}"\n'.format(os.path.join(path,"metadata.tsv")))
    if label_img is not None:
        f.write('sprite {\n')
        f.write('image_path: "{}"\n'.format(os.path.join(path,"sprite.png")))
        f.write('single_image_dim: {}\n'.format(label_img.size(3)))
        f.write('single_image_dim: {}\n'.format(label_img.size(2)))
        f.write('}\n')
    f.write('}\n')


class EmbeddingWriter(object):
    """
    Class to allow writing embeddings at a defined timestep
    """
    def __init__(self,save_path):
        """
        :param save_path: should be the same path as your SummaryWriter
        """
        self.save_path = save_path
        #make dir if needed (it normally should already exist)
        try:
            os.makedirs(save_path)
        except OSError:
            print('warning: dir exists')
        #create config file to store all embedding configs
        self.f = open(os.path.join(save_path, 'projector_config.pbtxt'), 'w')

    def add_embedding(self,mat, metadata=None, label_img=None,timestep=0):
        """
        add an embedding at the defined timestep
        :param mat:
        :param metadata:
        :param label_img:
        :param timestep:
        :return:
        """
        # TODO make doc
        #path to the new subdir
        timestep_path = "{}".format(timestep)
        # TODO should this be handled?
        os.makedirs(os.path.join(self.save_path,timestep_path))
        #check other info
        #save all this metadata in the new subfolder
        if metadata is not None:
            assert mat.size(0) == len(metadata), '#labels should equal with #data points'
            make_tsv(metadata, os.path.join(self.save_path,timestep_path))
        if label_img is not None:
            assert mat.size(0) == label_img.size(0), '#images should equal with #data points'
            make_sprite(label_img, os.path.join(self.save_path,timestep_path))
        assert mat.dim() == 2, 'mat should be 2D, where mat.size(0) is the number of data points'
        make_mat(mat.tolist(), os.path.join(self.save_path,timestep_path))
        #new function to append the new embedding to the config file
        append_pbtxt(self.f, metadata, label_img,timestep_path)


    def __del__(self):
        #close the file at the end of the script
        self.f.close()
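
For orientation, the on-disk layout EmbeddingWriter produces can be read off the code above: one projector_config.pbtxt at save_path, plus a subfolder per timestep containing tensors.tsv, metadata.tsv (only if metadata was given) and sprite.png (only if label_img was given). Assuming timestep 0 with metadata and 14x14 label images as in the demo, the entry appended to projector_config.pbtxt would look roughly like:

embeddings {
tensor_name: "0/embedding"
tensor_path: "0/tensors.tsv"
metadata_path: "0/metadata.tsv"
sprite {
image_path: "0/sprite.png"
single_image_dim: 14
single_image_dim: 14
}
}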
