# In [None]:
import json
import os
from shutil import copy2

def split_coco_dataset_by_filename(coco_json_path, images_dir, output_dir, test_keyword):
    """
    Split a COCO dataset into train and test subsets by a file-name keyword.

    Images whose ``file_name`` contains ``test_keyword`` go to the test subset;
    all other images go to the train subset. Note that an empty keyword is a
    substring of every file name, so it sends every image to the test subset.

    Outputs written under ``output_dir``::

        train/annotations.json   train/images/<file_name>
        test/annotations.json    test/images/<file_name>

    Annotations are kept only when their ``image_id`` belongs to an image of the
    subset. ``categories`` is copied to both subsets, and the top-level ``info``
    and ``licenses`` entries are carried over when the input has them.

    :param coco_json_path: path to the source COCO annotation file (read as UTF-8)
    :param images_dir: directory that ``file_name`` entries are relative to
    :param output_dir: root directory for the generated train/test folders
    :param test_keyword: substring marking an image as belonging to the test set
    :raises KeyError: if the annotation file has no ``images`` entry
    :raises FileNotFoundError: if a listed image is missing from ``images_dir``
    """
    # Create the output directories.
    train_images_dir = os.path.join(output_dir, 'train', 'images')
    test_images_dir = os.path.join(output_dir, 'test', 'images')
    os.makedirs(train_images_dir, exist_ok=True)
    os.makedirs(test_images_dir, exist_ok=True)

    # Load the COCO annotations. The encoding is explicit so that non-ASCII
    # content does not depend on the platform's default locale encoding.
    with open(coco_json_path, 'r', encoding='utf-8') as f:
        coco_data = json.load(f)

    # Partition the images in a single pass over the list.
    train_images = []
    test_images = []
    for img in coco_data['images']:
        if test_keyword in img['file_name']:
            test_images.append(img)
        else:
            train_images.append(img)

    # Image-only COCO files may omit these keys; treat them as empty.
    all_annotations = coco_data.get('annotations', [])
    categories = coco_data.get('categories', [])
    # Dataset-level metadata shared by both subsets (images reference license ids).
    shared_metadata = {
        key: coco_data[key] for key in ('info', 'licenses') if key in coco_data
    }

    def filter_annotations(images_subset):
        image_ids = {img['id'] for img in images_subset}
        return [ann for ann in all_annotations if ann['image_id'] in image_ids]

    train_annotations = filter_annotations(train_images)
    test_annotations = filter_annotations(test_images)

    train_coco = {
        "images": train_images,
        "annotations": train_annotations,
        "categories": categories,
    }
    train_coco.update(shared_metadata)
    test_coco = {
        "images": test_images,
        "annotations": test_annotations,
        "categories": categories,
    }
    test_coco.update(shared_metadata)

    # Save the new COCO annotation files.
    train_annotations_path = os.path.join(output_dir, 'train', 'annotations.json')
    test_annotations_path = os.path.join(output_dir, 'test', 'annotations.json')
    with open(train_annotations_path, 'w', encoding='utf-8') as f:
        json.dump(train_coco, f, indent=4)
    with open(test_annotations_path, 'w', encoding='utf-8') as f:
        json.dump(test_coco, f, indent=4)

    # Copy the images into the new directories.
    def copy_images(images_subset, target_dir):
        for img in images_subset:
            src_path = os.path.join(images_dir, img['file_name'])
            dst_path = os.path.join(target_dir, img['file_name'])
            # file_name may contain sub-directories; make sure they exist.
            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
            copy2(src_path, dst_path)

    copy_images(train_images, train_images_dir)
    copy_images(test_images, test_images_dir)

    print("数据集已成功划分：")
    print(f"训练集图片数：{len(train_images)}，标注数：{len(train_annotations)}")
    print(f"测试集图片数：{len(test_images)}，标注数：{len(test_annotations)}")

if __name__ == "__main__":
    # Script entry point: configure the split, then run it.

    # Substring of the file name that marks an image as part of the test set.
    test_keyword = "62e74158-000006"

    # TODO: Update these paths to your dataset location
    output_dir = "./data/processed/"  # root directory for the generated split
    images_dir = "./data/raw/"  # directory holding the source images
    coco_json_path = "./data/raw/augmented_annotations.json"  # source COCO annotation file

    # Run the split.
    split_coco_dataset_by_filename(
        coco_json_path=coco_json_path,
        images_dir=images_dir,
        output_dir=output_dir,
        test_keyword=test_keyword,
    )