- This model requires `transformers` to be installed from source (e.g. `pip install git+https://github.com/huggingface/transformers`).

In [None]:
from transformers import Sam3Processor, Sam3Model
import torch
from PIL import Image
import requests
import os

# SECURITY: never hardcode access tokens in a notebook or source file. An
# earlier revision of this cell embedded a literal Hugging Face token; that
# token must be considered leaked and revoked. Provide the token from outside
# the code instead, e.g. `export HF_TOKEN=...` before starting the kernel, or
# run `huggingface-cli login` once so the credentials are cached locally.
if not os.environ.get("HF_TOKEN"):
    print("HF_TOKEN is not set; falling back to credentials cached by `huggingface-cli login`.")

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Sam3Model.from_pretrained("facebook/sam3").to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")

# Load image
image_url = "http://images.cocodataset.org/val2017/000000077595.jpg"
# A timeout keeps the cell from hanging on a stalled connection, and
# raise_for_status() surfaces HTTP errors before PIL tries to decode an error
# page. convert() fully decodes the image, so the connection can be released
# as soon as the `with` block ends.
with requests.get(image_url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw).convert("RGB")

# Segment using text prompt
inputs = processor(images=image, text="ear", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

# Post-process results
results = processor.post_process_instance_segmentation(
    outputs,
    threshold=0.5,
    mask_threshold=0.5,
    target_sizes=inputs.get("original_sizes").tolist()
)[0]

print(f"Found {len(results['masks'])} objects")
# Results contain:
# - masks: Binary masks resized to original image size
# - boxes: Bounding boxes in absolute pixel coordinates (xyxy format)
# - scores: Confidence scores

In [None]:
# Inspect the backbone's embeddings module: which parameters does it register,
# and is a `cls_token` attribute among them?
embeddings = model.vision_encoder.backbone.embeddings

# Print each registered parameter together with its shape
# (named_parameters() does not include buffers).
for param_name, tensor in embeddings.named_parameters():
    print(f"{param_name}: {tensor.shape}")

# Report on the CLS token specifically
print(
    f"CLS token exists! Shape: {embeddings.cls_token.shape}"
    if hasattr(embeddings, 'cls_token')
    else "No CLS token in this model"
)

In [None]:
import torch.nn.functional as F

def get_image_embedding(model, processor, image, device=None):
    """Return an L2-normalized, mean-pooled backbone embedding for ``image``.

    The image is preprocessed with ``processor``, passed through
    ``model.vision_encoder.backbone``, averaged over dim 1 of the last hidden
    state (assumed to be the token axis) and L2-normalized.

    Args:
        model: Model exposing ``vision_encoder.backbone`` and ``parameters()``.
        processor: Callable accepting ``images=`` and ``return_tensors=`` whose
            result supports ``.to(device)`` and has a ``'pixel_values'`` entry.
        image: Image (or anything the processor accepts as ``images``).
        device: Device the inputs are moved to. Defaults to the device of the
            model's first parameter, so the function no longer depends on a
            module-level ``device`` global and cannot disagree with the model.

    Returns:
        NumPy array of the embedding with the leading batch axis squeezed out
        (shape ``[1024]`` for a single image, per the shapes noted below).
    """
    if device is None:
        # Follow the model rather than a global; fall back to CPU only for a
        # parameter-less model.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model.vision_encoder.backbone(pixel_values=inputs['pixel_values'])
        features = output.last_hidden_state  # [1, 5184, 1024]

        # Global average pooling over dim 1 -> [1, 1024]
        pooled = features.mean(dim=1)

        # L2 normalize so dot products between embeddings are cosine similarities
        embedding = F.normalize(pooled, p=2, dim=-1)

    return embedding.squeeze(0).cpu().numpy()  # [1024]

# Quick check on the sample image; as the last expression of a notebook cell
# the returned embedding array is shown as the cell output.
get_image_embedding(model, processor, image)

In [None]:
import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import numpy as np

# Root directory that is searched recursively for images to embed.
FOLD_PATH = "/workspace/yolo_dataset_4_dec"
# Destination .npz archive for the embeddings and their per-image metadata.
SAVE_PATH = "/workspace/embeddings.npz"
# Number of images sent through the backbone per forward pass.
BATCH_SIZE = 32

def extract_embeddings_batched(model, processor, root_path, device, batch_size=32):
    """Embed every image found under ``root_path`` in batches.

    Images are discovered recursively, loaded ``batch_size`` at a time, passed
    through ``model.vision_encoder.backbone``, mean-pooled over dim 1 of the
    last hidden state and L2-normalized. Images that cannot be opened are
    reported on stdout and skipped.

    Args:
        model: Model exposing ``vision_encoder.backbone``.
        processor: Callable accepting ``images=`` and ``return_tensors=`` whose
            result supports ``.to(device)`` and has a ``'pixel_values'`` entry.
        root_path: Directory searched recursively for .jpg/.jpeg/.png/.webp
            files (extension match is case-insensitive).
        device: Device the preprocessed inputs are moved to.
        batch_size: Images per forward pass; must be positive.

    Returns:
        Tuple ``(embeddings, paths, labels, splits)``. ``embeddings`` stacks
        one row per successfully loaded image; the three lists are aligned
        with its rows. ``labels`` holds each image's parent directory name.
        ``splits`` holds the grandparent directory name when the path string
        contains 'train' or 'val', otherwise 'unknown'. If nothing could be
        embedded, an empty ``(0, 0)`` float32 array and three empty lists are
        returned instead of raising.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    root = Path(root_path)
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp'}

    image_paths = sorted(p for p in root.rglob('*') if p.suffix.lower() in image_extensions)
    print(f"Found {len(image_paths)} images")

    all_embeddings = []
    valid_paths = []
    valid_labels = []
    valid_splits = []

    if not image_paths:
        # np.vstack([]) raises, so hand back an explicit empty result.
        return np.empty((0, 0), dtype=np.float32), valid_paths, valid_labels, valid_splits

    for i in tqdm(range(0, len(image_paths), batch_size), desc="Processing batches"):
        batch_paths = image_paths[i:i+batch_size]
        batch_images = []
        batch_meta = []

        for p in batch_paths:
            try:
                # The context manager closes the file handle promptly;
                # convert() returns a fully loaded, independent image.
                with Image.open(p) as raw:
                    img = raw.convert("RGB")
            except Exception as e:
                # Deliberately broad: any unreadable file is skipped, not fatal.
                print(f"Skipping {p}: {e}")
                continue

            # Metadata is built outside the try and appended together with the
            # image, so the image list and metadata list can never go out of
            # step (which would misalign embeddings and paths).
            path_str = str(p)
            # NOTE: substring test against the whole path string, as before.
            mentions_split = 'train' in path_str or 'val' in path_str
            ancestors = p.parents
            if mentions_split and len(ancestors) > 1:
                split = ancestors[1].name
            else:
                split = 'unknown'

            batch_images.append(img)
            batch_meta.append({
                'path': path_str,
                'label': p.parent.name,
                'split': split,
            })

        if not batch_images:
            continue

        inputs = processor(images=batch_images, return_tensors="pt").to(device)

        with torch.no_grad():
            output = model.vision_encoder.backbone(pixel_values=inputs['pixel_values'])
            features = output.last_hidden_state
            # Mean-pool over dim 1, then L2-normalize each row.
            embeddings = features.mean(dim=1)
            embeddings = F.normalize(embeddings, p=2, dim=-1)

        all_embeddings.append(embeddings.cpu().numpy())
        for meta in batch_meta:
            valid_paths.append(meta['path'])
            valid_labels.append(meta['label'])
            valid_splits.append(meta['split'])

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not all_embeddings:
        # Every discovered image failed to load.
        return np.empty((0, 0), dtype=np.float32), valid_paths, valid_labels, valid_splits

    return np.vstack(all_embeddings), valid_paths, valid_labels, valid_splits

# Embed the whole dataset folder, then summarize what was collected.
embeddings, paths, labels, splits = extract_embeddings_batched(
    model,
    processor,
    root_path=FOLD_PATH,
    device=device,
    batch_size=BATCH_SIZE,
)

print(f"\nEmbeddings shape: {embeddings.shape}")
print(f"Train samples: {splits.count('train')}")
print(f"Val samples: {splits.count('val')}")
print(f"Classes: {set(labels)}")

# Persist the embeddings and their row-aligned metadata in one .npz archive;
# the Python lists are converted to NumPy string arrays for storage.
np.savez(
    SAVE_PATH,
    **{
        "embeddings": embeddings,
        "paths": np.array(paths),
        "labels": np.array(labels),
        "splits": np.array(splits),
    },
)
print(f"\nSaved to: {SAVE_PATH}")