# 目标检测模型预测

In [1]:
import sys
import os
import json
import cv2
from tqdm import tqdm
import pandas as pd
from hq_det.models.dino import hq_dino

In [2]:
# Checkpoint path handed to hq_dino.HQDINO when the model is constructed.
model_path = "/root/autodl-tmp/seat_model/seat_dino_baseline.pth"
# Dataset directory: the COCO annotation file and the images are read from here.
input_path = "/root/autodl-tmp/seat_dataset/chengdu_customer/"
# Directory that receives the exported prediction JSON.
output_path = "out"

In [None]:
def load_coco_data(path):
    """Load a COCO annotation file and return the parsed JSON document.

    Args:
        path: Path to the COCO ``.json`` file.

    Returns:
        The decoded JSON content (a dict for a well-formed COCO file).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Without an explicit encoding, open() uses the locale codec, which can
    # fail on or mangle non-ASCII category names on non-UTF-8 systems.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

In [5]:
# Tabular view of the per-image metadata records for a quick visual check.
# NOTE(review): images_info is defined in a cell that is not part of this excerpt.
df = pd.DataFrame(images_info)
df.head()

Unnamed: 0,id,license,file_name,height,width,date_captured,mask_null,interested_area,not_interested_area
0,0,1,1756717183496__p_0002_f9f1d25c48db4b5086d68fa9...,2048.0,2046.0,,False,[],[]
1,1,1,1756717182510__p_0032_34e0b790feab421d8063f798...,2048.0,2046.0,,False,[],[]
2,2,1,1756717183312__p_0019_0683254627f94ff08cfcf1aa...,2048.0,2046.0,,False,[],[]
3,3,1,1756717181427_2088f67d-f81d-448a-bdb5-e7c64377...,4100.0,4096.0,,False,[],[]
4,4,1,1756717181552__p_0005_8b3a587957524633adfc54cf...,2048.0,2046.0,,False,[],[]


In [6]:
# Make sure the output directory exists before anything is written into it.
os.makedirs(output_path, exist_ok=True)
# Build the detector from the checkpoint path.
model = hq_dino.HQDINO(model=model_path)
# presumably switches the model to inference mode (torch-style API) — TODO confirm
model.eval()
# presumably moves the model to the first CUDA device — needs a GPU; TODO confirm
model.to("cuda:0")

627 627


In [8]:
def replace_coco_file(coco_file, id2names):
    """Load a COCO file and renumber its categories to follow ``id2names``.

    The new category list has one entry per value of ``id2names``, numbered
    0..n-1 in iteration order. Every annotation whose original category name
    appears in ``id2names`` gets the matching new ``category_id``; all
    annotations are tagged with ``'sorce': 'gt'``.

    Args:
        coco_file: Path to the COCO annotation JSON file.
        id2names: Mapping whose values are the target category names.

    Returns:
        The loaded COCO dict with remapped annotations and the new category
        list in ``'categories'``.
    """
    coco_data = load_coco_data(coco_file)
    new_categories = [
        {
            "id": new_id,
            "name": name,
            "supercategory": "none"
        }
        for new_id, name in enumerate(id2names.values())
    ]

    # Original id of each category name, taken from the category's own 'id'
    # field (not its list position). The first occurrence wins on duplicates.
    raw_id_by_name = {}
    for cat in coco_data['categories']:
        raw_id_by_name.setdefault(cat['name'], cat['id'])

    # Build the complete old-id -> new-id table before touching annotations.
    # Iterating in reverse keeps the "not find label" messages in the same
    # order as before.
    old_to_new = {}
    for cat in reversed(new_categories):
        name = cat['name']
        if name not in raw_id_by_name:
            print(f"not find label: {name}")
            continue
        old_to_new[raw_id_by_name[name]] = cat['id']

    for ann in coco_data['annotations']:
        # A single lookup against the ORIGINAL id: an annotation that just
        # received a new id can never be remapped a second time.
        # Annotations whose category is absent from id2names keep their id.
        ann['category_id'] = old_to_new.get(ann['category_id'], ann['category_id'])
        # Key spelling kept as-is; the prediction records use the same key.
        ann['sorce'] = 'gt'
    coco_data['categories'] = new_categories

    return coco_data

# Ground-truth annotations from the dataset folder, with categories renumbered
# according to the model's id2names.
new_coco_data = replace_coco_file(os.path.join(input_path, "_annotations.coco.json"), model.id2names)


not find label: 烫伤
not find label: 吊紧


In [9]:
from collections import Counter
import json
import os

def check_coco_category_counts(coco):
    """Count the annotations of every category in a COCO dataset.

    Args:
        coco: Either a COCO dict (with ``'annotations'`` and ``'categories'``)
            or a path (``str`` or ``os.PathLike``) to a COCO ``.json`` file.

    Returns:
        dict mapping category id to ``(category name, annotation count)``.
        Categories without any annotation are reported with a count of 0.

    Raises:
        FileNotFoundError: If ``coco`` is a path that does not exist.
    """
    if isinstance(coco, (str, os.PathLike)):
        # Any path is opened here, so a wrong path fails with a clear
        # FileNotFoundError instead of being mistaken for the dataset itself.
        with open(coco, "r", encoding="utf-8") as f:
            coco_data = json.load(f)
    else:
        coco_data = coco

    category_counts = Counter(ann['category_id'] for ann in coco_data['annotations'])
    return {
        cat['id']: (cat['name'], category_counts.get(cat['id'], 0))
        for cat in coco_data['categories']
    }

In [12]:
import json

# Accumulates the model's detections in COCO layout.
coco_output = {
    "images": [],
    "annotations": [],
    "categories": []
}

# One category entry per model class; empty when the model has no id2names.
id2names = getattr(model, "id2names", {})
for cls_id, cls_name in id2names.items():
    coco_output["categories"].append({
        "id": int(cls_id),
        "name": cls_name
    })

annotation_id = 1  # annotation ids start at 1 and increase across all images
for image_info in tqdm(images_info):
    img_path = os.path.join(input_path, image_info['file_name'])
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # rather than raising, so fail here with the offending path.
        raise FileNotFoundError(f"cannot read image: {img_path}")
    height, width = img.shape[:2]
    # Explicit check instead of assert, which is stripped under `python -O`.
    if height != image_info['height'] or width != image_info['width']:
        raise ValueError(
            f"image size mismatch for {img_path}: file is {width}x{height}, "
            f"metadata says {image_info['width']}x{image_info['height']}"
        )
    coco_output["images"].append(image_info.copy())
    result = model.predict([img], bgr=True, confidence=0.3, max_size=1536)[0]
    bboxes = result.bboxes
    scores = result.scores
    labels = result.cls
    for bbox, score, label in zip(bboxes, scores, labels):
        # Boxes are unpacked as two corner points and stored as [x, y, w, h].
        x, y, x2, y2 = bbox
        w = x2 - x
        h = y2 - y
        coco_output["annotations"].append({
            "id": annotation_id,
            "image_id": image_info['id'],
            "category_id": int(label),
            "bbox": [float(x), float(y), float(w), float(h)],
            "score": float(score),
            "area": float(w * h),
            "iscrowd": 0,
            "sorce": "predict"
        })
        annotation_id += 1

# Save the predictions as a JSON file.
with open(os.path.join(output_path, "predict_coco.json"), "w", encoding="utf-8") as f:
    json.dump(coco_output, f, ensure_ascii=False, indent=2)

