# TUM Rosbag Loading and Model Comparison

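This notebook loads a TUM RGB-D dataset rosbag and compares `embed_slam` encoder models on a frame from the sequence. For each model it visualizes the per-pixel similarity between the image features and a text query. Run the cells in order from the top: the model-loading cell depends on the imports and on the `device` defined in the configuration cell.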

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import os

from vlm_eval.core import EncoderRegistry, DatasetRegistry
from vlm_eval.encoders import *
from vlm_eval.datasets.tum_rosbag import TUMRosbagDataset

# The vlm_eval.encoders import above is expected to register the encoders;
# show which encoder names the registry now knows about.
available_encoders = EncoderRegistry.list_available()
print("Available encoders:", available_encoders)

In [None]:
# Configuration
# Path to the TUM RGB-D rosbag. Set the TUM_BAG_PATH environment variable to
# point at a different file without editing the notebook.
bag_path = os.environ.get(
    "TUM_BAG_PATH",
    "/home/jovyan/cps_persistent1_shared/datasets/public/TUM/rgbd_dataset_freiburg1_room.bag",
)
rgb_topic = "/camera/rgb/image_color"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reset explicitly so a dataset left over from an earlier run (e.g. with a
# different bag_path) is never silently reused by later cells.
dataset = None
if not os.path.exists(bag_path):
    print(f"Warning: Bag file not found at {bag_path}. Please update the path or set TUM_BAG_PATH.")
else:
    # Load Dataset
    dataset = TUMRosbagDataset(bag_path, topics=[rgb_topic])
    print(f"Loaded dataset with {len(dataset)} frames.")

In [1]:
# Load Models
# Instantiate each embed_slam encoder by registry name on `device`. A model
# that fails to load is reported and skipped, so the rest can still be compared.
model_names = ["concept_fusion", "dino_fusion", "x_fusion", "naradio_fusion"]
models = {}

for encoder_name in model_names:
    print(f"Loading {encoder_name}...")
    try:
        loaded_encoder = EncoderRegistry.get(encoder_name, device=device)
    except Exception as err:
        print(f"Failed to load {encoder_name}: {err}")
    else:
        models[encoder_name] = loaded_encoder
        print(f"Loaded {encoder_name}")

Loading concept_fusion...
Failed to load concept_fusion: name 'EncoderRegistry' is not defined
Loading dino_fusion...
Failed to load dino_fusion: name 'EncoderRegistry' is not defined
Loading x_fusion...
Failed to load x_fusion: name 'EncoderRegistry' is not defined
Loading naradio_fusion...
Failed to load naradio_fusion: name 'EncoderRegistry' is not defined


In [None]:
# Visualize and Compare on a Frame
# Shows one dataset frame, then one text-similarity heatmap per loaded model.
frame_idx = 0
text_query = "chair"

if not os.path.exists(bag_path):
    print(f"Skipping comparison: bag file not found at {bag_path}.")
elif not models:
    # plt.subplots(1, 0) would raise, e.g. when every model failed to load.
    print("Skipping comparison: no models were loaded successfully.")
else:
    data = dataset[frame_idx]
    # Add a batch dimension of 1 before moving to the target device.
    image_tensor = data["image"].unsqueeze(0).to(device)

    # Display Image
    # permute(1, 2, 0) assumes the image is channel-first (C, H, W) — TODO confirm
    # against TUMRosbagDataset.
    frame_fig, frame_ax = plt.subplots(figsize=(5, 5))
    img_np = image_tensor[0].permute(1, 2, 0).cpu().numpy()
    frame_ax.imshow(img_np)
    frame_ax.set_title(f"Frame {frame_idx}")
    frame_ax.axis("off")
    plt.show()

    # Run Models
    # squeeze=False always returns a 2-D array of Axes, so a single model needs
    # no special case; width scales with the number of panels.
    n_models = len(models)
    cmp_fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 5), squeeze=False)

    for ax, (name, model) in zip(axes[0], models.items()):
        print(f"Running {name}...")
        with torch.no_grad():
            features = model(image_tensor)
            text_emb = model.encode_text([text_query])
            # Dot product over the channel axis per pixel; presumably features is
            # (B, C, H, W) and text_emb is (B, C) with B == 1 — TODO confirm.
            sim = torch.einsum("bchw,bc->bhw", features, text_emb)
            # .float() so reduced-precision outputs (e.g. bfloat16) convert to numpy.
            sim_map = sim[0].float().cpu().numpy()

        heatmap = ax.imshow(sim_map, cmap="jet")
        ax.set_title(f"{name} - '{text_query}'")
        ax.axis("off")
        # Each panel is normalized independently; the colorbar shows its scale.
        cmp_fig.colorbar(heatmap, ax=ax, fraction=0.046, pad=0.04)

        # Release cached GPU memory between models.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    plt.tight_layout()
    plt.show()