Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Image API #1562

Merged
merged 47 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
7feaf52
initial commit for image api
edgarriba Feb 6, 2022
39f00dd
implement normalize/denormalize
edgarriba Feb 27, 2022
76d11ed
add more tests
edgarriba Feb 28, 2022
c9539aa
add docs
edgarriba Feb 28, 2022
f6300ce
remove normalize and polish api
edgarriba Mar 5, 2022
54f1cbe
add docs
edgarriba Mar 5, 2022
6c482de
polishing gray functionality
edgarriba Mar 5, 2022
3f22d91
more polishing
edgarriba Mar 6, 2022
2737f48
implement onxx tests
edgarriba Mar 6, 2022
28d5f06
rename ImageColor, and add contrib.io
edgarriba Mar 20, 2022
4c92b62
image as non-tensor
edgarriba Mar 24, 2022
fe8f6bf
magix functions
edgarriba Mar 26, 2022
8669dcd
rename ImagePrompter to VisualPrompter
edgarriba Apr 29, 2023
8d34619
Merge branch 'master' into feat/image_api
edgarriba May 27, 2023
762ec63
Merge branch 'master' into feat/image_api
edgarriba May 30, 2023
8ae1bea
[pre-commit.ci] pre-commit suggestions (#2394)
pre-commit-ci[bot] May 30, 2023
a5f973a
remove artifacts
edgarriba May 31, 2023
fd60abe
Merge branch 'master' into feat/image_api
edgarriba May 31, 2023
60952ee
remove requirements
edgarriba May 31, 2023
ac3ab33
Merge branch 'master' into feat/image_api
edgarriba Jul 10, 2023
aefdcea
simplify base class
edgarriba Jul 10, 2023
30ab7e9
add TensorWrapper
edgarriba Jul 10, 2023
6c499ee
fix docs generation
edgarriba Jul 10, 2023
3fcaec2
fix test
edgarriba Jul 12, 2023
0c57bb4
recover setup
edgarriba Jul 12, 2023
d211998
more fixes
edgarriba Jul 12, 2023
c7e2a86
implement write
edgarriba Jul 12, 2023
2379134
Merge branch 'master' into feat/image_api
edgarriba Jul 12, 2023
9621d72
add docs
edgarriba Jul 12, 2023
9e9e5ff
fix backend
edgarriba Jul 12, 2023
cf69857
fix backend
edgarriba Jul 12, 2023
009ec77
fix mypy
edgarriba Jul 12, 2023
fbf8daa
Fix write shape
edgarriba Jul 13, 2023
a5aa2cb
apply code review
edgarriba Jul 13, 2023
d631498
artifact
edgarriba Jul 13, 2023
58ba513
fix docstring
edgarriba Jul 13, 2023
af2030b
fix annotations
edgarriba Jul 13, 2023
2a83848
skip test because of dlpack
edgarriba Jul 13, 2023
53f7893
Update kornia/image/image.py
edgarriba Jul 13, 2023
0d9d455
simplify
edgarriba Jul 13, 2023
dd9c050
fix doctest
edgarriba Jul 13, 2023
ed5f92f
fix doctest
edgarriba Jul 13, 2023
c31123c
fix doctest
edgarriba Jul 13, 2023
21728c2
fix build docs
edgarriba Jul 13, 2023
089f648
iterate on the pixel format definition
edgarriba Jul 15, 2023
7e5ff6d
Update kornia/image/base.py
edgarriba Jul 16, 2023
42a1509
adjust test tolerance
edgarriba Jul 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 26 additions & 9 deletions kornia/image/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,32 @@ class ImageSize:
width: int | Tensor


class PixelFormat(Enum):
class ColorSpace(Enum):
r"""Enum that represents the pixel format of an image."""
edgarriba marked this conversation as resolved.
Show resolved Hide resolved
GRAY = 0
RGB = 1
BGR = 2
UNKNOWN = 0 # for now, in case of multi band images
GRAY = 1
RGB = 2
BGR = 3


@dataclass(frozen=True)
class PixelFormat:
r"""Data class to represent the pixel format of an image.

Args:
color_space: color space.
bit_depth: the number of bits per channel.

Example:
>>> pixel_format = PixelFormat(ColorSpace.RGB, 8)
>>> pixel_format.color_space
<ColorSpace.RGB: 2>
>>> pixel_format.bit_depth
8
"""

color_space: ColorSpace
bit_depth: int


class ChannelsOrder(Enum):
Expand All @@ -45,24 +66,20 @@ class ImageLayout:
Args:
image_size: image size.
channels: number of channels.
pixel_format: pixel format.
channels_order: channels order.

Example:
>>> layout = ImageLayout(ImageSize(3, 4), 3, PixelFormat.RGB, ChannelsOrder.CHANNELS_LAST)
>>> layout = ImageLayout(ImageSize(3, 4), 3, ChannelsOrder.CHANNELS_LAST)
>>> layout.image_size
ImageSize(height=3, width=4)
>>> layout.channels
3
>>> layout.pixel_format
<PixelFormat.RGB: 1>
>>> layout.channels_order
<ChannelsOrder.CHANNELS_LAST: 1>
"""

edgarriba marked this conversation as resolved.
Show resolved Hide resolved
image_size: ImageSize
channels: int
pixel_format: PixelFormat
channels_order: ChannelsOrder


Expand Down
67 changes: 41 additions & 26 deletions kornia/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from torch.utils.dlpack import from_dlpack, to_dlpack

from kornia.core import Device, Dtype, Tensor
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.image.base import ChannelsOrder, ImageLayout, ImageSize, PixelFormat
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
from kornia.image.base import ChannelsOrder, ColorSpace, ImageLayout, ImageSize, PixelFormat
from kornia.io.io import ImageLoadType, load_image, write_image

# placeholder for numpy
Expand All @@ -35,29 +35,33 @@

Examples:
>>> # from a torch.tensor
>>> data = torch.randint(0, 255, (3, 4, 5)) # CxHxW
>>> data = torch.randint(0, 255, (3, 4, 5), dtype=torch.uint8) # CxHxW
>>> pixel_format = PixelFormat(
... color_space=ColorSpace.RGB,
... bit_depth=8,
... )
>>> layout = ImageLayout(
... image_size=ImageSize(4, 5),
... channels=3,
... pixel_format=PixelFormat.RGB,
... channels_order=ChannelsOrder.CHANNELS_FIRST,
... )
>>> img = Image(data, layout)
>>> img = Image(data, pixel_format, layout)
>>> assert img.channels == 3

>>> # from a numpy array (like opencv)
>>> data = np.ones((4, 5, 3), dtype=np.uint8) # HxWxC
>>> img = Image.from_numpy(data, pixel_format=PixelFormat.BGR)
>>> img = Image.from_numpy(data, color_space=ColorSpace.RGB)
>>> assert img.channels == 3
>>> assert img.width == 5
>>> assert img.height == 4
"""

def __init__(self, data: Tensor, layout: ImageLayout) -> None:
def __init__(self, data: Tensor, pixel_format: PixelFormat, layout: ImageLayout) -> None:
"""Image constructor.

Args:
data: a torch tensor containing the image data.
pixel_format: the pixel format of the image.
layout: a dataclass containing the image layout information.
"""
# TODO: move this to a function KORNIA_CHECK_IMAGE_LAYOUT
Expand All @@ -69,12 +73,14 @@
raise NotImplementedError(f"Layout {layout.channels_order} not implemented.")

KORNIA_CHECK_SHAPE(data, shape)
edgarriba marked this conversation as resolved.
Show resolved Hide resolved
KORNIA_CHECK(data.element_size() == pixel_format.bit_depth // 8, "Invalid bit depth.")

self._data = data
edgarriba marked this conversation as resolved.
Show resolved Hide resolved
self._pixel_format = pixel_format
self._layout = layout

def __repr__(self) -> str:
return f"Image data: {self.data}\nLayout: {self.layout}"
return f"Image data: {self.data}\nPixel Format: {self.pixel_format}\n Layout: {self.layout}"

# TODO: explore use TensorWrapper
def to(self, device: Device = None, dtype: Dtype = None) -> Image:
Expand All @@ -86,7 +92,7 @@

# TODO: explore use TensorWrapper
def clone(self) -> Image:
return Image(self.data.clone(), self.layout)
return Image(self.data.clone(), self.pixel_format, self.layout)

@property
def data(self) -> Tensor:
Expand All @@ -108,6 +114,11 @@
"""Return the image device."""
return self.data.device

@property
def pixel_format(self) -> PixelFormat:
"""Return the pixel format."""
return self._pixel_format

@property
def layout(self) -> ImageLayout:
"""Return the image layout."""
Expand All @@ -133,11 +144,6 @@
"""Return the image width (rows)."""
return int(self.layout.image_size.width)

@property
def pixel_format(self) -> PixelFormat:
"""Return the pixel format."""
return self.layout.pixel_format

@property
def channels_order(self) -> ChannelsOrder:
"""Return the channels order."""
Expand All @@ -149,23 +155,28 @@
self._data = self.data.float()
return self

# TODO implement this
def to_color_space(self, color_space: ColorSpace) -> Image:
"""Convert the image to a different color space."""
raise NotImplementedError

@classmethod
def from_numpy(
cls,
data: np_ndarray,
color_space: ColorSpace = ColorSpace.RGB,
channels_order: ChannelsOrder = ChannelsOrder.CHANNELS_LAST,
pixel_format: PixelFormat = PixelFormat.RGB,
) -> Image:
"""Construct an image tensor from a numpy array.

Args:
data: a numpy array containing the image data.
channels_order: the channel order of the image.
color_space: the color space of the image.
pixel_format: the pixel format of the image.

Example:
>>> data = np.ones((4, 5, 3), dtype=np.uint8) # HxWxC
>>> img = Image.from_numpy(data, pixel_format=PixelFormat.BGR)
>>> img = Image.from_numpy(data, color_space=ColorSpace.RGB)
>>> assert img.channels == 3
>>> assert img.width == 5
>>> assert img.height == 4
Expand All @@ -179,17 +190,18 @@
else:
raise ValueError("channels_order must be either `CHANNELS_LAST` or `CHANNELS_FIRST`")

# create the pixel format based on the input data
pixel_format = PixelFormat(color_space=color_space, bit_depth=data.itemsize * 8)

# create the image layout based on the input data
layout = ImageLayout(
image_size=image_size, channels=channels, pixel_format=pixel_format, channels_order=channels_order
)
layout = ImageLayout(image_size=image_size, channels=channels, channels_order=channels_order)

# create the image tensor
return cls(torch.from_numpy(data), layout)
return cls(torch.from_numpy(data), pixel_format, layout)

Check warning on line 200 in kornia/image/image.py

View check run for this annotation

Codecov / codecov/patch

kornia/image/image.py#L200

Added line #L200 was not covered by tests

def to_numpy(self) -> np_ndarray:
"""Return a numpy array in cpu from the image tensor."""
return self.data.cpu().detach().numpy()

Check warning on line 204 in kornia/image/image.py

View check run for this annotation

Codecov / codecov/patch

kornia/image/image.py#L204

Added line #L204 was not covered by tests

@classmethod
def from_dlpack(cls, data: DLPack) -> Image:
Expand All @@ -202,21 +214,22 @@
>>> x = np.ones((4, 5, 3))
>>> img = Image.from_dlpack(x.__dlpack__())
"""
_data = from_dlpack(data)
_data: Tensor = from_dlpack(data)

Check warning on line 217 in kornia/image/image.py

View check run for this annotation

Codecov / codecov/patch

kornia/image/image.py#L217

Added line #L217 was not covered by tests

pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=_data.element_size() * 8)

# create the image layout based on the input data
layout = ImageLayout(
image_size=ImageSize(height=_data.shape[1], width=_data.shape[2]),
channels=_data.shape[0],
pixel_format=PixelFormat.RGB,
channels_order=ChannelsOrder.CHANNELS_FIRST,
)

return cls(_data, layout)
return cls(_data, pixel_format, layout)

def to_dlpack(self) -> DLPack:
"""Return a DLPack capsule from the image tensor."""
return to_dlpack(self.data)

Check warning on line 232 in kornia/image/image.py

View check run for this annotation

Codecov / codecov/patch

kornia/image/image.py#L232

Added line #L232 was not covered by tests

@classmethod
def from_file(cls, file_path: str | Path) -> Image:
Expand All @@ -227,13 +240,15 @@
"""
# TODO: allow user to specify the desired type and device
data: Tensor = load_image(file_path, desired_type=ImageLoadType.RGB8, device="cpu")

pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=data.element_size() * 8)

layout = ImageLayout(
image_size=ImageSize(height=data.shape[1], width=data.shape[2]),
channels=data.shape[0],
pixel_format=PixelFormat.RGB,
channels_order=ChannelsOrder.CHANNELS_FIRST,
)
return cls(data, layout)
return cls(data, pixel_format, layout)

def write(self, file_path: str | Path) -> None:
"""Write the image to a file.
Expand Down
33 changes: 13 additions & 20 deletions test/image/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch

from kornia.image.base import ChannelsOrder, ImageLayout, ImageSize, PixelFormat
from kornia.image.base import ChannelsOrder, ColorSpace, ImageLayout, ImageSize, PixelFormat
from kornia.image.image import Image
from kornia.testing import assert_close
from kornia.utils._compat import torch_version_le
Expand All @@ -13,14 +13,10 @@
class TestImage:
def test_smoke(self, device):
data = torch.randint(0, 255, (3, 4, 5), device=device, dtype=torch.uint8)
layout = ImageLayout(
image_size=ImageSize(4, 5),
channels=3,
pixel_format=PixelFormat.RGB,
channels_order=ChannelsOrder.CHANNELS_FIRST,
)
pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=8)
layout = ImageLayout(image_size=ImageSize(4, 5), channels=3, channels_order=ChannelsOrder.CHANNELS_FIRST)

