# Creating Balanced Image Datasets from COCO

This notebook creates balanced datasets from COCO val2017 by:
1. Finding images in which a specific category covers a significant share of the image area (at least 5%)
2. Pairing each with 99 random images that don't contain that category
3. Storing the indices to create balanced datasets (1:99 ratio)

## 1. Import Libraries

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
import skimage.io as io
import random
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

## 2. Set Up Paths & Parameters

In [None]:
# Directory containing COCO dataset - update these paths to match your local setup
dataDir = '/your-path'
dataType = 'val2017'
# Both paths are derived from dataDir so that editing dataDir alone is enough
annFile = f'{dataDir}/coco_annotations/instances_{dataType}.json'
imgDir = f'{dataDir}/{dataType}/'

# Coverage threshold (as a percentage of image area)
coverage_threshold = 5.0  # Minimum coverage percentage

# Number of negative examples to select for each positive example
num_negative_examples = 99  # 99 negative examples + 1 positive = 100 total

# Random seed for reproducibility (seeds both the stdlib and NumPy generators)
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)

## 3. Initialize COCO API

In [None]:
# Build the COCO API object from the instance annotation file
coco = COCO(annFile)

# Report the size of what was loaded: images, categories and annotations
print(
    f"COCO {dataType} dataset loaded successfully!",
    f"Number of images: {len(coco.imgs)}",
    f"Number of categories: {len(coco.cats)}",
    f"Number of annotations: {len(coco.anns)}",
    sep="\n",
)

## 4. Get All Categories

In [None]:
# Look up every category ID, then load the matching category records
all_category_ids = coco.getCatIds()
categories = coco.loadCats(all_category_ids)
print(f"COCO has {len(categories)} categories:")

# Print one numbered line per category (numbering starts at 1)
for i, cat in enumerate(categories):
    print(f"{i+1}. ID: {cat['id']}, Name: {cat['name']}, Supercategory: {cat['supercategory']}")

## 5. Define Utility Functions

In [None]:
def calculate_category_coverage(coco, img_id, cat_id):
    """
    Calculate what percentage of the image area is covered by a specific category.

    Coverage is the sum of the ``area`` field of every annotation of the
    category in the image, divided by ``width * height`` of the image.
    Each annotation contributes its own area, so overlapping instances can
    push the raw sum past the image area; the result is capped at 100 to
    keep the documented range.

    Args:
        coco: COCO API instance
        img_id: Image ID
        cat_id: Category ID

    Returns:
        float: Coverage percentage (0-100); 0.0 when the image has no
            annotation of the category or reports a non-positive area
        list: Annotations for the category in this image
    """
    # Get image info
    img_info = coco.loadImgs(img_id)[0]
    image_area = img_info['width'] * img_info['height']

    # Get annotations for this category in this image
    ann_ids = coco.getAnnIds(imgIds=img_id, catIds=cat_id)
    anns = coco.loadAnns(ann_ids)

    if not anns:
        return 0.0, []

    # A degenerate image record would otherwise cause a division by zero
    if image_area <= 0:
        return 0.0, anns

    # Calculate total area covered by annotations
    total_area = sum(ann['area'] for ann in anns)

    # Convert to a percentage and cap it, since summed areas may overlap
    coverage_percent = min((total_area / image_area) * 100, 100.0)

    return coverage_percent, anns

def visualize_sample(coco, dataset_info, category_name, num_samples=3):
    """
    Visualize sample images from a dataset.

    Draws a grid with one row per sample: the left panel shows a positive
    image with the category's annotations overlaid, the right panel shows a
    negative image. Images are read from the module-level ``imgDir`` and the
    figure title reports the module-level ``coverage_threshold``.

    Args:
        coco: COCO API instance
        dataset_info: Dictionary with dataset information; must provide
            'positive_examples', 'negative_examples' (lists of image IDs)
            and 'category_id'
        category_name: Name of the category
        num_samples: Number of samples to visualize
    """
    positives = dataset_info['positive_examples']
    if not positives:
        print(f"No positive examples found for category: {category_name}")
        return

    # Never ask for more rows than there are positive examples
    samples = min(num_samples, len(positives))
    if samples < 1:
        # Nothing to draw; a zero-row subplot grid is not valid
        return

    cat_id = dataset_info['category_id']
    negatives = dataset_info['negative_examples']

    positive_samples = random.sample(positives, samples)
    # There may be fewer negatives than rows; sampling more than exist would raise
    negative_samples = random.sample(negatives, min(samples, len(negatives)))

    # squeeze=False keeps `axes` two-dimensional even when there is one row
    fig, axes = plt.subplots(samples, 2, figsize=(12, 5*samples), squeeze=False)

    for i, pos_img_id in enumerate(positive_samples):
        pos_ax, neg_ax = axes[i, 0], axes[i, 1]

        # Positive example
        pos_img_info = coco.loadImgs(pos_img_id)[0]
        pos_img_path = os.path.join(imgDir, pos_img_info['file_name'])
        pos_img = io.imread(pos_img_path)

        # Coverage and the category's annotations come from the same lookup
        pos_coverage, pos_anns = calculate_category_coverage(coco, pos_img_id, cat_id)

        pos_ax.imshow(pos_img)
        pos_ax.set_title(f"Positive Example\nID: {pos_img_id}, Coverage: {pos_coverage:.1f}%")
        pos_ax.axis('off')

        # showAnns draws on the current pyplot axes, so make the positive
        # panel current first; otherwise the masks land on another subplot
        plt.sca(pos_ax)
        coco.showAnns(pos_anns, draw_bbox=True)

        # Negative example (panel stays blank when no negative is available)
        neg_ax.axis('off')
        if i >= len(negative_samples):
            continue

        neg_img_id = negative_samples[i]
        neg_img_info = coco.loadImgs(neg_img_id)[0]
        neg_img_path = os.path.join(imgDir, neg_img_info['file_name'])
        neg_img = io.imread(neg_img_path)

        neg_ax.imshow(neg_img)
        neg_ax.set_title(f"Negative Example\nID: {neg_img_id}, No {category_name}")

    plt.tight_layout()
    plt.suptitle(f"{category_name} Examples (Coverage Threshold: {coverage_threshold}%)",
                 fontsize=16, y=1.02)
    plt.show()

