# Sampling CoCo image dataset to train and use DPM on

In [1]:
import json
import random
from pycocotools.coco import COCO
import requests
import os
import shutil
from tqdm import tqdm

from xml.etree.ElementTree import Element, SubElement, ElementTree, tostring 
from xml.dom.minidom import parseString 

We clear the folders beforehand

In [2]:
def clear_folder(folder_path):
    """
    Clears all files and subdirectories in the specified folder.

    Args:
        folder_path (str): Path to the folder to clear.
    """
    if os.path.exists(folder_path):
        # Remove all contents of the folder
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove file or symbolic link
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove directory
            except Exception as e:
                print(f"Failed to delete {file_path}. Reason: {e}")
    else:
        # If folder doesn't exist, create it
        os.makedirs(folder_path)


Helper function to make PASCAL VOC-style XML annotations.

In [3]:
# Function to create Pascal-style XML annotations
def create_pascal_xml(img_info, annotations, output_annotation_dir):
    xml_root = Element('annotation')
    folder = SubElement(xml_root, 'folder')
    folder.text = 'VOC_COCO'

    filename = SubElement(xml_root, 'filename')
    filename.text = img_info['file_name']

    size = SubElement(xml_root, 'size')
    SubElement(size, 'width').text = str(img_info['width'])
    SubElement(size, 'height').text = str(img_info['height'])
    SubElement(size, 'depth').text = '3'  # Assuming RGB images

    for ann in annotations:
        obj = SubElement(xml_root, 'object')
        name = SubElement(obj, 'name')
        name.text = 'person'  # Category name

        # Pose (defaulted to 'Unspecified')
        pose = SubElement(obj, 'pose')
        pose.text = 'Unspecified'

        # Truncated (default 0, set to 1 if bbox exceeds image boundaries)
        truncated = SubElement(obj, 'truncated')
        bbox = ann['bbox']  # COCO format: [xmin, ymin, width, height]
        x_min = bbox[0]
        y_min = bbox[1]
        x_max = bbox[0] + bbox[2]
        y_max = bbox[1] + bbox[3]
        is_truncated = (
            x_min < 0 or y_min < 0 or x_max > img_info['width'] or y_max > img_info['height']
        )
        truncated.text = '1' if is_truncated else '0'
        # TODO: record how many truncated image data points we have? 

        # Bounding box
        bndbox = SubElement(obj, 'bndbox')
        SubElement(bndbox, 'xmin').text = str(max(0, int(x_min)))  # Clip to image boundaries
        SubElement(bndbox, 'ymin').text = str(max(0, int(y_min)))
        SubElement(bndbox, 'xmax').text = str(min(img_info['width'], int(x_max)))
        SubElement(bndbox, 'ymax').text = str(min(img_info['height'], int(y_max)))

        # # Keypoints (optional)
        # keypoints_elem = SubElement(obj, 'keypoints')
        # keypoints = ann.get('keypoints', [])
        # # if keypoints:
        # #     print(f"image {img_info['file_name']} has keypoints")
        # # else:
        # #     print(f"image {img_info['file_name']} does NOT have keypoints")
        # for i in range(0, len(keypoints), 3):
        #     kp_x, kp_y, visibility = keypoints[i:i+3]
        #     keypoint = SubElement(keypoints_elem, 'keypoint')
        #     SubElement(keypoint, 'x').text = str(int(kp_x)) if visibility > 0 else 'NaN'
        #     SubElement(keypoint, 'y').text = str(int(kp_y)) if visibility > 0 else 'NaN'
        #     SubElement(keypoint, 'visibility').text = str(visibility)

    # Pretty format 
    asstring = tostring(xml_root, 'utf-8')
    parsed_xml = parseString(asstring)
    pretty_xml = parsed_xml.toprettyxml(indent="    ")

    # Save XML
    output_file = os.path.join(output_annotation_dir, f"{img_info['file_name'].split('.')[0]}.xml")
    # tree = ElementTree(xml_root)
    # tree.write(output_file)
    with open(output_file, 'w') as f:
        f.write(pretty_xml)

