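Converts an ImageJ-created dataset (per-image CSV annotation files plus TIFF images) into JPEG images with a JSON annotation file. It then splits those images into smaller patches with bounding boxes remapped to each patch.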

In [None]:

# Install Pillow if not already installed
%pip install pillow  

In [None]:
# Standard Library
import json
import math
import os
import shutil
from typing import Any, NewType
import csv 

# Third-Party Libraries
from PIL import Image, ImageDraw
from PIL.TiffTags import TAGS
from tqdm import tqdm

from predictor import make_patches
# Bounding boxes are plain 4-int tuples; NewType only tags them for type checkers.
Box = NewType("Box", tuple[int, int, int, int])

# Dataset locations (FIXME: replace these placeholders with your actual paths)
DATASET_SOURCE = os.path.join("/mnt", "f", "dataUtils", "raw_data")  # holds the csv/ and imgs/ input folders
OUTPUT_DATASET = os.path.join("/mnt", "e", "dataset")  # destination of the converted dataset

Image resolution utilities (TIFF resolution metadata and area unit conversion)

In [None]:
#utils
def get_image_metadata(img: Image.Image) -> dict[str, Any]:
    """Extract resolution metadata from a TIFF image.

    Reads the XResolution, YResolution and ResolutionUnit tags through the
    image's ``tag`` mapping. There, each resolution value is indexed as
    ``value[0][0] / value[0][1]``, i.e. a ``(numerator, denominator)`` pair
    wrapped in a tuple.

    Args:
        img: A PIL TIFF image (it must expose the ``tag`` attribute).

    Returns:
        A dict with ``"x_resolution"`` and ``"y_resolution"`` as floats and
        ``"resolution_unit"`` holding the raw ResolutionUnit tag value.

    Raises:
        KeyError: If one of the required resolution tags is missing.
        ValueError: If a resolution tag has a zero denominator.
    """
    tiff_tags = {TAGS.get(tag, tag): value for tag, value in img.tag.items()}

    def _require(tag_name: str) -> Any:
        # Re-raise with a message naming the missing tag instead of a bare key.
        try:
            return tiff_tags[tag_name]
        except KeyError:
            raise KeyError(f"TIFF tag {tag_name!r} missing from image metadata") from None

    def _rational(tag_name: str) -> float:
        pair = _require(tag_name)[0]
        numerator, denominator = pair[0], pair[1]
        if denominator == 0:
            raise ValueError(f"TIFF tag {tag_name!r} has a zero denominator")
        return numerator / denominator

    return {
        "x_resolution": _rational("XResolution"),
        "y_resolution": _rational("YResolution"),
        "resolution_unit": _require("ResolutionUnit"),
    }

class ImageUtils:
    """Resolution-aware area conversions bound to a single opened image.

    The image's resolution metadata is read once, via ``get_image_metadata``,
    when the object is built. The x/y resolutions are stored as floats and
    used to convert areas between (squared) resolution units and pixels.
    """

    def __init__(self, img):
        self.img = img
        meta = get_image_metadata(img)
        self.metadata = meta
        self.x_resolution = float(meta["x_resolution"])
        self.y_resolution = float(meta["y_resolution"])

    def _squared_scale(self):
        """Return the larger of the two resolutions, squared."""
        # NOTE(review): the larger axis resolution is applied to both axes,
        # which is exact only when x_resolution == y_resolution; the general
        # factor would be x_resolution * y_resolution. Kept to preserve outputs.
        finest = max(self.x_resolution, self.y_resolution)
        return finest ** 2

    def area_units_to_pixels(self, area):
        """Convert an area expressed in squared resolution units to pixels."""
        return float(area) * self._squared_scale()

    def area_pixels_to_units(self, area):
        """Convert an area expressed in pixels to squared resolution units."""
        return float(area) / self._squared_scale()


class ImageUtilOpener:
    """Context manager that opens an image file and yields an ``ImageUtils``.

    The underlying PIL image is closed on exit. It is also closed when
    building the ``ImageUtils`` wrapper fails (e.g. the image lacks
    resolution tags), because Python does not call ``__exit__`` when
    ``__enter__`` raises.
    """

    def __init__(self, file_name):
        self.file_name = file_name

    def __enter__(self):
        self.img = Image.open(self.file_name)
        try:
            return ImageUtils(self.img)
        except BaseException:
            # __exit__ will not run for a failed __enter__, so release the file here.
            self.img.close()
            raise

    def __exit__(self, *args):
        self.img.close()


Read the ImageJ-created CSV + TIFF dataset and convert it into a dict that maps sequential image IDs
to their category labels and bounding boxes (the kept images are re-saved as JPEG).

In [None]:
def convert_cxcywh_to_xyxy(box: Box) -> Box:
    """Convert a box from (center_x, center_y, width, height) to (x1, y1, x2, y2).

    The half-extent before the centre is floored and the one after it is
    ceiled. For integer inputs this keeps ``x2 - x1 == width`` and
    ``y2 - y1 == height``.

    Args:
        box: The box in CXCYWH format.
    Returns:
        The box in XYXY format.
    """
    cx, cy, w, h = box
    before_x, after_x = w // 2, (w + 1) // 2  # floor / ceil of w / 2
    before_y, after_y = h // 2, (h + 1) // 2  # floor / ceil of h / 2
    return cx - before_x, cy - before_y, cx + after_x, cy + after_y

def sanitize_annotations(name: str, annotation: dict[str, Any]) -> dict[str, Any]:
    """Remove duplicate and invalid bounding boxes from an image's annotations.

    A box is dropped in three cases:
      * It repeats an earlier box. The first label seen is kept, and a
        conflicting label is reported as an error.
      * Any of its coordinates is negative.
      * Its width or height is zero or negative. Such boxes have zero area
        and would otherwise cause a division by zero when the image is
        patched.

    Args:
        name: Name of the image the annotations belong to (used in messages).
        annotation: Dict with parallel "category_id" and "boxes" lists; boxes
            must be hashable (x1, y1, x2, y2) tuples.

    Returns:
        A new dict with the sanitized "category_id" and "boxes" lists, in the
        order the boxes were first seen.
    """
    bbox_dict: dict[Any, Any] = {}  # box -> label; dicts preserve insertion order
    for label, bbox in zip(annotation["category_id"], annotation["boxes"]):
        if bbox in bbox_dict:
            if bbox_dict[bbox] == label:
                print(f"WARNING: Duplicate bbox found in {name}: {bbox}")
            else:
                print(f"ERROR: Same bbox with different label found in {name}: {bbox}")
            continue
        if any(coord < 0 for coord in bbox):
            print(f"Corrupted box found in {name}: {bbox}")
            continue
        if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
            print(f"Degenerate box found in {name}: {bbox}")
            continue

        bbox_dict[bbox] = label

    return {
        "category_id": list(bbox_dict.values()),
        "boxes": list(bbox_dict.keys()),
    }
        
def _parse_annotation_row(row: list[str], util: "ImageUtils") -> tuple[int, Box]:
    """Turn one ImageJ CSV data row into ``(category_id, xyxy_box)``; raises on malformed rows."""
    _, label, area, category_id, _ = row  # ignore filename and category_name columns
    _, coordinates = label.split(':')  # the label's suffix holds the centre as "<y>-<x>"
    y, x = coordinates.split('-')
    y, x = int(y), int(x)

    area_in_pixels = util.area_units_to_pixels(float(area))
    # Treat the area as a circle and use its diameter as the square box side.
    bbox_side = int(math.sqrt(area_in_pixels / math.pi) * 2)

    bbox = convert_cxcywh_to_xyxy((x, y, bbox_side, bbox_side))
    return int(category_id), bbox


def convert_annotations(csv_path: str, img_path: str) -> dict[str, Any]:
    """Convert an ImageJ CSV export and its TIFF image into annotations.

    The first CSV row is a header. Each data row must have five columns:
      * The second column is a label whose part after ``:`` is ``"<y>-<x>"``,
        the object centre.
      * The third column is the object area in the image's resolution units.
      * The fourth column is the integer category id.

    The area is converted to pixels, treated as the area of a circle, and
    turned into a square box whose side is that circle's diameter.

    Args:
        csv_path: Path to the CSV annotation file. A missing, empty, or
            header-only file yields no annotations.
        img_path: Path to the TIFF image whose resolution tags are used for
            the area conversion. It is only opened when the CSV has data rows.

    Returns:
        A dict with parallel "category_id" and "boxes" lists (boxes in XYXY
        format). If any row fails to parse, both lists are empty.
    """
    annotations = {"category_id": [], "boxes": []}

    if not os.path.exists(csv_path):
        return annotations

    # newline="" is what the csv module expects for files passed to csv.reader.
    with open(csv_path, newline="") as data_file:
        rows = list(csv.reader(data_file))

    data_rows = rows[1:]  # rows[0] is the header
    if not data_rows:
        return annotations

    with ImageUtilOpener(img_path) as util:
        for row in data_rows:
            try:
                category_id, bbox = _parse_annotation_row(row, util)
            except Exception as e:
                # One bad row invalidates the whole file: discard everything so
                # the caller skips this image.
                print(f"Error processing row in {csv_path}: {e}")
                print(f"{csv_path} Failed")
                return {"category_id": [], "boxes": []}
            annotations["category_id"].append(category_id)
            annotations["boxes"].append(bbox)

    return annotations
    
    
def convert_dataset(data_path: str, imgs_path: str, res_path: str, replace_imgs=True):
    """Convert the ImageJ CSV + TIFF dataset into JPEG images and a JSON annotation file.

    For every ``<name>.tif`` in ``imgs_path``, the matching
    ``<name>.tif.csv`` in ``data_path`` is converted with
    ``convert_annotations``. Images without annotations are skipped. The rest
    get sequential string ids starting at "1", and their sanitized
    annotations are written to ``<res_path>/dataset.json``.

    Args:
        data_path: Directory containing the CSV annotation files.
        imgs_path: Directory containing the ``.tif`` images.
        res_path: Output directory for ``dataset.json`` and the ``imgs`` folder.
        replace_imgs: If True, delete ``res_path`` entirely first and re-save
            each kept image as ``<res_path>/imgs/<id>.jpeg``. If False,
            existing JPEGs are left untouched and only ``dataset.json`` is
            rewritten.
    """
    # NOTE(review): ids follow os.listdir order, which is not guaranteed to be
    # stable. With replace_imgs=False the ids must match those of the run that
    # wrote the JPEGs; consider sorting once outputs can be regenerated.
    img_list = [
        os.path.splitext(f)[0]
        for f in os.listdir(imgs_path)
        if os.path.isfile(os.path.join(imgs_path, f)) and f.endswith(".tif")
    ]

    if replace_imgs:
        # This removes the whole output directory, not only the images.
        try:
            shutil.rmtree(res_path)
        except FileNotFoundError:
            pass

    out_imgs_dir = os.path.join(res_path, "imgs")
    os.makedirs(out_imgs_dir, exist_ok=True)

    dataset = {}
    image_id = 1  # image ids start at 1
    for name in img_list:
        csv_path = os.path.join(data_path, f"{name}.tif.csv")
        img_path = os.path.join(imgs_path, f"{name}.tif")
        annotations = convert_annotations(csv_path, img_path)
        if not annotations["boxes"]:
            continue  # skip images without annotations

        if replace_imgs:
            # Close the source TIFF as soon as it has been re-encoded.
            with Image.open(img_path) as img:
                img.save(os.path.join(out_imgs_dir, f"{image_id}.jpeg"), quality=100)

        dataset[str(image_id)] = sanitize_annotations(name, annotations)
        image_id += 1

    with open(os.path.join(res_path, "dataset.json"), "w") as outfile:
        json.dump(dataset, outfile, indent=1)
    print("Finished converting")
    
convert_dataset(
    data_path=f"{DATASET_SOURCE}/csv",
    imgs_path=f"{DATASET_SOURCE}/imgs",
    res_path=OUTPUT_DATASET,
    replace_imgs=False,  # only rewrite dataset.json; existing JPEGs in OUTPUT_DATASET/imgs are reused
)

Split the converted dataset's images into overlapping patches of the desired size and remap their annotations onto each patch

In [None]:
from tqdm import tqdm


def box_area(box: Box) -> float:
    """Return the area of an XYXY box, treating inverted extents as zero.

    Args:
        box: The box as (x1, y1, x2, y2).

    Returns:
        ``(x2 - x1) * (y2 - y1)`` with each side clamped at 0. A box with no
        extent, such as an empty intersection, has area 0.
    """
    x1, y1, x2, y2 = box
    return max(0, x2 - x1) * max(0, y2 - y1)
    

def patch_annots(
         cropped_img: Image.Image,
         crop_box: Box,
         annots: dict[str, Any],
         crop_tolerance: float,
         erase_cropped: bool = True) -> dict[str, list[Any]]:
    """Re-express annotations relative to a crop and optionally erase cut-off objects.

    Each box is clipped to ``crop_box`` and shifted into the crop's
    coordinate frame. ``cropped`` is the fraction of the box's area that
    lies outside the crop. Boxes are then handled as follows:
      * ``cropped <= crop_tolerance``: the clipped box is kept.
      * Cut off beyond the tolerance but still partly visible: the visible
        part is painted black on ``cropped_img`` (in place) when
        ``erase_cropped`` is True.
      * Zero-area input boxes are skipped, because their cut-off fraction
        is undefined.

    Args:
        cropped_img: The already-cropped PIL image; modified in place when erasing.
        crop_box: The crop region in XYXY format, in source-image coordinates.
        annots: Dict with parallel "category_id" and "boxes" lists (XYXY boxes).
        crop_tolerance: Maximum fraction of a box's area that may be cut off
            for the box to be kept.
        erase_cropped: Whether to black out the visible part of rejected boxes.

    Returns:
        A dict with "category_id" and "boxes" lists for the kept annotations,
        with boxes relative to the crop.
    """
    cropped_annots = {"category_id": [], "boxes": []}
    draw_context = ImageDraw.Draw(cropped_img)
    left, top, right, bottom = crop_box

    for label, bbox in zip(annots["category_id"], annots["boxes"]):
        full_area = box_area(bbox)
        if full_area == 0:
            continue  # degenerate box: avoid dividing by zero below
        relative_bbox = (max(left, bbox[0]) - left,    # xmin
                         max(top, bbox[1]) - top,      # ymin
                         min(right, bbox[2]) - left,   # xmax
                         min(bottom, bbox[3]) - top)   # ymax
        cropped = 1 - box_area(relative_bbox) / full_area
        if cropped <= crop_tolerance:
            cropped_annots["category_id"].append(label)
            cropped_annots["boxes"].append(relative_bbox)
        elif cropped < 1 and erase_cropped:
            x1, y1, x2, y2 = relative_bbox
            # ImageDraw.rectangle includes both corner points, whereas these
            # boxes use an exclusive max edge (area = (x2-x1)*(y2-y1)). Shrink
            # by one pixel so the fill does not spill past the box.
            draw_context.rectangle((x1, y1, x2 - 1, y2 - 1), width=1, fill="black")

    return cropped_annots

def patch_sample(img: Image.Image,
                 annots: dict[str, Any],
                 desired_image_size: int,
                 overlap: float,
                 crop_tolerance: float,
                 erase_cropped: bool = True) -> tuple[list[Image.Image], list[dict[str, list[Any]]]]:
    """Split an image into patches and remap its annotations onto each patch.

    ``make_patches`` supplies the padded canvas size and the patch boxes. The
    image is pasted at the top-left of a new RGB canvas of that size, and the
    canvas is cropped at every patch box.

    Args:
        img: The PIL Image object to patch.
        annots: Dict with parallel "category_id" and "boxes" lists (XYXY boxes).
        desired_image_size: The desired size of each patch.
        overlap: The overlap between adjacent patches (as a fraction of `desired_image_size`).
        crop_tolerance: Maximum fraction of a box's area that may be cut off
            for it to be kept in a patch.
        erase_cropped: Forwarded to ``patch_annots``. Controls whether objects
            cut off beyond ``crop_tolerance`` are blacked out in the patch.

    Returns:
        ``(patch_images, patch_annotations)``: two parallel lists. Each
        annotation entry is a dict with "category_id" and "boxes".
    """
    padded_width, padded_height, patch_boxes = make_patches(img.width, img.height, desired_image_size, overlap)
    # Image.new fills with black by default; paste() without a box places img at (0, 0).
    padded_img = Image.new("RGB", (padded_width, padded_height))
    padded_img.paste(img)

    res_imgs, res_annots = [], []
    for patch_box in patch_boxes:
        patched_img = padded_img.crop(patch_box)
        patched_annots = patch_annots(patched_img, patch_box, annots, crop_tolerance, erase_cropped)
        res_imgs.append(patched_img)
        res_annots.append(patched_annots)

    return res_imgs, res_annots

def patch_dataset(dataset_root: str,
                  desired_image_size: int = 1024,
                  overlap: float = 0.2,
                  crop_tolerance: float=0.3):
    """Patch every image of a converted dataset and write the patched dataset.

    Reads ``<dataset_root>/dataset.json`` and ``<dataset_root>/imgs/<id>.jpeg``.
    Output goes to a sibling directory named
    ``<basename>_patched_croptolerance=<crop_tolerance>``. It contains the
    patches as ``imgs/<id>_<n>.jpeg`` (n starting at 1) plus a
    ``dataset.json`` mapping each patch id to its annotations.

    Args:
        dataset_root: The root directory of the dataset.
        desired_image_size: The desired size of each patch.
        overlap: The overlap between adjacent patches (as a fraction of `desired_image_size`).
        crop_tolerance: Maximum fraction of a box's area that may be cut off
            for it to be kept in a patch.
    """
    dataset_root = os.path.normpath(dataset_root)
    annot_path = os.path.join(dataset_root, "dataset.json")
    imgs_dir = os.path.join(dataset_root, "imgs")
    patched_root = os.path.join(
        os.path.dirname(dataset_root),
        os.path.basename(dataset_root) + "_patched_croptolerance=" + str(crop_tolerance),
    )
    patched_imgs_dir = os.path.join(patched_root, "imgs")
    os.makedirs(patched_imgs_dir, exist_ok=True)

    with open(annot_path, 'r') as annot_file:
        dataset = json.load(annot_file)

    patched_dataset = {}
    for image_id, annotations in tqdm(dataset.items()):
        # patch_sample pastes the pixels into a new canvas, so the source file
        # can be closed before the patches are saved.
        with Image.open(os.path.join(imgs_dir, f"{image_id}.jpeg")) as img:
            patched_imgs, patched_annots = patch_sample(img, annotations, desired_image_size, overlap, crop_tolerance)

        for patch_num, (patched_image, patched_annotations) in enumerate(zip(patched_imgs, patched_annots), start=1):
            patch_id = f"{image_id}_{patch_num}"
            patched_dataset[patch_id] = patched_annotations
            patched_image.save(os.path.join(patched_imgs_dir, patch_id + ".jpeg"), quality=95)

    with open(os.path.join(patched_root, "dataset.json"), "w") as outfile:
        json.dump(patched_dataset, outfile, indent=1)
        
patch_dataset(
    OUTPUT_DATASET,
    crop_tolerance=0.7,  # keep boxes with at most 70% of their area outside a patch
)