Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions flax/metrics/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,29 +103,37 @@ def scalar(self, tag, value, step):
with self._event_writer.as_default():
tf.summary.scalar(name=tag, data=value, step=step)

def image(self, tag, image, step):
def image(self, tag, image, step, max_outputs=3):
"""Saves RGB image summary from np.ndarray [H,W], [H,W,1], or [H,W,3].

Args:
tag: str: label for this data
image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or colors.
image: ndarray: [H,W], [H,W,1], [H,W,3], [K,H,W], [K,H,W,1], [K,H,W,3]
Save image in greyscale or colors.
Pixel values could be either uint8 or float.
Floating point values should be in range [0, 1).
step: int: training step
max_outputs: At most this many images will be emitted at each step.
Defaults to 3.
"""
image = np.array(image)
# tf.summary.image expects image to have shape [k, h, w, c] where,
# k = number of samples, h = height, w = width, c = number of channels.
if len(np.shape(image)) == 2:
image = image[:, :, np.newaxis]
image = image[np.newaxis, :, :, np.newaxis]
elif len(np.shape(image)) == 3:
# this could be either [k, h, w] or [h, w, c]
if np.shape(image)[-1] in (1, 3):
image = image[np.newaxis, :, :, :]
else:
image = image[:, :, :, np.newaxis]
if np.shape(image)[-1] == 1:
image = np.repeat(image, 3, axis=-1)
# tf.summary.image expects image to have shape [k, h, w, c] where,
# k = number of samples, h = height, w = width, c = number of channels.
image = image[np.newaxis, :, :, :]

# Convert to tensor value as tf.summary.image expects data to be a tensor.
image = tf.convert_to_tensor(image)
with self._event_writer.as_default():
tf.summary.image(name=tag, data=image, step=step)
tf.summary.image(name=tag, data=image, step=step, max_outputs=max_outputs)

def audio(self, tag, audiodata, step, sample_rate=44100, max_outputs=3):
"""Saves audio as wave.
Expand Down
27 changes: 27 additions & 0 deletions tests/tensorboard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,33 @@ def test_summarywriter_single_channel_image_scaled(self):
# assert the image was increased in dimension
self.assertEqual(actual_img.shape, (30, 30, 3))

def test_summarywriter_multiple_images(self):
log_dir = tempfile.mkdtemp()
summary_writer = SummaryWriter(log_dir=log_dir)
expected_img = np.random.uniform(low=0., high=255., size=(2, 30, 30, 3))
expected_img = expected_img.astype(np.uint8)
summary_writer.image(tag='multiple_images_test', image=expected_img, step=1)
summary_value = self.parse_and_return_summary_value(path=log_dir)

self.assertEqual(summary_value.tag, 'multiple_images_test')
actual_imgs = [tf.image.decode_image(s)
for s in summary_value.tensor.string_val[2:]]
self.assertTrue(np.allclose(np.stack(actual_imgs, axis=0), expected_img))

def test_summarywriter_multiple_2dimages_scaled(self):
log_dir = tempfile.mkdtemp()
summary_writer = SummaryWriter(log_dir=log_dir)
img = np.random.uniform(low=0., high=255., size=(2, 30, 30))
img = img.astype(np.uint8)
summary_writer.image(tag='multiple_2dimages_test', image=img, step=1)
summary_value = self.parse_and_return_summary_value(path=log_dir)

self.assertEqual(summary_value.tag, 'multiple_2dimages_test')
actual_imgs = [tf.image.decode_image(s)
for s in summary_value.tensor.string_val[2:]]
# assert the images were increased in dimension
self.assertEqual(np.stack(actual_imgs, axis=0).shape, (2, 30, 30, 3))

def test_summarywriter_audio(self):
log_dir = tempfile.mkdtemp()
summary_writer = SummaryWriter(log_dir=log_dir)
Expand Down