In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection, SamProcessor, SamModel # Added SamProcessor, SamModel
from PIL import Image, ImageDraw
from torchvision.transforms.functional import pil_to_tensor
import torchmetrics

# Choose the compute backend: Apple-silicon MPS if present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    _device_name = 'mps'
elif torch.cuda.is_available():
    _device_name = 'cuda'
else:
    _device_name = 'cpu'
device = torch.device(_device_name)
print(f"Using device: {device}")

def imsc(img, *args, quiet=False, lim=None, interpolation='lanczos', **kwargs):
    """
    Rescale an image to the [0, 1] range and return it ready for ``plt.imshow``.

    The image is a CxHxW tensor with C == 3 (RGB) or C == 1 (grayscale), or a
    PIL image (converted with ``pil_to_tensor``). This function does not draw
    anything itself: pass the returned bitmap to ``plt.imshow``.

    Args:
        img (torch.Tensor or PIL.Image): image to rescale. Non-floating-point
            tensors (e.g. uint8 from PIL) are converted to float32 first.
        *args: accepted for backward compatibility; unused.
        quiet (bool, optional): if False (default), return an HxWx3 numpy
            bitmap (grayscale is replicated to 3 channels); if True, return
            the rescaled CxHxW tensor instead.
        lim (sequence, optional): ``[min, max]`` values mapped to 0 and 1.
            Default: the image's own minimum and maximum.
        interpolation (str, optional): accepted for backward compatibility
            (meant for ``imshow``); unused here.
        **kwargs: accepted for backward compatibility; unused.

    Returns:
        numpy.ndarray or torch.Tensor: HxWx3 bitmap if ``quiet`` is False,
        otherwise the rescaled CxHxW tensor. Values are clamped to [0, 1]; a
        degenerate range (``lim[1] == lim[0]``) yields all zeros rather than NaN.
    """
    if not torch.is_tensor(img):
        # Anything that is not already a tensor is expected to be a PIL image.
        img = pil_to_tensor(img)
    with torch.no_grad():
        if not img.is_floating_point():
            # Integer tensors cannot hold the fractional values produced below.
            img = img.float()
        if lim is None:
            lim = [img.min(), img.max()]
        span = lim[1] - lim[0]
        img = img - lim[0]  # new tensor: the caller's image is never modified
        if span != 0:
            img = img / span
        else:
            # Constant image / empty range: avoid 0 * inf = NaN.
            img = torch.zeros_like(img)
        img = torch.clamp(img, min=0, max=1)
        if quiet:
            return img
        if img.shape[0] == 1:
            img = img.expand(3, *img.shape[1:])  # grayscale -> 3 identical channels
        return img.permute(1, 2, 0).cpu().numpy()

In [None]:
# Example images and their ground-truth label masks for both datasets, loaded in
# order straight onto the selected device.
# NOTE(review): torch.load unpickles its input — only use it on trusted local files.
coco_images, vaihingen_images, coco_labels, vaihingen_labels = (
    torch.load(data_path, map_location=device)
    for data_path in (
        'data/example_images_coco.pth',
        'data/example_images_vaihingen.pth',
        'data/example_labels_coco.pth',
        'data/example_labels_vaihingen.pth',
    )
)

In [None]:
MAMMO_PATH = 'data/mammo_birads5.pt'  # presumably BI-RADS 5 mammograms, judging by the file name
mammo_images = torch.load(MAMMO_PATH, map_location=device)

In [None]:
# Grounding DINO (tiny): text-prompted, open-vocabulary object detector + its processor.
DINO_CHECKPOINT = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(DINO_CHECKPOINT)
model = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_CHECKPOINT)

In [None]:
# Segment Anything (ViT-B): turns the detector's boxes into pixel masks.
SAM_CHECKPOINT = "facebook/sam-vit-base"
sam_processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
sam_model = SamModel.from_pretrained(SAM_CHECKPOINT)

In [None]:
# Grounding DINO text prompts: class names separated by dots. The Hugging Face
# Grounding DINO docs ask for lowercase queries, each terminated by a ".".
# The 80 COCO detection classes, in the standard COCO order.
COCO_CLASS_NAMES = (
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
    "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
)
# ISPRS Vaihingen semantic-segmentation classes.
VAIHINGEN_CLASS_NAMES = (
    "impervious surfaces", "buildings", "low vegetation", "trees", "cars", "clutter",
)

