# Sampling CoCo image dataset to train and use DPM on

In [44]:
import json
import random
from pycocotools.coco import COCO
import requests
import os
import shutil

from xml.etree.ElementTree import Element, SubElement, ElementTree, tostring 
from xml.dom.minidom import parseString 

We clear the folders beforehand

In [32]:
def clear_folder(folder_path):
    """
    Clears all files and subdirectories in the specified folder.

    Args:
        folder_path (str): Path to the folder to clear.
    """
    if os.path.exists(folder_path):
        # Remove all contents of the folder
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove file or symbolic link
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove directory
            except Exception as e:
                print(f"Failed to delete {file_path}. Reason: {e}")
    else:
        # If folder doesn't exist, create it
        os.makedirs(folder_path)


Helper function to make PASCAL VOC-style XML annotations.

In [56]:
# Function to create Pascal-style XML annotations
def create_pascal_xml(img_info, annotations, output_annotation_dir):
    xml_root = Element('annotation')
    folder = SubElement(xml_root, 'folder')
    folder.text = 'VOC_COCO'

    filename = SubElement(xml_root, 'filename')
    filename.text = img_info['file_name']

    size = SubElement(xml_root, 'size')
    SubElement(size, 'width').text = str(img_info['width'])
    SubElement(size, 'height').text = str(img_info['height'])
    SubElement(size, 'depth').text = '3'  # Assuming RGB images

    for ann in annotations:
        obj = SubElement(xml_root, 'object')
        name = SubElement(obj, 'name')
        name.text = 'person'  # Category name

        # Pose (defaulted to 'Unspecified')
        pose = SubElement(obj, 'pose')
        pose.text = 'Unspecified'

        # Truncated (default 0, set to 1 if bbox exceeds image boundaries)
        truncated = SubElement(obj, 'truncated')
        bbox = ann['bbox']  # COCO format: [xmin, ymin, width, height]
        x_min = bbox[0]
        y_min = bbox[1]
        x_max = bbox[0] + bbox[2]
        y_max = bbox[1] + bbox[3]
        is_truncated = (
            x_min < 0 or y_min < 0 or x_max > img_info['width'] or y_max > img_info['height']
        )
        truncated.text = '1' if is_truncated else '0'
        # TODO: record how many truncated image data points we have? 

        # Bounding box
        bndbox = SubElement(obj, 'bndbox')
        SubElement(bndbox, 'xmin').text = str(max(0, int(x_min)))  # Clip to image boundaries
        SubElement(bndbox, 'ymin').text = str(max(0, int(y_min)))
        SubElement(bndbox, 'xmax').text = str(min(img_info['width'], int(x_max)))
        SubElement(bndbox, 'ymax').text = str(min(img_info['height'], int(y_max)))

        # Keypoints (optional)
        keypoints_elem = SubElement(obj, 'keypoints')
        keypoints = ann.get('keypoints', [])
        # if keypoints:
        #     print(f"image {img_info['file_name']} has keypoints")
        # else:
        #     print(f"image {img_info['file_name']} does NOT have keypoints")
        for i in range(0, len(keypoints), 3):
            kp_x, kp_y, visibility = keypoints[i:i+3]
            keypoint = SubElement(keypoints_elem, 'keypoint')
            SubElement(keypoint, 'x').text = str(int(kp_x)) if visibility > 0 else 'NaN'
            SubElement(keypoint, 'y').text = str(int(kp_y)) if visibility > 0 else 'NaN'
            SubElement(keypoint, 'visibility').text = str(visibility)

    # Pretty format 
    asstring = tostring(xml_root, 'utf-8')
    parsed_xml = parseString(asstring)
    pretty_xml = parsed_xml.toprettyxml(indent="    ")

    # Save XML
    output_file = os.path.join(output_annotation_dir, f"{img_info['file_name'].split('.')[0]}.xml")
    # tree = ElementTree(xml_root)
    # tree.write(output_file)
    with open(output_file, 'w') as f:
        f.write(pretty_xml)

In [61]:
# downloading 2000 images of training data for category person
# Paths
annotation_file = 'annotations/person_keypoints_train2017.json'  # Update with your COCO annotation file path
output_dir = 'coco_output'  # Folder to save downloaded images
annotations_dir = os.path.join(output_dir, 'Annotations')
sets_dir = os.path.join(output_dir, 'ImageSets', 'Main')
images_dir = os.path.join(output_dir, 'JPEGImages')
# os.makedirs(output_dir, exist_ok=True)

# Number of images to sample
num_images = 10

In [62]:
# Load COCO annotations
coco = COCO(annotation_file)

loading annotations into memory...
Done (t=8.12s)
creating index...
index created!


In [63]:
# Get category ID for 'person'
person_category_id = coco.getCatIds(catNms=['person'])[0]

# Get all image IDs containing persons
person_image_ids = coco.getImgIds(catIds=[person_category_id])

In [64]:
# Setting number of images & instances to be sampled
# sampled_image_ids = random.sample(person_image_ids, num_images)
instance_target = 50
image_count = 0

# Randomly sample images and instances
random.seed(429)
random.shuffle(person_image_ids)
selected_image_ids = []
instance_count = 0

for img_id in person_image_ids:
    ann_ids = coco.getAnnIds(imgIds=[img_id], catIds=[person_category_id])
    person_count = len(ann_ids)  # in this image

    if instance_count + person_count < instance_target:
        selected_image_ids.append(img_id)
        instance_count += person_count
    else:
        break  # target instances fulfilled 

In [65]:
len(selected_image_ids)

14

In [66]:
# To freshly download (erase existing data) 
clear_folder(annotations_dir)
clear_folder(sets_dir)
clear_folder(images_dir)

# Download sampled images
for img_id in selected_image_ids:
    # load image info and annotations
    img_info = coco.loadImgs(img_id)[0]
    ann_ids = coco.getAnnIds(imgIds=[img_id], catIds=[person_category_id])
    annotations = coco.loadAnns(ann_ids)

    # download image
    img_url = img_info['coco_url']
    img_filename = os.path.join(output_dir, 'JPEGImages', img_info['file_name'])
    
    # Download image
    response = requests.get(img_url, stream=True)
    if response.status_code == 200:
        with open(img_filename, 'wb') as f:
            for chunk in response.iter_content(1024):
                f.write(chunk)
        # if image downloaded then we save the annotation
        create_pascal_xml(img_info, annotations, annotations_dir)
        print(f"Downloaded {img_info['file_name']}")

print(f"Downloaded {num_images} person images to {output_dir}")

image 000000155268.jpg has keypoints
image 000000155268.jpg has keypoints
Downloaded 000000155268.jpg
image 000000407626.jpg has keypoints
Downloaded 000000407626.jpg
image 000000181886.jpg has keypoints
image 000000181886.jpg has keypoints
image 000000181886.jpg has keypoints
Downloaded 000000181886.jpg
image 000000263041.jpg has keypoints
image 000000263041.jpg has keypoints
image 000000263041.jpg has keypoints
Downloaded 000000263041.jpg
image 000000358767.jpg has keypoints
Downloaded 000000358767.jpg
image 000000568956.jpg has keypoints
image 000000568956.jpg has keypoints
image 000000568956.jpg has keypoints
image 000000568956.jpg has keypoints
Downloaded 000000568956.jpg
image 000000256431.jpg has keypoints
image 000000256431.jpg has keypoints
image 000000256431.jpg has keypoints
image 000000256431.jpg has keypoints
image 000000256431.jpg has keypoints
image 000000256431.jpg has keypoints
Downloaded 000000256431.jpg
image 000000052281.jpg has keypoints
Downloaded 000000052281.jpg