## VGGT

In [None]:
import torch
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map

import os
import gc
from copy import deepcopy
from scripts import utils, features
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
from sklearn.metrics import silhouette_score
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"{device=}")

In [None]:
# Input locations (challenge data, VGGT weights, cached VGGT features) and
# the CSV the predictions are written to.
DATA_DIR = "../data/image-matching-challenge-2025"
VGGT_DIR = "weights/vggt-1B"
FEATURES_DIR = "vggt_features/last_features"
OUTPUT_FILE = "train_predictions.csv"

# Scenes to process. The inference loop only filters when this list is
# non-empty, so an empty list means "process every scene".
_NEW_2025_DATASETS = [
    "amy_gardens",
    "ETs",
    "fbk_vineyard",
    "stairs",
]
_IMC_2023_2024_DATASETS = [
    "imc2024_dioscuri_baalshamin",
    "imc2023_theather_imc2024_church",
    "imc2023_heritage",
    "imc2023_haiper",
    "imc2024_lizard_pond",
]
_PHOTOTOURISM_DATASETS = [  # crowdsourced PhotoTourism scenes
    "pt_stpeters_stpauls",
    "pt_brandenburg_british_buckingham",
    "pt_piazzasanmarco_grandplace",
    "pt_sacrecoeur_trevi_tajmahal",
]
DATASETS_FILTER = _NEW_2025_DATASETS + _IMC_2023_2024_DATASETS + _PHOTOTOURISM_DATASETS

In [None]:
# Load every scene of the challenge data, then report how many images each has.
samples = utils.dataset.load_dataset(DATA_DIR)

for dataset, dataset_samples in samples.items():
    print(f'Dataset "{dataset}" -> num_images={len(dataset_samples)}')

In [None]:
# Compute device. Kept as a torch.device, consistent with the device chosen
# earlier in the notebook; str(device) is still "cuda" / "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preferred reduced-precision dtype. bfloat16 is supported on Ampere GPUs
# (compute capability 8.0+); older GPUs fall back to float16. The capability
# query is only valid when CUDA is available, so CPU-only runs use float32
# instead of crashing here.
if device.type == "cuda":
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    dtype = torch.float32

# Initialize the model, load the pretrained weights and switch to
# inference mode (eval() is a no-op if from_pretrained already did it).
model = VGGT.from_pretrained(VGGT_DIR).to(device)
model.eval()

## Run Inference

In [None]:
def patch_pooling(x, num_special_tokens=5):
    """Average each image's patch tokens into one descriptor per image.

    Args:
        x: Array-like of per-image token features with shape
            (num_images, num_tokens, num_channels).
        num_special_tokens: Number of leading non-patch tokens per image
            that are excluded from the average (5 by default, the token
            layout this notebook assumes).

    Returns:
        np.ndarray of shape (num_images, num_channels) holding the mean of
        tokens ``num_special_tokens:`` for every image.

    Raises:
        ValueError: if ``x`` is not three-dimensional.
    """
    tokens = np.asarray(x)
    if tokens.ndim != 3:
        raise ValueError(
            f"expected (num_images, num_tokens, num_channels), got shape {tokens.shape}"
        )
    # Vectorized slice + mean over the token axis; same result and dtype as
    # averaging a per-image list, without the Python-level loop.
    return tokens[:, num_special_tokens:, :].mean(axis=1)

In [None]:
import traceback

# Free memory left over from earlier cells to reduce the risk of OOM errors.
gc.collect()
mapping_result_strs = []  # One status line per processed dataset, printed at the end

# Token layout of the cached VGGT features: 5 leading special tokens followed
# by a row-major patch grid that is 37 patches wide. A 37x37 grid therefore
# yields 5 + 37 * 37 = 1374 tokens per image.
NUM_SPECIAL_TOKENS = 5
PATCH_GRID_SIZE = 37
TOKENS_PER_IMAGE = NUM_SPECIAL_TOKENS + PATCH_GRID_SIZE * PATCH_GRID_SIZE


def _resize_patch_grid(feature):
    """Resample a (num_images, num_tokens, num_channels) tensor to a 37x37 patch grid.

    The special tokens are kept unchanged. The patch tokens are treated as a
    grid PATCH_GRID_SIZE patches wide and bilinearly resized to
    PATCH_GRID_SIZE x PATCH_GRID_SIZE. Tensors already at TOKENS_PER_IMAGE
    tokens are returned as-is.
    """
    if feature.shape[1] == TOKENS_PER_IMAGE:
        return feature
    num_images, _, num_channels = feature.shape
    special_tokens = feature[:, :NUM_SPECIAL_TOKENS, :]
    patch_grid = (
        feature[:, NUM_SPECIAL_TOKENS:, :]
        .reshape(num_images, -1, PATCH_GRID_SIZE, num_channels)
        .permute(0, 3, 1, 2)  # -> (N, C, H, W) for interpolate
    )
    # Interpolate in float32: half-precision bilinear resize is not available
    # on CPU in every PyTorch version. Cast back to the stored dtype afterwards.
    resized = torch.nn.functional.interpolate(
        patch_grid.float(),
        size=(PATCH_GRID_SIZE, PATCH_GRID_SIZE),
        mode="bilinear",
        align_corners=False,
    ).to(feature.dtype)
    resized = resized.permute(0, 2, 3, 1).reshape(num_images, -1, num_channels)
    feature = torch.cat((special_tokens, resized), dim=1)
    if feature.shape[1] != TOKENS_PER_IMAGE:
        raise ValueError(f"Feature shape mismatch: {feature.shape}")
    return feature