coco_categories = ". ".join(COCO_CLASS_NAMES) + "."
vaihingen_categories = ". ".join(VAIHINGEN_CLASS_NAMES) + "."
print(f"Created coco_categories string with {len(COCO_CLASS_NAMES)} categories.")
print(f"Created vaihingen_categories string with {len(VAIHINGEN_CLASS_NAMES)} categories.")

In [None]:
def _masks_to_binary_long(masks: torch.Tensor, arg_name: str) -> torch.Tensor:
    """Convert a mask tensor to a 0/1 ``torch.long`` tensor.

    bool -> cast; floating point -> ``> 0.5`` (probabilities); integer -> ``> 0``
    (so label maps holding class IDs count as "any foreground").

    Raises:
        TypeError: for any other dtype (e.g. complex); the message names ``arg_name``.
    """
    if masks.dtype == torch.bool:
        return masks.long()
    if masks.dtype.is_floating_point:
        return (masks > 0.5).long()
    if masks.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
        return (masks > 0).long()
    raise TypeError(
        f"{arg_name} dtype {masks.dtype} not supported. "
        "Expected bool, float, or int."
    )


def calculate_batch_iou_torchmetrics(
    predicted_masks: torch.Tensor,
    ground_truth_masks: torch.Tensor
) -> torch.Tensor:
    """
    Calculates the Intersection over Union (Jaccard index) of a batch of binary
    segmentation masks with torchmetrics' ``JaccardIndex(task="binary")``.

    Note: this is the *pooled* IoU. Intersection and union are summed over all
    pixels of all N masks before dividing, so it is not the mean of per-image
    IoUs, and larger images/objects weigh more.

    Args:
        predicted_masks (torch.Tensor): predicted masks, shape (N, H, W).
            bool, integer (any value > 0 is foreground) or float (probabilities,
            thresholded at 0.5).
        ground_truth_masks (torch.Tensor): ground-truth masks, shape (N, H, W),
            same dtype rules. Moved to ``predicted_masks.device`` if needed.

    Returns:
        torch.Tensor: scalar tensor with the pooled IoU.

    Raises:
        TypeError: if an input is not a tensor or has an unsupported dtype.
        ValueError: if the shapes differ or are not 3-dimensional.
        RuntimeError: if the ground truth cannot be moved to the prediction's device.
    """
    if not isinstance(predicted_masks, torch.Tensor) or not isinstance(ground_truth_masks, torch.Tensor):
        raise TypeError("Inputs must be PyTorch Tensors.")

    if predicted_masks.shape != ground_truth_masks.shape:
        raise ValueError(
            f"Shape mismatch: predicted_masks have shape {predicted_masks.shape} "
            f"while ground_truth_masks have shape {ground_truth_masks.shape}."
        )
    if predicted_masks.ndim != 3:
        raise ValueError(
            f"Input tensors must be 3-dimensional (N, H, W). "
            f"Got {predicted_masks.ndim} dimensions."
        )

    metric_device = predicted_masks.device
    if ground_truth_masks.device != metric_device:
        try:
            ground_truth_masks = ground_truth_masks.to(metric_device)
        except Exception as e:
            raise RuntimeError(
                f"Could not move ground_truth_masks to device {metric_device}. "
                f"Ensure both tensors are on the same device. Error: {e}"
            ) from e

    preds = _masks_to_binary_long(predicted_masks, "predicted_masks")
    target = _masks_to_binary_long(ground_truth_masks, "ground_truth_masks")

    # Binary JaccardIndex == IoU of the foreground class, accumulated over every pixel.
    jaccard = torchmetrics.JaccardIndex(task="binary").to(metric_device)
    return jaccard(preds, target)

In [None]:
# Evaluate Grounded-SAM on every COCO example. Two numbers are reported: the mean
# per-image IoU and the pooled (all-pixel) IoU from `calculate_batch_iou_torchmetrics`.
# Both treat every labelled ground-truth pixel as foreground.
# NOTE(review): `segment_objects_on_image` is defined in a later cell; under
# Restart & Run All this cell must sit below that definition.
if "segment_objects_on_image" not in globals():
    raise NameError(
        "segment_objects_on_image is not defined yet - run the cell that defines it "
        "(further down in this notebook) before this evaluation cell."
    )


