Skip to content

Commit

Permalink
Add Image API (#1562)
Browse files Browse the repository at this point in the history
* initial commit for image api

* implement normalize/denormalize

* add more tests

* add docs

* remove normalize and polish api

* add docs

* polishing gray functionality

* more polishing

* implement onxx tests

* rename ImageColor, and add contrib.io

* image as non-tensor

* magix functions

* rename ImagePrompter to VisualPrompter

* [pre-commit.ci] pre-commit suggestions (#2394)

updates:
- [github.com/charliermarsh/ruff-pre-commit: v0.0.269 → v0.0.270](astral-sh/ruff-pre-commit@v0.0.269...v0.0.270)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* remove artifacts

* remove requirements

* simplify base class

* add TensorWrapper

* fix docs generation

* fix test

* recover setup

* more fixes

* implement write

* add docs

* fix backend

* fix backend

* fix mypy

* Fix write shape

* apply code review

* artifact

* fix docstring

* fix annotations

* skip test because of dlpack

* Update kornia/image/image.py

Co-authored-by: Christie Jacob <christeejacobs@gmail.com>

* simplify

* fix doctest

* fix doctest

* fix doctest

* fix build docs

* iterate on the pixel format definition

* Update kornia/image/base.py

Co-authored-by: Christie Jacob <christeejacobs@gmail.com>

* adjust test tolerance

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Christie Jacob <christeejacobs@gmail.com>
  • Loading branch information
3 people committed Jul 16, 2023
1 parent ebd402c commit 34017fc
Show file tree
Hide file tree
Showing 16 changed files with 640 additions and 130 deletions.
8 changes: 8 additions & 0 deletions docs/source/core.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
kornia.core
===========

.. currentmodule:: kornia.core

.. autoclass:: TensorWrapper
:members:
:undoc-members:
26 changes: 26 additions & 0 deletions docs/source/image.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
kornia.image
============

Module to provide a high level API to process images.

.. currentmodule:: kornia.image

.. autoclass:: ImageSize
:members:
:undoc-members:

.. autoclass:: PixelFormat
:members:
:undoc-members:

.. autoclass:: ChannelsOrder
:members:
:undoc-members:

.. autoclass:: ImageLayout
:members:
:undoc-members:

.. autoclass:: Image
:members:
:undoc-members:
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ Join the community
augmentation
color
contrib
core
enhance
feature
filters
geometry
sensors
io
image
losses
metrics
morphology
Expand Down
1 change: 1 addition & 0 deletions docs/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ natively in Rust to reduce the memory footprint during the decoding and types co
# will load 3xHxW / in torch.float32 in range [0,1] in "cuda"
.. autofunction:: load_image
.. autofunction:: write_image

.. autoclass:: ImageLoadType
:members:
Expand Down
14 changes: 8 additions & 6 deletions kornia/color/gray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from __future__ import annotations

import torch

Expand All @@ -15,7 +15,7 @@ def grayscale_to_rgb(image: Tensor) -> Tensor:
The image data is assumed to be in the range of (0, 1).
Args:
image: grayscale image to be converted to RGB with shape :math:`(*,1,H,W)`.
image: grayscale image tensor to be converted to RGB with shape :math:`(*,1,H,W)`.
Returns:
RGB version of the image with shape :math:`(*,3,H,W)`.
Expand All @@ -26,13 +26,13 @@ def grayscale_to_rgb(image: Tensor) -> Tensor:
"""
KORNIA_CHECK_IS_TENSOR(image)

if image.dim() < 3 or image.size(-3) != 1:
if len(image.shape) < 3 or image.shape[-3] != 1:
raise ValueError(f"Input size must have a shape of (*, 1, H, W). " f"Got {image.shape}.")

return concatenate([image, image, image], -3)


def rgb_to_grayscale(image: Tensor, rgb_weights: Optional[Tensor] = None) -> Tensor:
def rgb_to_grayscale(image: Tensor, rgb_weights: Tensor | None = None) -> Tensor:
r"""Convert a RGB image to grayscale version of image.
.. image:: _static/img/rgb_to_grayscale.png
Expand Down Expand Up @@ -81,7 +81,7 @@ def rgb_to_grayscale(image: Tensor, rgb_weights: Optional[Tensor] = None) -> Ten
return w_r * r + w_g * g + w_b * b


def bgr_to_grayscale(image: torch.Tensor) -> torch.Tensor:
def bgr_to_grayscale(image: Tensor) -> Tensor:
r"""Convert a BGR image to grayscale.
The image data is assumed to be in the range of (0, 1). First flips to RGB, then converts.
Expand Down Expand Up @@ -145,8 +145,10 @@ class RgbToGrayscale(Module):
>>> output = gray(input) # 2x1x4x5
"""

def __init__(self, rgb_weights: Optional[Tensor] = None) -> None:
def __init__(self, rgb_weights: Tensor | None = None) -> None:
super().__init__()
if rgb_weights is None:
rgb_weights = Tensor([0.299, 0.587, 0.114])
self.rgb_weights = rgb_weights

def forward(self, image: Tensor) -> Tensor:
Expand Down
2 changes: 2 additions & 0 deletions kornia/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
zeros,
zeros_like,
)
from .tensor_wrapper import TensorWrapper # type: ignore

__all__ = [
"arange",
Expand Down Expand Up @@ -52,4 +53,5 @@
"zeros_like",
"linspace",
"diag",
"TensorWrapper",
]
12 changes: 6 additions & 6 deletions kornia/core/_backend.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from typing import Union

import torch
import torch.nn.functional as F
from torch import device

# classes
Tensor = torch.Tensor
Expand All @@ -17,8 +17,8 @@
concatenate = torch.cat
stack = torch.stack
linspace = torch.linspace
normalize = F.normalize
pad = F.pad
normalize = torch.nn.functional.normalize
pad = torch.nn.functional.pad
eye = torch.eye
einsum = torch.einsum
zeros = torch.zeros
Expand All @@ -28,7 +28,7 @@
where = torch.where
complex = torch.complex
diag = torch.diag
softmax = F.softmax
softmax = torch.nn.functional.softmax


# constructors
Expand All @@ -38,5 +38,5 @@
rand = torch.rand

# type alias
Device = Union[str, device, None]
Device = Union[str, torch.device, None]
Dtype = Union[torch.dtype, None]
5 changes: 4 additions & 1 deletion kornia/image/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .base import ImageSize
from .base import ChannelsOrder, ImageLayout, ImageSize, PixelFormat
from .image import Image

__all__ = ["ImageSize", "PixelFormat", "ChannelsOrder", "ImageLayout", "Image"]
67 changes: 65 additions & 2 deletions kornia/image/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum

from kornia.core import Tensor


@dataclass
@dataclass(frozen=True)
class ImageSize:
r"""Data class to represent image shape.
Args:
height: image height.
width: image width
width: image width.
Example:
>>> size = ImageSize(3, 4)
>>> size.height
Expand All @@ -21,3 +23,64 @@ class ImageSize:
"""
height: int | Tensor
width: int | Tensor


class ColorSpace(Enum):
r"""Enum that represents the color space of an image."""
UNKNOWN = 0 # for now, in case of multi band images
GRAY = 1
RGB = 2
BGR = 3


@dataclass(frozen=True)
class PixelFormat:
r"""Data class to represent the pixel format of an image.
Args:
color_space: color space.
bit_depth: the number of bits per channel.
Example:
>>> pixel_format = PixelFormat(ColorSpace.RGB, 8)
>>> pixel_format.color_space
<ColorSpace.RGB: 2>
>>> pixel_format.bit_depth
8
"""

color_space: ColorSpace
bit_depth: int


class ChannelsOrder(Enum):
r"""Enum that represents the channels order of an image."""
CHANNELS_FIRST = 0
CHANNELS_LAST = 1


@dataclass(frozen=True)
class ImageLayout:
"""Data class to represent the layout of an image.
Args:
image_size: image size.
channels: number of channels.
channels_order: channels order.
Example:
>>> layout = ImageLayout(ImageSize(3, 4), 3, ChannelsOrder.CHANNELS_LAST)
>>> layout.image_size
ImageSize(height=3, width=4)
>>> layout.channels
3
>>> layout.channels_order
<ChannelsOrder.CHANNELS_LAST: 1>
"""

image_size: ImageSize
channels: int
channels_order: ChannelsOrder


# TODO: define CompressedImage

0 comments on commit 34017fc

Please sign in to comment.