Skip to content

Commit

Permalink
fix #11. strip Tensorflow (^• ω •^)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Aug 14, 2017
1 parent 6a2cac5 commit 0fd6c52
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ install:
- source activate test-environment
- which python
- conda list
- pip install tensorflow
#- pip install tensorflow
- pip install --upgrade pytest
- python setup.py install

Expand Down
24 changes: 10 additions & 14 deletions tensorboard/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def make_pbtxt(save_path, metadata, label_img):
with open(os.path.join(save_path, 'projector_config.pbtxt'), 'w') as f:
f.write('embeddings {\n')
f.write('tensor_name: "embedding:0"\n')
f.write('tensor_path: "tensors.tsv"\n')
if metadata is not None:
f.write('metadata_path: "metadata.tsv"\n')
if label_img is not None:
Expand All @@ -35,7 +36,12 @@ def make_pbtxt(save_path, metadata, label_img):
f.write('}\n')
f.write('}\n')


def make_mat(matlist, save_path):
with open(os.path.join(save_path, 'tensors.tsv'), 'w') as f:
for x in matlist:
x = [str(i) for i in x]
f.write('\t'.join(x)+'\n')


def add_embedding(mat, save_path, metadata=None, label_img=None):
"""add embedding
Expand All @@ -51,15 +57,13 @@ def add_embedding(mat, save_path, metadata=None, label_img=None):
label_img: :math:`(N, C, H, W)`
.. note::
This function needs tensorflow installed. It invokes tensorflow to dump data.
~~This function needs tensorflow installed. It invokes tensorflow to dump data. ~~
Therefore I separate it from the SummaryWriter class. Please pass ``writer.file_writer.get_logdir()`` to ``save_path`` to prevent glitches.
If ``save_path`` is different than SummaryWritter's save path, you need to pass the leave directory to tensorboard's logdir argument,
otherwise it cannot display anything. e.g. if ``save_path`` equals 'path/to/embedding',
you need to call 'tensorboard --logdir=path/to/embedding', instead of 'tensorboard --logdir=path'.
Finally, this funtion breaks PyTorch if you have 'torch.nn.DataParallel' in your code. Use it after training completes.
See https://github.com/pytorch/pytorch/issues/2230
Examples::
Expand Down Expand Up @@ -92,15 +96,7 @@ def add_embedding(mat, save_path, metadata=None, label_img=None):
if label_img is not None:
assert mat.size(0)==label_img.size(0), '#images should equal with #data points'
make_sprite(label_img, save_path)
import tensorflow as tf
tf.reset_default_graph()
with tf.device('/cpu:0'):
emb = tf.Variable(mat.tolist(), name="embedding")
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(emb.initializer)
saver = tf.train.Saver()
saver.save(sess, save_path=os.path.join(save_path, 'model.ckpt'), global_step=None, write_meta_graph=False)
assert mat.dim()==2, 'mat should be 2D, where mat.size(0) is the number of data points'
make_mat(mat.tolist(), save_path)
make_pbtxt(save_path, metadata, label_img)

1 comment on commit 0fd6c52

@acgtyrant
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(^• ω •^)

Please sign in to comment.