v0.7.0 Image API, RT-DETR and Object Detection API, LightGlue Matcher, MobileSam, new Sensors API and many more

@edgarriba edgarriba released this 02 Aug 09:57
29e4f96

Highlights

Image API

This release adds a new Image API as a placeholder for a more generic multi-backend API. You can import from and export to files, NumPy, and DLPack.

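>>> # import paths below are assumed from the kornia.image package layout
>>> import torch
>>> from kornia.image import Image
>>> from kornia.image.base import ChannelsOrder, ColorSpace, ImageLayout, ImageSize, PixelFormat
>>> # from a torch.tensor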
>>> data = torch.randint(0, 255, (3, 4, 5), dtype=torch.uint8)  # CxHxW
>>> pixel_format = PixelFormat(
...     color_space=ColorSpace.RGB,
...     bit_depth=8,
... )
>>> layout = ImageLayout(
...     image_size=ImageSize(4, 5),
...     channels=3,
...     channels_order=ChannelsOrder.CHANNELS_FIRST,
... )
>>> img = Image(data, pixel_format, layout)
>>> assert img.channels == 3

Object Detection API

We have added ObjectDetector, which ships with the RT-DETR model by default. The detection pipeline is fully configurable: you supply a pre-processor, a model, and a post-processor. Example usage is shown below.

from io import BytesIO

import cv2
import numpy as np
import requests
import torch
from PIL import Image
import matplotlib.pyplot as plt

from kornia.contrib.models.rt_detr import RTDETR, DETRPostProcessor, RTDETRConfig
from kornia.contrib.object_detection import ObjectDetector, ResizePreProcessor

model_type = "hgnetv2_x"  # also available: resnet18d, resnet34d, resnet50d, resnet101d, hgnetv2_l
checkpoint = f"https://github.com/kornia/kornia/releases/download/v0.7.0/rtdetr_{model_type}.ckpt"
config = RTDETRConfig(model_type, 80, checkpoint=checkpoint)
model = RTDETR.from_config(config).eval()

detector = ObjectDetector(model, ResizePreProcessor(640), DETRPostProcessor(0.3))

url = "https://github.com/kornia/data/raw/main/soccer.jpg"
img = Image.open(BytesIO(requests.get(url).content))
img = np.asarray(img, dtype=np.float32) / 255
img_pt = torch.from_numpy(img).permute(2, 0, 1)
detection = detector.predict([img_pt])

for cls_score_xywh in detection[0].numpy():
    class_id = int(cls_score_xywh[0])
    score = cls_score_xywh[1]
    x, y, w, h = cls_score_xywh[2:].round().astype(int)
    cv2.rectangle(img, (x, y, w, h), (255, 0, 0), 3)

    text = f"{class_id}, {score:.2f}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    (text_width, text_height), _ = cv2.getTextSize(text, font, 1, 2)
    cv2.rectangle(img, (x, y - text_height, text_width, text_height), (255, 0, 0), cv2.FILLED)
    cv2.putText(img, text, (x, y), font, 1, (255, 255, 255), 2)

plt.imshow(img)
plt.show()
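The loop above reads each detection row as `[class_id, score, x, y, w, h]`. As a minimal sketch of that row layout, the rows can be filtered by score and converted to corner coordinates with the standard library alone. The helper name `rows_to_xyxy` and the sample values are hypothetical, and the row layout is assumed from how the loop unpacks `cls_score_xywh`.

```python
from typing import List, Sequence, Tuple


# Hypothetical helper, not part of kornia. It assumes each row is laid out as
# [class_id, score, x, y, w, h], which is how the loop above unpacks it.
def rows_to_xyxy(
    rows: Sequence[Sequence[float]], min_score: float = 0.3
) -> List[Tuple[int, float, int, int, int, int]]:
    """Keep rows scoring at least `min_score` and convert xywh to xyxy."""
    boxes = []
    for class_id, score, x, y, w, h in rows:
        if score < min_score:
            continue
        # Top-left corner stays as is; bottom-right is the corner plus the size.
        x0, y0 = round(x), round(y)
        x1, y1 = round(x + w), round(y + h)
        boxes.append((int(class_id), float(score), x0, y0, x1, y1))
    return boxes


# Made-up detections, for illustration only.
sample = [
    [0.0, 0.91, 10.2, 20.7, 30.0, 40.0],
    [32.0, 0.12, 5.0, 5.0, 8.0, 8.0],
]
boxes = rows_to_xyxy(sample)
print(boxes)
```

Corner coordinates are often more convenient than `(x, y, w, h)` when a drawing or evaluation routine expects two points per box.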

Deep Models

As part of the kornia.contrib module, we have started building a models module to host deep learning models for computer vision (semantic segmentation, object detection, etc.).

These models (e.g., Segment Anything) are implemented on top of an abstract base class, ModelBase. We also provide standard structures for working with their results, such as SegmentationResults.

The goal is to abstract and standardize how these models behave with our high-level APIs, for example when interacting with the Visual Prompter backend (Segment Anything is available today).

ModelBase provides methods for loading checkpoints (load_checkpoint) and for compiling itself via the torch.compile API. We plan to extend it according to the needs of the community.

This release also makes other models available, such as RT_DETR and tiny_vit.

Example of using these abstractions to implement a model:

# Each model should be a submodule inside the `kornia.contrib.models`, and the Model class itself will be exposed under this
# `models` module.

from kornia.contrib.models.base import ModelBase
from dataclasses import dataclass
from kornia.contrib.models.structures import SegmentationResults
from enum import Enum

class MyModelType(Enum):
    """Map the model types."""
    a = 0
    ...

@dataclass
class MyModelConfig:
    model_type: str | int | MyModelType | None = None
    checkpoint: str | None = None
    ...

class MyModel(ModelBase[MyModelConfig]):
    def __init__(...) -> None:
        ...

    @staticmethod
    def from_config(config: MyModelConfig) -> MyModel:
        """Build the model based on the config"""
        ...

    def forward(...) -> SegmentationResults:
        ...
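The template above is schematic (it uses `...` placeholders) and depends on kornia. To show the same config-driven pattern in a form that runs with the standard library alone, here is a sketch in which every class (`SketchModelBase`, `ToyModelType`, `ToyConfig`, `ToyModel`) is a hypothetical stand-in. None of them is kornia's `ModelBase`, and they do not mirror its actual methods.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Generic, List, TypeVar

ConfigT = TypeVar("ConfigT")


class SketchModelBase(ABC, Generic[ConfigT]):
    """Hypothetical stand-in base class: each model is tied to one config type."""

    @staticmethod
    @abstractmethod
    def from_config(config: ConfigT) -> "SketchModelBase[ConfigT]":
        """Build the model from its config."""

    @abstractmethod
    def forward(self, values: List[float]) -> List[int]:
        """Run the model on some input."""


class ToyModelType(Enum):
    """Map the model types."""

    LOW = 0
    HIGH = 1


@dataclass
class ToyConfig:
    model_type: ToyModelType = ToyModelType.LOW


class ToyModel(SketchModelBase[ToyConfig]):
    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    @staticmethod
    def from_config(config: ToyConfig) -> "ToyModel":
        # The config alone decides how the model is built.
        threshold = 0.25 if config.model_type is ToyModelType.LOW else 0.75
        return ToyModel(threshold)

    def forward(self, values: List[float]) -> List[int]:
        # A trivial "segmentation": 1 where the value exceeds the threshold.
        return [int(v > self.threshold) for v in values]


model = ToyModel.from_config(ToyConfig(ToyModelType.HIGH))
mask = model.forward([0.1, 0.5, 0.9])
print(mask)
```

Keeping construction behind `from_config` means callers only need to know the config type, which is the property the high-level APIs rely on.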

RT-DETR

Most object detection models need non-maximum suppression (NMS) to remove overlapping and near-duplicate bounding boxes. This post-processing algorithm has high latency, which prevents object detectors from reaching real-time speed. DETR is a newer class of detectors that eliminates the NMS step by using a transformer decoder to predict bounding boxes directly. RT-DETR enhances Deformable DETR to achieve real-time speed on server-class GPUs by using an efficient backbone. More details can be seen here
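To make the NMS post-processing step concrete, here is a minimal pure-Python sketch of greedy NMS, the step that DETR-style detectors avoid. The function names, the `(x0, y0, x1, y1)` box format, and the threshold are illustrative assumptions, not code from kornia or RT-DETR.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x0, y0, x1, y1), an assumed format


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(
    boxes: Sequence[Box], scores: Sequence[float], iou_threshold: float = 0.5
) -> List[int]:
    """Return the indices of the kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        # Each candidate is compared against every box kept so far, so the
        # worst case is quadratic in the number of boxes.
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


# Two heavily overlapping boxes and one separate box (made-up values).
boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
kept = greedy_nms(boxes, scores)
print(kept)
```

The second box overlaps the first with an IoU of 81/119, so it is suppressed and only the first and third boxes survive.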

TinyViT

TinyViT is an efficient and high-performing transformer model for images. It achieves a top-1 accuracy of 84.8% on ImageNet-1k with only 21M parameters. See TinyViT for more information.

MobileSAM

MobileSAM replaces the heavy ViT-H backbone in the original SAM with TinyViT, which is more than 100 times smaller in terms of parameters and around 40 times faster in terms of inference speed. See MobileSAM for more details.

To use MobileSAM, simply specify "mobile_sam" in the SamConfig:

from kornia.contrib.visual_prompter import VisualPrompter
from kornia.contrib.models.sam import SamConfig

prompter = VisualPrompter(SamConfig("mobile_sam", pretrained=True))

LightGlue matcher

Added a LightGlue-based matcher to the kornia API. It is based on the original code from the paper “LightGlue: Local Feature Matching at Light Speed”. See [LSP23] for more details.

The LightGlue algorithm won a cash prize in the Image Matching Challenge 2023 @ CVPR23: https://www.kaggle.com/competitions/image-matching-challenge-2023/overview

See a working example integrating with COLMAP: #2469

New Sensors API

New kornia.sensors module to interface with sensors like Camera, IMU, GNSS etc.

We added CameraModel, PinholeModel, and CameraModelBase for now.

Usage example:

Define a CameraModel

>>> # Pinhole Camera Model
>>> cam = CameraModel(ImageSize(480, 640), CameraModelType.PINHOLE, torch.Tensor([328., 328., 320., 240.]))
>>> # Brown Conrady Camera Model
>>> cam = CameraModel(ImageSize(480, 640), CameraModelType.BROWN_CONRADY, torch.Tensor([1.0, 1.0, 1.0, 1.0,
... 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]))
>>> # Kannala Brandt K3 Camera Model
>>> cam = CameraModel(ImageSize(480, 640), CameraModelType.KANNALA_BRANDT_K3, torch.Tensor([1.0, 1.0, 1.0,
... 1.0, 1.0, 1.0, 1.0, 1.0]))
>>> # Orthographic Camera Model
>>> cam = CameraModel(ImageSize(480, 640), CameraModelType.ORTHOGRAPHIC, torch.Tensor([328., 328., 320., 240.]))
>>> cam.params
tensor([328., 328., 320., 240.])
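For the pinhole case, the four parameters are read here as `fx, fy, cx, cy`. That ordering is an assumption based on the values above (320 and 240 are half of the 640x480 image), not something these notes state. Under that assumption, projecting a 3D point in the camera frame can be sketched in plain Python. The function `project_pinhole` is hypothetical and is not the kornia API.

```python
from typing import Sequence, Tuple


def project_pinhole(point: Sequence[float], params: Sequence[float]) -> Tuple[float, float]:
    """Project a camera-frame 3D point to pixel coordinates.

    Assumes `params` is ordered as (fx, fy, cx, cy).
    """
    fx, fy, cx, cy = params
    x, y, z = point
    if z <= 0:
        raise ValueError("point must be in front of the camera (z > 0)")
    # Perspective division, then scale by focal length and shift by the principal point.
    return fx * x / z + cx, fy * y / z + cy


params = [328.0, 328.0, 320.0, 240.0]
u, v = project_pinhole((0.5, -0.25, 2.0), params)
print(u, v)
```

A point on the optical axis, such as `(0, 0, 1)`, lands on the principal point `(cx, cy)` under this model.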

Added kornia.geometry.solvers submodule

New module for geometric vision solvers that include the following:

This is part of an upgrade to find_fundamental to support the 7POINT algorithm.
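For context, the textbook seven-point formulation is summarized below as general background, not as a description of kornia's internals. The seven epipolar constraints form a 7x9 linear system whose two-dimensional null space yields two candidate matrices, written here as F_1 and F_2. The fundamental matrix is then the combination of the two that satisfies the rank-2 constraint:

```latex
\begin{aligned}
\mathbf{x}_i'^{\top} F \, \mathbf{x}_i &= 0, \quad i = 1, \dots, 7, \\
F(\alpha) &= \alpha F_1 + (1 - \alpha) F_2, \\
\det F(\alpha) &= 0 .
\end{aligned}
```

The determinant is a cubic polynomial in \alpha, so it has up to three real roots, which is why a seven-point solver can return up to three candidate matrices for one set of correspondences.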

Image terminal printing

Added the kornia.utils.print_image API for printing an image tensor or an image file path to the terminal.

>>> kornia.utils.print_image("panda.jpg")
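As a rough illustration of how an image can be rendered in a terminal at all, the sketch below emits colored cells using 24-bit ANSI background escape codes. This is a common technique chosen for illustration and is not necessarily how `kornia.utils.print_image` is implemented. The helper name `pixels_to_ansi` is hypothetical, and the output assumes a terminal with true-color support.

```python
from typing import Sequence, Tuple

RGB = Tuple[int, int, int]
RESET = "\x1b[0m"  # ANSI code that restores the default colors


def pixels_to_ansi(rows: Sequence[Sequence[RGB]]) -> str:
    """Render rows of (r, g, b) pixels as colored terminal cells."""
    lines = []
    for row in rows:
        # Two spaces per pixel roughly compensate for tall terminal cells.
        cells = "".join(f"\x1b[48;2;{r};{g};{b}m  " for r, g, b in row)
        lines.append(cells + RESET)
    return "\n".join(lines)


# A made-up 2x2 image: red, green / blue, white.
tiny = [
    [(255, 0, 0), (0, 255, 0)],
    [(0, 0, 255), (255, 255, 255)],
]
rendered = pixels_to_ansi(tiny)
print(rendered)
```

A real image would first be downscaled to the terminal width, since each pixel occupies two character cells here.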


What's Changed

New Contributors

Full Changelog: v0.6.12...v0.7.0