In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

# SAM3 Image grounding ground truth and prediction visualization



# <a target="_blank" href="https://colab.research.google.com/github/facebookresearch/sam3/blob/main/notebooks/sam3_data_and_predictions_visualization.ipynb">
#   <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
# </a>

In [None]:
# Toggle for Google Colab: when True, the following cell prints the
# PyTorch/Torchvision/CUDA status and pip-installs the notebook's dependencies.
using_colab = False

In [None]:
# Colab-only environment setup; skipped entirely when `using_colab` is False.
if using_colab:
    import torch
    import torchvision
    # Report the framework versions and whether a CUDA device is visible.
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    # Invoke pip through the running interpreter (`sys.executable`) so the
    # packages are installed for the same Python that executes this notebook.
    !{sys.executable} -m pip install opencv-python matplotlib scikit-learn
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam3.git'

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import json
import tqdm
import os
from collections import defaultdict

from sam3.visualization_utils import (
    save_side_by_side_visualization,
    convert_coco_to_masklet_format,
    single_visualization
)

### Load ground truth and predictions

In [None]:
# --- Configuration -----------------------------------------------------------
IMAGE_ROOT = None  # ADD YOUR PATH HERE (directory that contains the image files)
PREDICTION_THRESHOLD = 0.5  # predictions with a score below this are not drawn
NUM_VISUALIZATIONS = 20  # maximum number of images plotted per section

# _a_ is one of 3 annotators
gt_path = '../assets/gold_image_eval/gold_attributes_merged_a_test.json'
predictions_path = '../assets/sam3_predictions_gold_image/coco_predictions_gold_attr.json'


def _load_json(path):
    """Parse the JSON file at `path`.

    The file is read explicitly as UTF-8 so that the result does not depend on
    the platform's default locale encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _group_by_image_id(records):
    """Group dict records by their 'image_id' key.

    Returns a `defaultdict(list)` mapping each image id to the list of records
    that reference it, in their original order.
    """
    grouped = defaultdict(list)
    for record in records:
        grouped[record['image_id']].append(record)
    return grouped


# Load ground truth and predictions
gt_data = _load_json(gt_path)
predictions = _load_json(predictions_path)

gt_annotations = gt_data['annotations']
images = gt_data['images']
print(f"Number of GT annotations: {len(gt_annotations)}")
print(f"Number of images: {len(images)}")
print(f"Number of predictions: {len(predictions)}")


# Image metadata keyed by image id
image_dict = {img['id']: img for img in images}

# GT annotations and predictions, each grouped by image_id
gt_by_image = _group_by_image_id(gt_annotations)
pred_by_image = _group_by_image_id(predictions)

In [None]:
# For visualization purposes we separate positive and negative NPs:
# an image id is "positive" when it has at least one ground-truth annotation,
# and "negative" otherwise (negatives won't have any masks in the ground truth).
# Only image ids that have predictions are considered.
ids_with_predictions = list(pred_by_image)

# `.get` returns None for ids without GT and never inserts a key, so a missing
# id and an id with an empty annotation list are both treated as negative.
positiveNPs = [i for i in ids_with_predictions if gt_by_image.get(i)]
common_image_ids = positiveNPs  # alias of the same list object
negativeNPs = [i for i in ids_with_predictions if not gt_by_image.get(i)]

### Plot positive NP ground truth (left) and predictions (right)

In [None]:
# Side-by-side plots (ground truth vs. predictions) for the first
# NUM_VISUALIZATIONS positive noun phrases.

# Fail fast with a clear message instead of the opaque TypeError that
# os.path.join(None, ...) would raise on the first iteration.
if IMAGE_ROOT is None:
    raise ValueError("IMAGE_ROOT is not set; point it at the directory containing the images.")

for image_id in tqdm.tqdm(positiveNPs[:NUM_VISUALIZATIONS]):
    img_info = image_dict[image_id]
    gt_anns = gt_by_image[image_id]

    # Defensive guard: skip ids that turn out to have no ground truth.
    if not gt_anns:
        print(f"No ground truth annotations for image_id {image_id}, skipping.")
        continue

    image_path = os.path.join(IMAGE_ROOT, img_info["file_name"])

    # Load the image as an (H, W, 3) array. Converting to RGB drops an alpha
    # channel if present and also expands grayscale/palette images, which the
    # previous "shape[-1] == 4" check did not handle. The context manager
    # closes the underlying file once the pixels have been copied.
    with Image.open(image_path) as pil_img:
        img = np.array(pil_img.convert("RGB"))

    # Keep only predictions that reach the score threshold
    # (a prediction without a 'score' is treated as score 0).
    pred_anns = [
        p for p in pred_by_image[image_id]
        if p.get('score', 0) >= PREDICTION_THRESHOLD
    ]

    gt_converted = convert_coco_to_masklet_format(gt_anns, img_info, is_prediction=False)
    pred_converted = convert_coco_to_masklet_format(pred_anns, img_info, is_prediction=True)

    noun_phrase = img_info['text_input']

    save_side_by_side_visualization(img, gt_converted, pred_converted, noun_phrase)
    plt.show()
    plt.close()  # release the figure so memory does not grow across iterations


### Plot negative NP predictions

In [None]:
# every prediction is by definition a false positive
for idx, image_id in enumerate(tqdm.tqdm(list(negativeNPs[:NUM_VISUALIZATIONS]))):
    img_info = image_dict[image_id]
    pred_anns = pred_by_image[image_id]

    image_path = os.path.join(IMAGE_ROOT, img_info["file_name"])

    # Load original image as numpy array
    img = np.array(Image.open(image_path))
    if img.shape[-1] == 4:  # Remove alpha channel if present
        img = img[..., :3]

    pred_anns = [p for p in pred_anns if p.get('score', 0) >= PREDICTION_THRESHOLD]

    pred_converted = convert_coco_to_masklet_format(pred_anns, img_info, is_prediction=True)

    noun_phrase = img_info['text_input']

    single_visualization(img, pred_converted, noun_phrase)
    plt.show()
    plt.close()