In [None]:
# augment_with_mice.py

import os
import random

import cv2
import numpy as np
from tqdm import tqdm

def load_mouse_images(mouse_folder, extensions=('.png',)):
    """
    Load every mouse sprite image from a folder.

    Images are read with ``cv2.IMREAD_UNCHANGED`` so an alpha channel, if the
    file has one, is kept for blending. Files OpenCV cannot decode are skipped.

    Args:
        mouse_folder: Directory containing the mouse images (not searched recursively).
        extensions: Case-insensitive filename suffix, or tuple of suffixes, to accept.
            Defaults to PNG only.

    Returns:
        List of decoded images in sorted filename order. Sorting makes the list
        (and therefore any seeded random choice from it) independent of the
        filesystem's listing order.
    """
    # A bare string such as '.png' must not be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    mouse_imgs = []
    for file in sorted(os.listdir(mouse_folder)):
        if not file.lower().endswith(suffixes):
            continue
        img = cv2.imread(os.path.join(mouse_folder, file), cv2.IMREAD_UNCHANGED)
        if img is not None:
            mouse_imgs.append(img)
    return mouse_imgs

def _split_color_alpha(mouse_img):
    """
    Split a mouse sprite into a float BGR colour image and a float alpha mask.

    Accepts 2-D grayscale, gray+alpha, BGR and BGRA layouts with integer
    (e.g. 8- or 16-bit) or float pixel types.

    Returns:
        (color, alpha): ``color`` is float32 (H, W, 3) scaled to [0, 255] so it can
        be blended onto an 8-bit image; ``alpha`` is float32 (H, W) in [0, 1].
    """
    img = mouse_img if mouse_img.ndim == 3 else mouse_img[:, :, np.newaxis]
    if np.issubdtype(img.dtype, np.integer):
        # Normalise by the dtype's range so 16-bit sprites are not treated as 8-bit.
        max_val = float(np.iinfo(img.dtype).max)
    else:
        # NOTE(review): assumes float sprites are already normalised to [0, 1].
        max_val = 1.0

    if img.shape[2] in (2, 4):
        # Last channel is alpha.
        color = img[:, :, :-1]
        alpha = img[:, :, -1].astype(np.float32) / max_val
    else:
        color = img
        alpha = np.ones(img.shape[:2], dtype=np.float32)

    if color.shape[2] == 1:
        # Grayscale -> three identical channels so it blends onto a BGR image.
        color = np.repeat(color, 3, axis=2)
    color = color.astype(np.float32) * np.float32(255.0 / max_val)
    return color, alpha


