In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git

In [None]:
!pip install onnxruntime-gpu

In [None]:
!pip install opencv-python
!pip install tqdm

In [1]:
#Make sure the model at models folder
#https://nusu-my.sharepoint.com/personal/e1330352_u_nus_edu/Documents/ISY5004%20Practice%20Module/Models/models/u2net.onnx
#https://nusu-my.sharepoint.com/personal/e1330352_u_nus_edu/Documents/ISY5004%20Practice%20Module/Models/models/sam_vit_h_4b8939.pth

In [1]:
import os
import cv2
import csv
import numpy as np
import torch
import onnxruntime
from glob import glob
from tqdm import tqdm
from segment_anything import sam_model_registry, SamPredictor

# ---------- Helper Functions ----------

def preprocess_for_u2net(image, target_size=(320, 320)):
    """
    Preprocess the image for U2-Net:
      - Resize
      - Convert BGR to RGB
      - Normalize to [0, 1]
      - Rearrange to CHW and add batch dimension
    """
    orig_h, orig_w = image.shape[:2]
    image_resized = cv2.resize(image, target_size)
    image_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
    image_norm = image_rgb.astype(np.float32) / 255.0
    image_input = np.transpose(image_norm, (2, 0, 1))  # CHW
    image_input = np.expand_dims(image_input, axis=0)  # add batch dim
    return image_input, (orig_w, orig_h)

def postprocess_u2net(prediction, orig_size, threshold=0.5):
    """
    Process U2-Net output:
      - Remove batch and channel dimensions
      - Resize to original image size
      - Binarize using threshold
    """
    pred_mask = prediction[0, 0, :, :]
    pred_mask = cv2.resize(pred_mask, orig_size)
    pred_mask_bin = (pred_mask > threshold).astype(np.uint8) * 255
    return pred_mask_bin

def compute_iou(mask_pred, mask_gt):
    """
    Compute the Intersection over Union (IoU) for two binary masks.
    Assumes both masks are binary with values 0 or 255.
    """
    mask_pred_bool = mask_pred.astype(bool)
    mask_gt_bool = mask_gt.astype(bool)
    intersection = np.logical_and(mask_pred_bool, mask_gt_bool).sum()
    union = np.logical_or(mask_pred_bool, mask_gt_bool).sum()
    return intersection / union if union != 0 else 0

def ensure_dir(dir_path):
    """Create directory if it doesn't exist."""
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