In [4]:
# downloading 2000 images of training data for category person
# Paths
keypoints_annotation_file = 'annotations/person_keypoints_train2017.json'  # Update with your COCO annotation file path
annotation_file = 'annotations/instances_train2017.json'  # Update with your COCO annotation file path
output_dir = 'coco_output'  # Folder to save downloaded images
annotations_dir = os.path.join(output_dir, 'Annotations')
sets_dir = os.path.join(output_dir, 'ImageSets', 'Main')
images_dir = os.path.join(output_dir, 'JPEGImages')
# os.makedirs(output_dir, exist_ok=True)

# Number of images to sample
num_images = 10

In [5]:
# Load COCO annotations
coco = COCO(annotation_file)

loading annotations into memory...
Done (t=12.51s)
creating index...
index created!


In [6]:
categories = [
    'airplane', 
    'bicycle', 
    'bird', 
    'boat', 
    'bottle', 
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'dining table',
    'dog', 
    'horse', 
    'motorcycle', 
    'person', 
    'potted plant',
    'sheep',
    'couch',
    'train', 
    'TV',
]

In [7]:
category_ids = {cat['name']: cat['id'] for cat in coco.loadCats(coco.getCatIds()) if cat['name'] in categories}
# category_ids

In [16]:
# Sampling for non-person categories
num_images_per_category = 250
person_instance_target = 4100

## Clean all

In [17]:
# To freshly download (erase existing data) 
clear_folder(annotations_dir)
clear_folder(sets_dir)
clear_folder(images_dir)

## Sampling

In [18]:
all_images = set()
set.seed(429)

# Prepare ImageSets data
train_txt_path = os.path.join(sets_dir, "train.txt")
category_txt_files = {}  # Initialize category-specific text files

# Create category-specific text files
for category in categories:
    file_path = os.path.join(sets_dir, f"{category}_train.txt")
    category_txt_files[category] = open(file_path, "w")

# Dictionary to track images marked as positive for each category
positive_samples = {category: set() for category in categories}

# Process each category and collect positive samples
for category, category_id in category_ids.items():
    print(f"Processing category: {category}")

    # Adjust logic for 'person' category
    if category == 'person':
        instance_target = person_instance_target
        selected_image_ids = []
        instance_count = 0
        # Shuffle all image IDs containing 'person'
        image_ids = coco.getImgIds(catIds=[category_id])
        random.shuffle(image_ids)
        # Collect images until instance_target instances are reached
        for img_id in image_ids:
            ann_ids = coco.getAnnIds(imgIds=[img_id], catIds=[category_id])
            person_count = len(ann_ids)  # Number of person instances in this image

            if instance_count + person_count <= instance_target:
                selected_image_ids.append(img_id)
                instance_count += person_count
            else:
                break
        print(f"Selected {len(selected_image_ids)} images to meet {instance_target} 'person' instances.")
    else:
        # Get all image IDs for the current category
        image_ids = coco.getImgIds(catIds=[category_id])
        random.shuffle(image_ids)
    
        # Select up to `num_images_per_category`
        selected_image_ids = image_ids[:num_images_per_category]
        print(f"Selected {len(selected_image_ids)} images for category {category}.")
    
    # Download and annotate
    for img_id in tqdm(selected_image_ids, desc=f"Downloading {category} images"):
        img_info = coco.loadImgs(img_id)[0]
        ann_ids = coco.getAnnIds(imgIds=[img_id], catIds=[category_id])
        annotations = coco.loadAnns(ann_ids)

        img_filename = os.path.splitext(img_info['file_name'])[0]  # Remove the file extension
        img_filepath = os.path.join(images_dir, img_info['file_name'])
        all_images.add(img_filename)  # Add filename without extension

        try:
            # Skip if file already exists
            if not os.path.exists(img_filepath):
                response = requests.get(img_info['coco_url'], stream=True, timeout=10)
                response.raise_for_status()
                with open(img_filepath, 'wb') as f:
                    for chunk in response.iter_content(1024):
                        f.write(chunk)
            create_pascal_xml(img_info, annotations, annotations_dir)
        except requests.exceptions.RequestException as e:
            print(f"Failed to download {img_info['file_name']}: {e}")

        # Track positive samples for the current category
        positive_samples[category].add(img_filename)  # Mark as positive for this category

    print(f"Completed category: {category}")

# Sort all image filenames (without extensions)
sorted_imgs = sorted(all_images)

