Skip to content

Commit

Permalink
fixes #484 (#495)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Aug 21, 2019
1 parent 526ad95 commit 3e87c9f
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 8 deletions.
18 changes: 14 additions & 4 deletions tensorboardX/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,22 @@ def make_sprite(label_img, save_path):
from PIL import Image
# this ensures the sprite image has correct dimension as described in
# https://www.tensorflow.org/get_started/embedding_viz
nrow = int(math.ceil((label_img.size(0)) ** 0.5))
arranged_img_CHW = make_grid(make_np(label_img), ncols=nrow)
# There are some constraints for the sprite image:
# 1. The sprite image should be square.
# 2. Each image patch in the sprite image should be square.
# 2. The content is row major order, so we can padding the image on the
# bottom, but not on the right, otherwise, TB will treat some padded location
# as images to be shown.
# args: label_img: tensor in NCHW

# augment images so that #images equals nrow*nrow
arranged_augment_square_HWC = np.ndarray((arranged_img_CHW.shape[2], arranged_img_CHW.shape[2], 3))
assert label_img.shape[2] == label_img.shape[3], 'Image should be square, see tensorflow/tensorboard#670'
total_pixels = label_img.shape[0] * label_img.shape[2] * label_img.shape[3]
pixels_one_side = total_pixels ** 0.5
number_of_images_per_row = int(math.ceil(pixels_one_side / label_img.shape[3]))
arranged_img_CHW = make_grid(make_np(label_img), ncols=number_of_images_per_row)
arranged_img_HWC = arranged_img_CHW.transpose(1, 2, 0) # chw -> hwc

arranged_augment_square_HWC = np.ndarray((arranged_img_CHW.shape[2], arranged_img_CHW.shape[2], 3))
arranged_augment_square_HWC[:arranged_img_HWC.shape[0], :, :] = arranged_img_HWC
im = Image.fromarray(np.uint8((arranged_augment_square_HWC * 255).clip(0, 255)))
im.save(os.path.join(save_path, 'sprite.png'))
Expand Down
2 changes: 1 addition & 1 deletion tensorboardX/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def hparams(hparam_dict=None, metric_dict=None):
continue

if not isinstance(v, int) or not isinstance(v, float):
v = make_np(v)[0]
v = make_np(v)[0]
ssi.hparams[k].number_value = v

content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION)
Expand Down
7 changes: 4 additions & 3 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,13 +809,13 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
Args:
mat (torch.Tensor or numpy.array): A matrix which each row is the feature vector of the data point
metadata (list): A list of labels, each element will be convert to string
label_img (torch.Tensor): Images correspond to each data point
label_img (torch.Tensor): Images correspond to each data point. Each image should be square.
global_step (int): Global step value to record
tag (string): Name for the embedding
Shape:
mat: :math:`(N, D)`, where N is number of data and D is feature dimension
label_img: :math:`(N, C, H, W)`
label_img: :math:`(N, C, H, W)`, where `Height` should be equal to `Width`.
Examples::
Expand All @@ -829,7 +829,7 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
for i, v in enumerate(meta):
meta[i] = v+str(i)
label_img = torch.rand(100, 3, 10, 32)
label_img = torch.rand(100, 3, 32, 32)
for i in range(100):
label_img[i]*=i/100.0
Expand Down Expand Up @@ -857,6 +857,7 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
make_tsv(metadata, save_path, metadata_header=metadata_header)
if label_img is not None:
assert mat.shape[0] == label_img.shape[0], '#images should equal with #data points'
assert label_img.shape[2] == label_img.shape[3], 'Image should be square, see tensorflow/tensorboard#670'
make_sprite(label_img, save_path)
assert mat.ndim == 2, 'mat should be 2D, where mat.size(0) is the number of data points'
make_mat(mat, save_path)
Expand Down
21 changes: 21 additions & 0 deletions tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,24 @@ def test_embedding_64(self):
label_img=all_images,
metadata_header=['digit', 'dataset'],
global_step=2)

def test_embedding_square(self):
w = SummaryWriter(comment='sq')
all_features = torch.rand(228,256)
all_images = torch.rand(228, 3, 32, 32)
for i in range(all_images.shape[0]):
all_images[i] *= (float(i)+60)/(all_images.shape[0]+60)
w.add_embedding(all_features,
label_img=all_images,
global_step=2)

def test_embedding_fail(self):
with self.assertRaises(AssertionError):
w = SummaryWriter(comment='shouldfail')
all_features = torch.rand(228,256)
all_images = torch.rand(228, 3, 16, 32)
for i in range(all_images.shape[0]):
all_images[i] *= (float(i)+60)/(all_images.shape[0]+60)
w.add_embedding(all_features,
label_img=all_images,
global_step=2)

0 comments on commit 3e87c9f

Please sign in to comment.