Image api cleanup (#460)
* clean up image api

* add test case
lanpa committed Jun 27, 2019
1 parent 059d088 commit 96b86b7
Showing 8 changed files with 78 additions and 31 deletions.
33 changes: 13 additions & 20 deletions tensorboardX/summary.py
@@ -25,11 +25,6 @@
_INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]')


-def _calc_scale_factor(tensor):
-    converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor
-    return 1 if converted.dtype == np.uint8 else 255


def _clean_tag(name):
# In the past, the first argument to summary ops was a tag, which allowed
# arbitrary characters. Now we are changing the first argument to be the node
@@ -190,19 +185,15 @@ def make_histogram(values, bins, max_bins=None):
bucket=counts.tolist())


-def image(tag, tensor, rescale=1, dataformats='NCHW'):
+def image(tag, tensor, rescale=1, dataformats='CHW'):
"""Outputs a `Summary` protocol buffer with images.
The summary has up to `max_images` summary values containing images. The
images are built from `tensor` which must be 3-D with shape `[height, width,
channels]` and where `channels` can be:
* 1: `tensor` is interpreted as Grayscale.
* 3: `tensor` is interpreted as RGB.
* 4: `tensor` is interpreted as RGBA.
The `name` in the outputted Summary.Value protobufs is generated based on the
name, with a suffix depending on the max_outputs setting:
* If `max_outputs` is 1, the summary value tag is '*name*/image'.
* If `max_outputs` is greater than 1, the summary value tags are
generated sequentially as '*name*/image/0', '*name*/image/1', etc.
Args:
tag: A name for the generated node. Will also serve as a series name in
TensorBoard.
@@ -219,9 +210,9 @@ def image(tag, tensor, rescale=1, dataformats='NCHW'):
tensor = make_np(tensor)
tensor = convert_to_HWC(tensor, dataformats)
# Do not assume that user passes in values in [0, 255], use data type to detect
-scale_factor = _calc_scale_factor(tensor)
-tensor = tensor.astype(np.float32)
-tensor = (tensor * scale_factor).astype(np.uint8)
+if tensor.dtype != np.uint8:
+    tensor = (tensor * 255.0).astype(np.uint8)

image = make_image(tensor, rescale=rescale)
return Summary(value=[Summary.Value(tag=tag, image=image)])
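The practical effect of the change above is that float inputs in [0, 1] and uint8 inputs in [0, 255] now encode to the same image, decided purely by dtype. A minimal sketch (the tag, pixel values, and HWC layout below are illustrative, not part of the commit):

import numpy as np
from tensorboardX import summary

# One green pixel expressed two ways.
green_uint8 = np.array([[[0, 255, 0]]], dtype=np.uint8)          # HWC, values in [0, 255]
green_float32 = np.array([[[0.0, 1.0, 0.0]]], dtype=np.float32)  # HWC, values in [0, 1]

# The float image is scaled by 255 and cast to uint8; the uint8 image is
# passed through untouched, so both calls should yield identical protos.
a = summary.image('green', green_uint8, dataformats='HWC')
b = summary.image('green', green_float32, dataformats='HWC')
assert a == b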

@@ -231,9 +222,11 @@ def image_boxes(tag, tensor_image, tensor_boxes, rescale=1, dataformats='CHW', l
tensor_image = make_np(tensor_image)
tensor_image = convert_to_HWC(tensor_image, dataformats)
tensor_boxes = make_np(tensor_boxes)
-tensor_image = tensor_image.astype(
-    np.float32) * _calc_scale_factor(tensor_image)
-image = make_image(tensor_image.astype(np.uint8),

+if tensor_image.dtype != np.uint8:
+    tensor_image = (tensor_image * 255.0).astype(np.uint8)

+image = make_image(tensor_image,
rescale=rescale,
rois=tensor_boxes, labels=labels)
return Summary(value=[Summary.Value(tag=tag, image=image)])
@@ -280,9 +273,9 @@ def video(tag, tensor, fps=4):
tensor = make_np(tensor)
tensor = _prepare_video(tensor)
# If user passes in uint8, then we don't need to rescale by 255
-scale_factor = _calc_scale_factor(tensor)
-tensor = tensor.astype(np.float32)
-tensor = (tensor * scale_factor).astype(np.uint8)
+if tensor.dtype != np.uint8:
+    tensor = (tensor * 255.0).astype(np.uint8)

video = make_video(tensor, fps)
return Summary(value=[Summary.Value(tag=tag, image=video)])

4 changes: 2 additions & 2 deletions tensorboardX/utils.py
@@ -74,13 +74,13 @@ def make_grid(I, ncols=8):
I, np.ndarray), 'plugin error, should pass numpy array here'
if I.shape[1] == 1:
I = np.concatenate([I, I, I], 1)
-assert I.ndim == 4 and I.shape[1] == 3
+assert I.ndim == 4 and I.shape[1] == 3 or I.shape[1] == 4
nimg = I.shape[0]
H = I.shape[2]
W = I.shape[3]
ncols = min(nimg, ncols)
nrows = int(np.ceil(float(nimg) / ncols))
-canvas = np.zeros((3, H * nrows, W * ncols))
+canvas = np.zeros((I.shape[1], H * nrows, W * ncols))
i = 0
for y in range(nrows):
for x in range(ncols):
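With the canvas allocated from I.shape[1] instead of a hard-coded 3, a four-channel (RGBA) batch keeps its alpha channel when tiled; note that, as written, the new assert binds as (I.ndim == 4 and I.shape[1] == 3) or I.shape[1] == 4 because of operator precedence. A rough sketch of the resulting shapes (make_grid imported from tensorboardX.utils as in this file; the random data is illustrative):

import numpy as np
from tensorboardX.utils import make_grid

batch_rgba = np.random.rand(2, 4, 8, 8)  # NCHW: two 8x8 RGBA images with values in [0, 1]
grid = make_grid(batch_rgba, ncols=8)

# Two images placed side by side; the channel count is preserved.
print(grid.shape)  # expected: (4, 8, 16)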
16 changes: 11 additions & 5 deletions tensorboardX/writer.py
@@ -502,12 +502,16 @@ def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformat
Args:
tag (string): Data identifier
-img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
+img_tensor (torch.Tensor, numpy.array, or string/blobname): An `uint8` or `float`
Tensor of shape `[channel, height, width]` where `channel` is 1, 3, or 4.
The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8).
Users are responsible to scale the data in the correct range/type.
global_step (int): Global step value to record
-walltime (float): Optional override default walltime (time.time()) of event
+walltime (float): Optional override default walltime (time.time()) of event.
dataformats (string): This parameter specifies the meaning of each dimension of the input tensor.
Shape:
img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
-convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
+convert a batch of tensor into 3xHxW format or use ``add_images()`` and let us do the job.
Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
corresponding ``dataformats`` argument is passed. e.g. CHW, HWC, HW.
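A short usage sketch of the shapes documented above (the run directory, tags, and tensors are illustrative):

import numpy as np
from tensorboardX import SummaryWriter

writer = SummaryWriter('runs/image_demo')  # hypothetical log directory

chw = np.random.rand(3, 64, 64).astype(np.float32)        # default CHW, values in [0, 1]
hwc = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # HWC, values in [0, 255]
gray = np.random.rand(64, 64).astype(np.float32)          # HW, single channel

writer.add_image('demo/chw', chw, global_step=0)
writer.add_image('demo/hwc', hwc, global_step=0, dataformats='HWC')
writer.add_image('demo/gray', gray, global_step=0, dataformats='HW')
writer.close()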
@@ -542,14 +546,16 @@ def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformat
image(tag, img_tensor, dataformats=dataformats), global_step, walltime)

def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
"""Add batched image data to summary.
Besides pass 4D tensor, you can also pass a list of tensors of the same size.
"""Add batched (4D) image data to summary.
Besides passing 4D (NCHW) tensor, you can also pass a list of tensors of the same size.
In this case, the ``dataformats`` should be `CHW` or `HWC`.
Note that this requires the ``pillow`` package.
Args:
tag (string): Data identifier
img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
The elements in img_tensor can either have values in [0, 1] (float32) or [0, 255] (uint8).
Users are responsible to scale the data in the correct range/type.
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time()) of event
Shape:
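A sketch of the batched call described above (the run directory, tag, and tensor are illustrative; the default dataformats='NCHW' is assumed):

import numpy as np
from tensorboardX import SummaryWriter

writer = SummaryWriter('runs/images_demo')  # hypothetical log directory

batch = np.random.rand(16, 3, 32, 32).astype(np.float32)  # NCHW batch, values in [0, 1]
writer.add_images('demo/batch', batch, global_step=0)
writer.close()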
9 changes: 9 additions & 0 deletions tests/expect/test_summary.test_float32_image.expect
@@ -0,0 +1,9 @@
value {
tag: "dummy"
image {
height: 32
width: 32
colorspace: 3
encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000 \000\000\000 \010\002\000\000\000\374\030\355\243\000\000\000DIDATx\234cd``\370OK\300\370\340\301\003\232Z\3002j\301\360\267\200QAA\201\266\026\214\346\203Q\013\006\277\005\243\371\200 \030\372\221<j\001A0\232\017\010\202\241\037\311\243\026\020\0044\317\007\000]7\325\342\027k\025c\000\000\000\000IEND\256B`\202"
}
}
9 changes: 9 additions & 0 deletions tests/expect/test_summary.test_image_with_four_channel.expect
@@ -0,0 +1,9 @@
value {
tag: "dummy"
image {
height: 8
width: 8
colorspace: 4
encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\010\000\000\000\010\010\006\000\000\000\304\017\276\213\000\000\000\036IDATx\234cd8\320\340\360\037\017`\371\361\343\307\217\037\204\024\0204a\260+\000\000\240\302\373\327\246\231O\'\000\000\000\000IEND\256B`\202"
}
}
@@ -0,0 +1,9 @@
value {
tag: "dummy"
image {
height: 8
width: 16
colorspace: 4
encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\020\000\000\000\010\010\006\000\000\000\360v\177\227\000\000\000-IDATx\234cd8\320\340\360\037\017`ggg\307\'\317\362\343\307\217\037?\360(\370\001\305x\r\300g\003!0j\000\025\014\000\000\356b\366\370\366\336\316\301\000\000\000\000IEND\256B`\202"
}
}
9 changes: 9 additions & 0 deletions tests/expect/test_summary.test_uint8_image.expect
@@ -0,0 +1,9 @@
value {
tag: "dummy"
image {
height: 32
width: 32
colorspace: 3
encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000 \000\000\000 \010\002\000\000\000\374\030\355\243\000\000\000CIDATx\234cd```\244)PPP\240\251\371,\243\026\014\177\013\030\037<x@[\013F\363\301\250\005\203\337\202\321|@\020\014\375H\036\265\2000\030\315\007\204\300\320\217\344Q\013\010\003Z\347\003\000\211\014\037}z\035\001}\000\000\000\000IEND\256B`\202"
}
}
20 changes: 16 additions & 4 deletions tests/test_summary.py
@@ -17,17 +17,23 @@ def test_uint8_image(self):
Tests that uint8 image (pixel values in [0, 255]) is not changed
'''
test_image = tensor_N(shape=(3, 32, 32), dtype=np.uint8)
-scale_factor = summary._calc_scale_factor(test_image)
-assert scale_factor == 1, 'Values are already in [0, 255], scale factor should be 1'
compare_proto(summary.image('dummy', test_image), self)

def test_float32_image(self):
'''
Tests that float32 image (pixel values in [0, 1]) are scaled correctly
to [0, 255]
'''
test_image = tensor_N(shape=(3, 32, 32))
-scale_factor = summary._calc_scale_factor(test_image)
-assert scale_factor == 255, 'Values are in [0, 1], scale factor should be 255'
compare_proto(summary.image('dummy', test_image), self)

+def test_float_1_converts_to_uint8_255(self):
+    green_uint8 = np.array([[[0, 255, 0]]], dtype='uint8')
+    green_float32 = np.array([[[0, 1, 0]]], dtype='float32')
+
+    a = summary.image(tensor=green_uint8, tag='')
+    b = summary.image(tensor=green_float32, tag='')
+    self.assertEqual(a, b)

def test_list_input(self):
with pytest.raises(Exception):
@@ -46,12 +52,18 @@ def test_image_with_boxes(self):
def test_image_with_one_channel(self):
compare_proto(summary.image('dummy', tensor_N(shape=(1, 8, 8)), dataformats='CHW'), self)

+def test_image_with_four_channel(self):
+    compare_proto(summary.image('dummy', tensor_N(shape=(4, 8, 8)), dataformats='CHW'), self)

def test_image_with_one_channel_batched(self):
compare_proto(summary.image('dummy', tensor_N(shape=(2, 1, 8, 8)), dataformats='NCHW'), self)

def test_image_with_3_channel_batched(self):
compare_proto(summary.image('dummy', tensor_N(shape=(2, 3, 8, 8)), dataformats='NCHW'), self)

+def test_image_with_four_channel_batched(self):
+    compare_proto(summary.image('dummy', tensor_N(shape=(2, 4, 8, 8)), dataformats='NCHW'), self)

def test_image_without_channel(self):
compare_proto(summary.image('dummy', tensor_N(shape=(8, 8)), dataformats='HW'), self)