img = Image(data, layout)
img = Image(data, pixel_format, layout)
assert isinstance(img, Image)
assert img.channels == 3
assert img.height == 4
Expand All @@ -29,19 +25,20 @@ def test_smoke(self, device):
assert img.device == device
assert img.dtype == torch.uint8
assert img.layout == layout
assert img.pixel_format == PixelFormat.RGB
assert img.pixel_format.color_space == ColorSpace.RGB
assert img.pixel_format.bit_depth == 8
assert img.channels_order == ChannelsOrder.CHANNELS_FIRST

def test_numpy(self, device):
# as it was from cv2.imread
data = np.ones((4, 5, 3), dtype=np.uint8)
img = Image.from_numpy(data, pixel_format=PixelFormat.RGB)
img = Image.from_numpy(data, color_space=ColorSpace.RGB)
img = img.to(device)
assert isinstance(img, Image)
assert img.channels == 3
assert img.height == 4
assert img.width == 5
assert img.pixel_format == PixelFormat.RGB
assert img.pixel_format.color_space == ColorSpace.RGB
assert img.shape == (4, 5, 3)
assert img.device == device
assert img.dtype == torch.uint8
Expand All @@ -60,13 +57,9 @@ def test_numpy(self, device):

def test_dlpack(self, device, dtype):
data = torch.rand((3, 4, 5), device=device, dtype=dtype)
layout = ImageLayout(
image_size=ImageSize(4, 5),
channels=3,
pixel_format=PixelFormat.RGB,
channels_order=ChannelsOrder.CHANNELS_FIRST,
)
img = Image(data, layout=layout)
pixel_format = PixelFormat(color_space=ColorSpace.RGB, bit_depth=data.element_size() * 8)
layout = ImageLayout(image_size=ImageSize(4, 5), channels=3, channels_order=ChannelsOrder.CHANNELS_FIRST)
img = Image(data, pixel_format=pixel_format, layout=layout)
assert_close(data, Image.from_dlpack(img.to_dlpack()).data)

@pytest.mark.skipif(torch_version_le(1, 9, 1), reason="dlpack is broken in torch<=1.9.1")
Expand All @@ -80,9 +73,9 @@ def test_load_write(self, tmp_path: Path) -> None:
img2 = Image.from_file(file_name)

# NOTE: the tolerance is high due to the jpeg compression
assert (img.float().data - img2.float().data).pow(2).mean() > 0.5
assert (img.float().data - img2.float().data).pow(2).mean() < 0.6

def test_write_first_channel(self, tmp_path: Path) -> None:
data = np.ones((4, 5, 3), dtype=np.uint8)
img = Image.from_numpy(data, pixel_format=PixelFormat.RGB, channels_order=ChannelsOrder.CHANNELS_LAST)
img = Image.from_numpy(data, color_space=ColorSpace.RGB, channels_order=ChannelsOrder.CHANNELS_LAST)
img.write(tmp_path / "image.jpg")