In [1]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import json
import glob

In [2]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # Enter a bfloat16 autocast context for the rest of the notebook; it is
    # entered manually and never exited on purpose.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # TF32 is enabled for GPUs with compute capability major >= 8 (Ampere)
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    supports_tf32 = torch.cuda.get_device_properties(0).major >= 8
    if supports_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

using device: cuda


In [3]:
# Seed NumPy's global RNG so the colors drawn by show_mask(random_color=True)
# come out the same on every run.
np.random.seed(3)

def show_mask(mask, ax, random_color=False, borders = True):
    """Overlay a segmentation mask on ``ax`` as a translucent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width). Leading
            singleton dimensions (e.g. shape ``(1, H, W)``) are accepted. The
            values are cast to ``uint8`` and multiplied into the overlay color,
            so a binary (0/1 or boolean) mask is expected.
        ax: Object providing ``imshow`` (a matplotlib Axes in this notebook).
        random_color: When True, draw a random RGB color from NumPy's global
            RNG; otherwise use a fixed blue. Alpha is 0.6 in both cases.
        borders: When True, outline the mask's external contours using OpenCV.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    # Collapse any leading singleton dimensions once, so the RGBA overlay and
    # the contour search both see the same (h, w) array. cv2.findContours needs
    # a single-channel 2-D image and would misread a (1, h, w) input.
    mask_2d = mask.astype(np.uint8).reshape(h, w)
    mask_image = mask_2d[:, :, None] * color.reshape(1, 1, -1)
    if borders:
        # Imported lazily so that borders=False works without OpenCV installed.
        import cv2
        contours, _ = cv2.findContours(mask_2d, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as white-edged stars.

    Rows of ``coords`` whose entry in ``labels`` equals 1 are drawn in green,
    rows whose label equals 0 are drawn in red. Column 0 of ``coords`` is used
    as x and column 1 as y. ``marker_size`` is passed to ``scatter`` as ``s``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Label-1 points are drawn first, then label-0 points.
    for label_value, point_color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=point_color, **star_style)

def show_box(box, ax):
    """Add an unfilled green rectangle to ``ax`` for ``box``.

    ``box[0]`` and ``box[1]`` give the rectangle's anchor corner; the width and
    height are ``box[2] - box[0]`` and ``box[3] - box[1]``.
    """
    x_start, y_start, x_end, y_end = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Display every mask over ``image``, one 10x10 figure per mask.

    Args:
        image: Image passed to ``plt.imshow`` as the background.
        masks, scores: Iterated in lockstep; each mask is drawn with ``show_mask``.
        point_coords: Optional prompt points, drawn with ``show_points``.
            ``input_labels`` must be given alongside them.
        box_coords: Optional box, drawn with ``show_box``.
        input_labels: Labels for ``point_coords``.
        borders: Forwarded to ``show_mask``.

    A "Mask k, Score: s" title is added only when there is more than one score.
    """
    for rank, (mask, score) in enumerate(zip(masks, scores), start=1):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        axes = plt.gca()
        show_mask(mask, axes, borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, axes)
        if box_coords is not None:
            # boxes
            show_box(box_coords, axes)
        if len(scores) > 1:
            plt.title(f"Mask {rank}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()

def save_masks(model_name, index, image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Render every mask over ``image`` and save it as a PNG under ``qualitative/<model_name>/``.

    Args:
        model_name: Name of the sub-directory of ``qualitative/`` to write into
            (created on demand).
        index: Identifier embedded in the output file name.
        image: Image passed to ``plt.imshow`` as the background.
        masks, scores: Iterated in lockstep; each mask is drawn with ``show_mask``.
        point_coords: Optional prompt points, drawn with ``show_points``.
            ``input_labels`` must be given alongside them.
        box_coords: Optional box, drawn with ``show_box``.
        input_labels: Labels for ``point_coords``.
        borders: Forwarded to ``show_mask``.

    File naming:
        With a single mask the file is ``mask_<index>.png``. With several masks
        each gets its own ``mask_<index>_<k>.png`` (k is the 1-based position),
        so later masks no longer overwrite earlier ones.

    Note:
        Calls ``plt.ioff()``, which switches matplotlib's interactive mode off
        and leaves it off after the function returns.
    """
    plt.ioff()
    out_dir = f"qualitative/{model_name}"
    for i, (mask, score) in enumerate(zip(masks, scores)):
        several_masks = len(scores) > 1
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(image)
            show_mask(mask, plt.gca(), borders=borders)
            if point_coords is not None:
                assert input_labels is not None
                show_points(point_coords, input_labels, plt.gca())
            if box_coords is not None:
                # boxes
                show_box(box_coords, plt.gca())
            if several_masks:
                plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
            plt.axis('off')
            os.makedirs(out_dir, exist_ok=True)
            # Give each mask a distinct file when there are several; keep the
            # original name for the single-mask case.
            suffix = f"_{i+1}" if several_masks else ""
            plt.savefig(f"{out_dir}/mask_{index}{suffix}.png", bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

In [4]:
# Directory holding the SA-1B benchmark images, relative to the notebook's working directory.
sa1b_path = Path("../../datasets/benchmark/sa1b")
# glob.glob returns matches in arbitrary, filesystem-dependent order. Sort them so
# positional lookups such as sa1b_images[0] refer to the same file on every run.
sa1b_images = sorted(glob.glob(str(sa1b_path / "*.jpg")))

In [5]:
# Display the first discovered image path as a quick check that the glob found files.
sa1b_images[0]

'../../datasets/benchmark/sa1b/sa_10034.jpg'

In [6]:
# One hand-picked prompt point ([x, y] in image pixels) per benchmark image.
# The points are keyed by file name rather than by position in a directory
# listing, so each point stays attached to its image regardless of listing
# order or of how many other images the directory contains.
_sa1b_points_by_name = {
    "sa_10034.jpg": [[1000, 600]],
    "sa_10110.jpg": [[1000, 1200]],
    "sa_10129.jpg": [[250, 1500]],
    "sa_10202.jpg": [[1600, 700]],
    "sa_10487.jpg": [[1250, 1000]],
}

# Keys are full string paths under sa1b_path, e.g. "<sa1b_path>/sa_10034.jpg".
sa1b_points = {
    str(sa1b_path / file_name): points
    for file_name, points in _sa1b_points_by_name.items()
}

In [7]:
# Display the mapping to confirm which image path each prompt point is attached to.
sa1b_points

{'../../datasets/benchmark/sa1b/sa_10034.jpg': [[1000, 600]],
 '../../datasets/benchmark/sa1b/sa_10110.jpg': [[1000, 1200]],
 '../../datasets/benchmark/sa1b/sa_10129.jpg': [[250, 1500]],
 '../../datasets/benchmark/sa1b/sa_10202.jpg': [[1600, 700]],
 '../../datasets/benchmark/sa1b/sa_10487.jpg': [[1250, 1000]]}

In [8]:
# Hiera YAML config for each model size; every model family below reuses these.
_hiera_cfg_by_size = {
    "base_plus": "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "small": "configs/sam2.1/sam2.1_hiera_s.yaml",
    "tiny": "configs/sam2.1/sam2.1_hiera_t.yaml",
}


def _model_entry(model_name, size, sam2_checkpoint):
    """Return one config record: output name, YAML config for ``size``, checkpoint path."""
    return {
        "model_name": model_name,
        "model_cfg": _hiera_cfg_by_size[size],
        "sam2_checkpoint": sam2_checkpoint,
    }


# Checkpoints stored under ../sam2_logs/classic.
alpq_sam2_config = {
    "base_plus": _model_entry(
        "alpq_sam2_base_plus",
        "base_plus",
        "../sam2_logs/classic/adaptive_qat_toy_base_plus_20251110_155500/checkpoints/checkpoint.pt",
    ),
    "small": _model_entry(
        "alpq_sam2_small",
        "small",
        "../sam2_logs/classic/adaptive_qat_toy_small_20251111_172858/checkpoints/checkpoint.pt",
    ),
    "tiny": _model_entry(
        "alpq_sam2_tiny",
        "tiny",
        "../sam2_logs/classic/adaptive_qat_toy_tiny_20251112_161453_importancefixed/checkpoints/checkpoint.pt",
    ),
}

# Checkpoints stored under ../checkpoints.
sam2_config = {
    "base_plus": _model_entry("sam2_base_plus", "base_plus", "../checkpoints/sam2.1_hiera_base_plus.pt"),
    "small": _model_entry("sam2_small", "small", "../checkpoints/sam2.1_hiera_small.pt"),
    "tiny": _model_entry("sam2_tiny", "tiny", "../checkpoints/sam2.1_hiera_tiny.pt"),
}

# Checkpoints stored under ../sam2_minmax.
minmax_sam2_config = {
    "base_plus": _model_entry(
        "minmax_sam2_base_plus",
        "base_plus",
        "../sam2_minmax/minmax_qat_base_plus_20251111_122542/checkpoints/checkpoint_sam2.pt",
    ),
    "small": _model_entry(
        "minmax_sam2_small",
        "small",
        "../sam2_minmax/minmax_qat_small_20251111_233441/checkpoints/checkpoint_sam2.pt",
    ),
    "tiny": _model_entry(
        "minmax_sam2_tiny",
        "tiny",
        "../sam2_minmax/minmax_qat_tiny_20251111_165611/checkpoints/checkpoint_sam2.pt",
    ),
}

# Checkpoint stored under ../sam2_logs/ablations; only the base_plus size is present.
baseonly_sam2_config = {
    "base_plus": _model_entry(
        "baseonly_sam2_base_plus",
        "base_plus",
        "../sam2_logs/ablations/adaptive_qat_toy_base_plus_20251112_101653/checkpoints/checkpoint.pt",
    ),
}

In [9]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

def make_qualitatives(sa1b_path, model_config, points=None):
    """Run one SAM2 model on the prompted benchmark images and save the mask overlays.

    Args:
        sa1b_path: Directory object providing ``glob``; its ``*.jpg`` files are
            the candidate images.
        model_config: Mapping with ``"model_cfg"``, ``"sam2_checkpoint"`` and
            ``"model_name"`` entries. The first two are passed to ``build_sam2``;
            the last names the output sub-directory used by ``save_masks``.
        points: Optional mapping from ``str(image_path)`` to a list of ``[x, y]``
            prompt points. Defaults to the notebook-level ``sa1b_points``.

    Images without an entry in ``points`` are skipped with a warning. Each
    processed image is saved through ``save_masks`` with an index equal to its
    1-based position in the sorted ``*.jpg`` listing, so skipped images still
    consume an index.
    """
    if points is None:
        points = sa1b_points

    sam2_model = build_sam2(model_config["model_cfg"], model_config["sam2_checkpoint"], device=device)
    predictor = SAM2ImagePredictor(sam2_model)

    # Path.glob yields entries in arbitrary order; sort so the index embedded in
    # the output file name is stable across runs and across models.
    for i, sa1b_image in enumerate(sorted(sa1b_path.glob("*.jpg"))):
        # The points mapping is keyed by string paths, not Path objects.
        key = str(sa1b_image)
        if key not in points:
            print(f"Warning: {key} not found in sa1b_points, skipping.")
            continue
        input_point = np.array(points[key])
        # Every prompt point is treated as a positive click (label 1); one label
        # per point keeps the two arrays the same length for multi-point prompts.
        input_label = np.ones(len(input_point), dtype=int)
        # Close the image file as soon as the pixels have been copied out.
        with Image.open(sa1b_image) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        predictor.set_image(image)
        masks, scores, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
        # Order masks from highest to lowest score.
        sorted_ind = np.argsort(scores)[::-1]
        masks = masks[sorted_ind]
        scores = scores[sorted_ind]
        save_masks(model_config["model_name"], i+1, image, masks, scores, point_coords=input_point, input_labels=input_label)

In [14]:
# Save qualitative masks for every entry of alpq_sam2_config, then of sam2_config.
for model_name, model_config in alpq_sam2_config.items():
    make_qualitatives(sa1b_path, model_config)

for model_name, model_config in sam2_config.items():
    make_qualitatives(sa1b_path, model_config)

  plt.figure(figsize=(10, 10))


In [17]:
# Save qualitative masks for every entry of minmax_sam2_config.
for model_name, model_config in minmax_sam2_config.items():
    make_qualitatives(sa1b_path, model_config)

In [10]:
# Save qualitative masks for every entry of baseonly_sam2_config.
for model_name, model_config in baseonly_sam2_config.items():
    make_qualitatives(sa1b_path, model_config)