In [46]:
import os
import json
import pandas as pd

from typing import List, Dict
from PIL import Image


In [47]:

def crop_box(image: Image.Image, x: int, y: int, w: int, h: int, clip_to_bounds: bool = True) -> Image.Image:
    """Return the region of ``image`` covered by the box ``(x, y, w, h)``.

    Args:
        image: Source PIL image; it is not modified.
        x, y: Top-left corner of the box in pixels (cast with ``int()``, so floats truncate).
        w, h: Box width and height in pixels (cast with ``int()``).
        clip_to_bounds: If True, the box is intersected with the image bounds. If False,
            a box that extends past the image (or is empty) raises ``ValueError``.

    Returns:
        The cropped region, of size ``(right - left, bottom - top)`` after any clipping.
        No resizing or visualization is performed.

    Raises:
        ValueError: if ``clip_to_bounds`` is True and the clipped box is empty (``w``/``h``
            <= 0, or the box lies entirely outside the image), or if ``clip_to_bounds`` is
            False and the box is empty or not fully inside the image.
    """
    left = int(x)
    top = int(y)
    right = left + int(w)
    bottom = top + int(h)

    if clip_to_bounds:
        requested = (left, top, right, bottom)
        left = max(0, left)
        top = max(0, top)
        right = min(image.width, right)
        bottom = min(image.height, bottom)
        # A box fully outside the image (or with w/h <= 0) clips to nothing. PIL would then
        # either reject the inverted box with a cryptic message or return an empty
        # (0-pixel) image, which is not a usable crop -- fail here with context instead.
        if right <= left or bottom <= top:
            raise ValueError(
                f"Crop box {requested} does not overlap image bounds {(0, 0, image.width, image.height)}"
            )
    elif not (0 <= left < right <= image.width and 0 <= top < bottom <= image.height):
        raise ValueError(
            f"Crop box {(left, top, right, bottom)} is outside image bounds {(0, 0, image.width, image.height)}"
        )

    return image.crop((left, top, right, bottom))


def _resolve_image_path(image_path: str, json_dir: str) -> str:
    """Locate an image path read from a coordinates JSON file.

    Absolute paths, and relative paths that exist from the current working directory,
    are returned unchanged. Otherwise the path is retried relative to ``json_dir`` (the
    JSON file's directory). If neither location exists, ``image_path`` is returned as-is
    so the caller can report it.
    """
    if os.path.isabs(image_path) or os.path.exists(image_path):
        return image_path
    candidate = os.path.join(json_dir, image_path)
    return candidate if os.path.exists(candidate) else image_path