def _load_vggt_features(features_dir):
    """Load every `.pt` file in `features_dir` (on CPU) and concatenate them along the image axis.

    Files are read in sorted filename order so that the row order is
    deterministic (os.listdir order is arbitrary).
    NOTE(review): rows are later matched positionally to the *sorted* image
    paths. That is only correct if the sorted feature-file order matches the
    image order (e.g. one file per dataset, or zero-padded chunk names).
    Confirm against the script that wrote these files.
    """
    feature_files = sorted(f for f in os.listdir(features_dir) if f.endswith(".pt"))
    if not feature_files:
        raise FileNotFoundError(f"No .pt feature files found in {features_dir}")
    return torch.cat(
        [
            _resize_patch_grid(
                torch.load(os.path.join(features_dir, name), map_location="cpu")[0]
            )
            for name in feature_files
        ],
        dim=0,
    )


print(f"Extracting on device {device}")
for dataset, predictions in samples.items():
    # Skip datasets not in the filter list (an empty filter keeps everything).
    if DATASETS_FILTER and dataset not in DATASETS_FILTER:
        print(f'Skipping "{dataset}"')
        continue

    images_dir = os.path.join(DATA_DIR, "train", dataset)
    images = sorted(os.path.join(images_dir, p.filename) for p in predictions)
    # Map filenames to prediction indices so results can be written back.
    filename_to_index = {p.filename: idx for idx, p in enumerate(predictions)}

    try:
        vggt_features = _load_vggt_features(os.path.join(FEATURES_DIR, dataset))
        # Feature rows are matched to images purely by position below, so a
        # count mismatch would silently attach poses to the wrong images.
        if len(vggt_features) != len(images):
            raise ValueError(
                f"Found {len(vggt_features)} feature rows for {len(images)} images"
            )

        # Pool the patch tokens per image, reduce with UMAP, then cluster with
        # HDBSCAN. .float() first because numpy has no bfloat16 dtype.
        pooled = patch_pooling(vggt_features.float().cpu().numpy())
        reduced_features = features.extraction.feature_reducer(
            algorithm="UMAP",
            features=np.vstack(pooled),
            n_components=20,
            random_state=42,
        )
        cluster_labels = np.asarray(
            features.clustering.dino_clusterer(
                algorithm="HDBSCAN",
                features=reduced_features,
                scaler=None,
                min_cluster_size=2,
            )
        )
        inlier_mask = cluster_labels != -1  # label -1 marks outliers
        num_clusters = len(np.unique(cluster_labels[inlier_mask]))
        print(
            f"Clustering. Number of clusters: {num_clusters}, with {int((~inlier_mask).sum())} outliers."
        )
        gc.collect()

        # Drop outliers. Their predictions are left untouched.
        vggt_features = vggt_features[torch.from_numpy(inlier_mask)]
        images_np = np.array(images)[inlier_mask]
        cluster_labels = cluster_labels[inlier_mask]

        # Feed the camera head inputs on its device and in its parameter dtype.
        head_dtype = next(model.camera_head.parameters()).dtype
        for cluster in np.unique(cluster_labels):
            in_cluster = cluster_labels == cluster
            cluster_features = vggt_features[torch.from_numpy(in_cluster)].to(
                device=device, dtype=head_dtype
            )
            cluster_images = images_np[in_cluster]

            print(f"Processing Cluster {cluster}: {len(cluster_features)} images")

            # Pure inference: skip building an autograd graph.
            with torch.no_grad():
                # camera_head takes a list of token tensors with a leading
                # batch axis and returns a list; the last entry is used.
                pose_enc = model.camera_head([cluster_features.unsqueeze(0)])[-1]
                translations = pose_enc[0, :, :3]
                rotation_matrices = utils.camera.quat_to_cam_pose(pose_enc[0, :, 3:7])

            for image_path, translation, rotation_matrix in zip(
                cluster_images, translations, rotation_matrices
            ):
                prediction = predictions[filename_to_index[os.path.basename(image_path)]]
                prediction.cluster_index = cluster
                # .copy() so the stored arrays never alias tensor memory.
                prediction.translation = translation.detach().cpu().numpy().copy()
                prediction.rotation = rotation_matrix.detach().cpu().numpy().copy()

        mapping_result_str = f'Dataset "{dataset}" -> {len(images)} images with {len(np.unique(cluster_labels))} clusters'
        mapping_result_strs.append(mapping_result_str)
        print(mapping_result_str)
        gc.collect()

    except Exception:
        # Best-effort per dataset: show the full traceback and move on.
        traceback.print_exc()
        mapping_result_str = f'Dataset "{dataset}" -> Failed!'
        mapping_result_strs.append(mapping_result_str)
        print(mapping_result_str)

# Print summary of results
print("\nResults")
for s in mapping_result_strs:
    print(s)

In [None]:
# Create a submission file.
utils.submission.create_submission_file(samples, OUTPUT_FILE)

!head {OUTPUT_FILE}

In [None]:
# Score the written predictions against the training ground truth, using the
# per-scene thresholds shipped with the challenge data.
score_kwargs = dict(
    gt_csv=os.path.join(DATA_DIR, "train_labels.csv"),
    user_csv=OUTPUT_FILE,
    thresholds_csv=os.path.join(DATA_DIR, "train_thresholds.csv"),
    mask_csv=None,
    inl_cf=0,
    strict_cf=-1,
    verbose=True,
)
final_score, dataset_scores = utils.metric.score(**score_kwargs)