In [34]:
import cv2
import numpy as np
import os
import random
from tqdm import tqdm

def load_mouse_images(mouse_folder):
    """Load every ``.png`` file in ``mouse_folder`` with its alpha channel intact.

    Files are visited in sorted (name) order. ``os.listdir`` order is arbitrary
    and platform-dependent, so sorting is what makes a run that seeds the
    ``random`` module reproducible (``random.choice`` picks by list index).

    Args:
        mouse_folder: Directory containing the mouse cut-out images.

    Returns:
        List of images as returned by ``cv2.imread(..., cv2.IMREAD_UNCHANGED)``.
        Files OpenCV cannot decode are reported and skipped.
    """
    mouse_imgs = []
    for file in sorted(os.listdir(mouse_folder)):
        if not file.lower().endswith('.png'):
            continue
        path = os.path.join(mouse_folder, file)
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # Report rather than silently drop, matching the cat-image loader.
            print(f"Failed to load mouse image: {path}")
            continue
        mouse_imgs.append(img)
    return mouse_imgs

def _split_color_alpha(mouse_img):
    """Return ``(bgr, alpha)`` uint8 arrays for a mouse image of any channel layout.

    Accepts 2-D / single-channel grayscale, gray+alpha, BGR and BGRA images.
    16-bit images (which ``IMREAD_UNCHANGED`` preserves) are reduced to 8 bits
    so the ``/ 255`` alpha normalisation and uint8 blending stay valid.
    """
    img = mouse_img
    if img.dtype == np.uint16:
        img = (img >> 8).astype(np.uint8)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] in (2, 4):
        # Last channel is alpha (gray+alpha or BGRA).
        color, alpha = img[:, :, :-1], img[:, :, -1]
    else:
        color = img
        alpha = np.full(img.shape[:2], 255, dtype=np.uint8)
    if color.shape[2] == 1:
        # Expand grayscale to 3 channels so it blends with the BGR cat image.
        color = np.repeat(color, 3, axis=2)
    return np.ascontiguousarray(color), np.ascontiguousarray(alpha)


def _alpha_bbox(alpha):
    """Return ``(x0, y0, x1, y1)`` (x1/y1 exclusive) around non-zero alpha pixels,
    or ``None`` if the image is fully transparent."""
    ys, xs = np.nonzero(alpha)
    if ys.size == 0:
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1


def overlay_mouse_and_save_label(cat_img_path, mouse_imgs, output_img_dir, output_label_dir,
                                 num_augmentations=3, class_id=1):
    """Paste random mice onto a cat photo and write YOLO labels for each result.

    For each augmentation a random mouse from ``mouse_imgs`` is scaled by a
    random factor in [0.1, 0.3]. It is then alpha-blended into the cat image at
    a random position whose top-left corner lies in the central-lower region
    (x in [0.4, 0.6] * width, y in [0.6, 0.8] * height). The image is saved as
    ``<base>_aug_<i>.jpg`` and a one-line YOLO label as ``<base>_aug_<i>.txt``.

    The label box tightly bounds the *visible* (alpha > 0) mouse pixels rather
    than the whole, possibly transparent-padded, mouse canvas.

    Augmentations whose mouse would overflow the cat image, or whose resized
    mouse is fully transparent, are skipped. Fewer than ``num_augmentations``
    files may therefore be written, and the ``_aug_<i>`` indices can have gaps.

    Args:
        cat_img_path: Path of the background (cat) image.
        mouse_imgs: Images as loaded by ``load_mouse_images``.
        output_img_dir: Directory for augmented ``.jpg`` images (must exist).
        output_label_dir: Directory for YOLO ``.txt`` labels (must exist).
        num_augmentations: Number of placement attempts for this cat image.
        class_id: YOLO class index written in the label line.
    """
    original_cat_img = cv2.imread(cat_img_path, cv2.IMREAD_COLOR)
    if original_cat_img is None:
        print(f"Failed to load cat image: {cat_img_path}")
        return
    if not mouse_imgs:
        print(f"No mouse images available; skipping {cat_img_path}")
        return

    cat_h, cat_w = original_cat_img.shape[:2]
    base_name = os.path.splitext(os.path.basename(cat_img_path))[0]

    for i in range(num_augmentations):
        mouse_color, mouse_alpha = _split_color_alpha(random.choice(mouse_imgs))

        scale_factor = random.uniform(0.1, 0.3)
        src_h, src_w = mouse_color.shape[:2]
        # Explicit dsize clamped to >= 1px: resizing with (0, 0) + fx/fy can
        # round a small image to a zero-sized dimension, which cv2 rejects.
        dsize = (max(1, int(round(src_w * scale_factor))),
                 max(1, int(round(src_h * scale_factor))))
        mouse_color = cv2.resize(mouse_color, dsize)
        mouse_alpha = cv2.resize(mouse_alpha, dsize)
        mouse_h, mouse_w = mouse_color.shape[:2]

        x_offset = random.randint(int(cat_w * 0.4), int(cat_w * 0.6))
        y_offset = random.randint(int(cat_h * 0.6), int(cat_h * 0.8))

        if y_offset + mouse_h > cat_h or x_offset + mouse_w > cat_w:
            continue

        bbox = _alpha_bbox(mouse_alpha)
        if bbox is None:
            # Nothing visible to label.
            continue

        cat_img = original_cat_img.copy()
        roi = cat_img[y_offset:y_offset + mouse_h, x_offset:x_offset + mouse_w]
        alpha = mouse_alpha[:, :, np.newaxis] / 255.0
        blended = alpha * mouse_color + (1.0 - alpha) * roi
        cat_img[y_offset:y_offset + mouse_h, x_offset:x_offset + mouse_w] = (
            np.clip(np.rint(blended), 0, 255).astype(np.uint8)
        )

        # Save the new augmented image
        out_img_path = os.path.join(output_img_dir, f"{base_name}_aug_{i}.jpg")
        cv2.imwrite(out_img_path, cat_img)

        # YOLO label: class x_center y_center width height, all normalised to
        # the cat image size, for the visible-mouse bounding box.
        x0, y0, x1, y1 = bbox
        x_center = (x_offset + (x0 + x1) / 2) / cat_w
        y_center = (y_offset + (y0 + y1) / 2) / cat_h
        w_norm = (x1 - x0) / cat_w
        h_norm = (y1 - y0) / cat_h

        label_line = f"{class_id} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}\n"

        out_label_path = os.path.join(output_label_dir, f"{base_name}_aug_{i}.txt")
        with open(out_label_path, 'w') as f:
            f.write(label_line)

