Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Fix issue when logging images in tensorboardX > 1.12
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Oct 8, 2018
1 parent 28d7c16 commit 1f9831f
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions inferno/trainers/callbacks/logging/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,15 @@ def log_images(self, tag, images, step, image_format='CHW'):
pass
else:
raise RuntimeError

# FIXME in tensorboardX > 1.12 this will lead to some error.
# unfortunately tensorboardX does not have a __version__ attribute
# so I don't see how to check for the version and provide backwards
# compatability here
# tensorboardX borks if the number of image channels is not 3
if image.shape[-1] == 1:
image = image[..., [0, 0, 0]]
# if image.shape[-1] == 1:
# image = image[..., [0, 0, 0]]

image = self._normalize_image(image)
# print(image.dtype, image.shape)
self.writer.add_image(tag, img_tensor=image, global_step=step)
Expand Down

0 comments on commit 1f9831f

Please sign in to comment.