Skip to content

Commit

Permalink
Prevent saving too large sprites in add_embedding (#524)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdpino authored and lanpa committed Oct 22, 2019
1 parent d8c3aa7 commit 3b01f80
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion tensorboardX/embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os

# Maximum sprite size allowed by TB frontend,
# see https://github.com/lanpa/tensorboardX/issues/516
TB_MAX_SPRITE_SIZE = 8192

def make_tsv(metadata, save_path, metadata_header=None):
if not metadata_header:
Expand Down Expand Up @@ -44,7 +47,9 @@ def make_sprite(label_img, save_path):
arranged_img_CHW = make_grid(make_np(label_img), ncols=number_of_images_per_row)
arranged_img_HWC = arranged_img_CHW.transpose(1, 2, 0) # chw -> hwc

arranged_augment_square_HWC = np.ndarray((arranged_img_CHW.shape[2], arranged_img_CHW.shape[2], 3))
sprite_size = arranged_img_CHW.shape[2]
assert sprite_size <= TB_MAX_SPRITE_SIZE, 'Sprite too large, see label_img shape limits'
arranged_augment_square_HWC = np.ndarray((sprite_size, sprite_size, 3))
arranged_augment_square_HWC[:arranged_img_HWC.shape[0], :, :] = arranged_img_HWC
im = Image.fromarray(np.uint8((arranged_augment_square_HWC * 255).clip(0, 255)))
im.save(os.path.join(save_path, 'sprite.png'))
Expand Down
7 changes: 6 additions & 1 deletion tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,13 +834,18 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
Args:
mat (torch.Tensor or numpy.array): A matrix which each row is the feature vector of the data point
metadata (list): A list of labels, each element will be convert to string
label_img (torch.Tensor or numpy.array): Images correspond to each data point. Each image should be square.
label_img (torch.Tensor or numpy.array): Images correspond to each data point. Each image should
be square. The amount and size of the images are limited by the Tensorboard frontend,
see limits below.
global_step (int): Global step value to record
tag (string): Name for the embedding
Shape:
mat: :math:`(N, D)`, where N is number of data and D is feature dimension
label_img: :math:`(N, C, H, W)`, where `Height` should be equal to `Width`.
Also, :math:`\sqrt{N}*W` must be less than or equal to 8192, so that the generated sprite
image can be loaded by the Tensorboard frontend
(see `tensorboardX#516 <https://github.com/lanpa/tensorboardX/issues/516>`_ for more).
Examples::
Expand Down

0 comments on commit 3b01f80

Please sign in to comment.