def _binarize_ground_truth(gt_tensor):
    """Return an (H, W) boolean mask of labelled pixels, or None if the rank is not 2 or 3.

    A (1, H, W) mask is squeezed; a (C, H, W) stack is collapsed with a max over C.
    """
    gt = gt_tensor.cpu().numpy()
    if gt.ndim == 3:
        gt = gt[0] if gt.shape[0] == 1 else gt.max(axis=0)
    elif gt.ndim != 2:
        return None
    return gt > 0


def _combine_best_sam_masks(sam_masks, sam_iou_scores, height, width, image_idx):
    """OR together the highest-scoring SAM candidate mask of every detected box.

    Args:
        sam_masks: the 'sam_masks' entry returned by segment_objects_on_image
            (indexed [box][candidate]), or None / empty when nothing was detected.
        sam_iou_scores: SAM's predicted IoU per candidate, indexed [0, box, candidate].
        height, width: target size (the ground-truth mask); predictions are resized to it.
        image_idx: only used in warning messages.

    Returns:
        np.ndarray: (height, width) boolean mask; all False when nothing was segmented.
    """
    joined = np.zeros((height, width), dtype=bool)
    if sam_masks is None or len(sam_masks) == 0:
        return joined
    for obj_idx, candidate_masks in enumerate(sam_masks):
        if obj_idx >= sam_iou_scores.shape[1]:
            print(f"Warning: obj_idx {obj_idx} out of bounds for sam_iou_scores second dimension {sam_iou_scores.shape[1]} on image {image_idx}. Skipping object.")
            continue
        best_idx = int(torch.argmax(sam_iou_scores[0, obj_idx, :]))
        best_mask = candidate_masks[best_idx].cpu().numpy().astype(bool)
        if best_mask.shape != joined.shape:
            # Nearest-neighbour keeps the mask binary; PIL sizes are (width, height).
            best_mask = np.array(Image.fromarray(best_mask).resize((width, height), Image.NEAREST))
        joined |= best_mask
    return joined


all_predicted_masks_list = []
all_ground_truth_masks_list = []
per_image_ious = []

print(f"Starting IoU evaluation for {len(coco_images)} COCO images...")

# segment_objects_on_image always draws a figure. Silence it by closing each new
# figure immediately and turning plt.show into a no-op. The `finally` clause restores
# the originals and the previous interactive-mode state, even if an image fails.
_original_subplots = plt.subplots
_original_show = plt.show
_was_interactive = plt.isinteractive()


def _closed_subplots(*args, **kwargs):
    """plt.subplots replacement that closes the new figure so it is never rendered."""
    fig, axs = _original_subplots(*args, **kwargs)
    plt.close(fig)
    return fig, axs


plt.ioff()
plt.subplots = _closed_subplots
plt.show = lambda *args, **kwargs: None
try:
    for i in range(len(coco_images)):
        print(f"Processing image {i+1}/{len(coco_images)}...")

        if coco_labels is None or i >= len(coco_labels) or coco_labels[i] is None:
            print(f"Skipping image {i} due to missing ground truth label.")
            continue

        gt_mask_bin_np = _binarize_ground_truth(coco_labels[i])
        if gt_mask_bin_np is None:
            print(f"Skipping image {i} due to unexpected ground truth mask dimensions: {coco_labels[i].ndim}")
            continue

        segmentation_output = segment_objects_on_image(
            images_dataset=coco_images,
            image_index=i,
            text_prompts_str=coco_categories,
            ground_truth_labels=coco_labels,
            verbose=False
        )
        if segmentation_output is None:
            print(f"Skipping image {i} due to missing segmentation results or IoU scores.")
            continue

        predicted_sam_masks = segmentation_output.get('sam_masks')
        sam_results = segmentation_output.get('sam_results')
        has_masks = predicted_sam_masks is not None and len(predicted_sam_masks) > 0
        if has_masks and (sam_results is None or sam_results.iou_scores is None):
            print(f"Skipping image {i} due to missing segmentation results or IoU scores.")
            continue
        if not has_masks:
            # No detections is an empty prediction, not a missing one: skipping such
            # images would inflate the IoU.
            print(f"No objects detected/segmented for image {i}. Creating an empty predicted mask.")

        h_gt, w_gt = gt_mask_bin_np.shape
        joined_predicted_mask_np = _combine_best_sam_masks(
            predicted_sam_masks,
            sam_results.iou_scores if has_masks else None,
            h_gt, w_gt, i,
        )

        intersection = np.logical_and(joined_predicted_mask_np, gt_mask_bin_np).sum()
        union = np.logical_or(joined_predicted_mask_np, gt_mask_bin_np).sum()
        # An empty union (no ground truth and no prediction) scores 0.
        per_image_ious.append(float(intersection / union) if union > 0 else 0.0)

        all_predicted_masks_list.append(torch.from_numpy(joined_predicted_mask_np))
        all_ground_truth_masks_list.append(torch.from_numpy(gt_mask_bin_np))