def get_random_points_from_mask(mask, n_points=10):
    """
    From the binary mask, randomly sample up to n_points coordinates
    where the mask is foreground (non-zero).
    Returns an array of shape [N, 2] where N <= n_points.
    """
    # Get indices of foreground pixels
    ys, xs = np.where(mask > 0)
    coords = np.stack([xs, ys], axis=1)
    if len(coords) == 0:
        # If no foreground, fallback to image center
        h, w = mask.shape
        return np.array([[w // 2, h // 2]])
    # If more than n_points, randomly sample without replacement
    if len(coords) > n_points:
        indices = np.random.choice(len(coords), size=n_points, replace=False)
        coords = coords[indices]
    return coords

def guided_sam_inference_with_multiple_points(predictor, image, prompt_points):
    """
    Use SAM predictor with multiple prompt points.
    For each prompt point, get a mask and then combine (union) all masks.
    """
    predictor.set_image(image)
    input_labels = np.ones(len(prompt_points), dtype=np.int32)  # all positive prompts
    # SAM returns [N, H, W] if multimask_output=False and multiple points are given.
    masks, _, _ = predictor.predict(
        point_coords=prompt_points,
        point_labels=input_labels,
        multimask_output=False
    )
    # Combine all masks using a pixel-wise maximum (union)
    combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
    for m in masks:
        mask_uint8 = (m.astype(np.uint8)) * 255
        combined_mask = cv2.bitwise_or(combined_mask, mask_uint8)
    return combined_mask

# ---------- Device Setup ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ---------- Load Models ----------

# 1. Load U2-Net (ONNX) with GPU support if available
u2net_path = os.path.join("models", "u2net.onnx")
providers = onnxruntime.get_available_providers()
if "CUDAExecutionProvider" in providers:
    u2net_session = onnxruntime.InferenceSession(u2net_path, providers=["CUDAExecutionProvider"])
    print("U2-Net using CUDAExecutionProvider")
else:
    u2net_session = onnxruntime.InferenceSession(u2net_path)
    print("U2-Net using CPU ExecutionProvider")

# 2. Load SAM model using its registry (example: SamPredictor)
sam_checkpoint = os.path.join("models", "sam_vit_h_4b8939.pth")
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
sam.to(device=device)
sam_predictor = SamPredictor(sam)

# ---------- Setup Directories for Saving Results ----------

image_dir = "DUT-OMRON-image"
mask_dir = "DUT-OMRON-mask"

# We'll save outputs using the naming convention:
# {original file name}_mask.png, {original file name}_foreground.png, and {original file name}_background.png
u2net_output_dir = "u2net_result"
sam_output_dir = "sam_result"
ensure_dir(u2net_output_dir)
ensure_dir(sam_output_dir)

csv_filename = "compare_scores.csv"

# ---------- Processing and Evaluation ----------

image_paths = sorted(glob(os.path.join(image_dir, "*.*")))
iou_u2net_list = []
iou_sam_list = []
results = []

pbar = tqdm(image_paths, total=len(image_paths), desc="Processing images", leave=True)

for i, image_path in enumerate(pbar):
    filename = os.path.basename(image_path)
    basename = os.path.splitext(filename)[0]
    gt_mask_path = os.path.join(mask_dir, basename + ".png")  # Adjust extension if needed

    image = cv2.imread(image_path)
    gt_mask = cv2.imread(gt_mask_path, cv2.IMREAD_GRAYSCALE)
    if image is None or gt_mask is None:
        pbar.set_postfix_str(f"Skipping {basename} (missing image or mask)")
        continue

    # --- U2-Net Inference ---
    input_tensor, orig_size = preprocess_for_u2net(image)
    input_name = u2net_session.get_inputs()[0].name
    pred = u2net_session.run(None, {input_name: input_tensor})[0]
    u2net_mask = postprocess_u2net(pred, orig_size)
    
    # Generate filenames for U2-Net outputs
    u2net_mask_filename = f"{basename}_mask.png"
    u2net_fg_filename = f"{basename}_foreground.png"
    u2net_bg_filename = f"{basename}_background.png"
    
    # Save U2-Net mask using the new naming convention
    cv2.imwrite(os.path.join(u2net_output_dir, u2net_mask_filename), u2net_mask)
    
    # Create U2-Net foreground and background
    foreground_u2net = cv2.bitwise_and(image, image, mask=u2net_mask)
    background_u2net = cv2.bitwise_and(image, image, mask=cv2.bitwise_not(u2net_mask))
    
    # Save U2-Net foreground and background images
    cv2.imwrite(os.path.join(u2net_output_dir, u2net_fg_filename), foreground_u2net)
    cv2.imwrite(os.path.join(u2net_output_dir, u2net_bg_filename), background_u2net)

    # --- SAM Inference ---
    # Here, prompt points are derived from the ground truth mask (for demonstration)
    prompt_points = get_random_points_from_mask(gt_mask, n_points=10)
    sam_mask = guided_sam_inference_with_multiple_points(sam_predictor, image, prompt_points)
    
    # Generate filenames for SAM outputs
    sam_mask_filename = f"{basename}_mask.png"
    sam_fg_filename = f"{basename}_foreground.png"
    sam_bg_filename = f"{basename}_background.png"
    
    # Save SAM mask using the new naming convention
    cv2.imwrite(os.path.join(sam_output_dir, sam_mask_filename), sam_mask)
    
    # Create SAM foreground and background images
    foreground_sam = cv2.bitwise_and(image, image, mask=sam_mask)
    background_sam = cv2.bitwise_and(image, image, mask=cv2.bitwise_not(sam_mask))
    
    # Save SAM foreground and background images
    cv2.imwrite(os.path.join(sam_output_dir, sam_fg_filename), foreground_sam)
    cv2.imwrite(os.path.join(sam_output_dir, sam_bg_filename), background_sam)

    # --- Evaluation (IoU) ---
    # Evaluate both U2-Net and SAM predictions against the ground truth
    gt_mask_bin = (gt_mask > 0).astype(np.uint8) * 255
    iou_u2net = compute_iou(u2net_mask, gt_mask_bin)
    iou_sam = compute_iou(sam_mask, gt_mask_bin)
    iou_u2net_list.append(iou_u2net)
    iou_sam_list.append(iou_sam)
    results.append([filename, iou_u2net, iou_sam])

    # Update the progress bar with the latest IoU scores
    pbar.set_postfix_str(
        f"{i+1}/{len(image_paths)} U2Net IoU: {iou_u2net:.4f}, SAM IoU: {iou_sam:.4f}"
    )

# After processing all images, write the CSV file.
with open(csv_filename, mode='w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["Image", "U2-Net IoU", "SAM IoU"])
    writer.writerows(results)

# ---------- Summary of Results ----------
if iou_u2net_list and iou_sam_list:
    avg_iou_u2net = np.mean(iou_u2net_list)
    avg_iou_sam = np.mean(iou_sam_list)
    print("\n--- Overall Performance ---")
    print(f"Average U2-Net IoU: {avg_iou_u2net:.4f}")
    print(f"Average SAM IoU: {avg_iou_sam:.4f}")
else:
    print("No valid results to summarize.")


Using device: cuda
U2-Net using CPU ExecutionProvider


Processing images: 100%|██████████| 5168/5168 [2:07:13<00:00,  1.48s/it, 5168/5168 U2Net IoU: 0.7334, SAM IoU: 0.8967]  


--- Overall Performance ---
Average U2-Net IoU: 0.6151
Average SAM IoU: 0.8544



