In [4]:
import os
import json
import shutil
import random
# Category mapping: original dataset category id -> new id in the range 1-5.
# Built from (original_id, new_id) pairs; insertion order is preserved.
category_mapping = dict((
    (87, 1),     # belt -> 1
    (1034, 2),   # sunglasses -> 2
    (131, 3),    # boot -> 3
    (318, 4),    # cowboy_hat -> 4
    (588, 5),    # jacket -> 5
))

# Path of the original annotation file.
json_file = 'train.json'

# Read the file and parse its contents into a plain Python dict.
with open(json_file, mode='r') as f:
    data = json.loads(f.read())
data

{'info': {'description': 'CowboySuit',
  'url': 'http://github.com/dmlc/gluon-cv',
  'version': '1.0',
  'year': 2021,
  'contributor': 'GluonCV/AutoGluon',
  'date_created': '2021/07/01'},
 'images': [{'id': 9860841628484337660,
   'file_name': '88d8bf3754317ffc.jpg',
   'neg_category_ids': [434],
   'pos_category_ids': [69, 161, 216, 277, 433],
   'width': 1024,
   'height': 681,
   'source': 'OpenImages'},
  {'id': 15984033263460081658,
   'file_name': 'ddd2b190ea90dffa.jpg',
   'neg_category_ids': [],
   'pos_category_ids': [45, 69, 277, 434],
   'width': 768,
   'height': 1024,
   'source': 'OpenImages'},
  {'id': 76077631043502082,
   'file_name': '010e4833cdb38002.jpg',
   'neg_category_ids': [308, 333, 404, 584],
   'pos_category_ids': [35, 69, 106, 228, 433, 502, 567],
   'width': 1024,
   'height': 683,
   'source': 'OpenImages'},
  {'id': 18065680256228130812,
   'file_name': 'fab6307a1a43fffc.jpg',
   'neg_category_ids': [514],
   'pos_category_ids': [69, 216, 228, 277, 333

In [5]:
# Display the value stored under the 'annotations' key of the loaded JSON
# (the recorded output shows a list of per-box records with image_id,
# category_id and bbox fields).
data['annotations']

[{'id': 12550146,
  'image_id': 15526467552013451612,
  'freebase_id': '/m/017ftj',
  'category_id': 1034,
  'iscrowd': False,
  'bbox': [102.49, 181.12, 137.08, 97.92],
  'area': 13423.09},
 {'id': 9764874,
  'image_id': 12017556593931260822,
  'freebase_id': '/m/025rp__',
  'category_id': 318,
  'iscrowd': False,
  'bbox': [284.8, 297.6, 252.8, 192.0],
  'area': 48537.65},
 {'id': 13729810,
  'image_id': 17157301591781950087,
  'freebase_id': '/m/025rp__',
  'category_id': 318,
  'iscrowd': False,
  'bbox': [431.36, 277.25, 471.04, 314.39],
  'area': 148089.55},
 {'id': 13959212,
  'image_id': 17494374403857745282,
  'freebase_id': '/m/017ftj',
  'category_id': 1034,
  'iscrowd': False,
  'bbox': [195.76, 237.59, 363.09, 85.33],
  'area': 30983.2},
 {'id': 12910638,
  'image_id': 15979914912395313233,
  'freebase_id': '/m/01b638',
  'category_id': 131,
  'iscrowd': False,
  'bbox': [263.04, 368.64, 182.4, 238.72],
  'area': 43542.48},
 {'id': 12910640,
  'image_id': 15979914912395313

In [6]:
# Fixed seed so the shuffle below (and therefore the train/val split) is
# reproducible.
random.seed(34)

# Original category ids, in the order belt, sunglasses, boot, cowboy_hat, jacket.
ci = [87, 1034, 131, 318, 588]

# Shuffle a *copy* of the annotation list. Shuffling data['annotations'] itself
# would make this cell non-idempotent: re-running it would reshuffle an
# already-shuffled list and silently produce a different split.
ann = list(data['annotations'])
random.shuffle(ann)

# Count the annotations of each tracked category in a single pass.
counts = dict.fromkeys(ci, 0)
for each in ann:
    if each['category_id'] in counts:
        counts[each['category_id']] += 1

print('total:')
for cid in ci:
    print(f'id: {cid} counts: {counts[cid]}')

total:
id: 87 counts: 25
id: 1034 counts: 2330
id: 131 counts: 449
id: 318 counts: 595
id: 588 counts: 2195


In [7]:

# Every image id that is referenced by at least one annotation.
total_id = set(each['image_id'] for each in ann)

# Per-category quota of annotations used to seed the validation set, keyed by
# the original category id (same order as `ci`). The quota counts annotations,
# not distinct images: an image picked twice still consumes two slots.
val_quota = dict(zip(ci, (2, 20, 4, 7, 17)))
picked = dict.fromkeys(val_quota, 0)  # annotations taken so far per category

# Walk the shuffled annotations and take the first `quota` hits of each
# category; the images they belong to form the validation set.
val_id = set()
for each in ann:
    cat = each['category_id']
    if cat in val_quota and picked[cat] < val_quota[cat]:
        val_id.add(each['image_id'])
        picked[cat] += 1

# Collect every annotation (of any category) that belongs to a validation
# image. Set membership makes this one pass over `ann`; the resulting list is
# in `ann` order rather than grouped by image, which does not affect the
# counts computed from it.
val_ann = [each_ann for each_ann in ann if each_ann['image_id'] in val_id]

len(val_id), len(val_ann)

(50, 177)

In [8]:
# Report how many validation annotations fall into each tracked category.
print('val set:')
for kind in ci:
    matching = [rec for rec in val_ann if rec['category_id'] == kind]
    print(f'id: {kind} counts: {len(matching)}')

val set:
id: 87 counts: 3
id: 1034 counts: 57
id: 131 counts: 27
id: 318 counts: 36
id: 588 counts: 54


In [9]:
# The remaining images are used for training.
train_id = total_id - val_id

# Keep every annotation whose image is in the training set. `train_id` is a
# set, so a direct membership test replaces the linear scan over all ids;
# the result and its order (that of `ann`) are unchanged.
train_ann = [each_ann for each_ann in ann if each_ann['image_id'] in train_id]

len(train_id), len(train_ann)

(3012, 5417)

In [10]:
os.makedirs('./datasets/coco/train2017', exist_ok=True)
os.makedirs('./datasets/coco/val2017', exist_ok=True)


def _copy_split_images(images, wanted_ids, dest_dir):
    """Copy the images of one split into ``dest_dir``.

    Args:
        images: iterable of image records (dicts with 'id' and 'file_name').
        wanted_ids: set of image ids that belong to the split.
        dest_dir: existing directory that receives the copies.

    Returns:
        The image records that were copied, in the order of ``images``.
    """
    selected = []
    for image in images:
        # Set membership replaces the former scan over every id in the split.
        if image['id'] in wanted_ids:
            # Source files are copied (not moved), so 'images/' stays intact.
            shutil.copy('images/' + image['file_name'], dest_dir)
            selected.append(image)
    return selected


# Copy train images
train_img = _copy_split_images(data['images'], train_id, './datasets/coco/train2017')

# Copy val images
val_img = _copy_split_images(data['images'], val_id, './datasets/coco/val2017')

len(val_img), len(train_img)

(50, 3012)

In [11]:
os.makedirs('./datasets/coco/annotations', exist_ok=True)

# Annotation records for the training and validation splits.
train_annotations = []
val_annotations = []

# Route each annotation to its split by image_id. Annotations whose image is in
# neither split are dropped.
for annotation in data['annotations']:
    image_id = annotation['image_id']
    if image_id in train_id:
        target = train_annotations
    elif image_id in val_id:
        target = val_annotations
    else:
        continue
    # Remap the category id on a *copy* of the record. The loaded dicts are
    # shared with the lists built in earlier cells, so editing them in place
    # would make those cells report wrong counts if re-run afterwards.
    # Ids that are not in category_mapping are kept unchanged.
    old_id = annotation['category_id']
    target.append({**annotation, 'category_id': category_mapping.get(old_id, old_id)})

# Remapped copies of the category records (ids not in the mapping are kept).
categories = [
    {**category, 'id': category_mapping.get(category['id'], category['id'])}
    for category in data['categories']
]

# Assemble COCO-style dictionaries (info, images, annotations, categories).
train_json = {
    'info': data['info'],
    'images': train_img,
    'annotations': train_annotations,
    'categories': categories
}

val_json = {
    'info': data['info'],
    'images': val_img,
    'annotations': val_annotations,
    'categories': categories
}

# Write both splits under the COCO 2017 file-name convention.
with open('./datasets/coco/annotations/instances_train2017.json', 'w') as train_file:
    json.dump(train_json, train_file, indent=4)

with open('./datasets/coco/annotations/instances_val2017.json', 'w') as val_file:
    json.dump(val_json, val_file, indent=4)

print("train.json 和 val.json 已生成")
train_json


train.json 和 val.json 已生成


{'info': {'description': 'CowboySuit',
  'url': 'http://github.com/dmlc/gluon-cv',
  'version': '1.0',
  'year': 2021,
  'contributor': 'GluonCV/AutoGluon',
  'date_created': '2021/07/01'},
 'images': [{'id': 9860841628484337660,
   'file_name': '88d8bf3754317ffc.jpg',
   'neg_category_ids': [434],
   'pos_category_ids': [69, 161, 216, 277, 433],
   'width': 1024,
   'height': 681,
   'source': 'OpenImages'},
  {'id': 15984033263460081658,
   'file_name': 'ddd2b190ea90dffa.jpg',
   'neg_category_ids': [],
   'pos_category_ids': [45, 69, 277, 434],
   'width': 768,
   'height': 1024,
   'source': 'OpenImages'},
  {'id': 76077631043502082,
   'file_name': '010e4833cdb38002.jpg',
   'neg_category_ids': [308, 333, 404, 584],
   'pos_category_ids': [35, 69, 106, 228, 433, 502, 567],
   'width': 1024,
   'height': 683,
   'source': 'OpenImages'},
  {'id': 18065680256228130812,
   'file_name': 'fab6307a1a43fffc.jpg',
   'neg_category_ids': [514],
   'pos_category_ids': [69, 216, 228, 277, 333