finally:
    plt.subplots = _original_subplots
    plt.show = _original_show
    if _was_interactive:
        plt.ion()

if per_image_ious:
    mean_per_image_iou = float(np.mean(per_image_ious))
    print(f"\nMean per-image IoU over {len(per_image_ious)} COCO images: {mean_per_image_iou:.4f}")
    # Masks are kept at each image's own size, so they can only be stacked when all match.
    if len({m.shape for m in all_predicted_masks_list}) == 1:
        try:
            batched_preds = torch.stack(all_predicted_masks_list).to(device)  # (N, H, W)
            batched_gts = torch.stack(all_ground_truth_masks_list).to(device)  # (N, H, W)
            pooled_iou = calculate_batch_iou_torchmetrics(batched_preds, batched_gts)
            print(f"Pooled IoU (torchmetrics, all pixels of the batch): {pooled_iou.item():.4f}")
        except RuntimeError as e:
            print(f"Error during stacking masks or calculating IoU: {e}")
    else:
        print("Skipping pooled torchmetrics IoU: the images have different sizes, so their masks cannot be stacked into one batch.")
else:
    print("No masks were collected to calculate batch IoU. Please check processing steps and data.")


In [None]:
def detect_objects_on_image(images, image_index, text_prompts_str, labels, verbose=False,
                            box_threshold=0.4, text_threshold=0.3):
    """
    Run Grounding DINO zero-shot detection on one image and plot the detections
    next to the ground-truth label mask.

    Uses the notebook-level ``processor`` and ``model`` (Grounding DINO).

    Args:
        images: indexable collection of image tensors (e.g. ``coco_images``).
            Each is assumed CxHxW with C in {1, 3}, or HxW — TODO confirm for new datasets.
        image_index (int): index of the image in ``images``.
        text_prompts_str (str): dot-separated prompt, e.g. "a cat. a dog.".
        labels: matching collection of ground-truth label masks, or None to leave
            the right panel empty.
        verbose (bool): print the detected labels, scores and boxes.
        box_threshold (float): minimum box confidence kept by post-processing.
        text_threshold (float): minimum text-token score for a label to be assigned.

    Returns:
        dict: post-processed results for this image ("scores", "boxes",
        "text_labels", ...), or None if ``image_index`` is out of range.
    """
    if not (0 <= image_index < len(images)):
        print(f"Error: image_index {image_index} is out of bounds for images (size {len(images)}).")
        return None

    # 1. Tensor -> 8-bit RGB PIL image, min-max normalised per image.
    image_tensor = images[image_index]
    if image_tensor.ndim == 3 and image_tensor.shape[0] in (1, 3):
        image_tensor = image_tensor.permute(1, 2, 0)  # CHW -> HWC
    image_np = image_tensor.cpu().numpy()
    if image_np.ndim == 3 and image_np.shape[-1] == 1:
        image_np = image_np[..., 0]  # PIL cannot build an image from an HxWx1 array
    min_val = image_np.min()
    max_val = image_np.max()
    if max_val == min_val:  # constant image: avoid dividing by zero
        image_scaled_np = np.zeros_like(image_np, dtype=np.float32)
    else:
        image_scaled_np = (image_np - min_val) / (max_val - min_val)
    # Grayscale inputs (e.g. mammograms) are replicated to RGB so the detector gets
    # three channels and the boxes below are really drawn in red.
    image_pil = Image.fromarray((image_scaled_np * 255).astype(np.uint8)).convert("RGB")

    # 2. Grounding DINO forward pass (inputs stay on the CPU, where the processor creates them).
    inputs = processor(images=image_pil, text=text_prompts_str, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # 3. Keep confident boxes and map them back to pixel coordinates.
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[image_pil.size[::-1]],  # PIL size is (W, H); (H, W) is expected
    )[0]  # single image

    # 4. Draw the boxes and their labels onto the image.
    draw = ImageDraw.Draw(image_pil)
    for score, label, box in zip(results["scores"], results["text_labels"], results["boxes"]):
        box_coords = [round(coord, 2) for coord in box.tolist()]
        draw.rectangle(box_coords, outline="red", width=1)
        draw.text((box_coords[0], box_coords[1] - 10), f"{label}: {round(score.item(), 2)}", fill="red")

    # 5. Detections (left) next to the ground-truth label mask (right).
    prompt_title = text_prompts_str if len(text_prompts_str) < 50 else text_prompts_str[:47] + "..."
    fig, axs = plt.subplots(1, 2, figsize=(16, 8))
    axs[0].imshow(image_pil)
    axs[0].axis('off')
    axs[0].set_title(f"Detections for prompt: '{prompt_title}'\nimage index {image_index}")
    if labels is not None:
        gt_mask = labels[image_index].cpu().numpy()
        if gt_mask.ndim == 3 and gt_mask.shape[0] == 1:
            gt_mask = gt_mask[0]  # imshow cannot display a 1xHxW array
        axs[1].imshow(gt_mask, cmap='tab20')
        axs[1].set_title("Ground Truth Label")
    axs[1].axis('off')
    plt.tight_layout()
    plt.show()

    if verbose:
        print(f"Detected objects: {results['text_labels']}")
        print(f"Scores: {results['scores']}")
        print(f"Boxes: {results['boxes']}")

    return results

