# In [2]:
import sys
import os
import glob
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
from ultralytics import YOLO

# Import project-local modules (the directory that contains `src/` must be on sys.path).
from src.YoloSAM.models.sam import SAMModel
from src.YoloSAM.utils.config import SAMFinetuneConfig, SAMDatasetConfig # ÂÄüÁî®ConfigÊù•ÂàùÂßãÂåñÊ®°Âûã

# =========================================================
# 1. Define the inference class
# =========================================================
class YoloSAMInference:
    """End-to-end inference: YOLO proposes boxes, a fine-tuned SAM segments each box.

    The pipeline resizes the input image to a square, runs YOLO on it to get
    bounding boxes, then feeds every box as a prompt to the SAM model and
    thresholds the resulting logits into a binary mask.
    """

    def __init__(self, yolo_path, sam_path, device='cuda',
                 sam_model_type='vit_b',
                 sam_base_checkpoint='/root/task/checkpoints/sam_vit_b_01ec64.pth'):
        """Load both models and move SAM to the target device.

        Args:
            yolo_path: Path to the YOLO weights file.
            sam_path: Path to the fine-tuned SAM checkpoint.
            device: Requested torch device; falls back to CPU when CUDA is
                requested but unavailable.
            sam_model_type: ``model_type`` forwarded to ``SAMFinetuneConfig``.
            sam_base_checkpoint: ``sam_path`` forwarded to ``SAMFinetuneConfig``
                when building the model skeleton.
        """
        self.device = self._resolve_device(device)
        self.sam_model_type = sam_model_type
        self.sam_base_checkpoint = sam_base_checkpoint

        # --- 1. Load YOLO ---
        print(f"Loading YOLO from: {yolo_path}")
        self.yolo_model = YOLO(yolo_path)
        print("‚úÖ YOLO Model Loaded.")

        # --- 2. Load the fine-tuned SAM model ---
        print(f"Loading Fine-tuned SAM from: {sam_path}")
        self.sam_model = self._load_finetuned_sam(sam_path)
        self.sam_model.to(self.device)
        self.sam_model.eval()  # inference mode: disables dropout / batch-norm updates
        print("‚úÖ Fine-tuned SAM Model Loaded.")

    @staticmethod
    def _resolve_device(device):
        """Return a ``torch.device``, degrading a CUDA request to CPU if CUDA is missing."""
        resolved = torch.device(device)
        if resolved.type == 'cuda' and not torch.cuda.is_available():
            print("CUDA requested but not available; falling back to CPU.")
            return torch.device('cpu')
        return resolved

    @staticmethod
    def _merge_masks(masks):
        """Return the element-wise logical OR of a non-empty sequence of same-shaped masks."""
        return np.logical_or.reduce([np.asarray(mask).astype(bool) for mask in masks])

    def _load_finetuned_sam(self, checkpoint_path):
        """Build a ``SAMModel`` skeleton and load fine-tuned weights into it.

        Args:
            checkpoint_path: Path to the fine-tuned checkpoint. Either a raw
                state dict or a dict wrapping it under ``'model_state_dict'``.

        Returns:
            The ``SAMModel`` with the checkpoint weights loaded.
        """
        # Build the model skeleton from a config; the weights loaded below
        # replace whatever the skeleton was initialised with.
        config = SAMFinetuneConfig(
            model_type=self.sam_model_type,
            sam_path=self.sam_base_checkpoint,
        )
        model = SAMModel(config)

        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state_dict = torch.load(checkpoint_path, map_location=self.device)

        # Training checkpoints may wrap the weights under 'model_state_dict'.
        if 'model_state_dict' in state_dict:
            print("üì¶ Checkpoint format detected, extracting 'model_state_dict'...")
            state_dict = state_dict['model_state_dict']

        model.load_state_dict(state_dict)
        return model

    def predict(self, image_path, yolo_conf=0.01, yolo_iou=0.5, image_size=1024):
        """Detect objects with YOLO and segment each detection with SAM.

        Args:
            image_path: Path of the image to process.
            yolo_conf: Confidence threshold passed to YOLO.
            yolo_iou: IoU threshold passed to YOLO.
            image_size: Side length of the square the image is resized to.

        Returns:
            A dict with:
                ``"original_image"``: the *resized* RGB image as a numpy array;
                ``"detected_boxes"``: xyxy boxes in resized-image coordinates
                    (numpy array), or ``[]`` when nothing was detected;
                ``"predicted_masks"``: list with one boolean mask per box.
        """
        # --- 1. Image preprocessing ---
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(image_path) as image_file:
            image_np = np.array(image_file.convert("RGB"))

        # Resize to a fixed square before running either model.
        resized_image_np = cv2.resize(image_np, (image_size, image_size))

        # --- 2. YOLO inference ---
        # Ultralytics interprets numpy inputs as BGR (OpenCV convention), so
        # convert the RGB array before handing it over.
        yolo_input = cv2.cvtColor(resized_image_np, cv2.COLOR_RGB2BGR)
        yolo_results = self.yolo_model.predict(
            yolo_input, conf=yolo_conf, iou=yolo_iou, verbose=False
        )

        detected_boxes = []
        if yolo_results and len(yolo_results[0].boxes) > 0:
            detected_boxes = yolo_results[0].boxes.xyxy.cpu()

        # --- 3. SAM inference, one prompt box at a time ---
        all_masks = []
        if len(detected_boxes) > 0:
            # (H, W, C) uint8 -> (1, C, H, W) float in [0, 1]; only built when needed.
            image_tensor = torch.from_numpy(resized_image_np).permute(2, 0, 1).float() / 255.0
            image_tensor = image_tensor.to(self.device).unsqueeze(0)

            with torch.no_grad():
                for box in detected_boxes:
                    prompt_box = box.unsqueeze(0).to(self.device)  # (1, 4)

                    pred_mask_logits, _ = self.sam_model.forward_one_image(
                        image=image_tensor,
                        bounding_box=prompt_box,
                        is_train=False
                    )

                    pred_mask_prob = torch.sigmoid(pred_mask_logits)
                    pred_mask_binary = (pred_mask_prob > 0.5).squeeze().cpu().numpy()
                    all_masks.append(pred_mask_binary)

        return {
            "original_image": resized_image_np,
            "detected_boxes": detected_boxes.numpy() if len(detected_boxes) > 0 else [],
            "predicted_masks": all_masks
        }

    def visualize_results(self, results):
        """Show the image with the outline of the union of all predicted masks.

        Args:
            results: Dict as returned by ``predict``. Boxes are not drawn; only
                their count appears in the figure title.
        """
        image = results['original_image']
        boxes = results['detected_boxes']
        masks = results['predicted_masks']

        # Draw on a copy so the caller's array is left untouched.
        vis_image = image.copy()

        if len(masks) > 0:
            combined_mask = self._merge_masks(masks)

            # findContours expects a uint8 image (0 or 255).
            binary_mask_uint8 = combined_mask.astype(np.uint8) * 255

            # NOTE(review): assumes each mask is 2-D; if the model emits masks at
            # a different resolution than the image, scale them so the contours
            # line up with the picture.
            height, width = vis_image.shape[:2]
            if binary_mask_uint8.shape[:2] != (height, width):
                binary_mask_uint8 = cv2.resize(
                    binary_mask_uint8, (width, height), interpolation=cv2.INTER_NEAREST
                )

            # The image is RGB (loaded through PIL, shown with matplotlib), so the
            # colour tuple is RGB as well.
            contour_color_rgb = (0, 255, 0)
            contours, _ = cv2.findContours(
                binary_mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            cv2.drawContours(vis_image, contours, -1, contour_color_rgb, thickness=2)

        # The array is already RGB, so no channel conversion is needed for imshow.
        plt.figure(figsize=(12, 12))
        plt.imshow(vis_image)
        plt.title(f"End-to-End Inference\nDetected Objects: {len(boxes)}")
        plt.axis('off')
        plt.show()


# =========================================================
# 2. Locate the weight files
# =========================================================
# YOLO weights.
# NOTE: adjust these paths to match your environment.
yolo_files = glob.glob("/root/task/src/YoloSAM/drive/runs/**/weights/best.pt", recursive=True)
if not yolo_files:
    raise FileNotFoundError("Êâæ‰∏çÂà∞‰ªª‰Ωï YOLO ÁöÑ best.ptÔºÅ")

# A specific run is pinned. Use it when it exists; otherwise fall back to the
# most recently modified best.pt found above instead of passing a dead path on.
pinned_yolo_path = "/root/task/src/YoloSAM/drive/runs/yolo_vessel_detection3/weights/best.pt"
if os.path.isfile(pinned_yolo_path):
    latest_yolo_path = pinned_yolo_path
else:
    latest_yolo_path = max(yolo_files, key=os.path.getmtime)
print(f"üëâ Found latest YOLO weights: {latest_yolo_path}")

# Fine-tuned SAM weights: newest best_model.pth under /root/autodl-tmp/run_*.
# (The pattern has no "**", so glob's recursive flag would have no effect.)
sam_files = glob.glob("/root/autodl-tmp/run_*/best_model.pth")
if not sam_files:
    raise FileNotFoundError("Êâæ‰∏çÂà∞‰ªª‰ΩïÂæÆË∞ÉÁöÑ SAM best_model.pthÔºÅ")
latest_sam_path = max(sam_files, key=os.path.getmtime)
print(f"üëâ Found latest SAM weights: {latest_sam_path}")


# =========================================================
# 3. Run inference and visualize
# =========================================================
# --- Build the inference pipeline ---
pipeline = YoloSAMInference(
    yolo_path=latest_yolo_path,
    sam_path=latest_sam_path
)

# --- Pick an image to predict on ---
image_path = "/root/task/datasets/DRIVE/val/images/01_test.tif"
print(f"\nüöÄ Predicting on image: {image_path}")

# --- Run prediction ---
results = pipeline.predict(image_path)

# --- Visualize the results ---
print(f"üìä YOLO detected {len(results['detected_boxes'])} objects.")
print("üé® Visualizing results...")
pipeline.visualize_results(results)

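# Cell output: ModuleNotFoundError: No module named 'src'
# NOTE(review): this means the directory containing `src/` was not on sys.path when the cell ran (presumably /root/task — confirm).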