In [13]:
import os
import shutil
import random
import xml.etree.ElementTree as ET
from pathlib import Path

# ===============================
# 🧱 Configuration
# ===============================
DATASET_URL = "andrewmvd/face-mask-detection"  # Kaggle dataset handle handed to download_dataset()
DATASET_DIR = "dataset"        # Local directory the dataset is copied into
ANNOT_DIR = os.path.join(DATASET_DIR, "annotations")  # Pascal VOC *.xml annotation files
IMAGES_DIR = os.path.join(DATASET_DIR, "images")  # Images; looked up as <annotation stem>.png

# Note: the output tree lives inside DATASET_DIR.
OUTPUT_DIR = "dataset/output"
TRAIN_DIR = os.path.join(OUTPUT_DIR, "train")
TEST_DIR = os.path.join(OUTPUT_DIR, "test")  # Also passed as the `val` split when the YAML is written
YAML_PATH = "voc.yaml"

# A name's position in this list is the class id written to the YOLO labels.
# Objects whose name is not listed are skipped during conversion.
# NOTE(review): confirm the dataset has no additional classes that should be detected.
CLASS_NAMES = ["with_mask", "without_mask"]
TRAIN_RATIO = 0.8  # Fraction of the samples placed in the training split

# ===============================
# 📥 下载并解压数据集
# ===============================
def download_dataset(dataset_url, output_dir):
    """Download a Kaggle dataset with kagglehub and copy it to ``output_dir``.

    The function is a no-op when ``output_dir`` already exists, so re-running
    the notebook does not trigger another download.

    Args:
        dataset_url: Kaggle dataset handle, e.g. ``"owner/dataset-name"``.
        output_dir: Directory the downloaded files are copied into.

    Returns:
        None. Progress and problems are reported on stdout.
    """
    # Check the target first: an already-prepared dataset must not require
    # kagglehub to be installed.
    if os.path.exists(output_dir):
        print(f"✅ 数据集已存在：{os.path.abspath(output_dir)}")
        return

    try:
        import kagglehub
    except ImportError:
        print("❌ 请先安装 kagglehub：pip install kagglehub")
        return

    print("🚀 正在下载数据集中...")
    downloaded_path = kagglehub.dataset_download(dataset_url)
    shutil.copytree(downloaded_path, output_dir)
    # Removing the download location is best-effort clean-up; the copy above
    # already succeeded, so a failure here must not abort the function.
    shutil.rmtree(downloaded_path, ignore_errors=True)
    print(f"✅ 数据集下载完成：{os.path.abspath(output_dir)}")

