Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scale images to [0,1] range #1664

Merged
merged 12 commits into from
Feb 29, 2024
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
## [UNRELEASED] neptune 1.9.2
## [UNRELEASED] neptune 1.10.0

### Features
- Added auto-scaling pixel values for image logging ([#1664](https://github.com/neptune-ai/neptune-client/pull/1664))

### Fixes
- Restored support for SSL verification exception ([#1661](https://github.com/neptune-ai/neptune-client/pull/1661))
Expand Down
66 changes: 40 additions & 26 deletions src/neptune/internal/utils/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

__all__ = [
"get_image_content",
"get_html_content",
Expand All @@ -37,6 +39,7 @@
)
from typing import Optional

import numpy as np
from packaging import version
from pandas import DataFrame

Expand All @@ -45,6 +48,7 @@

logger = get_logger()
SEABORN_GRID_CLASSES = {"FacetGrid", "PairGrid", "JointGrid"}
ALLOWED_IMG_PIXEL_RANGES = ("[0, 255]", "[0.0, 1.0]")

try:
from numpy import array as numpy_array
Expand All @@ -65,8 +69,8 @@ def pilimage_fromarray():
pass


def get_image_content(image) -> Optional[bytes]:
content = _image_to_bytes(image)
def get_image_content(image, autoscale=True) -> Optional[bytes]:
content = _image_to_bytes(image, autoscale)

return content

Expand All @@ -83,12 +87,12 @@ def get_pickle_content(obj) -> Optional[bytes]:
return content


def _image_to_bytes(image) -> bytes:
def _image_to_bytes(image, autoscale) -> bytes:
if image is None:
raise ValueError("image is None")

elif is_numpy_array(image):
return _get_numpy_as_image(image)
return _get_numpy_as_image(image, autoscale)

elif is_pil_image(image):
return _get_pil_image_data(image)
Expand All @@ -97,10 +101,10 @@ def _image_to_bytes(image) -> bytes:
return _get_figure_image_data(image)

elif _is_torch_tensor(image):
return _get_numpy_as_image(image.detach().numpy())
return _get_numpy_as_image(image.detach().numpy(), autoscale)

elif _is_tensorflow_tensor(image):
return _get_numpy_as_image(image.numpy())
return _get_numpy_as_image(image.numpy(), autoscale)

elif is_seaborn_figure(image):
return _get_figure_image_data(image.figure)
Expand Down Expand Up @@ -196,38 +200,48 @@ def _image_content_to_html(content: bytes) -> str:
return "<img src='data:image/png;base64," + str_equivalent_image + "'/>"


def _get_numpy_as_image(array):
def _get_numpy_as_image(array: np.ndarray, autoscale: bool) -> bytes:
Raalsky marked this conversation as resolved.
Show resolved Hide resolved
array = array.copy() # prevent original array from modifying
if autoscale:
array = _scale_array(array)

data_range_warnings = []
array_min = array.min()
array_max = array.max()
if array_min < 0:
data_range_warnings.append(f"the smallest value in the array is {array_min}")
if array_max > 1:
data_range_warnings.append(f"the largest value in the array is {array_max}")
if data_range_warnings:
data_range_warning_message = (" and ".join(data_range_warnings) + ".").capitalize()
logger.warning(
"%s To be interpreted as colors correctly values in the array need to be in the [0, 1] range.",
data_range_warning_message,
)
array *= 255
shape = array.shape
if len(shape) == 2:
if len(array.shape) == 2:
return _get_pil_image_data(pilimage_fromarray(array.astype(numpy_uint8)))
if len(shape) == 3:
if shape[2] == 1:
if len(array.shape) == 3:
if array.shape[2] == 1:
array2d = numpy_array([[col[0] for col in row] for row in array])
return _get_pil_image_data(pilimage_fromarray(array2d.astype(numpy_uint8)))
if shape[2] in (3, 4):
if array.shape[2] in (3, 4):
return _get_pil_image_data(pilimage_fromarray(array.astype(numpy_uint8)))
raise ValueError(
"Incorrect size of numpy.ndarray. Should be 2-dimensional or"
"3-dimensional with 3rd dimension of size 1, 3 or 4."
)


def _scale_array(array: np.ndarray) -> np.ndarray:
array_min = array.min()
array_max = array.max()

if array_min >= 0 and 1 < array_max <= 255:
return array

if array_min >= 0 and array_max <= 1:
return array * 255

_warn_about_incorrect_image_data_range(array_min, array_max)
return array


def _warn_about_incorrect_image_data_range(array_min: int | float, array_max: int | float) -> None:
msg = f"Image data is in range [{array_min}, {array_max}]."
logger.warning(
"%s To be interpreted as colors correctly values in the array need to be in the %s or %s range.",
msg,
*ALLOWED_IMG_PIXEL_RANGES,
)


def _get_pil_image_data(image: PILImage) -> bytes:
with io.BytesIO() as image_buffer:
image.save(image_buffer, format="PNG")
Expand Down
4 changes: 2 additions & 2 deletions src/neptune/types/atoms/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def from_stream(stream: IOBase, *, seek: Optional[int] = 0, extension: Optional[
return File(file_composite=file_composite)

@staticmethod
def as_image(image) -> "File":
def as_image(image, autoscale: bool = True) -> "File":
"""Static method for converting image objects or image-like objects to an image File value object.

This way you can upload `Matplotlib` figures, `Seaborn` figures, `PIL` images, `NumPy` arrays, as static images.
Expand Down Expand Up @@ -207,7 +207,7 @@ def as_image(image) -> "File":
.. _as_image docs page:
https://docs.neptune.ai/api/field_types#as_image
"""
content_bytes = get_image_content(image)
content_bytes = get_image_content(image, autoscale=autoscale)
return File.from_content(content_bytes if content_bytes is not None else b"", extension="png")

@staticmethod
Expand Down
58 changes: 46 additions & 12 deletions tests/unit/neptune/new/internal/utils/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
IS_WINDOWS,
)
from neptune.internal.utils.images import (
_scale_array,
get_html_content,
get_image_content,
)
Expand Down Expand Up @@ -75,23 +76,13 @@ def test_get_image_content_from_2d_grayscale_array(self):
def test_get_image_content_from_3d_grayscale_array(self):
# given
image_array = numpy.array([[[1], [0]], [[-3], [4]], [[5], [6]]])
expected_array = numpy.array([[1, 0], [-3, 4], [5, 6]]) * 255
expected_array = numpy.array([[1, 0], [-3, 4], [5, 6]])
expected_image = Image.fromarray(expected_array.astype(numpy.uint8))

# when
_log = partial(format_log, "WARNING")

# expect
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
self.assertEqual(get_image_content(image_array), self._encode_pil_image(expected_image))
self.assertEqual(
stdout.getvalue(),
_log(
"The smallest value in the array is -3 and the largest value in the array is 6."
" To be interpreted as colors correctly values in the array need to be in the [0, 1] range.\n",
),
)
self.assertEqual(get_image_content(image_array), self._encode_pil_image(expected_image))

def test_get_image_content_from_rgb_array(self):
# given
Expand Down Expand Up @@ -292,3 +283,46 @@ def _random_image_array(w=20, h=30, d: Optional[int] = 3):
return numpy.random.rand(w, h, d)
else:
return numpy.random.rand(w, h)


def test_scale_array_when_array_already_scaled():
# given
arr = numpy.array([[123, 32], [255, 0]])

# when
result = _scale_array(arr)

# then
assert numpy.all(arr == result)


def test_scale_array_when_array_not_scaled():
# given
arr = numpy.array([[0.3, 0], [0.5, 1]])

# when
result = _scale_array(arr)
expected = numpy.array([[76.5, 0.0], [127.5, 255.0]])

# then
assert numpy.all(expected == result)


def test_scale_array_incorrect_range():
# given
arr = numpy.array([[-12, 7], [300, 0]])

# when
_log = partial(format_log, "WARNING")

stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
result = _scale_array(arr)

# then
assert numpy.all(arr == result) # returned original array

assert stdout.getvalue() == _log(
"Image data is in range [-12, 300]. To be interpreted as colors "
"correctly values in the array need to be in the [0, 255] or [0.0, 1.0] range.\n",
)