In [None]:
def segment_objects_on_image(images_dataset, image_index, text_prompts_str, ground_truth_labels, verbose=False,
                             box_threshold=0.4, text_threshold=0.3):
    """
    Detect objects from a text prompt with Grounding DINO, then segment each
    detected box with SAM. Plots detections, segmentations and ground truth side
    by side.

    Uses the notebook-level ``processor``/``model`` (Grounding DINO) and
    ``sam_processor``/``sam_model`` (SAM). Inputs are left on the CPU where the
    processors create them.

    Args:
        images_dataset: indexable collection of image tensors (e.g. ``coco_images``).
            Each is assumed CxHxW with C in {1, 3}, or HxW — TODO confirm for new datasets.
        image_index (int): index of the image in ``images_dataset``.
        text_prompts_str (str): dot-separated text prompt.
        ground_truth_labels: matching label masks for the third panel, or None.
        verbose (bool): print the DINO detections and the number of SAM masks.
        box_threshold (float): minimum DINO box confidence kept.
        text_threshold (float): minimum text-token score for a label to be assigned.

    Returns:
        dict or None. None is returned if ``image_index`` is out of range. Otherwise
        the dict has these keys:

        - ``"dino_results"``: post-processed DINO output.
        - ``"sam_masks"``: per-box stacks of SAM candidate masks, resized to the
          original image size; None when nothing was detected.
        - ``"sam_results"``: raw SAM output, whose ``iou_scores`` rank the
          candidates; None when nothing was detected.
    """
    if not (0 <= image_index < len(images_dataset)):
        print(f"Error: image_index {image_index} is out of bounds for images_dataset (size {len(images_dataset)}).")
        return None

    # 1. Tensor -> 8-bit RGB PIL image, min-max normalised per image.
    image_tensor = images_dataset[image_index]
    if image_tensor.ndim == 3 and image_tensor.shape[0] in (1, 3):
        image_tensor = image_tensor.permute(1, 2, 0)  # CHW -> HWC
    image_np = image_tensor.cpu().numpy()
    if image_np.ndim == 3 and image_np.shape[-1] == 1:
        image_np = image_np[..., 0]  # PIL cannot build an image from an HxWx1 array
    min_val, max_val = image_np.min(), image_np.max()
    image_scaled_np = (image_np - min_val) / (max_val - min_val) if max_val > min_val else np.zeros_like(image_np)
    image_pil = Image.fromarray((image_scaled_np * 255).astype(np.uint8)).convert("RGB")

    # 2. Object detection with Grounding DINO.
    dino_inputs = processor(images=image_pil, text=text_prompts_str, return_tensors="pt")
    with torch.no_grad():
        dino_outputs = model(**dino_inputs)
    dino_results_processed = processor.post_process_grounded_object_detection(
        dino_outputs,
        dino_inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[image_pil.size[::-1]],  # PIL size is (W, H); (H, W) is expected
    )[0]  # single image

    detected_boxes = dino_results_processed["boxes"]
    detected_labels = dino_results_processed["text_labels"]
    detected_scores = dino_results_processed["scores"]
    if verbose:
        print(f"DINO Detected objects: {detected_labels}")
        print(f"DINO Scores: {detected_scores}")
        print(f"DINO Boxes: {detected_boxes}")

    # 3. Segmentation with SAM, prompted with every detected box.
    sam_outputs = None
    sam_masks_processed = None
    if len(detected_boxes) > 0:
        sam_inputs = sam_processor(image_pil, input_boxes=[detected_boxes.cpu().tolist()], return_tensors="pt")
        with torch.no_grad():
            sam_outputs = sam_model(**sam_inputs)
        # Upsample the low-resolution mask logits back to the original image size.
        sam_masks_processed = sam_processor.image_processor.post_process_masks(
            sam_outputs.pred_masks.cpu(),
            sam_inputs["original_sizes"].cpu(),
            sam_inputs["reshaped_input_sizes"].cpu()
        )[0]  # single image
        if verbose:
            print(f"SAM generated {len(sam_masks_processed)} masks for {len(detected_boxes)} detected boxes.")
    elif verbose:
        print("No objects detected by Grounding DINO, skipping SAM segmentation.")

    # 4. Visualisation.
    prompt_title = text_prompts_str if len(text_prompts_str) < 50 else text_prompts_str[:47] + "..."
    fig, axs = plt.subplots(1, 3, figsize=(36, 12))

    # Panel 1: image + DINO boxes and labels.
    image_dino_plot = image_pil.copy()
    draw_dino = ImageDraw.Draw(image_dino_plot)
    for score, label, box in zip(detected_scores, detected_labels, detected_boxes):
        box_coords = [round(coord, 2) for coord in box.tolist()]
        draw_dino.rectangle(box_coords, outline="red", width=2)
        draw_dino.text((box_coords[0], box_coords[1] - 12), f"{label}: {round(score.item(), 2)}", fill="red")
    axs[0].imshow(image_dino_plot)
    axs[0].set_title(f"DINO Detections: '{prompt_title}'\nimage index {image_index}")
    axs[0].axis('off')

    # Panel 2: image + one coloured SAM mask per detected box.
    axs[1].imshow(image_pil)
    if sam_masks_processed is not None and len(sam_masks_processed) > 0:
        colors = [
            [255, 255, 102],  # Light Yellow
            [173, 255, 47],   # Light Green (GreenYellow)
            [135, 206, 250],  # Light Sky Blue
            [255, 182, 193],  # Light Pink
            [255, 160, 122],  # Light Salmon
            [240, 255, 240],  # Honeydew (very light green)
            [255, 250, 205],  # Lemon Chiffon
            [224, 255, 255],  # Light Cyan
            [250, 235, 215],  # Antique White
            [255, 228, 181],  # Moccasin (Light Orange/Peach)
            [245, 245, 220],  # Beige
            [211, 211, 211],  # Light Grey
            [175, 238, 238],  # Pale Turquoise
            [255, 218, 185],  # Peach Puff
            [255, 105, 180],  # Hot Pink (brighter, but still light)
        ]
        for box_idx, candidate_masks in enumerate(sam_masks_processed):
            # SAM proposes several candidate masks per box; show the one SAM itself
            # scores highest (predicted IoU).
            best_idx = int(torch.argmax(sam_outputs.iou_scores[0, box_idx]))
            best_mask = candidate_masks[best_idx].cpu().numpy().astype(bool)
            # RGBA overlay that is transparent outside the mask, so each object tints
            # only its own pixels instead of darkening the whole image.
            overlay = np.zeros(best_mask.shape + (4,), dtype=np.float32)
            overlay[best_mask, :3] = np.asarray(colors[box_idx % len(colors)], dtype=np.float32) / 255.0
            overlay[best_mask, 3] = 0.4
            axs[1].imshow(overlay)

            if box_idx < len(detected_labels):
                box_coords = [round(coord, 2) for coord in detected_boxes[box_idx].tolist()]
                axs[1].text(box_coords[0], box_coords[1] - 12, detected_labels[box_idx],
                            color='white', backgroundcolor='black', fontsize=8)
    axs[1].set_title("SAM Segmentations")
    axs[1].axis('off')

    # Panel 3: ground-truth label mask (class IDs shown with a categorical colormap).
    if ground_truth_labels is not None and 0 <= image_index < len(ground_truth_labels):
        gt_mask = ground_truth_labels[image_index].cpu().numpy()
        if gt_mask.ndim == 3 and gt_mask.shape[0] == 1:
            gt_mask = gt_mask[0]  # imshow cannot display a 1xHxW array
        axs[2].imshow(gt_mask, cmap='tab20')
        axs[2].set_title("Ground Truth Label")
    else:
        axs[2].set_title("Ground Truth (Not available/plotted)")
    axs[2].axis('off')

    plt.tight_layout()
    plt.show()

    return {"dino_results": dino_results_processed, "sam_masks": sam_masks_processed, "sam_results": sam_outputs}


