Skip to content

Commit

Permalink
Fixes #471 (#621)
Browse files Browse the repository at this point in the history
* Fix add_video dataformats

* Fix doc string
  • Loading branch information
lanpa committed Apr 2, 2021
1 parent 56acb82 commit 0e4fef3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
5 changes: 3 additions & 2 deletions tensorboardX/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .proto.plugin_mesh_pb2 import MeshPluginData
from .proto import layout_pb2
from .x2num import make_np
from .utils import _prepare_video, convert_to_HWC
from .utils import _prepare_video, convert_to_HWC, convert_to_NTCHW

_INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]')

Expand Down Expand Up @@ -340,9 +340,10 @@ def make_image(tensor, rescale=1, rois=None, labels=None):
encoded_image_string=image_string)


def video(tag, tensor, fps=4):
def video(tag, tensor, fps=4, dataformats="NTCHW"):
tag = _clean_tag(tag)
tensor = make_np(tensor)
tensor = convert_to_NTCHW(tensor, input_format=dataformats)
tensor = _prepare_video(tensor)
# If user passes in uint8, then we don't need to rescale by 255
if tensor.dtype != np.uint8:
Expand Down
14 changes: 11 additions & 3 deletions tensorboardX/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,17 @@ def make_grid(I, ncols=8):
i = i + 1
return canvas

# if modality == 'IMG':
# if x.dtype == np.uint8:
# x = x.astype(np.float32) / 255.0

def convert_to_NTCHW(tensor, input_format):
assert(len(input_format) == 5), "Only 5D tensor is supported."
assert(len(set(input_format)) == len(input_format)), "You can not use the same dimension shordhand twice. \
input_format: {}".format(input_format)
assert(len(tensor.shape) == len(input_format)), "size of input tensor and input format are different. \
tensor shape: {}, input_format: {}".format(tensor.shape, input_format)
input_format = input_format.upper()
index = [input_format.find(c) for c in 'NTCHW']
tensor_NTCHW = tensor.transpose(index)
return tensor_NTCHW


def convert_to_HWC(tensor, input_format): # tensor: numpy array
Expand Down
10 changes: 6 additions & 4 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,8 @@ def add_video(
vid_tensor: numpy_compatible,
global_step: Optional[int] = None,
fps: Optional[Union[int, float]] = 4,
walltime: Optional[float] = None):
walltime: Optional[float] = None,
dataformats: Optional[str] = 'NTCHW'):
"""Add video data to summary.
Note that this requires the ``moviepy`` package.
Expand All @@ -807,12 +808,13 @@ def add_video(
global_step: Global step value to record
fps: Frames per second
walltime: Optional override default walltime (time.time()) of event
dataformats: Specify different permutation of the video tensor
Shape:
vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255] for type
`uint8` or [0, 1] for type `float`.
vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255]
for type `uint8` or [0, 1] for type `float`.
"""
self._get_file_writer().add_summary(
video(tag, vid_tensor, fps), global_step, walltime)
video(tag, vid_tensor, fps, dataformats=dataformats), global_step, walltime)

def add_audio(
self,
Expand Down

0 comments on commit 0e4fef3

Please sign in to comment.