In [73]:
import functools
import itertools
import json
from pathlib import Path

import attrs
import numpy as np
from torchvision.io import ImageReadMode

# How many TextOCR images the dataset loads (passed as its `num_samples`).
NUM_IMAGES = 100
# Relative paths resolve against the notebook's current working directory.
IMG_ROOT = Path("../../dataset", "TextOCR", "train_val_images")
TOCR_JSON = Path("../../dataset", "TextOCR", "TextOCR_0.1_train.json")

In [None]:
# Parsed TextOCR train annotations.
# Source: https://huggingface.co/datasets/yunusserhat/TextOCR-Dataset
# JSON text is UTF-8; don't rely on the platform's default locale encoding.
with open(TOCR_JSON, encoding="utf-8") as tocr_file:
    tocr_json = json.load(tocr_file)
# Schema (key names as accessed in this notebook):
# {
#   "imgs": {
#     "OpenImages_ImageID_1": {
#       "id": "OpenImages_ImageID_1",
#       "width":  W,
#       "height": H,
#       "set": "train|val|test",
#       "file_name": "train|test/OpenImages_ImageID_1.jpg"   # joined onto IMG_ROOT
#     }
#   },
#   "anns": {
#     "OpenImages_ImageID_1_1": {
#       "id": "OpenImages_ImageID_1_1",
#       "image_id": "OpenImages_ImageID_1",
#       "bbox": [x1, y1, x2, y2],   # NOTE(review): unused here; confirm x2,y2 vs w,h
#       "points": [x1, y1, x2, y2, ..., xN, yN],   # flat polygon vertex list
#       "utf8_string": "text",
#       "area": A
#     }
#   },
#   "imgToAnns": {
#     "OpenImages_ImageID_1": ["OpenImages_ImageID_1_1", "..."]
#   }
# }

In [144]:
from torch.utils.data import Dataset, DataLoader
from torch import Tensor
import torch
from torchvision.io import decode_image
from typing import Any

