In [None]:
%pip install pycocotools

In [None]:
import os
import json
from pycocotools.coco import COCO

def filter_coco_data(coco, class_names):
    """Collect the images and annotations that belong to the given classes.

    Args:
        coco: A ``pycocotools.coco.COCO``-style index exposing ``getCatIds``,
            ``getImgIds``, ``getAnnIds``, ``loadImgs`` and ``loadAnns``.
        class_names: Iterable of category names to keep, e.g.
            ``['dog', 'cat']``.

    Returns:
        A ``(images, annotations)`` tuple of lists of record dicts. Every
        image and every annotation appears at most once, in first-seen
        order, even when an image contains several of the requested classes
        or a class name is listed more than once.

    Warns:
        UserWarning: For each class name that matches no category; that name
            is skipped instead of contributing records.
    """
    import warnings

    filtered_images = []
    filtered_annotations = []
    seen_image_ids = set()
    seen_annotation_ids = set()

    for class_name in class_names:
        cat_ids = coco.getCatIds(catNms=[class_name])
        if not cat_ids:
            # pycocotools treats an empty catIds filter as "no filter", so an
            # unknown name would otherwise pull in the *entire* dataset.
            warnings.warn(
                f"Unknown COCO category name {class_name!r}; skipping it.",
                stacklevel=2,
            )
            continue

        img_ids = coco.getImgIds(catIds=cat_ids)
        ann_ids = coco.getAnnIds(catIds=cat_ids)

        # An image holding several requested classes is returned once per
        # class; keep only its first occurrence.
        for image in coco.loadImgs(ids=img_ids):
            if image['id'] not in seen_image_ids:
                seen_image_ids.add(image['id'])
                filtered_images.append(image)

        for annotation in coco.loadAnns(ids=ann_ids):
            if annotation['id'] not in seen_annotation_ids:
                seen_annotation_ids.add(annotation['id'])
                filtered_annotations.append(annotation)

    return filtered_images, filtered_annotations

# Point at the COCO dataset root and pick which split to filter.
data_dir = './dataset/coco'
data_type = 'train2017'  # or 'val2017', etc.
annotations_dir = os.path.join(data_dir, 'annotations')
ann_file = os.path.join(annotations_dir, f'instances_{data_type}.json')

coco = COCO(ann_file)

# Category names to keep; extend this list with any classes you want.
SUBCLASS = ['dog', 'cat']

filtered_images, filtered_annotations = filter_coco_data(coco, SUBCLASS)

# Write the filtered subset to a new annotation JSON next to the original,
# tagging the file name with the selected class names.
filter_name = "_".join(SUBCLASS)
new_ann_file = os.path.join(
    annotations_dir,
    f'instances_{data_type}_filtered_{filter_name}.json',
)

with open(new_ann_file, 'w') as f:
    # The full category list is kept, so category ids stay unchanged.
    json.dump(
        dict(
            images=filtered_images,
            annotations=filtered_annotations,
            categories=coco.loadCats(coco.getCatIds()),
        ),
        f,
    )
