Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support batched and float data in apply colormap #2886

Merged
merged 4 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/generate_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def main():
# ITERATE OVER THE COLORMAPS
for colormap_name, args in colormaps_list.items():
cm = K.color.ColorMap(base=colormap_name, num_colors=args[0])
out = K.color.rgb_to_bgr(K.color.apply_colormap(bar_img_gray, cm))
out = K.color.rgb_to_bgr(K.color.apply_colormap(bar_img_gray, cm))[0]

out = torch.cat([bar_img, out], dim=-1)

Expand Down Expand Up @@ -399,7 +399,7 @@ def main():
for i, ax in enumerate(axes.flat):
if i < num_colormaps:
cmap = K.color.ColorMap(base=colormap_list[i], num_colors=num_colors)
res = K.color.ApplyColorMap(colormap=cmap)(input_tensor)
res = K.color.ApplyColorMap(colormap=cmap)(input_tensor)[0]
ax.imshow(res.permute(1, 2, 0).numpy())
ax.set_title(colormap_list[i], fontsize=12)
ax.axis("off")
Expand Down
87 changes: 49 additions & 38 deletions kornia/color/colormap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import kornia.color._colormap_data as cm_data
from kornia.color._colormap_data import RGBColor
from kornia.core import Module, Tensor, tensor
from kornia.core.check import KORNIA_CHECK_IS_GRAY
from kornia.core.check import KORNIA_CHECK
from kornia.utils.helpers import deprecated


Expand Down Expand Up @@ -84,8 +84,8 @@ class ColorMap:
the `ColorMapType` enum class to view all available colormaps.

Args:
base: A list of RGB colors to define a new custom colormap or
the name of a built-in colormap as str or using ColorMapType class.
base: A list of RGB colors to define a new custom colormap or the name of a built-in colormap as str or
using `ColorMapType` class.
num_colors: Number of colors in the colormap.
device: The device to put the generated colormap on.
dtype: The data type of the generated colormap.
Expand Down Expand Up @@ -164,7 +164,7 @@ def apply_colormap(input_tensor: Tensor, colormap: ColorMap) -> Tensor:
.. image:: _static/img/apply_colormap.png

Args:
input_tensor: the input tensor of a gray image.
input_tensor: the input tensor of image.
colormap: the colormap desired to be applied to the input tensor.

Returns:
Expand All @@ -174,41 +174,49 @@ def apply_colormap(input_tensor: Tensor, colormap: ColorMap) -> Tensor:
ValueError: If `colormap` is not a ColorMap object.

.. note::
The image data is assumed to be integer values in range of [0-255].
The input tensor must be integer values in the range of [0-255] or float values in the range of [0-1].

Example:
>>> input_tensor = torch.tensor([[[0, 1, 2], [25, 50, 63]]])
>>> colormap = ColorMap(base='autumn')
>>> input_tensor = torch.tensor([[[0, 1, 2], [15, 25, 33], [128, 158, 188]]])
>>> colormap = ColorMap(base=ColorMapType.autumn)
>>> apply_colormap(input_tensor, colormap)
tensor([[[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000]],
tensor([[[[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000]],
<BLANKLINE>
[[0.0000, 0.0159, 0.0317],
[0.3968, 0.7937, 1.0000]],
[[0.0000, 0.0159, 0.0159],
[0.0635, 0.1111, 0.1429],
[0.5079, 0.6190, 0.7302]],
<BLANKLINE>
[[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]]])
[[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]]]])
"""
# FIXME: implement to work with RGB images
# should work with KORNIA_CHECK_SHAPE(x, ["B","C", "H", "W"])

KORNIA_CHECK_IS_GRAY(input_tensor)
KORNIA_CHECK(isinstance(input_tensor, Tensor), f"`input_tensor` must be a Tensor. Got: {type(input_tensor)}")
valid_types = [torch.half, torch.float, torch.double, torch.uint8, torch.int, torch.long, torch.short]
KORNIA_CHECK(
input_tensor.dtype in valid_types, f"`input_tensor` must be a {valid_types}. Got: {input_tensor.dtype}"
)
KORNIA_CHECK(len(input_tensor.shape) in (3, 4), "Wrong input tensor dimension.")
if len(input_tensor.shape) == 3:
input_tensor = input_tensor.unsqueeze_(0)
johnnv1 marked this conversation as resolved.
Show resolved Hide resolved

if len(input_tensor.shape) == 4 and input_tensor.shape[1] == 1: # if (B x 1 X H x W)
input_tensor = input_tensor[:, 0, ...] # (B x H x W)
elif len(input_tensor.shape) == 3 and input_tensor.shape[0] == 1: # if (1 X H x W)
input_tensor = input_tensor[0, ...] # (H x W)
B, C, H, W = input_tensor.shape
input_tensor = input_tensor.reshape(B, C, -1)
max_value = 1.0 if input_tensor.max() <= 1.0 else 255.0
input_tensor = input_tensor.float().div_(max_value)

keys = torch.arange(0, len(colormap) - 1, dtype=input_tensor.dtype, device=input_tensor.device) # (num_colors)
colors = colormap.colors.permute(1, 0)
num_colors, channels_cmap = colors.shape
keys = torch.linspace(0.0, 1.0, num_colors - 1, device=input_tensor.device, dtype=input_tensor.dtype)
indices = torch.bucketize(input_tensor, keys).unsqueeze(-1).expand(-1, -1, -1, 3)

index = torch.bucketize(input_tensor, keys) # shape equals <input_tensor>: (B x H x W) or (H x W)
output = torch.gather(colors.expand(B, C, -1, -1), 2, indices)
# (B, C, H*W, channels_cmap) -> (B, C*channels_cmap, H, W)
output = output.permute(0, 1, 3, 2).reshape(B, C * channels_cmap, H, W)

output = colormap.colors[:, index] # (3 x B x H x W) or (3 x H x W)

if len(output.shape) == 4:
output = output.permute(1, 0, -2, -1) # (B x 3 x H x W)

return output # (B x 3 x H x W) or (3 x H x W)
return output


class ApplyColorMap(Module):
Expand All @@ -229,20 +237,23 @@ class ApplyColorMap(Module):
ValueError: If `colormap` is not a ColorMap object.

.. note::
The image data is assumed to be integer values in range of [0-255].
The input tensor must be integer values in the range of [0-255] or float values in the range of [0-1].

Example:
>>> input_tensor = torch.tensor([[[0, 1, 2], [25, 50, 63]]])
>>> colormap = ColorMap(base='autumn')
>>> input_tensor = torch.tensor([[[0, 1, 2], [15, 25, 33], [128, 158, 188]]])
>>> colormap = ColorMap(base=ColorMapType.autumn)
>>> ApplyColorMap(colormap=colormap)(input_tensor)
tensor([[[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000]],
tensor([[[[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000]],
<BLANKLINE>
[[0.0000, 0.0159, 0.0317],
[0.3968, 0.7937, 1.0000]],
[[0.0000, 0.0159, 0.0159],
[0.0635, 0.1111, 0.1429],
[0.5079, 0.6190, 0.7302]],
<BLANKLINE>
[[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]]])
[[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000]]]])
"""