In [None]:
# Mammograms prompted with an object that should not be there — presumably a
# negative-control check for hallucinated detections. No ground truth is plotted.
for i in range(10):
    detect_objects_on_image(mammo_images, i, "balloon.", None)

In [None]:
detect_objects_on_image(
    images=coco_images,
    image_index=0,
    text_prompts_str="traffic light . car . person .",
    labels=coco_labels,
)

In [None]:
detect_objects_on_image(
    images=coco_images,
    image_index=0,
    text_prompts_str=coco_categories,
    labels=coco_labels,
)

In [None]:
detect_objects_on_image(
    images=vaihingen_images,
    image_index=0,
    text_prompts_str="cars . window . yellow car .",
    labels=vaihingen_labels,
)

In [None]:
# Grounded-SAM on the first Vaihingen tile, prompting for windows and yellow cars.
segmentation_results = segment_objects_on_image(
    vaihingen_images, 0, " window . yellow car .", vaihingen_labels
)

In [None]:
# Grounded-SAM on the first ten mammograms with an out-of-domain prompt (no ground truth).
for i in range(10):
    segment_objects_on_image(mammo_images, i, "Jumbo jet.", None)

In [None]:
# Mean per-image IoU of Grounded-SAM on the COCO examples. For each image, the
# highest-scoring SAM candidate of every detected box is OR-ed into one foreground
# mask. That mask is compared with the ground-truth "any labelled pixel" mask.
ious = []
for i in range(len(coco_images)):
    print(f"Segmenting image {i} of {len(coco_images)}")
    result = segment_objects_on_image(
        images_dataset=coco_images,
        image_index=i,
        text_prompts_str=coco_categories,
        ground_truth_labels=coco_labels
    )

    gt_mask = coco_labels[i].cpu().numpy()
    if gt_mask.ndim == 3:
        # (1, H, W) or (C, H, W) label stack -> a single (H, W) map
        gt_mask = gt_mask.max(axis=0)
    gt_mask_bin = gt_mask > 0
    joined_mask = np.zeros(gt_mask_bin.shape, dtype=bool)

    masks = None if result is None else result['sam_masks']
    if masks is None or len(masks) == 0:
        # Score the image with an empty prediction instead of skipping it: leaving
        # out images without detections would inflate the mean IoU.
        print(f"No masks generated for image {i}. Using an empty prediction.")
    else:
        iou_scores = result['sam_results'].iou_scores
        for j, candidate_masks in enumerate(masks):
            # SAM returns several candidates per box; keep the one it scores highest.
            best_mask = candidate_masks[int(torch.argmax(iou_scores[0, j]))]
            mask_np = best_mask.cpu().numpy() if hasattr(best_mask, 'cpu') else np.array(best_mask)
            mask_np = mask_np.astype(bool)
            if mask_np.shape != joined_mask.shape:
                # Nearest-neighbour keeps the mask binary; PIL sizes are (width, height).
                mask_np = np.array(Image.fromarray(mask_np).resize(joined_mask.shape[::-1], Image.NEAREST))
            joined_mask |= mask_np

    intersection = np.logical_and(joined_mask, gt_mask_bin).sum()
    union = np.logical_or(joined_mask, gt_mask_bin).sum()
    iou = float(intersection / union) if union > 0 else 0.0

    print(f"IoU for image {i}: {iou:.4f}")
    ious.append(iou)



In [None]:
# Unweighted mean of the per-image IoUs collected above.
np.mean(ious)