Skip to content

Commit

Permalink
made optional data format args actually optional
Browse files Browse the repository at this point in the history
  • Loading branch information
ttf2050 committed Jun 16, 2024
1 parent ce08ab4 commit 8eae8a8
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
11 changes: 7 additions & 4 deletions tensorboardX/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .proto.plugin_mesh_pb2 import MeshPluginData
from .proto import layout_pb2
from .x2num import make_np
from .utils import _prepare_video, convert_to_HWC, convert_to_NTCHW
from .utils import _prepare_video, convert_to_HWC, convert_to_NTCHW, infer_image_format
logger = logging.getLogger(__name__)

_INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]')
Expand Down Expand Up @@ -257,7 +257,7 @@ def make_histogram(values, bins, max_bins=None):
bucket=counts.tolist())


def image(tag, tensor, rescale=1, dataformats='CHW'):
def image(tag, tensor, rescale=1, dataformats=None):
"""Outputs a `Summary` protocol buffer with images.
The summary has up to `max_images` summary values containing images. The
images are built from `tensor` which must be 3-D with shape `[height, width,
Expand All @@ -280,6 +280,7 @@ def image(tag, tensor, rescale=1, dataformats='CHW'):
"""
tag = _clean_tag(tag)
tensor = make_np(tensor)
dataformats = dataformats or infer_image_format(tensor)
tensor = convert_to_HWC(tensor, dataformats)
# Do not assume that user passes in values in [0, 255], use data type to detect
if tensor.dtype != np.uint8:
Expand All @@ -289,9 +290,10 @@ def image(tag, tensor, rescale=1, dataformats='CHW'):
return Summary(value=[Summary.Value(tag=tag, image=image)])


def image_boxes(tag, tensor_image, tensor_boxes, rescale=1, dataformats='CHW', labels=None):
def image_boxes(tag, tensor_image, tensor_boxes, rescale=1, dataformats=None, labels=None):
'''Outputs a `Summary` protocol buffer with images.'''
tensor_image = make_np(tensor_image)
dataformats = dataformats or infer_image_format(tensor)
tensor_image = convert_to_HWC(tensor_image, dataformats)
tensor_boxes = make_np(tensor_boxes)

Expand Down Expand Up @@ -346,9 +348,10 @@ def make_image(tensor, rescale=1, rois=None, labels=None):
encoded_image_string=image_string)


def video(tag, tensor, fps=4, dataformats="NTCHW"):
def video(tag, tensor, fps=4, dataformats=None):
tag = _clean_tag(tag)
tensor = make_np(tensor)
dataformats = dataformats or infer_image_format(tensor)
tensor = convert_to_NTCHW(tensor, input_format=dataformats)
tensor = _prepare_video(tensor)
# If user passes in uint8, then we don't need to rescale by 255
Expand Down
25 changes: 25 additions & 0 deletions tensorboardX/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,28 @@ def convert_to_HWC(tensor, input_format): # tensor: numpy array
tensor = tensor.transpose(index)
tensor = np.stack([tensor, tensor, tensor], 2)
return tensor

def infer_image_format(tensor):
"""
Attempt to infer the image format from the data.
This function can really only detect the color channel, and assumes that
all other channels are in the order NTHW (if present).
"""
import numpy as np
tensor: np.ndarray = tensor
if tensor.ndim == 2:
return "HW"
elif tensor.ndim == 3:
# Hopefully we can determine that exactly one of these is the channel
C_first = tensor.shape[0] in [1,3,4]
C_last = tensor.shape[2] in [1,3,4]
assert C_first != C_last, "Could not uniquely determine the color channel index: \
you must spedicify the format yourself"
return "CHW" if C_first else "HWC"
elif tensor.ndim == 4:
return 'N'+ infer_image_format(tensor[0])
elif tensor.ndim == 5:
return 'NT'+ infer_image_format(tensor[0,0])
else:
raise NotImplementedError("Unable to infer image format for arrays with <2 or >5 dimensions.")
15 changes: 6 additions & 9 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def add_image(
img_tensor: numpy_compatible,
global_step: Optional[int] = None,
walltime: Optional[float] = None,
dataformats: Optional[str] = 'CHW'):
dataformats: Optional[str] = None):
"""Add image data to summary.
Note that this requires the ``pillow`` package.
Expand Down Expand Up @@ -694,7 +694,7 @@ def add_images(
img_tensor: numpy_compatible,
global_step: Optional[int] = None,
walltime: Optional[float] = None,
dataformats: Optional[str] = 'NCHW'):
dataformats: Optional[str] = None):
"""Add batched (4D) image data to summary.
Besides passing 4D (NCHW) tensor, you can also pass a list of tensors of the same size.
In this case, the ``dataformats`` should be `CHW` or `HWC`.
Expand Down Expand Up @@ -734,18 +734,15 @@ def add_images(
if self._check_caffe2_blob(img_tensor):
img_tensor = workspace.FetchBlob(img_tensor)
if isinstance(img_tensor, list): # a list of tensors in CHW or HWC
if dataformats.upper() != 'CHW' and dataformats.upper() != 'HWC':
print('A list of image is passed, but the dataformat is neither CHW nor HWC.')
print('Nothing is written.')
return
import torch
try:
img_tensor = torch.stack(img_tensor, 0)
except TypeError as e:
import numpy as np
img_tensor = np.stack(img_tensor, 0)

dataformats = 'N' + dataformats
if dataformats is not None and len(dataformats) == 3:
dataformats = 'N' + dataformats

summary = image(tag, img_tensor, dataformats=dataformats)
encoded_image_string = summary.value[0].image.encoded_image_string
Expand All @@ -760,7 +757,7 @@ def add_image_with_boxes(
box_tensor: numpy_compatible,
global_step: Optional[int] = None,
walltime: Optional[float] = None,
dataformats: Optional[str] = 'CHW',
dataformats: Optional[str] = None,
labels: Optional[List[str]] = None,
**kwargs):
"""Add image and draw bounding boxes on the image.
Expand Down Expand Up @@ -827,7 +824,7 @@ def add_video(
global_step: Optional[int] = None,
fps: Optional[Union[int, float]] = 4,
walltime: Optional[float] = None,
dataformats: Optional[str] = 'NTCHW'):
dataformats: Optional[str] = None):
"""Add video data to summary.
Note that this requires the ``moviepy`` package.
Expand Down

0 comments on commit 8eae8a8

Please sign in to comment.