In [None]:
import os
import cv2
import torch
import numpy as np
import pandas as pd
from ultralytics import YOLO
from timm import create_model
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm

class Config:
    """Static settings for the detection + classification inference run."""

    # Directories holding the detector weights, the classifier weights,
    # and the images to run inference on.
    YOLO_WEIGHT_DIR = r"E:\ISICDM2025\yolo_weight"
    CLS_WEIGHT_DIR = r"E:\ISICDM2025\cls_weight"
    TEST_IMG_DIR = r"E:\ISICDM2025\ISICDM2025_images_for_test"

    # Side length passed to YOLO (`imgsz`) and to the classifier resize;
    # it is also embedded in the weight file names.
    IMG_SIZE = 512
    # Number of outputs of the classifier head.
    NUM_CLASSES = 7
    # Value passed to YOLO as its `conf` threshold.
    YOLO_CONF_THRESH = 0.02

In [None]:
# Select the inference device once so the detector and every classifier share it.
# Without this definition a fresh-kernel run fails with NameError on DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Loading YOLO model (size {Config.IMG_SIZE})...")
# The detector weight file name is prefixed with the input size, e.g. "512yolov8.pt".
weight_path = os.path.join(Config.YOLO_WEIGHT_DIR, f"{Config.IMG_SIZE}yolov8.pt")
yolo_model = YOLO(weight_path)
yolo_model.to(DEVICE)

print(f"Loading EfficientNet classifiers (size {Config.IMG_SIZE})...")
# One EfficientNet-B0 per fold checkpoint (folds 1..5); their logits are
# averaged at inference time.
num_classes = Config.NUM_CLASSES
cls_models = []
for fold in range(1, 6):
    WEIGHT_PATH = os.path.join(Config.CLS_WEIGHT_DIR, f"{Config.IMG_SIZE}_efficientnet_b0_fold_{fold}.pth")

    # pretrained=False: every weight comes from the fold checkpoint.
    # in_chans=1: the classifier is fed single-channel (grayscale) crops.
    model = create_model('tf_efficientnet_b0.ns_jft_in1k', pretrained=False, in_chans=1)
    model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(WEIGHT_PATH, map_location=DEVICE)
    pretrained_dict = checkpoint['model_state_dict']
    model_dict = model.state_dict()

    # Ignore checkpoint entries the model does not have, but DO keep the
    # classifier head: skipping it would leave the freshly created nn.Linear
    # with random weights and make every prediction meaningless.
    filtered_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

    # Strict load: raises if any model weight (including the head) is missing
    # from the checkpoint or has a mismatched shape, instead of silently
    # keeping the random initialisation.
    model.load_state_dict(filtered_dict)

    model = model.to(DEVICE)
    model.eval()
    cls_models.append(model)

def get_transform(img_size):
    """Build the preprocessing pipeline applied to a grayscale crop.

    Args:
        img_size: Target side length; the crop is resized to
            ``(img_size, img_size)``.

    Returns:
        A ``transforms.Compose`` that converts the input to a mode-'L' PIL
        image, resizes it, converts it to a tensor and normalizes the single
        channel with a fixed mean/std.
    """
    target_hw = (img_size, img_size)
    steps = [
        transforms.ToPILImage(mode='L'),
        transforms.Resize(target_hw),
        transforms.ToTensor(),  # yields a [1, H, W] tensor
        transforms.Normalize(mean=[0.98755298], std=[0.026713483]),
    ]
    return transforms.Compose(steps)

# Collects one record (dict) per classified detection.
results = []
# Names of the PNG files in the test directory, in sorted order.
test_images = sorted(
    name for name in os.listdir(Config.TEST_IMG_DIR) if name.endswith('.png')
)

# The preprocessing pipeline depends only on IMG_SIZE, so build it once
# instead of once per detected box.
transform = get_transform(Config.IMG_SIZE)

for img_name in tqdm(test_images):
    img_path = os.path.join(Config.TEST_IMG_DIR, img_name)
    orig_img = cv2.imread(img_path)  # read as 3-channel BGR (the format fed to YOLO)
    if orig_img is None:
        print(f"Warning: cannot read {img_path}")
        continue
    h, w = orig_img.shape[:2]

    # Step 1: YOLO detection on the full image.
    results_yolo = yolo_model(
        orig_img,
        imgsz=Config.IMG_SIZE,
        conf=Config.YOLO_CONF_THRESH,
        verbose=False
    )
    pred = results_yolo[0].boxes

    # Images without any detection contribute no rows to `results`.
    if pred is None or len(pred) == 0:
        continue

    boxes_xyxy = pred.xyxy.cpu()  # [N, 4]

    # Step 2: classify every detected box with the fold ensemble.
    for box in boxes_xyxy:
        x1, y1, x2, y2 = box.int().tolist()
        # Clamp the box to the image so the slice below stays in bounds.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)
        if x2 <= x1 or y2 <= y1:
            continue

        crop = orig_img[y1:y2, x1:x2]  # 3-channel BGR
        if crop.size == 0:
            continue

        crop_gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)  # [H, W]
        input_tensor = transform(crop_gray).unsqueeze(0).to(DEVICE)  # [1, 1, H, W]

        # Run every fold model on the same crop; each returns [1, num_classes].
        with torch.no_grad():
            fold_logits = [model(input_tensor) for model in cls_models]
        # Average the logits over the folds -> [num_classes].
        avg_logits = torch.cat(fold_logits, dim=0).mean(dim=0)

        probs = F.softmax(avg_logits, dim=0)
        pred_class = torch.argmax(probs).item()
        confidence = probs[pred_class].item()

        results.append({
            'image_name': img_name,
            'xmin': x1,
            'ymin': y1,
            'xmax': x2,
            'ymax': y2,
            'predicted_class': pred_class,
            'confidence': confidence
        })