class TextOCRDoctrDetDataset(Dataset[tuple[Tensor, Tensor]]):
    """Map-style dataset of (image, zero-padded text polygons) pairs from TextOCR.

    The first ``num_samples`` images listed in ``imgToAnns`` are selected and
    all of their polygons are pre-loaded into one zero-filled tensor
    ``self.annotations`` of shape ``(num_images, max_labels, max_pts, 2)``:
    ``max_labels`` is the largest annotation count of any selected image and
    ``max_pts`` the largest vertex count of any of their polygons. Annotation
    slots past an image's real annotations, and vertices past a polygon's real
    points, are left as zeros.
    """

    def __init__(self, images_dir: Path, json_path: Path, num_samples: int,
                 json_data: Any = None):
        """Select images and build the padded annotation tensor.

        Args:
            images_dir: Directory each image's ``file_name`` is relative to.
            json_path: TextOCR annotation JSON; read only when ``json_data``
                is None.
            num_samples: Maximum number of images to include.
            json_data: Optional already-parsed annotation JSON (same schema as
                the file), to avoid re-reading ``json_path``.
        """
        if json_data is None:
            with open(json_path, encoding="utf-8") as json_file:
                json_data = json.load(json_file)

        # Materialize the selection once so padding sizes are computed over
        # exactly the images that get stored, not the whole annotation file.
        selected = list(itertools.islice(json_data["imgToAnns"].items(), num_samples))
        max_labels, max_pts = self.get_max_dimensions(
            json_data, [img_id for img_id, _ in selected])

        self.image_paths: list[Path] = []
        self.annotations: Tensor = torch.zeros((len(selected), max_labels, max_pts, 2))

        for sample_num, (img_id, ann_ids) in enumerate(selected):
            tocr_img_json = json_data["imgs"][img_id]
            self.image_paths.append(images_dir / tocr_img_json["file_name"])

            for label_num, ann_id in enumerate(ann_ids):
                points: list[float] = json_data["anns"][ann_id]["points"]
                self.annotations[sample_num, label_num] = self.points_to_tensor(
                    points, padding=max_pts)

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """Return (uint8 image of shape (3, H, W), padded polygons of sample ``idx``)."""
        # Decode as RGB so grayscale / RGBA sources also come back with 3 channels.
        image = decode_image(str(self.image_paths[idx]), mode=ImageReadMode.RGB)
        return image, self.annotations[idx]

    def points_to_tensor(self, point_values: list[float], padding: int) -> Tensor:
        """Convert a flat [x1, y1, ..., xN, yN] list into a (padding, 2) float tensor.

        Vertices are written first and the remaining rows stay zero. A trailing
        unpaired coordinate is ignored.

        Raises:
            ValueError: if the polygon has more than ``padding`` vertices.
        """
        num_pts = len(point_values) // 2
        if num_pts > padding:
            raise ValueError(f"polygon has {num_pts} points but padding is {padding}")
        out = torch.zeros((padding, 2))
        if num_pts:
            out[:num_pts] = torch.tensor(
                point_values[: 2 * num_pts], dtype=torch.float32).view(num_pts, 2)
        return out

    def get_max_dimensions(self, json_data: Any, img_ids: Any = None) -> tuple[int, int]:
        """Return (max annotations per image, max vertices per polygon).

        Args:
            json_data: Parsed TextOCR annotation JSON.
            img_ids: Image ids to consider (any iterable of keys of
                ``imgToAnns``); all images when None.

        ``points`` is a flat coordinate list, so a polygon has
        ``len(points) // 2`` vertices.
        """
        img_to_anns = json_data["imgToAnns"]
        anns = json_data["anns"]
        if img_ids is None:
            img_ids = img_to_anns.keys()

        max_labels = 0
        max_pts = 0
        for img_id in img_ids:
            ann_ids = img_to_anns[img_id]
            max_labels = max(max_labels, len(ann_ids))
            for ann_id in ann_ids:
                max_pts = max(max_pts, len(anns[ann_id]["points"]) // 2)
        return max_labels, max_pts

dataset = TextOCRDoctrDetDataset(IMG_ROOT, TOCR_JSON, NUM_IMAGES)
# batch_size=1: images may differ in size, which the default collate_fn cannot
# stack. The fixed-seed generator makes the shuffled order (and therefore the
# sample visualised below) reproducible across Restart & Run All.
loader = DataLoader(dataset, batch_size=1, shuffle=True,
                    generator=torch.Generator().manual_seed(0))

In [145]:
import torch
from PIL import Image, ImageDraw
from torchvision.utils import draw_segmentation_masks
from torchvision.transforms.functional import to_pil_image, pil_to_tensor

def clean_polygons(polys_p_n_2: torch.Tensor) -> list:
    """
    Strip trailing zero-padding from polygon tensors and drop degenerate polygons.

    polys_p_n_2: [P, N, 2] tensor of (x, y) vertices; each polygon is expected
        to be right-padded with [0, 0] rows (as points_to_tensor produces).
    Returns: list of (N_i, 2) numpy arrays, one per polygon that keeps at least
        3 vertices (fully padded rows are dropped).

    Only the trailing run of [0, 0] rows is treated as padding, so a genuine
    vertex at the origin elsewhere in the polygon is kept. A genuine (0, 0)
    *last* vertex is indistinguishable from padding and is still removed.
    """
    cleaned = []
    for poly in polys_p_n_2:
        is_pad = torch.all(poly == 0, dim=-1)
        n_valid = len(poly)
        while n_valid > 0 and bool(is_pad[n_valid - 1]):
            n_valid -= 1
        if n_valid >= 3:  # at least a triangle
            cleaned.append(poly[:n_valid].cpu().numpy())
    return cleaned

def overlay_polygons(img_chw_uint8: torch.Tensor,
                     polys_p_n_2: torch.Tensor,
                     alpha: float = 0.5) -> torch.Tensor:
    """Blend filled, colour-cycled polygon masks onto an image.

    Args:
        img_chw_uint8: Image tensor of shape (C, H, W); draw_segmentation_masks
            expects a 3-channel image.
        polys_p_n_2: Zero-padded polygons of shape (P, N, 2) in (x, y) pixel
            coordinates; cleaned with clean_polygons before drawing.
        alpha: Mask opacity forwarded to draw_segmentation_masks.

    Returns:
        The image tensor returned by draw_segmentation_masks.
    """
    _, height, width = img_chw_uint8.shape
    polygons = clean_polygons(polys_p_n_2)
    palette = ("red", "lime", "blue", "yellow", "magenta", "cyan")

    def rasterize(vertices) -> torch.Tensor:
        # Fill the polygon on a 1-bit PIL canvas, then convert it to a bool mask.
        canvas = Image.new(mode="1", size=(width, height), color=0)
        ImageDraw.Draw(canvas).polygon([tuple(v) for v in vertices], fill=1)
        return torch.from_numpy(np.array(canvas, dtype=bool))

    masks = torch.zeros((len(polygons), height, width), dtype=torch.bool)
    for idx, vertices in enumerate(polygons):
        masks[idx] = rasterize(vertices)

    mask_colors = [palette[idx % len(palette)] for idx in range(len(polygons))]
    return draw_segmentation_masks(img_chw_uint8, masks, colors=mask_colors, alpha=alpha)

In [None]:
from torchvision.transforms.functional import to_pil_image

# The loader yields batches with a leading batch dimension; take the first sample.
image, polygons = next(iter(loader))
overlay = overlay_polygons(image[0], polygons[0], alpha=0.4)

# Return the PIL image as the cell's last expression so Jupyter renders it
# inline; Image.show() would instead launch an external viewer.
to_pil_image(overlay)