# Write to category-specific and train.txt files
for img_filename in sorted_imgs:
    for category in categories:
        # Check if this image is positive or negative for the category
        label = "1" if img_filename in positive_samples[category] else "-1"
        category_txt_files[category].write(f"{img_filename} {label}\n")

# Write train.txt with all images (in sorted order)
with open(train_txt_path, "w") as train_file:
    for img_filename in sorted_imgs:
        train_file.write(img_filename + "\n")

# Close all category-specific text files
for file in category_txt_files.values():
    file.close()

print(f"ImageSets folder created at {sets_dir}.")

Processing category: person
Selected 980 images to meet 4100 'person' instances.


Downloading person images: 100%|████████████████████████████████████████████| 980/980 [02:31<00:00,  6.46it/s]


Completed category: person
Processing category: bicycle
Selected 250 images for category bicycle.


Downloading bicycle images: 100%|███████████████████████████████████████████| 250/250 [00:38<00:00,  6.57it/s]


Completed category: bicycle
Processing category: car
Selected 250 images for category car.


Downloading car images: 100%|███████████████████████████████████████████████| 250/250 [00:38<00:00,  6.45it/s]


Completed category: car
Processing category: motorcycle
Selected 250 images for category motorcycle.


Downloading motorcycle images: 100%|████████████████████████████████████████| 250/250 [00:37<00:00,  6.64it/s]


Completed category: motorcycle
Processing category: airplane
Selected 250 images for category airplane.


Downloading airplane images: 100%|██████████████████████████████████████████| 250/250 [00:39<00:00,  6.36it/s]


Completed category: airplane
Processing category: bus
Selected 250 images for category bus.


Downloading bus images: 100%|███████████████████████████████████████████████| 250/250 [00:38<00:00,  6.50it/s]


Completed category: bus
Processing category: train
Selected 250 images for category train.


Downloading train images: 100%|█████████████████████████████████████████████| 250/250 [00:38<00:00,  6.45it/s]


Completed category: train
Processing category: boat
Selected 250 images for category boat.


Downloading boat images: 100%|██████████████████████████████████████████████| 250/250 [00:38<00:00,  6.44it/s]


Completed category: boat
Processing category: bird
Selected 250 images for category bird.


Downloading bird images: 100%|██████████████████████████████████████████████| 250/250 [00:37<00:00,  6.66it/s]


Completed category: bird
Processing category: cat
Selected 250 images for category cat.


Downloading cat images: 100%|███████████████████████████████████████████████| 250/250 [00:36<00:00,  6.84it/s]


Completed category: cat
Processing category: dog
Selected 250 images for category dog.


Downloading dog images: 100%|███████████████████████████████████████████████| 250/250 [00:38<00:00,  6.56it/s]


Completed category: dog
Processing category: horse
Selected 250 images for category horse.


Downloading horse images: 100%|█████████████████████████████████████████████| 250/250 [00:39<00:00,  6.37it/s]


Completed category: horse
Processing category: sheep
Selected 250 images for category sheep.


Downloading sheep images: 100%|█████████████████████████████████████████████| 250/250 [00:42<00:00,  5.92it/s]


Completed category: sheep
Processing category: cow
Selected 250 images for category cow.


Downloading cow images: 100%|███████████████████████████████████████████████| 250/250 [00:42<00:00,  5.93it/s]


Completed category: cow
Processing category: bottle
Selected 250 images for category bottle.


Downloading bottle images: 100%|████████████████████████████████████████████| 250/250 [00:39<00:00,  6.35it/s]


Completed category: bottle
Processing category: chair
Selected 250 images for category chair.


Downloading chair images: 100%|█████████████████████████████████████████████| 250/250 [00:41<00:00,  6.03it/s]


Completed category: chair
Processing category: couch
Selected 250 images for category couch.


Downloading couch images: 100%|█████████████████████████████████████████████| 250/250 [00:36<00:00,  6.84it/s]


Completed category: couch
Processing category: potted plant
Selected 250 images for category potted plant.


Downloading potted plant images: 100%|██████████████████████████████████████| 250/250 [00:36<00:00,  6.78it/s]


Completed category: potted plant
Processing category: dining table
Selected 250 images for category dining table.


Downloading dining table images: 100%|██████████████████████████████████████| 250/250 [00:38<00:00,  6.50it/s]

Completed category: dining table
ImageSets folder created at coco_output/ImageSets/Main.



