In [None]:
from pathlib import Path
import cv2
from ultralytics import YOLO
from tqdm import tqdm

# --- configuration ---------------------------------------------------------
# Display names, indexed by the integer class id used by both the model
# output and the label files.
CLASS_NAMES = ["neg_cocci", "pos_cocci", "neg_bacilli", "pos_bacilli"]

# Per-class box colours keyed by class id. OpenCV draws in BGR channel
# order, so (0, 0, 255) is red and (255, 0, 0) is blue.
CLASS_COLORS = dict(enumerate([
    (0, 255, 255),  # 0 -> yellow
    (0, 255, 0),    # 1 -> green
    (0, 0, 255),    # 2 -> red
    (255, 0, 0),    # 3 -> blue
]))
GT_COLOR = (0, 0, 0)      # black: dashed ground-truth boxes and their labels
ERR_COLOR = (0, 0, 255)   # red: X drawn on mis-classified GT boxes

# --- filesystem locations --------------------------------------------------
test_images_dir = Path("/home/siu856582712/Documents/dataset/dataset/Alln/test/images")
test_labels_dir = Path("/home/siu856582712/Documents/dataset/dataset/Alln/test/labels")
save_dir = Path("/home/siu856582712/Documents/dataset/test_vis_analysisdd/imagesMM")
save_dir.mkdir(exist_ok=True, parents=True)  # output folder for annotated images

# --- model -----------------------------------------------------------------
model_path = Path("/home/siu856582712/Documents/dataset/dataset/Allnresults/Allen/weights/best.pt")
model = YOLO(f"{model_path}")

def draw_dashed_rect(img, pt1, pt2, color, thickness=1, dash_len=5):
    """Draw a dashed, axis-aligned rectangle on ``img`` in place.

    Args:
        img: Image array handed to ``cv2.line``; it is modified in place.
        pt1, pt2: Opposite corners ``(x, y)``. Either corner order is accepted.
            Coordinates are truncated to ``int`` because both ``range`` and
            ``cv2.line`` need integer pixel positions.
        color: Line colour in the image's channel order (BGR for cv2 images).
        thickness: Line thickness in pixels.
        dash_len: Length in pixels of each dash and of each gap between dashes.

    Raises:
        ValueError: If ``dash_len`` is not positive.
    """
    if dash_len <= 0:
        # A zero step would make range() fail, and a negative step would
        # silently draw nothing, so reject both explicitly.
        raise ValueError(f"dash_len must be positive, got {dash_len}")

    # Normalise the corners so the ranges below always step forward. With
    # reversed corners the forward ranges would be empty and nothing drawn.
    xa, ya = int(pt1[0]), int(pt1[1])
    xb, yb = int(pt2[0]), int(pt2[1])
    x1, x2 = min(xa, xb), max(xa, xb)
    y1, y2 = min(ya, yb), max(ya, yb)
    step = dash_len * 2  # one dash followed by one gap of equal length

    # Top and bottom edges.
    for x in range(x1, x2, step):
        x_end = min(x + dash_len, x2)
        cv2.line(img, (x, y1), (x_end, y1), color, thickness)
        cv2.line(img, (x, y2), (x_end, y2), color, thickness)
    # Left and right edges.
    for y in range(y1, y2, step):
        y_end = min(y + dash_len, y2)
        cv2.line(img, (x1, y), (x1, y_end), color, thickness)
        cv2.line(img, (x2, y), (x2, y_end), color, thickness)

def draw_error_mark(img, x, y, size=10, color=None, thickness=2):
    """Draw an X centred on ``(x, y)`` on ``img`` in place.

    Args:
        img: Image array handed to ``cv2.line``; it is modified in place.
        x, y: Centre of the mark in pixels. Values are truncated to ``int``
            because ``cv2.line`` rejects non-integer points.
        size: Half the width and height of the X, in pixels.
        color: Line colour. ``None`` (the default) uses the module-level
            ``ERR_COLOR``, which is looked up at call time as before.
        thickness: Line thickness in pixels.
    """
    if color is None:
        color = ERR_COLOR
    cx, cy, r = int(x), int(y), int(size)
    # "\" diagonal, then "/" diagonal.
    cv2.line(img, (cx - r, cy - r), (cx + r, cy + r), color, thickness)
    cv2.line(img, (cx - r, cy + r), (cx + r, cy - r), color, thickness)

