# Embed SLAM Models Comparison

This notebook compares the models from `embed_slam` (ConceptFusion, DINOFusion, XFusion, and NARadioFusion). For each model, it visualizes how strongly the dense image features match a text query.

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os

from vlm_eval.core import EncoderRegistry
from vlm_eval.encoders import *

# Report which encoders the registry knows about; the star import of
# vlm_eval.encoders above is presumably what registers them (confirm).
available_encoders = EncoderRegistry.list_available()
print("Available encoders:", available_encoders)

In [None]:
# Load every encoder under comparison.
# Note: the required checkpoints and dependencies must already be installed.
# A model that fails to load is reported and skipped (best-effort), so the
# remaining models can still be compared; `models` holds only the successes.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model_names = ["concept_fusion", "dino_fusion", "x_fusion", "naradio_fusion"]
models = {}

for encoder_name in model_names:
    print(f"Loading {encoder_name}...")
    try:
        loaded_encoder = EncoderRegistry.get(encoder_name, device=device)
    except Exception as err:
        print(f"Failed to load {encoder_name}: {err}")
        continue
    models[encoder_name] = loaded_encoder
    print(f"Loaded {encoder_name}")


In [None]:
# Load a sample image (replace the path with your own image).
# If the file is missing, fall back to an in-memory random-noise image rather
# than writing one to disk: saving fails when ../examples does not exist, and a
# persisted noise file would silently be reused as the "sample" on later runs.
image_path = "../examples/sample_image.jpg"

if os.path.exists(image_path):
    # Context manager closes the file handle; convert() returns an independent copy.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
else:
    # randint's upper bound is exclusive, so 256 is needed to cover 0..255.
    noise = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    image = Image.fromarray(noise).convert("RGB")
    print(f"{image_path} not found; using a random dummy image (not saved)")

plt.imshow(image)
plt.title("Input Image")
plt.show()

# Preprocess: (H, W, 3) uint8 -> (1, 3, H, W) float32 in [0, 1] on `device`.
# NOTE(review): no resizing or mean/std normalisation is applied here; confirm
# that each encoder performs its own preprocessing internally.
image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
image_tensor = image_tensor.unsqueeze(0).to(device)
print("Image tensor shape:", image_tensor.shape)

In [None]:
# Run inference and compare: for each loaded encoder, score its dense image
# features against the embedding of a text query and plot the similarity map.
text_query = "chair"

if not models:
    # plt.subplots(1, 0) raises, and the loading cell may legitimately leave
    # `models` empty, so report it instead of crashing.
    print("No models were loaded; nothing to compare.")
else:
    n_models = len(models)
    # squeeze=False always yields a 2-D Axes array, so one model needs no special case.
    fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 5), squeeze=False)
    axes = axes[0]

    for ax, (name, model) in zip(axes, models.items()):
        print(f"Running {name}...")
        try:
            with torch.no_grad():
                # Dense image features; assumed (B, C, H, W) — TODO confirm per encoder.
                features = model(image_tensor)
                # Text embedding; assumed (1, C) in the same space as the features.
                text_emb = model.encode_text([text_query])

                # Cosine similarity: L2-normalise both along the channel axis so the
                # map reflects direction agreement, not per-pixel feature magnitude
                # (which differs between models and would skew the comparison).
                # .float()/.to() guard against half-precision or device mismatches.
                features = torch.nn.functional.normalize(features.float(), dim=1)
                text_emb = torch.nn.functional.normalize(
                    text_emb.float().to(features.device), dim=-1
                )
                sim = torch.einsum("bchw,bc->bhw", features, text_emb)
        except Exception as err:
            # Best-effort, consistent with the loading cell: one encoder failing
            # (e.g. lacking a text head) should not abort the whole comparison.
            print(f"Failed to run {name}: {err}")
            ax.set_title(f"{name} - failed")
            ax.axis("off")
            continue

        sim_map = sim[0].cpu().numpy()
        ax.imshow(sim_map, cmap="jet")
        ax.set_title(f"{name} - '{text_query}'")
        ax.axis("off")

    plt.tight_layout()
    plt.show()