In [1]:
import json
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw
from torchvision.utils import draw_segmentation_masks
from IPython.display import display
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import Dataset, DataLoader
from torch import Tensor
import torch
from torchvision.io import decode_image
from torchvision.transforms.functional import resize
import attrs
from typing import Any


# Number of samples requested from the dataset; passed as `num_samples` when the dataset is built below.
NUM_IMAGES = 100
# Image size in pixels. NOTE(review): presumably the (height, width) the dataset resizes samples to — confirm.
TARGET_IMAGE_SIZE = [1024, 1024]
# Path to the TextOCR data, relative to the notebook's working directory.
# NOTE(review): not referenced by any cell visible here — confirm it is still needed.
DATASET_DIR = Path("../../dataset/TextOCR")



In [2]:
# TextOCR annotation JSON layout (source: https://huggingface.co/datasets/yunusserhat/TextOCR-Dataset):
# {
#   "imgs": {
#     "OpenImages_ImageID_1": {
#       "id": "OpenImages_ImageID_1",
#       "width":  W,
#       "height": H,
#       "set": "train|val|test",
#       "filename": "train|test/OpenImages_ImageID_1.jpg"
#     }
#   },
#   "anns": {
#     "OpenImages_ImageID_1_1": {
#       "id": "OpenImages_ImageID_1_1",
#       "image_id": "OpenImages_ImageID_1",
#       "bbox": [x1, y1, x2, y2],
#       "points": [x1, y1, x2, y2, ..., xN, yN],
#       "utf8_string": "text",
#       "area": A
#     }
#   },
#   "img2Anns": {
#     "OpenImages_ImageID_1": ["OpenImages_ImageID_1_1", "..."]
#   }
# }

In [3]:
from ml.ingest.textocr_to_torch import TextOCRDoctrDetDataset


# Build the detection dataset, limited via `num_samples` to NUM_IMAGES entries.
dataset = TextOCRDoctrDetDataset(num_samples=NUM_IMAGES)

# Seed the shuffle so that "Restart Kernel -> Run All" shows the same samples every time.
# Each new iter(loader) still advances the generator, so re-running a later cell
# within one session keeps yielding different samples.
LOADER_SEED = 0
loader = DataLoader(
    dataset,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

In [4]:
def clean_polygons(polys_p_n_2: torch.Tensor) -> list[Tensor]:
    """
    Strip zero-padding from a batch of polygons.

    polys_p_n_2: [P, N, 2] tensor; rows equal to [0, 0] are treated as padding.
    Returns: list of (N_i, 2) tensors, one per polygon that still has at
        least three points after the padding rows are removed.
    """
    kept: list[Tensor] = []
    for candidate in polys_p_n_2:
        # a row is padding only when both of its coordinates are zero
        is_padding = (candidate == 0).all(dim=-1)
        points = candidate[~is_padding]
        if points.shape[0] < 3:
            # fewer than three vertices cannot form a polygon
            continue
        kept.append(points)
    return kept

def overlay_polygons(image: torch.Tensor,
                     polygons: torch.Tensor,
                     alpha: float = 0.5) -> torch.Tensor:
    """
    Overlay filled polygons on an image as semi-transparent coloured masks.

    image: [C, H, W] image tensor, forwarded to torchvision's draw_segmentation_masks.
    polygons: [P, N, 2] zero-padded polygons (padding is removed with clean_polygons).
        Coordinates are treated as (x, y) fractions of the image width / height
        — NOTE(review): confirm the dataset emits normalized coordinates.
    alpha: mask transparency, forwarded to draw_segmentation_masks.

    Returns: `image` itself when no polygon has at least three valid points,
        otherwise the tensor produced by draw_segmentation_masks with one
        coloured mask per polygon.
    """
    _, H, W = image.shape
    unpadded_polys = clean_polygons(polygons)
    if not unpadded_polys:
        return image

    masks = torch.zeros((len(unpadded_polys), H, W), dtype=torch.bool)
    for i, poly in enumerate(unpadded_polys):
        # Scale x by the image width and y by the image height, so the overlay
        # follows the actual image size instead of a notebook-level constant.
        scale = poly.new_tensor([W, H])
        # ImageDraw.polygon documents 2-tuples (or a flat sequence) as input.
        pixel_points = [tuple(pt) for pt in (poly * scale).tolist()]

        # Rasterise the polygon into a 1-bit image, then copy it into the mask stack.
        m = Image.new(mode="1", size=(W, H), color=0)
        d = ImageDraw.Draw(m)
        d.polygon(pixel_points, fill=1)
        masks[i] = torch.from_numpy(np.array(m, dtype=bool))

    # Cycle through a small palette so neighbouring polygons get different colours.
    colors = ["red", "lime", "blue", "yellow", "magenta", "cyan"]
    return draw_segmentation_masks(image, masks,
                                   colors=[colors[i % len(colors)] for i in range(len(unpadded_polys))],
                                   alpha=alpha)

In [None]:
# Pull one batch from the loader (it is built with shuffle=True, so the sample is not fixed by this cell).
images, targets = next(iter(loader))
# First image of the batch.
sample_image = images[0]
# Polygons belonging to that image.
# NOTE(review): assumes `targets` is a dict whose 'words' entry holds batched, zero-padded
# polygons shaped [B, P, N, 2] — confirm against TextOCRDoctrDetDataset.
sample_polygons = targets['words'][0]
# Blend the polygons over the image and render the result inline as a PIL image.
overlay = overlay_polygons(sample_image, sample_polygons, alpha=0.4)
display(to_pil_image(overlay))
