In [None]:
import os
import cv2
import numpy as np
import random
from glob import glob

class FingerprintDatasetYOLO:
    """Build synthetic YOLO detection samples from fingerprint scans.

    Each sample is a white canvas with two cropped fingerprints pasted at
    random positions, plus a label file holding one line per fingerprint in
    the form ``class x_center y_center width height`` (class is always 0,
    coordinates are normalised by the canvas size).
    """

    # Upper bound on random positions tried for the second crop before an
    # overlapping placement is accepted (e.g. when both crops cannot fit
    # side by side on the canvas).
    _MAX_PLACEMENT_ATTEMPTS = 100

    def __init__(self, image_dir, save_dir, image_size=(640, 640)):
        """
        Args:
            image_dir (str): Directory containing fingerprint images in .tif format.
            save_dir (str): Directory to save the processed images and labels.
            image_size (tuple): Size of the final white image as (height, width).
        """
        # Sorted so that sampling is reproducible under a fixed random seed;
        # glob() itself returns paths in arbitrary order.
        self.image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
        self.image_size = image_size
        self.save_dir = save_dir

        # Create save directories
        self.images_save_dir = os.path.join(save_dir, "images")
        self.labels_save_dir = os.path.join(save_dir, "labels")
        os.makedirs(self.images_save_dir, exist_ok=True)
        os.makedirs(self.labels_save_dir, exist_ok=True)

    @staticmethod
    def _load_image(path):
        """Read *path* as an 8-bit, 3-channel BGR image.

        IMREAD_COLOR is used so that grayscale, alpha-channel and 16-bit
        files all come back in the one layout the canvas can hold.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            raise OSError(f"Could not read image: {path}")
        return img

    @staticmethod
    def _rects_overlap(x1, y1, w1, h1, x2, y2, w2, h2):
        """Return True when the two axis-aligned rectangles share any pixel."""
        return x1 < x2 + w2 and x2 < x1 + w1 and y1 < y2 + h2 and y2 < y1 + h1

    def _crop_fingerprint(self, img):
        """Crop *img* to the bounding box of its largest dark region.

        The image is binarised with an inverted Otsu threshold, so dark
        pixels become foreground. If no contour is found the image is
        returned unchanged.
        """
        gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # The explicit threshold value is ignored when THRESH_OTSU is set.
        _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        # findContours returns (contours, hierarchy) in OpenCV 4.x but
        # (image, contours, hierarchy) in 3.x; [-2] is the contour list in both.
        contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        if not contours:
            return img
        x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
        return img[y:y + h, x:x + w]

    def _fit_to_canvas(self, img):
        """Downscale *img* (keeping aspect ratio) so it fits inside the canvas.

        Images that already fit are returned untouched.
        """
        canvas_h, canvas_w = self.image_size
        h, w = img.shape[:2]
        scale = min(canvas_h / h, canvas_w / w)
        if scale >= 1.0:
            return img
        new_w = min(canvas_w, max(1, int(w * scale)))
        new_h = min(canvas_h, max(1, int(h * scale)))
        return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

    def _paste_on_blank(self, img1, img2):
        """Paste two images at random positions on a white canvas.

        Crops larger than the canvas are shrunk first. The second image is
        re-positioned (up to ``_MAX_PLACEMENT_ATTEMPTS`` times) until it does
        not overlap the first; if no free position is found the last
        candidate is used and the images overlap.

        Returns:
            tuple: ``(canvas, bboxes)`` where ``canvas`` is a uint8 array of
            shape ``(height, width, 3)`` and ``bboxes`` is a list of two
            ``[0, x_center, y_center, width, height]`` entries normalised by
            the canvas size.
        """
        canvas_h, canvas_w = self.image_size
        blank_image = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)

        img1 = self._fit_to_canvas(img1)
        img2 = self._fit_to_canvas(img2)
        h1, w1 = img1.shape[:2]
        h2, w2 = img2.shape[:2]

        x1 = random.randint(0, canvas_w - w1)
        y1 = random.randint(0, canvas_h - h1)

        for _ in range(self._MAX_PLACEMENT_ATTEMPTS):
            x2 = random.randint(0, canvas_w - w2)
            y2 = random.randint(0, canvas_h - h2)
            if not self._rects_overlap(x1, y1, w1, h1, x2, y2, w2, h2):
                break

        blank_image[y1:y1 + h1, x1:x1 + w1] = img1
        blank_image[y2:y2 + h2, x2:x2 + w2] = img2

        bboxes = [
            [0, (x + w / 2) / canvas_w, (y + h / 2) / canvas_h, w / canvas_w, h / canvas_h]
            for x, y, w, h in ((x1, y1, w1, h1), (x2, y2, w2, h2))
        ]
        return blank_image, bboxes

    def save_processed_images_and_labels(self, num_samples):
        """Generate *num_samples* composite images with matching label files.

        Sample ``i`` is written to ``images/image_{i:04d}.jpg`` and
        ``labels/image_{i:04d}.txt`` under ``save_dir``.

        Raises:
            ValueError: If fewer than two source images were found.
            OSError: If a source image cannot be read or an output image
                cannot be written.
        """
        if len(self.image_paths) < 2:
            raise ValueError("Dataset must contain at least two images.")

        for i in range(num_samples):
            # Two distinct source files per sample.
            img1_path, img2_path = random.sample(self.image_paths, 2)

            cropped_img1 = self._crop_fingerprint(self._load_image(img1_path))
            cropped_img2 = self._crop_fingerprint(self._load_image(img2_path))

            final_image, bboxes = self._paste_on_blank(cropped_img1, cropped_img2)

            image_name = f"image_{i:04d}.jpg"
            image_path = os.path.join(self.images_save_dir, image_name)
            # imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(image_path, final_image):
                raise OSError(f"Could not write image: {image_path}")

            label_name = f"image_{i:04d}.txt"
            label_path = os.path.join(self.labels_save_dir, label_name)
            with open(label_path, 'w') as f:
                f.writelines(" ".join(map(str, bbox)) + "\n" for bbox in bboxes)