In [None]:
class YOLOv5SlidingWindowDetector:
    """Run a YOLOv5 model over a large image tile by tile and merge the detections.

    The image is cut into square, possibly overlapping tiles. Every tile is passed
    to the model, the per-tile boxes are shifted back into full-image coordinates,
    and duplicates produced by overlapping tiles are removed with NMS.
    """

    def __init__(self, model_path, device='cuda'):
        """Load custom YOLOv5 weights through torch.hub and switch the model to eval mode.

        Args:
            model_path: Path of the weights file handed to the hub 'custom' entry point.
            device: Device the model is moved to (default 'cuda').
        """
        self.device = device
        self.model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path)
        self.model.to(self.device).eval()
        self.class_names = self.model.names

    @staticmethod
    def _window_origins(length, crop_size, stride):
        """Return the tile start offsets along one axis so that the whole axis is covered.

        The last offset plus ``crop_size`` is the padded axis length the caller must
        provide. Raises ValueError if ``crop_size`` or ``stride`` is not positive.
        """
        if crop_size <= 0 or stride <= 0:
            raise ValueError("crop_size and stride must be positive")
        # Round the axis up to a whole number of tiles (at least one tile).
        padded = max(crop_size, length + (-length) % crop_size)
        # Extend further so the final stride step ends exactly on the padded edge;
        # without this the bottom/right strip is skipped whenever stride does not
        # divide (padded - crop_size), e.g. crop_size=640 with stride=512.
        padded += (crop_size - padded) % stride
        return list(range(0, padded - crop_size + 1, stride))

    def detect(
        self,
        image_path,
        crop_size=640,
        stride=512,
        conf_thres=0.25,
        iou_thres=0.45,
        min_box_size=10,  # smallest box width/height (pixels) that is kept
        save_result=True,
        save_img_path='detection_result.jpg',
        save_txt_path='detection_result.txt'
    ):
        """Detect objects in one image using overlapping tiles.

        Args:
            image_path: Image file to read with OpenCV.
            crop_size: Side length of the square tiles, in pixels.
            stride: Step between consecutive tile origins, in pixels.
            conf_thres: Detections scoring below this value are discarded.
            iou_thres: IoU threshold of the final NMS over the merged boxes.
            min_box_size: Boxes narrower or shorter than this (after clipping to
                the image) are discarded.
            save_result: If True, write the annotated image and the label file;
                if False, return them instead.
            save_img_path: Output path of the annotated image.
            save_txt_path: Output path of the YOLO-format label file.

        Returns:
            None when ``save_result`` is True. Otherwise a tuple
            ``(annotated_image, yolo_labels)``; ``yolo_labels`` is empty when
            nothing was detected.

        Raises:
            FileNotFoundError: The image could not be read.
            ValueError: ``crop_size`` or ``stride`` is not positive.
            IOError: The annotated image could not be written.
        """
        orig_img = cv2.imread(str(image_path))
        if orig_img is None:
            raise FileNotFoundError(f"Cannot read image: {image_path}")
        H, W = orig_img.shape[:2]

        ys = self._window_origins(H, crop_size, stride)
        xs = self._window_origins(W, crop_size, stride)
        # Pad bottom/right with black so that every tile is a full crop_size square.
        pad_H = ys[-1] + crop_size - H
        pad_W = xs[-1] + crop_size - W
        img = cv2.copyMakeBorder(orig_img, 0, pad_H, 0, pad_W, cv2.BORDER_CONSTANT, value=0)

        # The hub wrapper is documented to apply its own confidence cut (model.conf)
        # before returning boxes; align it so a lower conf_thres actually takes effect.
        if hasattr(self.model, 'conf'):
            self.model.conf = conf_thres

        all_boxes, all_scores, all_classes = [], [], []

        for y in tqdm(ys):
            for x in xs:
                crop = img[y:y + crop_size, x:x + crop_size]
                # cv2.imread yields BGR; the hub model expects RGB for numpy input.
                results = self.model(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
                for *xyxy, conf, cls in results.xyxy[0].cpu().numpy():
                    if conf < conf_thres:
                        continue
                    x1, y1, x2, y2 = xyxy
                    # Shift into full-image coordinates and clip to the original
                    # image so boxes never extend into the padding.
                    abs_x1 = min(max(float(x1) + x, 0.0), float(W))
                    abs_y1 = min(max(float(y1) + y, 0.0), float(H))
                    abs_x2 = min(max(float(x2) + x, 0.0), float(W))
                    abs_y2 = min(max(float(y2) + y, 0.0), float(H))
                    box_w = abs_x2 - abs_x1
                    box_h = abs_y2 - abs_y1

                    # Drop boxes that lie entirely in the padding or are too small.
                    if box_w <= 0 or box_h <= 0:
                        continue
                    if box_w < min_box_size or box_h < min_box_size:
                        continue

                    all_boxes.append([abs_x1, abs_y1, abs_x2, abs_y2])
                    all_scores.append(float(conf))
                    all_classes.append(int(cls))

        if not all_boxes:
            print("⚠️ No valid objects detected.")
            if save_result:
                return None
            # Keep the return shape consistent for callers that unpack the result.
            return orig_img.copy(), []

        boxes_tensor = torch.tensor(all_boxes, dtype=torch.float32)
        scores_tensor = torch.tensor(all_scores, dtype=torch.float32)
        # Overlapping tiles report the same object several times; NMS here runs
        # over all classes together (no per-class separation).
        keep = nms(boxes_tensor, scores_tensor, iou_thres)

        boxes_tensor = boxes_tensor[keep]
        scores_tensor = scores_tensor[keep]
        classes_tensor = torch.tensor(all_classes)[keep]

        result_img = orig_img.copy()
        yolo_labels = []

        for box, score, cls_id in zip(boxes_tensor, scores_tensor, classes_tensor):
            x1, y1, x2, y2 = map(int, box.tolist())
            cv2.rectangle(result_img, (x1, y1), (x2, y2), (0, 255, 0), 2)

            # Normalised YOLO-format label: class cx cy w h, relative to image size.
            cx = (x1 + x2) / 2 / W
            cy = (y1 + y2) / 2 / H
            bw = (x2 - x1) / W
            bh = (y2 - y1) / H
            yolo_labels.append(f"{int(cls_id)} {cx:.6f} {cy:.6f} {bw:.6f} {bh:.6f}")

        if save_result:
            # cv2.imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(save_img_path, result_img):
                raise IOError(f"Failed to write image: {save_img_path}")
            with open(save_txt_path, 'w') as f:
                f.write('\n'.join(yolo_labels))
            print(f"✅ Saved detection result to {save_img_path}")
            print(f"✅ Saved YOLO label to {save_txt_path}")
        else:
            return result_img, yolo_labels


In [None]:

# Single-image run: load the weights, tile one image with heavily overlapping
# windows (stride 128 on 640-pixel crops) and save the annotated image and labels.
detector = YOLOv5SlidingWindowDetector('yolov5/run/train/stomata_yolov5x2/weights/best.pt')

detector.detect(
    '海南裁剪/47 (2).jpg',
    crop_size=640, stride=128,
    conf_thres=0.25, iou_thres=0.15,
    min_box_size=20,
    save_result=True,
    save_img_path='test.jpg', save_txt_path='test.txt',
)

In [None]:
import os
from glob import glob

def batch_detect_from_folder(
    detector,
    input_folder,
    output_folder='outputs',
    crop_size=640,
    stride=512,
    conf_thres=0.25,
    iou_thres=0.45,
    min_box_size=10
):
    """Run ``detector.detect`` on every JPEG/PNG image found directly in a folder.

    For an input file ``<name>.<ext>`` the annotated image is written to
    ``<output_folder>/<name>_det.jpg`` and the labels to ``<output_folder>/<name>.txt``.

    Args:
        detector: Object exposing ``detect(image_path=..., ...)`` with the keyword
            arguments used below.
        input_folder: Folder that is scanned (non-recursively) for images.
        output_folder: Folder that receives the results; created if missing.
        crop_size, stride, conf_thres, iou_thres, min_box_size: Forwarded
            unchanged to ``detector.detect``.
    """
    # Extensions are compared case-insensitively, so .JPG / .jpeg / .PNG are found too.
    image_exts = {'.jpg', '.jpeg', '.png'}

    os.makedirs(output_folder, exist_ok=True)

    # Sorted for a reproducible processing order.
    image_paths = sorted(
        candidate
        for candidate in glob(os.path.join(input_folder, '*'))
        if os.path.splitext(candidate)[1].lower() in image_exts
    )
    print(f"📂 Found {len(image_paths)} images in {input_folder}")

    for img_path in image_paths:
        # File name without directory and extension.
        filename = os.path.splitext(os.path.basename(img_path))[0]
        save_img_path = os.path.join(output_folder, f'{filename}_det.jpg')
        save_txt_path = os.path.join(output_folder, f'{filename}.txt')

        print(f"🔍 Processing {img_path}...")
        detector.detect(
            image_path=img_path,
            crop_size=crop_size,
            stride=stride,
            conf_thres=conf_thres,
            iou_thres=iou_thres,
            min_box_size=min_box_size,
            save_result=True,
            save_img_path=save_img_path,
            save_txt_path=save_txt_path
        )


In [None]:
# Batch run over the '海南裁剪' folder; results are written to '海南预测'.
detector = YOLOv5SlidingWindowDetector('yolov5/run/train/stomata_yolov5x2/weights/best.pt')

batch_detect_from_folder(
    detector,
    '海南裁剪',  # input image folder
    '海南预测',  # output folder
    crop_size=640, stride=128,
    conf_thres=0.25, iou_thres=0.15,
    min_box_size=20,
)

In [None]:
# Batch run over the '河南自交系穗位叶' folder; results are written to '河南预测'.
detector = YOLOv5SlidingWindowDetector('yolov5/run/train/stomata_yolov5x2/weights/best.pt')

batch_detect_from_folder(
    detector,
    '河南自交系穗位叶',  # input image folder
    '河南预测',  # output folder
    crop_size=640, stride=128,
    conf_thres=0.25, iou_thres=0.15,
    min_box_size=20,
)