Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ tqdm = "*"
pyyaml = "*"
pillow = "*"
pydantic-xml = "*"
numpy = "*"

[tool.poetry.group.dev.dependencies]
mypy = "*"
Expand Down
95 changes: 77 additions & 18 deletions src/labelformat/formats/coco.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from __future__ import annotations

import json
from argparse import ArgumentParser
from pathlib import Path
from typing import Dict, Iterable, List
from typing import Dict, Iterable, List, TypedDict

from labelformat.cli.registry import Task, cli_register
from labelformat.model.binary_mask_segmentation import (
BinaryMaskSegmentation,
RLEDecoderEncoder,
)
from labelformat.model.bounding_box import BoundingBox, BoundingBoxFormat
from labelformat.model.category import Category
from labelformat.model.image import Image
Expand Down Expand Up @@ -54,6 +60,14 @@ def get_images(self) -> Iterable[Image]:
)


class _COCOInstanceSegmentationRLE(TypedDict):
    """Shape of a COCO run-length-encoded segmentation entry as used in this module."""

    # Run lengths handled with the column-wise RLE codec in this module
    # (decoded via decode_column_wise_rle, produced by encode_column_wise_rle).
    counts: list[int]
    # Two-element list unpacked as [height, width] of the mask.
    size: list[int]


_COCOInstanceSegmentationMultiPolygon = List[List[float]]