def crop_images_from_coords(
    json_path: str,
    output_dir: str,
    filename_pattern: str = "{image_stem}_{note}.png",
    clip_to_bounds: bool = True,
    skip_missing_images: bool = True,
) -> List[str]:
    """Crop bounding-box regions out of the images listed in a coordinates JSON file.

    Expected JSON format::

        {
          "images": [
            {
              "image": "path/to/patient001.jpg",
              "bboxes": [
                {"x": 1082, "y": 500, "w": 79, "h": 116, "note": "15"},
                {"x": 1169, "y": 474, "w": 82, "h": 134, "note": "14"},
                ...
              ]
            },
            ...
          ]
        }

    ``x``/``y`` are the box's top-left corner and ``w``/``h`` its size, in pixels. A
    per-image ``"boxes"`` key is accepted as an alias for ``"bboxes"``. Entries without
    an ``"image"`` value are skipped.

    Args:
        json_path: Path to the coordinates JSON (read as UTF-8).
        output_dir: Directory the crops are written to; created if needed.
        filename_pattern: ``str.format`` pattern for each crop's file name. Available
            fields: ``image_stem`` (image file name without extension), ``note`` (the
            box's ``"note"``, or "" if absent) and ``idx`` (box index within its image).
        clip_to_bounds: Forwarded to ``crop_box``: clip boxes to the image (True) or
            raise ``ValueError`` for out-of-bounds boxes (False).
        skip_missing_images: If True, images that cannot be found are skipped with a
            warning; if False, a ``FileNotFoundError`` is raised.

    Relative image paths are looked up from the current working directory first and,
    if not found there, relative to the JSON file's directory.

    Returns:
        The saved file paths, in processing order. If ``filename_pattern`` maps two boxes
        to the same file, the later crop overwrites the earlier one (a warning is emitted)
        and the path appears in the list once per write.

    Raises:
        ValueError: if the JSON has no ``"images"`` list, or a crop is rejected by ``crop_box``.
        FileNotFoundError: for a missing image when ``skip_missing_images`` is False.
    """
    import warnings

    os.makedirs(output_dir, exist_ok=True)

    # JSON text is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    images = data.get("images")
    if images is None:
        raise ValueError(f"{json_path}: expected a top-level 'images' list")

    json_dir = os.path.dirname(os.path.abspath(json_path))
    saved_paths: List[str] = []
    written = set()

    for entry in images:
        image_path = entry.get("image")
        if not image_path:
            continue

        resolved_path = _resolve_image_path(image_path, json_dir)
        if not os.path.isfile(resolved_path):
            if skip_missing_images:
                warnings.warn(f"Skipping missing image: {image_path}")
                continue
            raise FileNotFoundError(
                f"Image not found: {image_path} (looked in the working directory and {json_dir})"
            )

        image_stem = os.path.splitext(os.path.basename(image_path))[0]
        bboxes: List[Dict] = entry.get("bboxes") or entry.get("boxes") or []

        # convert() returns an independent in-memory copy, so the file handle can be
        # closed before cropping.
        with Image.open(resolved_path) as source:
            rgb = source.convert("RGB")

        for idx, box in enumerate(bboxes):
            x = int(box["x"])  # top-left x
            y = int(box["y"])  # top-left y
            w = int(box["w"])  # width
            h = int(box["h"])  # height
            note = str(box.get("note", ""))

            cropped = crop_box(rgb, x, y, w, h, clip_to_bounds=clip_to_bounds)

            filename = filename_pattern.format(image_stem=image_stem, note=note, idx=idx)
            save_path = os.path.join(output_dir, filename)
            if save_path in written:
                # e.g. two boxes in one image with the same (or empty) note.
                warnings.warn(f"Overwriting crop already written in this run: {save_path}")
            cropped.save(save_path)
            written.add(save_path)
            saved_paths.append(save_path)

    return saved_paths


In [127]:
# Training set: crop every annotated tooth box listed in coords_all.json into data/crops.
saved = crop_images_from_coords(
    "data/coords_all.json",      # json_path
    "data/crops",                # output_dir
    "{image_stem}_{note}.png",   # filename_pattern -- customize if needed
    clip_to_bounds=True,         # set False to error on out-of-bounds
    skip_missing_images=True,
)
print(f"Saved {len(saved)} crops")

Saved 972 crops


In [48]:
# Held-out test set: crop every annotated tooth box into data/held_out_test/crops.
saved = crop_images_from_coords(
    "data/held_out_test/held_out_test.json",  # json_path
    "data/held_out_test/crops",               # output_dir
    "{image_stem}_{note}.png",                # filename_pattern -- customize if needed
    clip_to_bounds=True,                      # set False to error on out-of-bounds
    skip_missing_images=True,
)
print(f"Saved {len(saved)} crops")

# Root-count folders under both `crops/` and `crops_all/`.
for subset in ("crops", "crops_all"):
    for n_roots in (1, 2):
        os.makedirs(f"data/held_out_test/{subset}/{n_roots}_root_images", exist_ok=True)

# Held-out root-count table; dtype=str keeps every cell as text.
df = pd.read_csv("data/held_out_test/patient_root_info.csv", dtype=str)

Saved 207 crops


In [49]:
import shutil

# Copy each held-out crop ("{patient}_{tooth}.png") into crops_all/{root_count}_root_images,
# using the root count from the held-out table `df` loaded in the previous cell.
for patient in os.listdir("data/held_out_test/crops"):
    if not (patient.startswith("patient") and patient.endswith(".png")):
        continue
    patient_id = patient.split("_")[0]
    tooth_num = patient.split("_")[1].replace(".png", "")
    # Guard the lookup: an unknown tooth column or patient used to raise KeyError/IndexError,
    # and a root count other than "1"/"2" (e.g. "-" or an empty cell) pointed at a
    # destination folder that was never created.
    if tooth_num not in df.columns:
        print(f"Skipping {patient}: no column {tooth_num!r} in patient_root_info.csv")
        continue
    matches = df.loc[df["patient"] == patient_id, tooth_num]
    if matches.empty:
        print(f"Skipping {patient}: patient {patient_id!r} not in patient_root_info.csv")
        continue
    root_num = str(matches.values[0]).strip()
    if root_num not in {"1", "2"}:
        print(f"Skipping {patient}: root count {root_num!r} is not 1 or 2")
        continue
    shutil.copy(
        f"data/held_out_test/crops/{patient}",
        f"data/held_out_test/crops_all/{root_num}_root_images/{patient}",
    )


