Skip to content

Commit

Permalink
forward labels to drawing function (#415)
Browse files Browse the repository at this point in the history
* forward labels to drawing function

* add example and error handling
  • Loading branch information
dtmoodie authored and lanpa committed May 26, 2019
1 parent e0f47c1 commit 3e35c9b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
5 changes: 4 additions & 1 deletion examples/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
if n_iter % 10 == 0:
x = vutils.make_grid(x, normalize=True, scale_each=True)
writer.add_image('Image', x, n_iter) # Tensor
writer.add_image_with_boxes('imagebox', x, torch.Tensor([[10, 10, 40, 40], [40, 40, 60, 60]]), n_iter)
writer.add_image_with_boxes('imagebox_label', torch.ones(3, 240, 240) * 0.5,
torch.Tensor([[10, 10, 100, 100], [101, 101, 200, 200]]),
n_iter,
labels=['abcde' + str(n_iter), 'fgh' + str(n_iter)])
x = torch.zeros(sample_rate * 2)
for i in range(x.size(0)):
# sound amplitude should in [-1, 1]
Expand Down
12 changes: 6 additions & 6 deletions tensorboardX/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def image(tag, tensor, rescale=1, dataformats='NCHW'):
return Summary(value=[Summary.Value(tag=tag, image=image)])


def image_boxes(tag, tensor_image, tensor_boxes, rescale=1, dataformats='CHW'):
def image_boxes(tag, tensor_image, tensor_boxes, rescale=1, dataformats='CHW', labels=None):
'''Outputs a `Summary` protocol buffer with images.'''
tensor_image = make_np(tensor_image)
tensor_image = convert_to_HWC(tensor_image, dataformats)
Expand All @@ -234,11 +234,11 @@ def image_boxes(tag, tensor_image, tensor_boxes, rescale=1, dataformats='CHW'):
np.float32) * _calc_scale_factor(tensor_image)
image = make_image(tensor_image.astype(np.uint8),
rescale=rescale,
rois=tensor_boxes)
rois=tensor_boxes, labels=labels)
return Summary(value=[Summary.Value(tag=tag, image=image)])


def draw_boxes(disp_image, boxes):
def draw_boxes(disp_image, boxes, labels=None):
# xyxy format
num_boxes = boxes.shape[0]
list_gt = range(num_boxes)
Expand All @@ -248,20 +248,20 @@ def draw_boxes(disp_image, boxes):
boxes[i, 1],
boxes[i, 2],
boxes[i, 3],
display_str=None,
display_str=None if labels is None else labels[i],
color='Red')
return disp_image


def make_image(tensor, rescale=1, rois=None):
def make_image(tensor, rescale=1, rois=None, labels=None):
"""Convert an numpy representation image to Image protobuf"""
from PIL import Image
height, width, channel = tensor.shape
scaled_height = int(height * rescale)
scaled_width = int(width * rescale)
image = Image.fromarray(tensor)
if rois is not None:
image = draw_boxes(image, rois)
image = draw_boxes(image, rois, labels=labels)
image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS)
import io
output = io.BytesIO()
Expand Down
13 changes: 11 additions & 2 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import six
import time
import logging

from .embedding import make_mat, make_sprite, make_tsv, append_pbtxt
from .event_file_writer import EventFileWriter
Expand Down Expand Up @@ -587,15 +588,17 @@ def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataforma
image(tag, img_tensor, dataformats=dataformats), global_step, walltime)

def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
walltime=None, dataformats='CHW', **kwargs):
walltime=None, dataformats='CHW', labels=None, **kwargs):
"""Add image and draw bounding boxes on the image.
Args:
tag (string): Data identifier
img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
box_tensor (torch.Tensor, numpy.array, or string/blobname): Box data (for detected objects)
box should be represented as [x1, y1, x2, y2].
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time()) of event
labels (list of string): The strings to be show on each bounding box.
Shape:
img_tensor: Default is :math:`(3, H, W)`. It can be specified with ``dataformat`` agrument.
e.g. CHW or HWC
Expand All @@ -607,8 +610,14 @@ def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
img_tensor = workspace.FetchBlob(img_tensor)
if self._check_caffe2_blob(box_tensor):
box_tensor = workspace.FetchBlob(box_tensor)
if labels is not None:
if isinstance(labels, str):
labels = [labels]
if len(labels) != box_tensor.shape[0]:
logging.warning('Number of labels do not equal to number of box, skip the labels.')
labels = None
self._get_file_writer().add_summary(image_boxes(
tag, img_tensor, box_tensor, dataformats=dataformats, **kwargs), global_step, walltime)
tag, img_tensor, box_tensor, dataformats=dataformats, labels=labels, **kwargs), global_step, walltime)

def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
"""Render matplotlib figure into an image and add it to summary.
Expand Down

0 comments on commit 3e35c9b

Please sign in to comment.