def __init__(
Expand All @@ -259,7 +270,7 @@ def forward(self, input_tensor: Tensor) -> Tensor:
input_tensor: The input tensor representing the grayscale image.

.. note::
The image data is assumed to be integer values in range of [0-255].
The input tensor must be integer values in the range of [0-255] or float values in the range of [0-1].

Returns:
The output tensor representing the image with the applied colormap.
Expand Down
60 changes: 31 additions & 29 deletions tests/color/test_colormap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,26 @@ def test_autumn(device, dtype):

class TestApplyColorMap(BaseTester):
def test_smoke(self, device, dtype):
input_tensor = tensor([[[0, 1, 3], [25, 50, 63]]], device=device, dtype=dtype)

input_tensor = tensor([[[0, 1, 2], [15, 25, 33], [128, 158, 188]]], device=device, dtype=dtype)
expected_tensor = tensor(
[
[[1, 1, 1], [1, 1, 1]],
[[0, 0.01587301587301587, 0.04761904761904762], [0.3968253968253968, 0.7936507936507936, 1]],
[[0, 0, 0], [0, 0, 0]],
[
[
[1.0000000000, 1.0000000000, 1.0000000000],
[1.0000000000, 1.0000000000, 1.0000000000],
[1.0000000000, 1.0000000000, 1.0000000000],
],
[
[0.0000000000, 0.0158730168, 0.0158730168],
[0.0634920672, 0.1111111119, 0.1428571492],
[0.5079365373, 0.6190476418, 0.7301587462],
],
[
[0.0000000000, 0.0000000000, 0.0000000000],
[0.0000000000, 0.0000000000, 0.0000000000],
[0.0000000000, 0.0000000000, 0.0000000000],
],
]
],
device=device,
dtype=dtype,
Expand All @@ -42,39 +55,28 @@ def test_smoke(self, device, dtype):

self.assert_close(actual, expected_tensor)

def test_eye(self, device, dtype):
input_tensor = torch.stack(
[torch.eye(2, dtype=dtype, device=device) * 255, torch.eye(2, dtype=dtype, device=device) * 150]
).view(2, -1, 2, 2)

expected_tensor = tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]],
[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]],
],
device=device,
dtype=dtype,
)

actual = apply_colormap(input_tensor, ColorMap(base="autumn", device=device, dtype=dtype))
self.assert_close(actual, expected_tensor)

def test_exception(self, device, dtype):
cm = ColorMap(base="autumn", device=device, dtype=dtype)
with pytest.raises(TypeError):
apply_colormap(torch.rand(size=(5, 1, 1), dtype=dtype, device=device), cm)
with pytest.raises(Exception):
apply_colormap(torch.rand(size=(3, 3), dtype=dtype, device=device), cm)

with pytest.raises(Exception):
apply_colormap(torch.rand(size=(3), dtype=dtype, device=device), cm)

with pytest.raises(Exception):
apply_colormap(torch.rand(size=(3), dtype=dtype, device=device).item(), cm)

@pytest.mark.parametrize("shape", [(2, 1, 4, 4), (1, 4, 4), (4, 4)])
@pytest.mark.parametrize("shape", [(2, 1, 3, 3), (1, 3, 3, 3), (1, 3, 3)])
@pytest.mark.parametrize("cmap_base", ColorMapType)
def test_cardinality(self, shape, device, dtype, cmap_base):
cm = ColorMap(base=cmap_base, device=device, dtype=dtype)
input_tensor = torch.randint(0, 63, shape, device=device, dtype=dtype)
cm = ColorMap(base=cmap_base, num_colors=256, device=device, dtype=dtype)
input_tensor = torch.randint(0, 256, shape, device=device, dtype=dtype)
actual = apply_colormap(input_tensor, cm)

if len(shape) == 4:
expected_shape = (shape[0], 3, shape[-2], shape[-1])
expected_shape = (shape[-4], shape[-3] * 3, shape[-2], shape[-1])
else:
expected_shape = (3, shape[-2], shape[-1])
expected_shape = (1, shape[-3] * 3, shape[-2], shape[-1])

assert actual.shape == expected_shape

Expand Down
Loading