In [None]:
import random as rnd
import shutil

# Build a class-balanced held-out set in data/held_out_test/crops: copy every 2-root crop,
# plus an equally sized, seeded random sample of 1-root crops.
# NOTE: rnd.sample raises ValueError if there are fewer 1-root than 2-root crops, and
# files left in crops/1_root_images by earlier runs are not removed.
rnd.seed(42)
root_num = len(os.listdir("data/held_out_test/crops_all/2_root_images"))
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort first so the
# seeded sample is reproducible across machines and runs.
root_list = rnd.sample(sorted(os.listdir("data/held_out_test/crops_all/1_root_images")), root_num)

shutil.copytree("data/held_out_test/crops_all/2_root_images", "data/held_out_test/crops/2_root_images", dirs_exist_ok=True)

for teeth in root_list:
    shutil.copy(f"data/held_out_test/crops_all/1_root_images/{teeth}", f"data/held_out_test/crops/1_root_images/{teeth}")

In [128]:
# One folder per quality value (0, 1, 2), each split into 1-root and 2-root subfolders.
for quality in range(3):
    for n_roots in (1, 2):
        os.makedirs(f"data/crops_org/{quality}/{n_roots}_root_images", exist_ok=True)


In [129]:
# Training-set table keyed by patient_id, with a root-count column ("14", ...) and a
# quality column ("14.1", ...) per tooth. dtype=str keeps every cell as text, so later
# cells compare against string literals such as "1".
root_info_path = os.path.join("data", "patient_root_info.csv")
df = pd.read_csv(root_info_path, dtype=str)
df.head()

Unnamed: 0,patient_id,14,24,15,25,14.1,24.1,15.1,25.1
0,patient001,2,2,1,1,1,1,2,2
1,patient002,2,2,2,2,2,2,1,1
2,patient003,2,2,1,1,2,1,2,1
3,patient005,-,-,1,1,0,0,2,2
4,patient006,1,1,1,1,0,1,2,2


In [130]:
# Quality values of tooth 15 among single-rooted cases, excluding quality "0".
df.loc[
    df["15"].eq("1") & df["15.1"].ne("0"),
    "15.1",
].value_counts()

15.1
2    183
1     39
Name: count, dtype: int64

In [131]:
import os
import math
import shutil
from typing import Dict, List, Tuple
from pathlib import Path
import pandas as pd


