In [1]:
import json
import os
import random
from pycocotools.coco import COCO

In [2]:
def _gather_split(coco, img_ids):
    """Return ``(images, annotations)`` for the given image ids.

    An empty ``img_ids`` yields ``([], [])``. The guard is required because
    pycocotools' ``COCO.getAnnIds(imgIds=[])`` treats an empty list as "no
    filter" and returns every annotation id in the dataset.
    """
    if not img_ids:
        return [], []
    images = coco.loadImgs(img_ids)
    annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_ids))
    return images, annotations


def split_data(ann_file: str, valid_rate: float = .15, target_root: str = "./",
               seed: int = 3407, background_id: int = 0) -> None:
    """Split a COCO annotation file into train/valid sets, stratified by category.

    Non-background categories are processed from the one with the fewest
    images to the one with the most. For each category, the images containing
    it that have not already been assigned are shuffled, and the first
    ``int(n * valid_rate)`` go to the validation set; the rest go to the
    training set. Each image is assigned at most once, and all of its
    annotations follow it. Images with no annotation of a non-background
    category are not written to either split. Per-category and total counts
    are printed.

    Args:
        ann_file (str): path to the COCO annotation JSON file.
        valid_rate (float): fraction of each category's (unassigned) images
            placed in the validation set.
        target_root (str): output directory; ``train_coco.json`` and
            ``valid_coco.json`` are written there (created if missing).
        seed (int): base random seed; the k-th processed category is shuffled
            with ``seed + k``.
        background_id (int): category id treated as background. It is kept in
            the ``categories`` list but never drives sampling (default 0).
    """
    coco = COCO(ann_file)
    cat_ids = sorted(coco.getCatIds())
    categories = coco.loadCats(cat_ids)

    train_coco = {"images": [], "annotations": [], "categories": list(categories)}
    val_coco = {"images": [], "annotations": [], "categories": list(categories)}

    # (category id, number of images containing it), rarest category first,
    # so that small classes get first pick of images shared with larger ones.
    cats_count = [
        (cat_id, len(coco.getImgIds(catIds=[cat_id])))
        for cat_id in cat_ids
        if cat_id != background_id
    ]
    cats_count.sort(key=lambda x: x[-1])

    train_count = [0, 0]  # [images, annotations]
    valid_count = [0, 0]  # [images, annotations]
    used_image_id = set()
    for cat_id, _ in cats_count:
        image_id = [img_id for img_id in coco.getImgIds(catIds=[cat_id])
                    if img_id not in used_image_id]
        used_image_id.update(image_id)
        # A local RNG yields the same permutation as
        # `random.seed(seed); random.shuffle(...)` without resetting the
        # caller's global random state.
        random.Random(seed).shuffle(image_id)
        seed += 1

        val_sample = int(len(image_id) * valid_rate)
        val_img_id = image_id[:val_sample]
        train_img_id = image_id[val_sample:]

        train_image_info, train_ann_info = _gather_split(coco, train_img_id)
        val_image_info, val_ann_info = _gather_split(coco, val_img_id)
        train_coco["images"] += train_image_info
        train_coco["annotations"] += train_ann_info
        val_coco["images"] += val_image_info
        val_coco["annotations"] += val_ann_info

        print(f"类别 - {coco.cats[cat_id]['name']}: \n"
              f"\t Train\n"
              f"\t\t Image: {len(train_image_info)}\n"
              f"\t\t Bbox: {len(train_ann_info)}\n"
              f"\t Valid\n"
              f"\t\t Image: {len(val_image_info)}\n"
              f"\t\t Bbox: {len(val_ann_info)}\n"
        )

        train_count[0] += len(train_image_info)
        train_count[1] += len(train_ann_info)
        valid_count[0] += len(val_image_info)
        valid_count[1] += len(val_ann_info)

    print(
        f"训练集共有图片{train_count[0]}张，标定框{train_count[1]}个\n"
        f"验证集共有图片{valid_count[0]}张，标定框{valid_count[1]}个\n"
    )
    if target_root:
        os.makedirs(target_root, exist_ok=True)
    with open(os.path.join(target_root, "train_coco.json"), "w") as f:
        json.dump(train_coco, f)
    with open(os.path.join(target_root, "valid_coco.json"), "w") as f:
        json.dump(val_coco, f)

In [3]:
# Write train_coco.json and valid_coco.json into sift_data/, using the
# function's default validation rate and seed.
split_data(
    ann_file="sift_data/coco_ann.json",
    target_root="sift_data/",
)

loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
类别 - pressure: 
	 Train
		 Image: 179
		 Bbox: 418
	 Valid
		 Image: 31
		 Bbox: 64

类别 - metalcup: 
	 Train
		 Image: 204
		 Bbox: 333
	 Valid
		 Image: 36
		 Bbox: 56

类别 - tongs: 
	 Train
		 Image: 230
		 Bbox: 497
	 Valid
		 Image: 40
		 Bbox: 82

类别 - lighter: 
	 Train
		 Image: 254
		 Bbox: 520
	 Valid
		 Image: 44
		 Bbox: 85

类别 - laptop: 
	 Train
		 Image: 280
		 Bbox: 386
	 Valid
		 Image: 49
		 Bbox: 64

类别 - scissor: 
	 Train
		 Image: 318
		 Bbox: 439
	 Valid
		 Image: 55
		 Bbox: 79

类别 - knife: 
	 Train
		 Image: 344
		 Bbox: 503
	 Valid
		 Image: 60
		 Bbox: 97

类别 - umbrella: 
	 Train
		 Image: 868
		 Bbox: 1019
	 Valid
		 Image: 153
		 Bbox: 168

类别 - glassbottle: 
	 Train
		 Image: 1150
		 Bbox: 1945
	 Valid
		 Image: 202
		 Bbox: 323

训练集共有图片3827张，标定框6060个
验证集共有图片670张，标定框1018个