@cli_register(format="coco", task=Task.OBJECT_DETECTION)
class COCOObjectDetectionInput(_COCOBaseInput, ObjectDetectionInput):
def get_labels(self) -> Iterable[ImageObjectDetection]:
Expand Down Expand Up @@ -103,14 +117,15 @@ def get_labels(self) -> Iterable[ImageInstanceSegmentation]:
for ann in annotations:
if "segmentation" not in ann:
raise ParseError(f"Segmentation missing for image id {image_id}")
segmentation: MultiPolygon | BinaryMaskSegmentation
if ann["iscrowd"] == 1:
raise ParseError(
"Parsing segmentations with iscrowd=1 is not yet supported. "
f"(image id {image_id})"
segmentation = _coco_segmentation_to_binary_mask_rle(
segmentation=ann["segmentation"], bbox=ann["bbox"]
)
else:
segmentation = _coco_segmentation_to_multipolygon(
coco_segmentation=ann["segmentation"]
)
segmentation = _coco_segmentation_to_multipolygon(
coco_segmentation=ann["segmentation"]
)
objects.append(
SingleInstanceSegmentation(
category=category_id_to_category[ann["category_id"]],
Expand Down Expand Up @@ -173,19 +188,41 @@ def save(self, label_input: InstanceSegmentationInput) -> None:
data["annotations"] = []
for label in label_input.get_labels():
for obj in label.objects:
annotation = {
"image_id": label.image.id,
"category_id": obj.category.id,
"bbox": [
segmentation: (
_COCOInstanceSegmentationMultiPolygon | _COCOInstanceSegmentationRLE
)
if isinstance(obj.segmentation, BinaryMaskSegmentation):
segmentation = _binary_mask_rle_to_coco_segmentation(
binary_mask_rle=obj.segmentation
)
bbox = [
float(v)
for v in obj.segmentation.bounding_box().to_format(
for v in obj.segmentation.bounding_box.to_format(
BoundingBoxFormat.XYWH
)
],
"iscrowd": 0,
"segmentation": _multipolygon_to_coco_segmentation(
]
iscrowd = 1
elif isinstance(obj.segmentation, MultiPolygon):
segmentation = _multipolygon_to_coco_segmentation(
multipolygon=obj.segmentation
),
)
bbox = [
float(v)
for v in obj.segmentation.bounding_box().to_format(
BoundingBoxFormat.XYWH
)
]
iscrowd = 0
else:
raise ParseError(
f"Unsupported segmentation type: {type(obj.segmentation)}"
)
annotation = {
"image_id": label.image.id,
"category_id": obj.category.id,
"bbox": bbox,
"iscrowd": iscrowd,
"segmentation": segmentation,
}
data["annotations"].append(annotation)

Expand All @@ -195,7 +232,7 @@ def save(self, label_input: InstanceSegmentationInput) -> None:


def _coco_segmentation_to_multipolygon(
coco_segmentation: List[List[float]],
coco_segmentation: _COCOInstanceSegmentationMultiPolygon,
) -> MultiPolygon:
"""Convert COCO segmentation to MultiPolygon."""
polygons = []
Expand All @@ -213,14 +250,36 @@ def _coco_segmentation_to_multipolygon(
return MultiPolygon(polygons=polygons)


def _multipolygon_to_coco_segmentation(multipolygon: MultiPolygon) -> List[List[float]]:
def _multipolygon_to_coco_segmentation(
    multipolygon: MultiPolygon,
) -> _COCOInstanceSegmentationMultiPolygon:
    """Convert MultiPolygon to COCO segmentation.

    Every polygon in ``multipolygon.polygons`` becomes one flat list in which
    the coordinates of its points are written out one point after another,
    preserving both the polygon order and the point order.

    Args:
        multipolygon: The multipolygon whose ``polygons`` are flattened.

    Returns:
        One flat coordinate list per polygon; an empty list when the
        multipolygon has no polygons.
    """
    return [
        [coordinate for point in polygon for coordinate in point]
        for polygon in multipolygon.polygons
    ]


def _coco_segmentation_to_binary_mask_rle(
    segmentation: _COCOInstanceSegmentationRLE, bbox: list[float]
) -> BinaryMaskSegmentation:
    """Build a BinaryMaskSegmentation from a COCO RLE segmentation and its bbox.

    The ``counts`` are decoded column-wise into a 2D mask of shape
    ``size`` (read as ``[height, width]``); ``bbox`` is interpreted in XYWH
    format and attached to the resulting segmentation.
    """
    run_lengths = segmentation["counts"]
    mask_height, mask_width = segmentation["size"]
    decoded_mask = RLEDecoderEncoder.decode_column_wise_rle(
        run_lengths, mask_height, mask_width
    )
    return BinaryMaskSegmentation.from_binary_mask(
        decoded_mask,
        bounding_box=BoundingBox.from_format(
            bbox=bbox, format=BoundingBoxFormat.XYWH
        ),
    )


def _binary_mask_rle_to_coco_segmentation(
    binary_mask_rle: BinaryMaskSegmentation,
) -> _COCOInstanceSegmentationRLE:
    """Convert a BinaryMaskSegmentation into a COCO RLE segmentation dict.

    The mask is materialised as a 2D array and re-encoded column-wise; the
    result carries those run lengths under ``counts`` and ``[height, width]``
    under ``size``.
    """
    column_wise_counts = RLEDecoderEncoder.encode_column_wise_rle(
        binary_mask_rle.get_binary_mask()
    )
    mask_size = [binary_mask_rle.height, binary_mask_rle.width]
    return {"counts": column_wise_counts, "size": mask_size}


def _get_output_images_dict(
images: Iterable[Image],
) -> List[JsonDict]:
Expand Down
4 changes: 4 additions & 0 deletions src/labelformat/formats/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ def save(self, label_input: InstanceSegmentationInput) -> None:
label_path.parent.mkdir(parents=True, exist_ok=True)
with label_path.open("w") as file:
for obj in label.objects:
if not isinstance(obj.segmentation, MultiPolygon):
raise ValueError(
f"YOLOv8 format only supports MultiPolygon segmentation."
)
polygon = _multipolygon_to_polygon(multipolygon=obj.segmentation)
polygon_str = " ".join(
[
Expand Down
114 changes: 114 additions & 0 deletions src/labelformat/model/binary_mask_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

from labelformat.model.bounding_box import BoundingBox


@dataclass(frozen=True)
class BinaryMaskSegmentation:
    """
    A binary segmentation mask together with a bounding box.

    The mask itself is not kept as an array: it is stored as a row-wise
    run-length encoding (RLE) and decoded again on request.
    """

    # Run lengths of the mask flattened row by row.
    _rle_row_wise: list[int]
    # Mask dimensions, needed to rebuild the 2D array from the RLE.
    width: int
    height: int
    # Bounding box supplied by the caller; stored as given.
    bounding_box: BoundingBox

    @classmethod
    def from_binary_mask(
        cls, binary_mask: NDArray[np.int_], bounding_box: BoundingBox
    ) -> "BinaryMaskSegmentation":
        """
        Build an instance from a 2D numpy mask by encoding it row-wise as RLE.

        Raises:
            ValueError: If the mask is not a numpy array or is not 2D.
        """
        if not isinstance(binary_mask, np.ndarray):
            raise ValueError("Binary mask must be a numpy array.")
        if binary_mask.ndim != 2:
            raise ValueError("Binary mask must be a 2D array.")

        mask_height, mask_width = binary_mask.shape
        return cls(
            _rle_row_wise=RLEDecoderEncoder.encode_row_wise_rle(binary_mask),
            width=mask_width,
            height=mask_height,
            bounding_box=bounding_box,
        )

    def get_binary_mask(self) -> NDArray[np.int_]:
        """
        Decode the stored row-wise RLE back into a 2D numpy mask.
        """
        rows, cols = self.height, self.width
        return RLEDecoderEncoder.decode_row_wise_rle(self._rle_row_wise, rows, cols)


class RLEDecoderEncoder:
    """
    A class for encoding and decoding binary masks using run-length encoding (RLE).

    An RLE is a list of run lengths over the flattened mask. Runs alternate
    between background (0) and foreground (1) and always start with a
    background run, so a mask whose first pixel is foreground starts with a
    run of length 0.

    The encoding and decoding can be done both row-wise and column-wise.

    Example:
        Consider a binary mask of shape 2x4:
            [[0, 1, 1, 0],
             [1, 1, 1, 1]]
        Row-wise RLE: [1, 2, 1, 4]
        Column-wise RLE: [1, 5, 1, 1]
    """

    @staticmethod
    def encode_row_wise_rle(binary_mask: NDArray[np.int_]) -> list[int]:
        """Encode a mask as RLE over its row-major (C order) flattening.

        Any non-zero value counts as foreground. An empty mask encodes to [].
        """
        return RLEDecoderEncoder._encode(binary_mask, order="C")

    @staticmethod
    def encode_column_wise_rle(binary_mask: NDArray[np.int_]) -> list[int]:
        """Encode a mask as RLE over its column-major (F order) flattening.

        Any non-zero value counts as foreground. An empty mask encodes to [].
        """
        return RLEDecoderEncoder._encode(binary_mask, order="F")

    @staticmethod
    def decode_row_wise_rle(
        rle: list[int], height: int, width: int
    ) -> NDArray[np.int_]:
        """Decode a row-major RLE into a 2D 0/1 mask of shape (height, width).

        Raises:
            ValueError: If a count is negative or the counts do not sum to
                height * width.
        """
        flat = RLEDecoderEncoder._decode_flat(rle, height, width)
        return flat.reshape((height, width), order="C")

    @staticmethod
    def decode_column_wise_rle(
        rle: list[int], height: int, width: int
    ) -> NDArray[np.int_]:
        """Decode a column-major RLE into a 2D 0/1 mask of shape (height, width).

        Raises:
            ValueError: If a count is negative or the counts do not sum to
                height * width.
        """
        flat = RLEDecoderEncoder._decode_flat(rle, height, width)
        return flat.reshape((height, width), order="F")

    @staticmethod
    def _encode(binary_mask: NDArray[np.int_], order: str) -> list[int]:
        # Shared encoder: flatten in the requested order and emit run lengths.
        # Binarise first so that only background<->foreground transitions
        # split runs (e.g. 255 next to 1 is still one foreground run).
        foreground = np.asarray(binary_mask).ravel(order=order) != 0
        if foreground.size == 0:
            return []
        # Indices at which a new run starts (excluding index 0).
        change_points = np.flatnonzero(foreground[1:] != foreground[:-1]) + 1
        boundaries = np.concatenate(([0], change_points, [foreground.size]))
        run_lengths = np.diff(boundaries)
        if foreground[0]:
            # Runs must start with background; insert an empty background run.
            run_lengths = np.concatenate(([0], run_lengths))
        encoded: list[int] = run_lengths.tolist()
        return encoded

    @staticmethod
    def _decode_flat(rle: list[int], height: int, width: int) -> NDArray[np.int_]:
        # Shared decoder: expand run lengths into a flat 0/1 array.
        counts = np.asarray(rle, dtype=np.int_)
        if counts.ndim != 1:
            raise ValueError("RLE must be a flat list of run lengths.")
        if (counts < 0).any():
            raise ValueError("RLE run lengths must be non-negative.")
        total = int(counts.sum())
        if total != height * width:
            raise ValueError(
                f"RLE run lengths sum to {total}, but a {height}x{width} mask "
                f"has {height * width} pixels."
            )
        # Run i holds value i % 2: background first, then alternating.
        run_values = np.arange(counts.size, dtype=np.int_) % 2
        return np.repeat(run_values, counts)
9 changes: 6 additions & 3 deletions src/labelformat/model/instance_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from argparse import ArgumentParser
from dataclasses import dataclass
from typing import Iterable, List
from typing import Iterable

from labelformat.model.binary_mask_segmentation import BinaryMaskSegmentation
from labelformat.model.category import Category
from labelformat.model.image import Image
from labelformat.model.multipolygon import MultiPolygon
Expand All @@ -11,13 +14,13 @@
@dataclass(frozen=True)
class SingleInstanceSegmentation:
category: Category
segmentation: MultiPolygon
segmentation: MultiPolygon | BinaryMaskSegmentation


@dataclass(frozen=True)
class ImageInstanceSegmentation:
image: Image
objects: List[SingleInstanceSegmentation]
objects: list[SingleInstanceSegmentation]


class InstanceSegmentationInput(ABC):
Expand Down
Loading