In [97]:
import json
from pathlib import Path
import random
import shutil
import os

In [98]:
cropped_images_path = Path("/workspaces/gorillatracker/datasets/vast/external-datasets/LeopardID2022/cropped_images/all")
annotation_path = Path("/workspaces/gorillatracker/datasets/vast/external-datasets/LeopardID2022/leopard.coco/annotations/instances_train2022.json")
image_path = Path("/workspaces/gorillatracker/datasets/vast/external-datasets/LeopardID2022/leopard.coco/images/train2022")

In [99]:
def load_json(path: Path) -> dict:
    with open(path, "r") as f:
        data = json.load(f)
    return data

annotation_dict = load_json(annotation_path)
print("Given number of images: ", len(annotation_dict["images"]))
print("Given number of annotations: ", len(annotation_dict["annotations"]))
image_files = [file for file in image_path.glob("*.jpg") if not file.name.startswith("._")]
print("Actual number of images: ", len(image_files))
print("Actual number of cropped images: ", len(list(cropped_images_path.glob("*.png"))))

Given number of images:  6795
Given number of annotations:  6825
Actual number of images:  5848
Actual number of cropped images:  5823


In [100]:
# count number of images that have multiple annotations
image_ids = [ann["image_id"] for ann in annotation_dict["annotations"]]
duplicate_image_ids = set([image_id for image_id in image_ids if image_ids.count(image_id) > 1])
print("Number of images with multiple annotations: ", len(duplicate_image_ids))
# print first 8 duplicate image ids with their names 
for i, image_id in enumerate(list(duplicate_image_ids)[:8]):
    annotations = [ann for ann in annotation_dict["annotations"] if ann["image_id"] == image_id]
    print("Image_Id: ", image_id)
    for ann in annotations:
        print(ann["name"])

Number of images with multiple annotations:  30
Image_Id:  3076
81ad30d8-15c7-492e-918b-d9592d8f3f08
____
Image_Id:  1803
b3cadf4e-6dc5-4d1a-bc0e-adc6805d9329
____
Image_Id:  1807
b3cadf4e-6dc5-4d1a-bc0e-adc6805d9329
____
Image_Id:  2959
____
c3799f7d-5ba9-4266-beff-1cb38febbc86
Image_Id:  6551
96d393be-7d3b-429f-8e51-6048a8593b47
f1116150-8ea6-4a26-a0b8-90208a8a2248
Image_Id:  6552
96d393be-7d3b-429f-8e51-6048a8593b47
f1116150-8ea6-4a26-a0b8-90208a8a2248
Image_Id:  6553
96d393be-7d3b-429f-8e51-6048a8593b47
f1116150-8ea6-4a26-a0b8-90208a8a2248
Image_Id:  2456
c7fa3bc1-65ee-4028-86a7-a64b841d5dba
fc628bad-d8ce-4915-9cf1-29835d930e7c


In [101]:
def group_images(images):
    labels = {}
    for image in images:
        label = image.split("_")[0]
        if label not in labels:
            labels[label] = []
        labels[label].append(image)
    return labels

cropped_images = os.listdir(cropped_images_path)
labels_dict = group_images(cropped_images)
num_individuals = len(labels_dict)
print("Number of Individuals: ", num_individuals)
print(cropped_images[:5])

Number of Individuals:  413
['001e71df-ed06-43b6-bfdc-6beac644b1de_2099.png', '001e71df-ed06-43b6-bfdc-6beac644b1de_2100.png', '001e71df-ed06-43b6-bfdc-6beac644b1de_2107.png', '001e71df-ed06-43b6-bfdc-6beac644b1de_2108.png', '001e71df-ed06-43b6-bfdc-6beac644b1de_2109.png']


In [102]:
# Now filter out individuals with less than 5 images
filtered_individuals = {label: images for label, images in labels_dict.items() if len(images) > 4}
num_filtered_individuals = len(filtered_individuals)

print("Number of Individuals with more than 4 images: ", num_filtered_individuals)
print("Number of Individuals with 4 or less images: ", num_individuals - num_filtered_individuals)

Number of Individuals with more than 4 images:  167
Number of Individuals with 4 or less images:  246


In [103]:
def copy_images(filtered_individuals, max_images_per_individual, labels_dict, input_directory, target_directory):
    if not os.path.exists(target_directory):
        os.makedirs(target_directory)
    
    for individual in filtered_individuals:
        image_files = labels_dict[individual]
        if len(image_files) > max_images_per_individual:
            image_files = random.sample(image_files, max_images_per_individual)
        
        for image_file in image_files:
            src_path = os.path.join(input_directory, image_file)
            dst_path = os.path.join(target_directory, image_file)
            shutil.copy(src_path, dst_path)

filtered_individuals_path = "/workspaces/gorillatracker/datasets/vast/external-datasets/LeopardID2022/cropped_images/filtered"
max_images_per_individual = 50
copy_images(filtered_individuals, max_images_per_individual, labels_dict, cropped_images_path, filtered_individuals_path)