def overlay_random_mouse(cat_img_path, mouse_imgs, output_dir, num_augmentations=3, rng=None):
    """
    Write ``num_augmentations`` copies of a cat image, each with one random mouse pasted in.

    Every augmentation starts from a fresh copy of the cat image, so each output
    contains exactly one mouse. A mouse is picked from ``mouse_imgs``, scaled by a
    random factor in [0.1, 0.3], and alpha-blended with its top-left corner drawn
    from [40%, 60%] of the width and [60%, 80%] of the height. That range is pulled
    left/up where needed so the whole mouse stays inside the image. An augmentation
    is skipped (with a message) only when the scaled mouse is larger than the cat
    image. Outputs are saved as ``<output_dir>/<cat file stem>_aug_<i>.jpg``.

    Args:
        cat_img_path: Path to the cat image; read with ``cv2.IMREAD_COLOR`` (no alpha).
        mouse_imgs: Non-empty sequence of mouse images (e.g. from ``load_mouse_images``).
        output_dir: Directory to write into; it must already exist.
        num_augmentations: Number of augmented images to produce.
        rng: Object providing ``choice``, ``uniform`` and ``randint`` (e.g. a seeded
            ``random.Random``) for reproducible output. Defaults to the global
            ``random`` module.

    Raises:
        ValueError: If ``mouse_imgs`` is empty.
    """
    if not mouse_imgs:
        raise ValueError("mouse_imgs is empty; nothing to overlay")
    if rng is None:
        rng = random

    cat_img = cv2.imread(cat_img_path, cv2.IMREAD_COLOR)  # load without alpha
    if cat_img is None:
        print(f"Failed to load cat image: {cat_img_path}")
        return

    cat_h, cat_w = cat_img.shape[:2]
    # splitext keeps inner dots ("cat.v1.png" -> "cat.v1") so output names don't collide.
    base_name = os.path.splitext(os.path.basename(cat_img_path))[0]

    for i in range(num_augmentations):
        mouse_color, mouse_alpha = _split_color_alpha(rng.choice(mouse_imgs))

        # Resize mouse to an explicit size with a 1-px floor, so a tiny scale
        # can never request a zero-sized resize target.
        scale_factor = rng.uniform(0.1, 0.3)
        src_h, src_w = mouse_alpha.shape
        mouse_w = max(1, int(round(src_w * scale_factor)))
        mouse_h = max(1, int(round(src_h * scale_factor)))
        mouse_color = cv2.resize(mouse_color, (mouse_w, mouse_h))
        mouse_alpha = cv2.resize(mouse_alpha, (mouse_w, mouse_h))

        # Latest top-left corner that still keeps the whole mouse inside the image.
        x_max = min(int(cat_w * 0.6), cat_w - mouse_w)
        y_max = min(int(cat_h * 0.8), cat_h - mouse_h)
        if x_max < 0 or y_max < 0:
            print(f"Mouse larger than cat image, skipping augmentation {i} of {cat_img_path}")
            continue
        x_offset = rng.randint(min(int(cat_w * 0.4), x_max), x_max)
        y_offset = rng.randint(min(int(cat_h * 0.6), y_max), y_max)

        # Blend onto a fresh copy so earlier augmentations don't leak into this one.
        augmented = cat_img.copy()
        region = (slice(y_offset, y_offset + mouse_h), slice(x_offset, x_offset + mouse_w))
        roi = augmented[region].astype(np.float32)
        alpha = mouse_alpha[:, :, np.newaxis]  # broadcast over the 3 colour channels
        blended = alpha * mouse_color + (1.0 - alpha) * roi
        augmented[region] = np.clip(np.rint(blended), 0, 255).astype(np.uint8)

        out_path = os.path.join(output_dir, f"{base_name}_aug_{i}.jpg")
        if not cv2.imwrite(out_path, augmented):
            print(f"Failed to write augmented image: {out_path}")
    
def batch_augment(cat_images_folder, mouse_images_folder, output_folder, num_aug_per_image=3):
    """
    Augment every cat image in a folder with randomly placed mice.

    Args:
        cat_images_folder: Directory of cat images (.png/.jpg/.jpeg, case-insensitive).
        mouse_images_folder: Directory of mouse sprites, read by ``load_mouse_images``.
        output_folder: Destination directory; created if it does not exist.
        num_aug_per_image: Number of augmented outputs requested per cat image.

    Raises:
        ValueError: If no mouse images could be loaded from ``mouse_images_folder``.
    """
    mouse_imgs = load_mouse_images(mouse_images_folder)
    if not mouse_imgs:
        # Fail up front with a clear message instead of erroring on the first cat image.
        raise ValueError(f"No mouse images found in {mouse_images_folder!r}")
    os.makedirs(output_folder, exist_ok=True)

    # Sorted so the processing order is deterministic, independent of listdir order.
    cat_images = sorted(
        f for f in os.listdir(cat_images_folder)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    for cat_img in tqdm(cat_images, desc="Augmenting images", unit="image"):
        full_cat_path = os.path.join(cat_images_folder, cat_img)
        overlay_random_mouse(full_cat_path, mouse_imgs, output_folder, num_augmentations=num_aug_per_image)

In [5]:
# Source folders (cat photos, mouse sprites) and where the augmented images go.
CAT_IMAGES_FOLDER = "data/elfie"
MOUSE_IMAGES_FOLDER = "data/mice"
OUTPUT_FOLDER = "data/do_not_open"

batch_augment(
    cat_images_folder=CAT_IMAGES_FOLDER,
    mouse_images_folder=MOUSE_IMAGES_FOLDER,
    output_folder=OUTPUT_FOLDER,
    num_aug_per_image=3,
)

Augmenting images: 100%|██████████| 343/343 [00:05<00:00, 58.18image/s]
