In [6]:
import os
import json
import shutil
import random
# 类别映射字典：将原类别 ID 映射为新的 1-5
category_mapping = {
    87: 0,     # belt -> 1
    1034: 1,   # sunglasses -> 2
    131: 2,    # boot -> 3
    318: 3,    # cowboy_hat -> 4
    588: 4     # jacket -> 5
}

# 假设你原始的标注文件路径
json_file = r'C:\data\program\python\CowBoy\test\add_json\all_json.json'

with open(json_file, 'r') as f:
    data = json.load(f)
data['annotations']

[{'id': 0,
  'image_id': 1148038,
  'iscrowd': False,
  'bbox': [62.2, 166.03, 698.5, 263.06],
  'category_id': 87,
  'area': 183747.41},
 {'id': 1,
  'image_id': 1148038,
  'iscrowd': False,
  'bbox': [94.67, 98.62, 539.51, 258.13],
  'category_id': 87,
  'area': 139263.72},
 {'id': 2,
  'image_id': 902882,
  'iscrowd': False,
  'bbox': [12.91, 106.63, 715.43, 315.25],
  'category_id': 87,
  'area': 225539.31},
 {'id': 3,
  'image_id': 902882,
  'iscrowd': False,
  'bbox': [142.49, 25.85, 552.59, 309.34],
  'category_id': 87,
  'area': 170938.19},
 {'id': 4,
  'image_id': 948578,
  'iscrowd': False,
  'bbox': [85.12, 250.92, 922.24, 307.25],
  'category_id': 87,
  'area': 283358.24},
 {'id': 5,
  'image_id': 948578,
  'iscrowd': False,
  'bbox': [128.0, 172.19, 712.32, 301.49],
  'category_id': 87,
  'area': 214757.36},
 {'id': 6,
  'image_id': 1502240,
  'iscrowd': False,
  'bbox': [31.1, 183.6, 733.21, 292.4],
  'category_id': 87,
  'area': 214390.6},
 {'id': 7,
  'image_id': 150224

In [7]:
len(data['annotations'])

15798

In [8]:
random.seed(34)

ci = [87, 1034, 131, 318, 588]  # category_id, 分别对应belt,sunglasses,boot,cowboy_hat,jacket
ann = data['annotations']
random.shuffle(ann)
random.shuffle(data['images'])
print('total:')
for i in ci:
    count = 0
    for j in ann:
        if j['category_id'] == i:
            count += 1
    print(f'id: {i} counts: {count}')

total:
id: 87 counts: 2025
id: 1034 counts: 2330
id: 131 counts: 4017
id: 318 counts: 5231
id: 588 counts: 2195


In [9]:

total_id = set(each['image_id'] for each in ann)
val_id = set()
a, b, c, d, e = 0, 0, 0, 0, 0  # 用于每类的计数
for each in ann:
    if (each['category_id'] == ci[0]) and (a < 100):
        val_id.add(each['image_id'])
        a += 1
    elif (each['category_id'] == ci[1]) and (b < 100):
        val_id.add(each['image_id'])
        b += 1
    elif (each['category_id'] == ci[2]) and (c < 100):
        val_id.add(each['image_id'])
        c += 1
    elif (each['category_id'] == ci[3]) and (d < 100):
        val_id.add(each['image_id'])
        d += 1
    elif (each['category_id'] == ci[4]) and (e < 100):
        val_id.add(each['image_id'])
        e += 1

val_ann = []
for imid in val_id:
    for each_ann in ann:
        if each_ann['image_id'] == imid:
            val_ann.append(each_ann)
            
len(val_id),len(val_ann)

(488, 1493)

In [10]:
print('val set:')
for kind in ci:
    num = 0
    for i in val_ann:
        if i['category_id'] == kind:
            num += 1
    print(f'id: {kind} counts: {num}')

val set:
id: 87 counts: 126
id: 1034 counts: 247
id: 131 counts: 370
id: 318 counts: 436
id: 588 counts: 314


In [11]:
# The rest images are for training
train_id = total_id - val_id
train_ann = []
for each_ann in ann:
    for tid in train_id:
        if each_ann['image_id'] == tid:
            train_ann.append(each_ann)
            break
len(train_id), len(train_ann)

(8669, 14305)

In [12]:
os.makedirs('./datasets/coco/train2017', exist_ok=True)
os.makedirs('./datasets/coco/val2017', exist_ok=True)
          
train_img = []
# Move train images
for j in data['images']:
    for i in train_id:
        if j['id'] == i:
            shutil.copy(r'C:\data\program\python\CowBoy\test\add/'+j['file_name'], './datasets/coco/train2017')
            train_img.append(j)
            
val_img = []
# Move val images
for j in data['images']:
    for i in val_id:
        if j['id'] == i:
            shutil.copy(r'C:\data\program\python\CowBoy\test\add/'+j['file_name'], './datasets/coco/val2017')
            val_img.append(j)

len(val_img), len(train_img)

(488, 8670)

In [13]:
os.makedirs('./datasets/coco/annotations', exist_ok=True)
# 用于存储训练集和验证集的标注数据
train_annotations = []
val_annotations = []

# 遍历标注数据，根据image_id将数据分类到train或val
for annotation in data['annotations']:
    image_id = annotation['image_id']
        # 如果类别ID在category_mapping中，进行映射
    if annotation['category_id'] in category_mapping:
        annotation['category_id'] = category_mapping[annotation['category_id']]
    if image_id in train_id:
        train_annotations.append(annotation)
    elif image_id in val_id:
        val_annotations.append(annotation)

# 替换categories中的id
for category in data['categories']:
    if category['id'] in category_mapping:
        category['id'] = category_mapping[category['id']]

# 创建包含image信息和类别的字典（类似于COCO格式）
train_json = {
    'info': data['info'],
    'images': train_img,
    'annotations': train_annotations,  # 这里添加训练集的标注数据
    'categories': data['categories']
}

val_json = {
    'info': data['info'],
    'images': val_img,
    'annotations': val_annotations,
    'categories': data['categories']
}

# 保存为train.json和val.json
with open('./datasets/coco/annotations/instances_train2017.json', 'w') as train_file:
    json.dump(train_json, train_file, indent=4)

with open('./datasets/coco/annotations/instances_val2017.json', 'w') as val_file:
    json.dump(val_json, val_file, indent=4)

print("train.json 和 val.json 已生成")
train_json


train.json 和 val.json 已生成


{'info': {'description': 'CowboySuit',
  'url': 'http://github.com/dmlc/gluon-cv',
  'version': '1.0',
  'year': 2021,
  'contributor': 'GluonCV/AutoGluon',
  'date_created': '2021/07/01'},
 'images': [{'id': 1897635,
   'file_name': 'add_0_131_394a6fd791ec5039.jpg',
   'neg_category_ids': [277, 355, 433],
   'pos_category_ids': [45, 434],
   'width': 768,
   'height': 768,
   'source': 'OpenImages'},
  {'id': 1661912,
   'file_name': 'add_12_87_03d1492ceb1b94a1.jpg',
   'neg_category_ids': [247, 293, 308, 333, 391, 409, 433, 434],
   'pos_category_ids': [34, 334],
   'width': 768,
   'height': 512,
   'source': 'OpenImages'},
  {'id': 543794809731845226,
   'file_name': '078bf27f90c48c6a.jpg',
   'neg_category_ids': [514],
   'pos_category_ids': [69, 211, 216, 228, 308, 433, 434, 462],
   'width': 683,
   'height': 1024,
   'source': 'OpenImages'},
  {'id': 2373721,
   'file_name': 'add_6_131_3fc57e737d0e631a.jpg',
   'neg_category_ids': [32],
   'pos_category_ids': [45, 58, 69, 433, 

In [14]:
len(train_json['annotations'])

14305