def distribute_crops_by_quality_and_root(
    crops_dir: str,
    df: pd.DataFrame,
    output_root: str,
    valid_teeth: Tuple[str, ...] = ("14", "24", "15", "25"),
    quality_suffix: str = ".1",
    move_files: bool = False,
) -> Dict[str, int]:
    """Distribute cropped images into quality/root folder structure.

    Assumes crop filenames are like "{image_stem}_{note}.png", where:
      - image_stem corresponds to df['patient_id'] (e.g., "patient001")
      - note is one of the tooth labels in valid_teeth (e.g., "14", "24", "15", "25")

    Folder structure created under `output_root`:
      output_root/
        {quality}/
          {root_count}_root_images/

    - Uses df columns {tooth} for root counts and {tooth}{quality_suffix} for quality values
    - Copies by default; set move_files=True to move instead
    - Returns a dict of destination folder -> number of files written
    """
    crops_path = Path(crops_dir)
    output_path = Path(output_root)

    if not crops_path.exists() or not crops_path.is_dir():
        raise NotADirectoryError(f"Crops directory not found: {crops_dir}")

    # Fast lookups by patient_id
    if "patient_id" not in df.columns:
        raise KeyError("DataFrame must contain a 'patient_id' column")
    df_indexed = df.set_index("patient_id", drop=False)

    written_counts: Dict[str, int] = {}

    for file in crops_path.iterdir():
        if not file.is_file():
            continue
        if file.suffix.lower() not in {".png", ".jpg", ".jpeg"}:
            continue

        # Parse filename as {image_stem}_{note}.ext using rsplit to allow underscores in stem
        stem = file.stem
        if "_" not in stem:
            continue  # skip unexpected filenames
        image_stem, tooth_note = stem.rsplit("_", 1)

        if tooth_note not in valid_teeth:
            continue

        # Find row by patient_id (image_stem)
        if image_stem not in df_indexed.index:
            continue
        row = df_indexed.loc[image_stem]

        # Get root count and quality for the tooth
        root_value = row.get(tooth_note)
        quality_value = row.get(f"{tooth_note}{quality_suffix}")

        # Normalize values: handle '-', NaN, strings
        def to_int_or_none(value):
            if value is None:
                return None
            if isinstance(value, str):
                value = value.strip()
                if value == "-" or value == "":
                    return None
                try:
                    value = float(value)
                except ValueError:
                    return None
            if isinstance(value, (int, float)):
                if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
                    return None
                try:
                    return int(round(float(value)))
                except Exception:
                    return None
            return None

        root_count = to_int_or_none(root_value)
        quality = to_int_or_none(quality_value)

        if root_count not in {1, 2}:
            continue  # skip unknown/invalid root counts
        if quality not in {0, 1, 2}:
            continue  # skip unknown/invalid qualities

        dest_dir = output_path / str(quality) / f"{root_count}_root_images"
        dest_dir.mkdir(parents=True, exist_ok=True)

        dest_file = dest_dir / file.name
        if move_files:
            if dest_file.exists():
                dest_file.unlink()
            shutil.move(str(file), str(dest_file))
        else:
            shutil.copy2(str(file), str(dest_file))

        written_counts[str(dest_dir)] = written_counts.get(str(dest_dir), 0) + 1

    return written_counts


In [132]:
counts = distribute_crops_by_quality_and_root(
    "data/crops",       # crops_dir
    df,                 # training root/quality table loaded in an earlier cell
    "data/crops_org",   # output_root
    move_files=False,   # set True if you prefer moving instead of copying
)

# Report how many crops landed in each quality/root folder, then the overall total.
for dest in sorted(counts):
    print(f"{dest}: {counts[dest]}")
print(f"Total distributed: {sum(counts.values())}")


data/crops_org/0/1_root_images: 17
data/crops_org/0/2_root_images: 19
data/crops_org/1/1_root_images: 126
data/crops_org/1/2_root_images: 124
data/crops_org/2/1_root_images: 385
data/crops_org/2/2_root_images: 290
Total distributed: 961


In [133]:
def count_crops_per_tooth(n_roots: int, qualities=(1, 2), crops_root: str = "data/crops_org") -> dict:
    """Tally crops per tooth number under ``{crops_root}/{quality}/{n_roots}_root_images``.

    Crop names are parsed as ``{image_stem}_{tooth}.{ext}``, splitting on the last "_" as
    ``distribute_crops_by_quality_and_root`` does. Non-image files (e.g. hidden files) and
    names without a numeric tooth label are ignored instead of crashing the cell. Teeth
    14/15/24/25 are always present in the result (possibly 0); any other tooth number
    found is added as its own key.
    """
    counts = {14: 0, 15: 0, 24: 0, 25: 0}
    for quality in qualities:
        folder = os.path.join(crops_root, str(quality), f"{n_roots}_root_images")
        for name in os.listdir(folder):
            stem, ext = os.path.splitext(name)
            if ext.lower() not in {".png", ".jpg", ".jpeg"} or "_" not in stem:
                continue
            tooth = stem.rsplit("_", 1)[1]
            if not tooth.isdigit():
                continue
            counts[int(tooth)] = counts.get(int(tooth), 0) + 1
    return counts


# Quality-0 crops are deliberately excluded: only quality folders 1 and 2 are tallied.
tooth_1root = count_crops_per_tooth(1)
tooth_2root = count_crops_per_tooth(2)

print("1 root: ", tooth_1root)
print("2 root: ", tooth_2root)

1 root:  {14: 58, 15: 222, 24: 38, 25: 193}
2 root:  {14: 160, 15: 52, 24: 154, 25: 48}