## 6. Process All Categories to Create Balanced Datasets

In [None]:
# List of retrieval items; each item pairs one positive image with
# num_negative_examples negative images for a single category
new_data = []

# The pool of all image IDs is the same for every category, so build it once
all_img_ids = list(coco.imgs.keys())
all_img_id_set = set(all_img_ids)

# Process each category
for category in tqdm(categories, desc="Processing Categories"):
    cat_id = category['id']
    cat_name = category['name']
    # Categories in the 'animal' supercategory are excluded from this dataset
    if category['supercategory'] == 'animal':
        continue

    # Get all image IDs containing this category
    img_ids_with_category = coco.getImgIds(catIds=cat_id)

    # Keep only the images where the category reaches the coverage threshold
    images_with_significant_coverage = [
        img_id
        for img_id in img_ids_with_category
        if calculate_category_coverage(coco, img_id, cat_id)[0] >= coverage_threshold
    ]

    # Negatives are all images with no annotation of this category at all
    # (images below the coverage threshold are neither positive nor negative)
    img_ids_without_category = list(all_img_id_set - set(img_ids_with_category))

    # Per-category summary; after the loop it describes the last category processed
    dataset_info = {
        'category_id': cat_id,
        'category_name': cat_name,
        'positive_examples': list(images_with_significant_coverage),
        'negative_examples': img_ids_without_category,
        'dataset_indices': []
    }

    # Without any negative image there is nothing to sample from
    if images_with_significant_coverage and not img_ids_without_category:
        tqdm.write(f"Skipping category '{cat_name}': no images without this category")
        continue

    # For each positive example, randomly select negative examples
    for pos_img_id in images_with_significant_coverage:
        # If we don't have enough negative examples, sample with replacement
        if len(img_ids_without_category) < num_negative_examples:
            neg_img_ids = random.choices(img_ids_without_category, k=num_negative_examples)
        else:
            neg_img_ids = random.sample(img_ids_without_category, num_negative_examples)

        # Create dataset with 1 positive + num_negative_examples negatives;
        # the positive image is always first in the list
        dataset_indices = [pos_img_id] + neg_img_ids
        dataset_info['dataset_indices'].append(dataset_indices)

        new_data.append({
            'qry_text': f"Find me an image that contains any {cat_name}.\n",
            'qry_img_path': '',
            'tgt_text': "<|image_1|> Represent the given image.",
            # Paths follow the split in dataType instead of a hard-coded folder
            'tgt_img_path': [f"{dataType}/{img_id:012d}.jpg" for img_id in dataset_indices],
        })



In [None]:
# Number of retrieval items generated (one per positive image)
len(new_data)

In [None]:
# Inspect the first generated item as a quick sanity check
new_data[0]

In [None]:
import json

# Persist every generated retrieval item as indented JSON
with open('COCO_object_retrieval.json', mode='w') as out_file:
    json.dump(new_data, fp=out_file, indent=4)

In [None]:
from datasets import load_dataset

# Reload the saved JSON file through `datasets` to confirm it parses;
# the resulting dataset is what the category split below iterates over
new_eval_data = load_dataset(
    'json',
    data_files='COCO_object_retrieval.json',
    split="train",
)

In [None]:
# Define animal and non-living object categories
animal_categories = ['cat', 'dog', 'bird', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe']
animal_category_set = set(animal_categories)

# Generated queries have the form "Find me an image that contains any <category>.\n"
_QUERY_PREFIX = "Find me an image that contains any "


def _category_from_query(qry_text):
    """Return the category name embedded in a generated query string."""
    text = qry_text.strip()
    if text.startswith(_QUERY_PREFIX):
        text = text[len(_QUERY_PREFIX):]
    if text.endswith('.'):
        text = text[:-1]
    return text


# Filter the dataset
# NOTE(review): the generation loop skips the 'animal' supercategory, so a
# file produced by that loop yields an empty animal_data here - confirm
# whether the skip or this split reflects the intended behaviour.
animal_data = []
non_living_data = []

for item in new_eval_data:
    # Compare the exact category name rather than searching for substrings,
    # so names such as "teddy bear" or "hot dog" need no special-casing
    category_name = _category_from_query(item['qry_text'])
    if category_name == 'person':
        continue  # Skip rows related to the "person" category
    if category_name in animal_category_set:
        animal_data.append(item)
    else:
        non_living_data.append(item)

# Save the filtered data to JSON files
with open('COCO_animal_retrieval.json', 'w') as f:
    json.dump(animal_data, f, indent=4)

with open('COCO_object_retrieval_new.json', 'w') as f:
    json.dump(non_living_data, f, indent=4)