def iou(box_a, box_b):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Negative overlaps and inverted or degenerate boxes contribute zero area.
    A ``1e-6`` term is added to the union so that two zero-area boxes return
    ``0.0`` instead of dividing by zero.
    """
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    def _area(left, top, right, bottom):
        # Clamp each side at zero so inverted boxes count as empty.
        return max(0, right - left) * max(0, bottom - top)

    overlap = _area(max(ax1, bx1), max(ay1, by1), min(ax2, bx2), min(ay2, by2))
    union = _area(ax1, ay1, ax2, ay2) + _area(bx1, by1, bx2, by2) - overlap
    return overlap / (union + 1e-6)

# Visualise up to MAX_IMAGES test images. Predictions are drawn as solid
# boxes and ground truth as dashed boxes. A red X marks every GT box whose
# best-overlapping prediction (IoU >= MATCH_IOU) has a different class.
MAX_IMAGES = 500
MATCH_IOU = 0.5
PREDICT_KWARGS = dict(imgsz=640, conf=0.2, iou=0.45, device=1, half=True, verbose=False)


def _class_name(cls_id):
    """Return the display name for ``cls_id``, falling back to the number for unknown ids."""
    # Guard both ends: a negative id would otherwise index CLASS_NAMES from the end.
    return CLASS_NAMES[cls_id] if 0 <= cls_id < len(CLASS_NAMES) else str(cls_id)


def _draw_predictions(img, result):
    """Draw each predicted box of ``result`` on ``img``; return ``[(cls_id, [x1, y1, x2, y2]), ...]``."""
    pred_boxes = []
    for box in result.boxes:
        cls_id = int(box.cls.item())
        conf = float(box.conf.item())
        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
        color = CLASS_COLORS.get(cls_id, (255, 255, 255))
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
        # putText anchors at the text baseline. Keeping it at least ~one text
        # height below the top edge keeps labels of top-edge boxes visible
        # (a baseline of 0 draws the text entirely outside the image).
        cv2.putText(img, f"{_class_name(cls_id)} {conf:.2f}", (x1, max(12, y1 - 5)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
        pred_boxes.append((cls_id, [x1, y1, x2, y2]))
    return pred_boxes


def _read_gt_boxes(gt_file, w, h):
    """Parse a YOLO label file into ``[(cls_id, [x1, y1, x2, y2]), ...]`` pixel boxes.

    Each line is read as ``cls xc yc bw bh``, with the box values scaled by the
    image width and height. Blank lines and lines with fewer than five fields
    are skipped. A missing file yields an empty list.
    """
    boxes = []
    if not gt_file.exists():
        return boxes
    with open(gt_file, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue
            cls_id = int(float(parts[0]))
            x, y, bw, bh = map(float, parts[1:5])
            xc, yc = x * w, y * h
            ww, hh = bw * w, bh * h
            boxes.append((cls_id, [int(xc - ww / 2), int(yc - hh / 2),
                                   int(xc + ww / 2), int(yc + hh / 2)]))
    return boxes


def _best_match(gt_box, pred_boxes):
    """Return ``(iou, cls_id)`` of the prediction overlapping ``gt_box`` most, or ``(0.0, None)``."""
    # max() keeps the first of equal maxima, so ties go to the earlier prediction.
    return max(((iou(gt_box, pred_box), pred_cls) for pred_cls, pred_box in pred_boxes),
               key=lambda match: match[0], default=(0.0, None))


image_paths = sorted(test_images_dir.glob("*.jpg"))[:MAX_IMAGES]

for img_path in tqdm(image_paths):
    img = cv2.imread(str(img_path))
    if img is None:
        tqdm.write(f"skipping unreadable image: {img_path}")
        continue
    h, w = img.shape[:2]

    # Predict one already-loaded BGR array at a time. A list passed as
    # `source` is converted and decoded in full by Ultralytics before the
    # first result is yielded, which defeats stream=True. This also reads
    # each file only once.
    result = model.predict(source=img, **PREDICT_KWARGS)[0]
    pred_boxes = _draw_predictions(img, result)

    gt_file = test_labels_dir / (img_path.stem + ".txt")
    for cls_id, gt_box in _read_gt_boxes(gt_file, w, h):
        gx1, gy1, gx2, gy2 = gt_box
        draw_dashed_rect(img, (gx1, gy1), (gx2, gy2), GT_COLOR)
        cv2.putText(img, _class_name(cls_id), (gx1, min(h - 5, gy2 + 15)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, GT_COLOR, 1)
        # Compare against the best-overlapping prediction, not the first one
        # above the threshold, so a weaker wrong-class overlap cannot trigger
        # a false error mark.
        best_iou, best_cls = _best_match(gt_box, pred_boxes)
        if best_iou >= MATCH_IOU and best_cls != cls_id:
            draw_error_mark(img, (gx1 + gx2) // 2, (gy1 + gy2) // 2)

    out_path = save_dir / img_path.name
    if not cv2.imwrite(str(out_path), img):
        tqdm.write(f"failed to write {out_path}")


  0%|          | 0/500 [00:00<?, ?it/s]