def batch_augment_and_label(cat_images_folder, mouse_images_folder, output_img_folder, output_label_folder,
                            num_aug_per_image=3, seed=None):
    """Run mouse-overlay augmentation over every cat image in a folder.

    Creates the output folders if needed, then loads the mouse cut-outs once.
    It then calls ``overlay_mouse_and_save_label`` for each ``.jpg``/``.jpeg``/
    ``.png`` file in ``cat_images_folder``. Files are processed in sorted name
    order so that, together with ``seed``, a run is reproducible.

    Args:
        cat_images_folder: Folder of background (cat) images.
        mouse_images_folder: Folder of ``.png`` mouse cut-outs.
        output_img_folder: Destination for augmented images.
        output_label_folder: Destination for YOLO label files.
        num_aug_per_image: Augmentation attempts per cat image.
        seed: Optional seed for the ``random`` module; ``None`` leaves the
            global RNG state untouched.

    Raises:
        ValueError: If no usable mouse image could be loaded.
    """
    if seed is not None:
        random.seed(seed)

    os.makedirs(output_img_folder, exist_ok=True)
    os.makedirs(output_label_folder, exist_ok=True)

    mouse_imgs = load_mouse_images(mouse_images_folder)
    if not mouse_imgs:
        # Fail fast with a clear message instead of an IndexError from
        # random.choice partway through the run.
        raise ValueError(f"No usable .png mouse images found in {mouse_images_folder!r}")

    cat_images = sorted(
        f for f in os.listdir(cat_images_folder)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    for cat_img in tqdm(cat_images, desc="Augmenting and labeling images", unit="image"):
        full_cat_path = os.path.join(cat_images_folder, cat_img)
        overlay_mouse_and_save_label(full_cat_path, mouse_imgs, output_img_folder, output_label_folder, num_aug_per_image)

In [35]:
# Source folders: cat photos and transparent mouse cut-outs.
CAT_IMAGES_FOLDER = "data/elfie"
MOUSE_IMAGES_FOLDER = "data/mice"
# Destination tree for augmented images and their YOLO labels.
OUTPUT_IMG_FOLDER = "data/do_not_open/images"
OUTPUT_LABEL_FOLDER = "data/do_not_open/labels"

batch_augment_and_label(
    cat_images_folder=CAT_IMAGES_FOLDER,
    mouse_images_folder=MOUSE_IMAGES_FOLDER,
    output_img_folder=OUTPUT_IMG_FOLDER,
    output_label_folder=OUTPUT_LABEL_FOLDER,
    num_aug_per_image=3,
)

Augmenting and labeling images: 100%|██████████| 343/343 [00:13<00:00, 24.90image/s]