# ===============================
# 🔁 VOC → YOLO 格式转换函数
# ===============================
def convert_voc_to_yolo(xml_file, yolo_save_dir, class_names):
    """Convert one Pascal VOC annotation file into a YOLO label file.

    Each ``<object>`` whose name is in ``class_names`` becomes one line
    ``<class_id> <x_center> <y_center> <width> <height>``, with all four
    geometry values normalised by the image size. Objects with other names are
    skipped. The label file is always written, even when it ends up empty.

    Args:
        xml_file: Path to the VOC ``.xml`` annotation.
        yolo_save_dir: Directory that receives ``<xml stem>.txt``; created if
            it does not exist.
        class_names: Ordered class names; the index is used as the class id.

    Returns:
        None.
    """
    root = ET.parse(xml_file).getroot()

    # Some VOC exporters write numbers as floats ("640.0"), which int() rejects,
    # so every numeric field is parsed with float().
    img_width = float(root.find("size/width").text)
    img_height = float(root.find("size/height").text)

    def _coord(bbox, tag, upper):
        # Clamp to the image so normalised values stay within [0, 1];
        # annotation boxes occasionally overrun the image border.
        return min(max(float(bbox.find(tag).text), 0.0), upper)

    label_lines = []
    for obj in root.findall("object"):
        cls_name = obj.find("name").text
        if cls_name not in class_names:
            continue
        cls_id = class_names.index(cls_name)

        bbox = obj.find("bndbox")
        xmin = _coord(bbox, "xmin", img_width)
        ymin = _coord(bbox, "ymin", img_height)
        xmax = _coord(bbox, "xmax", img_width)
        ymax = _coord(bbox, "ymax", img_height)

        # Corner format -> normalised centre + size.
        x_center = ((xmin + xmax) / 2) / img_width
        y_center = ((ymin + ymax) / 2) / img_height
        width = (xmax - xmin) / img_width
        height = (ymax - ymin) / img_height

        label_lines.append(f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

    save_dir = Path(yolo_save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    txt_file = save_dir / (Path(xml_file).stem + ".txt")
    with open(txt_file, "w", encoding="utf-8") as f:
        f.write("\n".join(label_lines))

# ===============================
# 🔀 划分数据集并执行转换
# ===============================
def prepare_dataset(train_ratio, seed=42):
    """Split the dataset into train/test folders and write YOLO labels.

    Every ``*.xml`` in ``ANNOT_DIR`` is paired with ``<stem>.png`` in
    ``IMAGES_DIR``. The pairs are shuffled and split by ``train_ratio``. Each
    pair is copied to ``TRAIN_DIR`` or ``TEST_DIR``, and a YOLO ``.txt`` label
    is generated next to it.

    ``TRAIN_DIR`` and ``TEST_DIR`` are emptied before being repopulated.
    Without that, a re-run with a different shuffle or ratio would leave the
    same image in both splits.

    Args:
        train_ratio: Fraction of the samples assigned to the training split.
        seed: Seed for the shuffle, so the split is reproducible. Pass
            ``None`` for a different split on every call.

    Raises:
        FileNotFoundError: If no annotation files are found, or an annotation
            has no matching ``.png`` image.
    """
    # Collect annotations and their images.
    xml_files = sorted(Path(ANNOT_DIR).glob("*.xml"))
    if not xml_files:
        raise FileNotFoundError(f"未找到标注文件：{ANNOT_DIR}")
    img_files = [Path(IMAGES_DIR) / (xml.stem + ".png") for xml in xml_files]

    # Validate before touching the output folders so a failure leaves any
    # previous split intact.
    for img in img_files:
        if not img.exists():
            raise FileNotFoundError(f"未找到图像文件：{img}")

    # A dedicated RNG gives a reproducible split without reseeding the global one.
    data = list(zip(xml_files, img_files))
    random.Random(seed).shuffle(data)
    split_idx = int(len(data) * train_ratio)
    train_data, test_data = data[:split_idx], data[split_idx:]

    # Start from empty output folders so stale files cannot leak across splits.
    for target_dir in (TRAIN_DIR, TEST_DIR):
        shutil.rmtree(target_dir, ignore_errors=True)
        os.makedirs(target_dir, exist_ok=True)

    def process(data_list, target_dir):
        # Copy the annotation and image, then write the YOLO label beside them.
        for xml_path, img_path in data_list:
            shutil.copy(xml_path, target_dir)
            shutil.copy(img_path, target_dir)
            convert_voc_to_yolo(xml_path, target_dir, CLASS_NAMES)

    process(train_data, TRAIN_DIR)
    process(test_data, TEST_DIR)

    print(f"✅ 数据集划分完成，训练集数量：{len(train_data)}，测试集数量：{len(test_data)}")

# ===============================
# 📝 生成 YOLO YAML 配置文件
# ===============================
def create_yaml(path, train_dir, val_dir, class_names):
    """Write the dataset YAML file consumed by YOLO training.

    The file contains absolute ``train`` and ``val`` paths, the class count
    ``nc`` and the ``names`` list.

    Args:
        path: Destination of the YAML file.
        train_dir: Training split directory (written as an absolute path).
        val_dir: Validation split directory (written as an absolute path).
        class_names: Sequence of class names, ordered by class id.
    """
    # list(...) keeps the flow-sequence form "['a', 'b']" even if a tuple is
    # passed; a tuple repr "('a', 'b')" would not be a YAML sequence.
    names = list(class_names)
    lines = [
        f"train: {os.path.abspath(train_dir)}",
        f"val: {os.path.abspath(val_dir)}",
        "",
        f"nc: {len(names)}",
        f"names: {names}",
    ]
    # Explicit UTF-8 so non-ASCII path components are written the same way on
    # every platform, regardless of the default locale encoding.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    print(f"📄 已生成配置文件：{path}")

# ===============================
# 🚀 Main entry point
# ===============================

# Run the preparation steps in order: fetch the dataset, split it and convert
# the labels, then write the YAML config. TEST_DIR is passed as the validation
# directory, so the test split doubles as the validation set.
download_dataset(DATASET_URL, DATASET_DIR)
prepare_dataset(TRAIN_RATIO)
create_yaml(YAML_PATH, TRAIN_DIR, TEST_DIR, CLASS_NAMES)


✅ 数据集已存在：/content/dataset
✅ 数据集划分完成，训练集数量：682，测试集数量：171
📄 已生成配置文件：voc.yaml


In [None]:
!pip install ultralytics
!nvidia-smi
from ultralytics import YOLO
device = 'cuda'  # 使用GPU训练,可选cuda或cpu

model = YOLO("baseModel/yolov8n.pt")  # 使用预训练模型
model.train(
    data="voc.yaml",
    device=0 if device == "cuda" else "cpu",
    epochs=50,
    batch=32,
    imgsz=640,
    optimizer="AdamW",
    multi_scale=True,
    augment=True,
    lr0=0.0001,               # 适当提高初始学习率
    lrf=0.01,                # 添加余弦退火最终学习率
    amp=True,               # 保持混合精度训练
    pretrained=True,        # 确保使用预训练权重
    patience=10,               # ⭐️ 添加早停机制 如果10个 epoch 没有提升，自动停止
    close_mosaic=10            # 提前关闭 mosaic 增强以稳定收敛
    save=True,  # 保存模型
    exist_ok=True,
    )  # 训练模型

Sat Apr 19 10:09:24 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   72C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

[34m[1mtrain: [0mScanning /content/dataset/output/train.cache... 682 images, 17 backgrounds, 0 corrupt: 100%|██████████| 682/682 [00:00<?, ?it/s]


[34m[1malbumentations: [0mBlur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01, num_output_channels=3, method='weighted_average'), CLAHE(p=0.01, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8))
[34m[1mval: [0mFast image access ✅ (ping: 0.0±0.0 ms, read: 1807.3±1425.2 MB/s, size: 340.0 KB)


[34m[1mval: [0mScanning /content/dataset/output/test.cache... 171 images, 4 backgrounds, 0 corrupt: 100%|██████████| 171/171 [00:00<?, ?it/s]


Plotting labels to runs/detect/train/labels.jpg... 
[34m[1moptimizer:[0m AdamW(lr=0.0001, momentum=0.937) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
Image sizes 640 train, 640 val
Using 2 dataloader workers
Logging results to [1mruns/detect/train[0m
Starting training for 50 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       1/50      10.4G      1.958      3.533      1.534        139        640: 100%|██████████| 22/22 [00:17<00:00,  1.27it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:03<00:00,  1.10s/it]

                   all        171        805    0.00279      0.222      0.023    0.00736






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       2/50      10.7G       1.51      2.252      1.202         44        512: 100%|██████████| 22/22 [00:15<00:00,  1.44it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.15it/s]

                   all        171        805    0.00402      0.308      0.107     0.0378






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       3/50      9.19G       1.38      1.576      1.128         57        864: 100%|██████████| 22/22 [00:15<00:00,  1.42it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.46it/s]

                   all        171        805     0.0176      0.709      0.358      0.203






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       4/50       9.2G       1.36        1.5        1.1         72        352: 100%|██████████| 22/22 [00:15<00:00,  1.42it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.10it/s]

                   all        171        805       0.97      0.265       0.45      0.281






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       5/50      11.7G      1.334      1.337      1.104         59        896: 100%|██████████| 22/22 [00:15<00:00,  1.44it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.44it/s]

                   all        171        805      0.801      0.431      0.532      0.331






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       6/50      9.31G      1.312      1.317      1.044         80        384: 100%|██████████| 22/22 [00:14<00:00,  1.48it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]

                   all        171        805      0.677      0.545      0.607      0.379






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       7/50      8.97G      1.345      1.279      1.035        106        864: 100%|██████████| 22/22 [00:16<00:00,  1.32it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:03<00:00,  1.01s/it]

                   all        171        805      0.716      0.581      0.642      0.398






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       8/50      11.8G      1.272      1.111      1.033        102        672: 100%|██████████| 22/22 [00:15<00:00,  1.44it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.24it/s]

                   all        171        805      0.793      0.606      0.686      0.417






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       9/50      12.5G      1.254      1.111      1.045         64        672: 100%|██████████| 22/22 [00:15<00:00,  1.40it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.42it/s]

                   all        171        805      0.789      0.646      0.708      0.438






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      10/50      11.3G      1.224      1.014      1.028         50        448: 100%|██████████| 22/22 [00:15<00:00,  1.40it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.33it/s]

                   all        171        805      0.786       0.65      0.731      0.451






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      11/50      10.5G      1.206     0.9955       1.02         66        352: 100%|██████████| 22/22 [00:14<00:00,  1.50it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.46it/s]

                   all        171        805      0.811      0.666      0.755      0.469






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      12/50      12.5G      1.214      1.005      1.026         45        384: 100%|██████████| 22/22 [00:14<00:00,  1.48it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:03<00:00,  1.06s/it]

                   all        171        805      0.874       0.65      0.758      0.463






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      13/50      9.97G      1.214     0.9752      1.004         95        800: 100%|██████████| 22/22 [00:13<00:00,  1.61it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.27it/s]

                   all        171        805      0.841      0.661      0.767      0.479






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      14/50      9.15G      1.183      0.939     0.9971         93        576: 100%|██████████| 22/22 [00:15<00:00,  1.44it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.15it/s]

                   all        171        805      0.812      0.685      0.779      0.476






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      15/50      13.1G      1.164        0.9      1.016        123        448: 100%|██████████| 22/22 [00:16<00:00,  1.35it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:03<00:00,  1.28s/it]

                   all        171        805      0.866       0.69      0.789      0.493






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      16/50      7.62G      1.164      0.868     0.9954         80        544: 100%|██████████| 22/22 [00:15<00:00,  1.46it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:03<00:00,  1.00s/it]

                   all        171        805      0.776      0.744      0.789      0.493






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      17/50      8.61G      1.136     0.8459     0.9975         82        736: 100%|██████████| 22/22 [00:14<00:00,  1.53it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.42it/s]

                   all        171        805      0.839      0.708      0.793       0.49






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      18/50      9.33G      1.176     0.8564     0.9828         72        608: 100%|██████████| 22/22 [00:14<00:00,  1.49it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.12it/s]

                   all        171        805      0.858      0.697      0.794      0.496






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      19/50      10.1G      1.155     0.8382     0.9886        134        800: 100%|██████████| 22/22 [00:13<00:00,  1.60it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.37it/s]

                   all        171        805      0.814       0.73        0.8      0.502






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      20/50      9.13G      1.162     0.8471      1.001        113        960: 100%|██████████| 22/22 [00:15<00:00,  1.46it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.45it/s]

                   all        171        805      0.847      0.723      0.799      0.506






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      21/50      10.7G      1.167     0.8346      1.007         54        704: 100%|██████████| 22/22 [00:15<00:00,  1.40it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.43it/s]

                   all        171        805      0.811      0.736      0.808      0.506






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      22/50        10G      1.145     0.8074      0.986        120        704: 100%|██████████| 22/22 [00:13<00:00,  1.59it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.38it/s]

                   all        171        805      0.829      0.747      0.809      0.515






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      23/50      10.4G      1.146     0.8413      1.003         61        480: 100%|██████████| 22/22 [00:14<00:00,  1.47it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:03<00:00,  1.06s/it]

                   all        171        805      0.811       0.73      0.807      0.514






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      24/50      11.3G      1.137     0.7848     0.9778         87        800: 100%|██████████| 22/22 [00:14<00:00,  1.56it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.43it/s]

                   all        171        805      0.828      0.753      0.819       0.52






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      25/50      10.5G      1.137     0.8118      0.984         58        448: 100%|██████████| 22/22 [00:14<00:00,  1.53it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.07it/s]

                   all        171        805      0.857      0.737      0.807      0.514






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      26/50       9.6G      1.108     0.7783     0.9994         79        704: 100%|██████████| 22/22 [00:14<00:00,  1.47it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.48it/s]

                   all        171        805      0.861       0.73      0.807      0.513






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      27/50      9.39G      1.133     0.7802     0.9833         43        512: 100%|██████████| 22/22 [00:14<00:00,  1.54it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.34it/s]

                   all        171        805      0.864      0.726       0.82      0.517






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      28/50      8.92G      1.098     0.7767     0.9767         67        768: 100%|██████████| 22/22 [00:14<00:00,  1.51it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.38it/s]

                   all        171        805      0.874      0.732      0.815      0.519






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      29/50      9.01G      1.164     0.7786      0.968        163        512: 100%|██████████| 22/22 [00:13<00:00,  1.58it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 3/3 [00:02<00:00,  1.46it/s]

                   all        171        805      0.885       0.73      0.816      0.522






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      30/50      2.43G      1.125     0.7675     0.9468        185        320:  18%|█▊        | 4/22 [00:01<00:04,  3.70it/s]

In [None]:
# 预测输出
import os
import cv2
import torch
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import matplotlib.pyplot as plt

# ------------ Global configuration ------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Use the GPU when one is available
# NOTE(review): presumably the best checkpoint of the training run above (it logs to runs/detect/train) — confirm.
MODEL_PATH = "runs/detect/train/weights/best.pt"
model = YOLO(MODEL_PATH)  # Module-level model shared by all predict_* functions below
INPUT_PATH = "dataset/output/test/"  # Input source: an image, a video, a folder, or a camera index
# INPUT_PATH = "dataset/output/video/test.mp4"  # Input source: an image, a video, a folder, or a camera index
# INPUT_PATH=0

SAVE = True  # Whether to save the prediction results
OUTPUT_PATH = "predict/"  # Where the prediction results are saved

# ------------ 工具函数 ------------
def draw_boxes_pil(image, results):
    """Draw the detections of ``results[0]`` onto a PIL image, in place.

    Every box gets a red outline and a red label tag with white text of the
    form ``"<class name> <confidence>"`` placed above its top-left corner.

    Args:
        image: PIL image to draw on; it is modified in place.
        results: Sequence returned by ``model.predict``; only the first element
            is used.

    Returns:
        The same ``image`` object, annotated.
    """
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # Arial is not installed on every host; fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    result = results[0]
    # Class names travel with the result, so no global model is needed here.
    names = result.names

    for box in result.boxes:
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        cls_id = int(box.cls)
        conf = float(box.conf)
        label = f"{names[cls_id]} {conf:.2f}"

        left, top, right, bottom = font.getbbox(label)
        text_w, text_h = right - left, bottom - top
        # Keep the tag on the canvas when the box touches the top edge.
        tag_top = max(y1 - text_h, 0)

        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.rectangle([x1, tag_top, x1 + text_w, tag_top + text_h], fill="red")
        # getbbox() is relative to the drawing origin; subtracting its offsets
        # places the glyphs exactly on the tag background.
        draw.text((x1 - left, tag_top - top), label, fill="white", font=font)

    return image

def save_image(image, save_path, origin_path=None):
    """Save ``image`` to a file, or into a directory under its original name.

    ``save_path`` is treated as a directory when it is an existing directory,
    ends with a path separator, or has no file extension. In that case the
    file name is taken from ``origin_path``. Otherwise ``save_path`` is the
    target file. Missing directories are created in both cases.

    Args:
        image: Object with a ``save(path)`` method (e.g. a PIL image).
        save_path: Target file path or target directory.
        origin_path: Path of the source image; required when ``save_path`` is
            a directory, since it provides the file name.

    Raises:
        ValueError: If ``save_path`` is a directory and ``origin_path`` is None.
    """
    save_path = os.fspath(save_path)
    # A directory that does not exist yet (e.g. "predict/") cannot be detected
    # with isdir() alone, so the trailing separator / missing extension are
    # used as additional hints.
    is_directory = (
        os.path.isdir(save_path)
        or save_path.endswith(("/", os.sep))
        or not os.path.splitext(save_path)[1]
    )

    if is_directory:
        if origin_path is None:
            raise ValueError("save_path 为目录时必须提供 origin_path")
        os.makedirs(save_path, exist_ok=True)
        save_path = os.path.join(save_path, os.path.basename(origin_path))
    else:
        parent = os.path.dirname(save_path)
        # dirname is "" for a bare file name; os.makedirs("") would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)

    image.save(save_path)
    print(f"✅ 已保存图片: {save_path}")

# ------------ 单图预测 ------------
def predict_image(image_path, save=False, save_path=None):
    """Run the model on one image, display the annotated result and optionally save it.

    Args:
        image_path: Path of the image to predict on.
        save: Whether the annotated image should be saved.
        save_path: Destination handed to ``save_image``; nothing is saved when
            it is empty or ``save`` is false.
    """
    source_image = Image.open(image_path).convert("RGB")
    detections = model.predict(image_path, imgsz=640, device=DEVICE)
    annotated = draw_boxes_pil(source_image, detections)

    # Display the annotated image without axes.
    plt.imshow(annotated)
    plt.axis("off")
    plt.title("预测结果")
    plt.show()

    # Saving needs both the flag and a destination.
    if not (save and save_path):
        return
    save_image(annotated, save_path, origin_path=image_path)

# ------------ 视频预测 ------------
def predict_video(video_path, save=False, save_path=None):
    """Run the model on every frame of a video, show it live and optionally save it.

    Detections are drawn as red boxes with a white ``"<class> <confidence>"``
    label. Press ``q`` in the preview window to stop early. When no display is
    available the preview is switched off and processing continues.

    Args:
        video_path: Path of the input video.
        save: Whether the annotated video should be written to disk.
        save_path: Output file, or a directory (existing, ending with a path
            separator, or without extension) in which ``<input stem>.mp4`` is
            created. Nothing is saved when this is empty.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("❌ 视频文件无法打开")
        return

    # Same rule as predict_image / predict_folder: saving needs flag and destination.
    save = bool(save and save_path)
    writer = None
    show_preview = True

    try:
        if save:
            save_path = os.fspath(save_path)
            is_directory = (
                os.path.isdir(save_path)
                or save_path.endswith(("/", os.sep))
                or not os.path.splitext(save_path)[1]
            )
            if is_directory:
                os.makedirs(save_path, exist_ok=True)
                stem = os.path.splitext(os.path.basename(video_path))[0]
                save_path = os.path.join(save_path, f"{stem}.mp4")
            else:
                parent = os.path.dirname(save_path)
                if parent:
                    os.makedirs(parent, exist_ok=True)

            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            # Some containers report 0 fps; fall back to 25 so the writer stays usable.
            fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
            frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            writer = cv2.VideoWriter(save_path, fourcc, fps, (frame_w, frame_h))
            if not writer.isOpened():
                # Report the failure instead of silently producing no file.
                print(f"❌ 无法创建输出视频: {save_path}")
                writer.release()
                writer = None
                save = False

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            results = model.predict(frame, imgsz=640, device=DEVICE)
            for box in results[0].boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                cls_id = int(box.cls)
                conf = float(box.conf)
                label = f"{model.names[cls_id]} {conf:.2f}"
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
                cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            if writer is not None:
                writer.write(frame)

            if show_preview:
                try:
                    cv2.imshow("预测中 - 按 Q 退出", frame)
                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break
                except cv2.error:
                    # No GUI backend (headless server / hosted notebook): keep
                    # processing so the video can still be saved.
                    show_preview = False
                    print("⚠️ 当前环境无法显示窗口，已关闭实时预览")
    finally:
        # Always free the capture and writer, even on errors or a kernel interrupt.
        cap.release()
        if writer is not None:
            writer.release()
        try:
            cv2.destroyAllWindows()
        except cv2.error:
            # Headless OpenCV builds have no window support; nothing to close.
            pass

    if save:
        print(f"✅ 已保存视频: {save_path}")

# ------------ 文件夹批量图片 ------------
def predict_folder(folder_path, save=False, output_dir=None):
    """Run the model on every image below ``folder_path`` (recursively).

    Files are selected by extension (jpg, jpeg, png, bmp, tiff, case
    insensitive). Annotated images are only written when ``save`` is true and
    ``output_dir`` is given; the sub-folder layout of ``folder_path`` is
    mirrored below ``output_dir``.

    Args:
        folder_path: Root folder to scan for images.
        save: Whether annotated images should be saved.
        output_dir: Root folder for the annotated images.
    """
    image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tiff")
    # Saving needs both the flag and a destination; this also keeps the final
    # message from claiming results were saved to "None".
    save = bool(save and output_dir)
    output_abs = os.path.abspath(output_dir) if save else None

    for root, dirs, files in os.walk(folder_path):
        if output_abs is not None:
            # Do not descend into the output folder when it lies inside the
            # input folder, otherwise earlier results would be predicted again.
            dirs[:] = [d for d in dirs if os.path.abspath(os.path.join(root, d)) != output_abs]

        for file in files:
            if not file.lower().endswith(image_exts):
                continue

            img_path = os.path.join(root, file)
            image = Image.open(img_path).convert("RGB")
            results = model.predict(img_path, imgsz=640, device=DEVICE)
            image = draw_boxes_pil(image, results)

            if save:
                rel_path = os.path.relpath(img_path, folder_path)
                save_path = os.path.join(output_dir, rel_path)
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                image.save(save_path)

    if save:
        print(f"✅ 文件夹预测完成，结果已保存至: {output_dir}")

# ------------ 摄像头实时预测 ------------
def predict_camera(index=0):
    """Run the model live on a camera stream and show the annotated frames.

    Detections are drawn as green boxes with a white ``"<class> <confidence>"``
    label. Press ``q`` in the preview window to stop. The loop also stops when
    no frame can be read or no window can be shown.

    Args:
        index: Camera index passed to ``cv2.VideoCapture``.
    """
    cap = cv2.VideoCapture(index)
    if not cap.isOpened():
        print(f"❌ 无法打开摄像头 {index}")
        return

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            results = model.predict(frame, imgsz=640, device=DEVICE)
            for box in results[0].boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                cls_id = int(box.cls)
                conf = float(box.conf)
                label = f"{model.names[cls_id]} {conf:.2f}"
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

            try:
                cv2.imshow("摄像头预测 - 按 Q 退出", frame)
                key = cv2.waitKey(1)
            except cv2.error as exc:
                # Without a window there is neither output nor a way to press
                # "q", so stop instead of looping forever.
                print(f"❌ 无法显示窗口，已停止摄像头预测: {exc}")
                break

            if key & 0xFF == ord("q"):
                break
    finally:
        # Release the device even on errors or a kernel interrupt; otherwise
        # the camera stays locked until the kernel is restarted.
        cap.release()
        try:
            cv2.destroyAllWindows()
        except cv2.error:
            # Headless OpenCV builds have no window support; nothing to close.
            pass

# ------------ 总入口函数 ------------
def run_predict(path, save=False, save_path=None):
    """Dispatch a prediction request based on the kind of input.

    * ``int``: camera index, handled by ``predict_camera``
    * image file (jpg, jpeg, png, bmp, tiff): ``predict_image``
    * video file (mp4, avi, mov, mkv): ``predict_video``
    * directory: ``predict_folder``

    Anything else is reported on stdout instead of being silently ignored.

    Args:
        path: Camera index, file path or directory path.
        save: Whether results should be saved (ignored for the camera).
        save_path: Destination file or directory for saved results.
    """
    image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".tiff")
    video_exts = (".mp4", ".avi", ".mov", ".mkv")

    if isinstance(path, int):
        predict_camera(index=path)
    elif os.path.isfile(path):
        ext = os.path.splitext(path)[1].lower()
        if ext in image_exts:
            predict_image(path, save, save_path)
        elif ext in video_exts:
            predict_video(path, save, save_path)
        else:
            # An existing file of an unknown type previously produced no output at all.
            print(f"❌ 不支持的文件类型：{ext or path}")
    elif os.path.isdir(path):
        predict_folder(path, save, save_path)
    else:
        print("❌ 无效路径，请确认输入正确的图片/视频/文件夹/摄像头编号")

# ------------ Example invocation ------------
run_predict(INPUT_PATH, SAVE, OUTPUT_PATH)      # Predict on INPUT_PATH; SAVE and OUTPUT_PATH are passed as save / save_path