In [None]:
# =============================================================================
# Project       : GSFeatLoc: Visual Localization Using Feature Correspondence on 3D Gaussian Splatting
# File          : gsfeatloc-demo.ipynb
# Description   : This is the demo notebook for GSFeatLoc.
# 
# Author        : Jongwon Lee (jongwon5@illinois.edu)
# Year          : 2025
# License       : BSD License
# =============================================================================

## GSFeatLoc: Visual Localization Using Feature Correspondence on 3D Gaussian Splatting

#### Import libraries and define utility functions

In [None]:
from nerfbaselines import new_cameras, camera_model_to_int
import numpy as np
from scipy.spatial.transform import Rotation as R
import json
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import cv2 as cv
import time
import os
import pycolmap

# Import the GSFeatLoc helper modules (utilities, visualizers, feature
# matchers, dataset loaders, and pose estimators).

# Make the `gsfeatloc` package importable by appending the parent of the
# current working directory to the module search path.
# NOTE(review): this relies on the notebook being started from its own folder,
# one level below the directory that contains `gsfeatloc` - confirm.
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from gsfeatloc.utils import get_world_frame_difference, perturb_SE3
from gsfeatloc.visualizer import visualize_matches, visualize_feature_points, visualize_3d_points
from gsfeatloc.feature_matcher import do_feature_matching_SIFT, do_feature_matching_SPSG, do_feature_matching_LoFTR
from gsfeatloc.dataset_loader import load_frames_blender, load_frames_colmap
from gsfeatloc.pose_estimator import get_2d_feature_points, compute_3d_points, estimate_camera_pose

# Print NumPy arrays without scientific notation, using four decimals
np.set_printoptions(suppress=True, precision=4)

# Seeded generator shared by the frame selection and the pose perturbation
# below, so that both are repeatable from a fresh kernel
rng = np.random.default_rng(0)


#### Load models and files

Load the trained Gsplat model

In [None]:
# We use the ExitStack to manage the context managers such that
# we can persist the contexts between cells. This is rarely needed
# in practice, but it is useful for this tutorial.
from contextlib import ExitStack
stack = ExitStack().__enter__()

import pprint
from nerfbaselines import load_checkpoint

# Scene selection
dataset = "blender"
sequence = "lego"
downscale_factor = 1

# NOTE(review): these are absolute paths on the author's machine; point them
# at your own dataset and checkpoint locations before running.
scene_path = f"/home/jongwonlee/datasets/nerfbaselines/{dataset}/{sequence}"
checkpoint_path = f"/home/jongwonlee/models/gsplat/{dataset}/{sequence}/checkpoint-30000"

# Sub-folder of `scene_path` that the query images are read from.
# The dataset name must be compared with `==`: `is` tests object identity,
# which is not guaranteed to hold for two equal strings.
if dataset == "blender":
    reading_dir = "test"
elif downscale_factor > 1:
    reading_dir = f"images_{downscale_factor}"
else:
    reading_dir = "images"

# Start the docker backend and load the checkpoint. The context is registered
# on `stack` so that the model stays usable in the following cells.
model, _ = stack.enter_context(load_checkpoint(checkpoint_path, backend="docker"))

# Print model information
pprint.pprint(model.get_info())

Load the metadata with ground-truth

Load a frame's filename, its ground-truth pose, and camera parameters

In [None]:
# Locations inside the scene folder that tell us which loader to use
colmap_sparse_path = Path(scene_path) / "sparse" / "0"
test_list_path = Path(scene_path) / "test_list.txt"
transforms_test_path = Path(scene_path) / "transforms_test.json"

if colmap_sparse_path.exists():
    # A COLMAP reconstruction is present: load every frame from it, then keep
    # only the frames named in test_list.txt.
    print(f"Using test data from COLMAP reconstruction in {colmap_sparse_path}")
    K, w, h, frames = load_frames_colmap(colmap_sparse_path, downscale_factor)

    if not test_list_path.exists():
        raise FileNotFoundError(f"test_list.txt not found in {scene_path}. Please provide a valid path to the dataset.")

    with open(test_list_path, "r") as f:
        test_list = f.read().splitlines()
    # The frame dictionary is looked up by file name without extension
    test_list_without_ext = [os.path.splitext(entry)[0] for entry in test_list]

    # Entries of the test list that the reconstruction does not contain are dropped
    test_frames = {}
    for stem in test_list_without_ext:
        if stem in frames:
            test_frames[stem] = frames[stem]
    frames = test_frames

elif transforms_test_path.exists():
    # No COLMAP reconstruction: read the test split from transforms_test.json
    print(f"Using test data from {transforms_test_path}")
    K, w, h, frames = load_frames_blender(transforms_test_path)

else:
    raise FileNotFoundError(f"No test data found in {scene_path}. Please provide a valid path to the dataset.")

print(f"Total number of frames: {len(frames)}")

In [None]:
# Sorted names of all test frames (the evaluation loop further down draws from this list too)
filenames = sorted(frames)

# Draw a frame at random. The draw is kept even though it is overridden on the
# next line, because it advances `rng` and therefore affects later random numbers.
filename = rng.choice(filenames)

# Pin the walkthrough to one fixed frame; remove this line to use the random draw
filename = "r_75"
print(f"Selected frame: {filename}")

Load the query image (with the ground-truth pose T_inW_ofC)

In [None]:
# Query images are stored as .JPG for mipnerf360 and as .png otherwise
image_ext = ".JPG" if dataset == "mipnerf360" else ".png"
im_query = np.array(Image.open(Path(scene_path, reading_dir, filename + image_ext)))

# Show the image. tight_layout() has to run before show(): once show() has
# rendered the figure, a later tight_layout() would act on a new, empty figure.
plt.imshow(im_query)
plt.axis("off")
plt.tight_layout()
plt.show()

In [None]:
# The size stored in the metadata (h, w) and the size of the image actually
# loaded may differ by up to two pixels. Anything larger is treated as an
# error; a small difference is resolved by adopting the loaded image's size.

height_gap = abs(im_query.shape[0] - h)
width_gap = abs(im_query.shape[1] - w)
assert height_gap <= 2, f"Image height mismatch: {im_query.shape[0]} vs {h}"
assert width_gap <= 2, f"Image width mismatch: {im_query.shape[1]} vs {w}"

if not (im_query.shape[0] == h and im_query.shape[1] == w):
    print(f"[Warning] Image size mismatch: {im_query.shape[:2]} vs ({h}, {w})")
    # Re-assign h and w so that they match the image
    h, w = im_query.shape[:2]
    print(f"[Warning] Updated image size: {h}, {w}")


Perturb the ground-truth camera pose. The perturbed pose serves as the initial guess (the reference pose) for the query.

In [None]:
# Nominal perturbation magnitudes.
# NOTE(review): M and N are not passed to perturb_SE3 in this cell - the call
# below uses 30.00 and 0.20 instead. Confirm which values are intended.
M = 12.10  # Rotation perturbation in degrees
N = 0.87  # Translation perturbation in meters

# Ground-truth pose of the selected frame
T_inW_ofC = frames[filename]

# Settings of the random perturbation applied to the ground-truth pose
perturbation_settings = dict(
    rotation_magnitude=30.00,
    translation_magnitude=0.20,
    axis_mode='random',
    magnitude_mode='gaussian',
    rng=rng,
)
T_inW_ofC_perturbated = perturb_SE3(T_inW_ofC, **perturbation_settings)

# Difference between the original and the perturbed camera pose
# (kept as the last expression so the cell also displays the returned value)
get_world_frame_difference(T_inW_ofC, T_inW_ofC_perturbated)

Render the image at the initial guess; this will serve as a reference image (with a known pose) against which the query performs feature matching to estimate its pose.

In [None]:
# Pinhole camera placed at the perturbed pose; intrinsics are passed as
# [K[0,0], K[1,1], K[0,2], K[1,2]] and the image size as [w, h]
intrinsics = np.array([K[0,0], K[1,1], K[0,2], K[1,2]], dtype=np.float32)
image_size = np.array([w, h], dtype=np.int32)
camera = new_cameras(
    poses=T_inW_ofC_perturbated,
    intrinsics=intrinsics,
    image_sizes=image_size,
    camera_models=np.array(camera_model_to_int("pinhole"), dtype=np.int32),
)

# Ask the model for a depth output in addition to the colour image
render_options = {"outputs": "depth", "output_type_dtypes": {'color': 'uint8', 'depth': 'float32'}}

tic = time.time()
outputs = model.render(camera=camera, options=render_options)
toc = time.time()
print(f"Rendering time: {toc - tic:.2f} seconds")

# The rendered colour image is the reference image; its depth map is used
# later when 3D points are computed from the reference keypoints
im_reference = outputs["color"]
depth_reference = outputs["depth"]

# Query and reference side by side
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
panels = [
    (im_query, "Query Image (Original)"),
    (im_reference, "Reference Image (Synthetic)"),
]
for axis, (image, title) in zip(ax, panels):
    axis.imshow(image)
    axis.set_title(title)
    axis.axis("off")
plt.show()

#### Feature matching

Perform SIFT feature matching between the images

In [None]:
# SIFT matching between the query and the rendered reference image.
# Only the matcher call itself is timed.
tic = time.time()
sift_output = do_feature_matching_SIFT(
    im_query,
    im_reference,
    ratio=0.7,
    do_visualize=True,
)
toc = time.time()

# Raw keypoints of both images and the matches between them
pts_query_raw_cv, pts_reference_raw_cv, matches = sift_output

print(f"Elapsed time: {toc - tic:.2f} seconds")

Perform SuperPoint + SuperGlue feature matching between the images

In [None]:
# SuperPoint + SuperGlue matching between the query and the rendered reference
# image. Only the matcher call itself is timed.
tic = time.time()
spsg_output = do_feature_matching_SPSG(
    im_query,
    im_reference,
    superpoint_threshold=0.01,
    superglue_threshold=0.5,
    do_visualize=True,
)
toc = time.time()

# Raw keypoints of both images and the matches between them
# (this overwrites the results of any matcher cell run before)
pts_query_raw_cv, pts_reference_raw_cv, matches = spsg_output

print(f"Elapsed time: {toc - tic:.2f} seconds")

Perform LoFTR for feature matching between the images

In [None]:
# LoFTR matching between the query and the rendered reference image.
# Only the matcher call itself is timed.
tic = time.time()
loftr_output = do_feature_matching_LoFTR(
    im_query,
    im_reference,
    do_visualize=True,
)
toc = time.time()

# Raw keypoints of both images and the matches between them
# (this overwrites the results of any matcher cell run before)
pts_query_raw_cv, pts_reference_raw_cv, matches = loftr_output

print(f"Elapsed time: {toc - tic:.2f} seconds")

#### Pose estimation

In [None]:
# Pose estimation from the matches produced by whichever matcher cell ran last
tic = time.time()

# 2D coordinates of the matched keypoints in the query and the reference image
pts_query, pts_reference, pts_query_cv, pts_reference_cv = get_2d_feature_points(
    matches, pts_query_raw_cv, pts_reference_raw_cv
)

# 3D points for the reference keypoints, computed from the rendered depth map,
# the intrinsics and the (known) perturbed pose; returned in world and camera frame
p_inW, p_inC = compute_3d_points(
    pts_reference, depth_reference, K, T_inW_ofC_perturbated, depth_scale=4.0
)

# Query pose from the world points and their 2D locations in the query image
rvec, tvec, inliers = estimate_camera_pose(
    p_inW, pts_query, K, reprojection_error=5.0, iterations=50
)

toc = time.time()
print(f"Elapsed time: {toc - tic:.2f} seconds")

# Optional visualizations (uncomment to use):
# visualize_feature_points(im_query, im_reference, pts_query, pts_reference, inliers)
# visualize_matches(im_query, im_reference, pts_query_raw_cv, pts_reference_raw_cv, [matches[i] for i in inliers])
# visualize_3d_points(p_inC)

In [None]:
# Project the world points into the query image with the estimated pose
# (no distortion coefficients) and flatten the result to one (x, y) row per point
pts_reprojected = cv.projectPoints(p_inW, rvec, tvec, K, None)[0].reshape(-1, 2)

# Per-point pixel distance between the matched query keypoint and its reprojection.
# All matched points are included here, not only the inliers.
reprojection_errors = np.linalg.norm(pts_query - pts_reprojected, axis=1)

# NaN-aware summary statistics
mean_reprojection_error = np.nanmean(reprojection_errors)
median_reprojection_error = np.nanmedian(reprojection_errors)
print(f"Mean Reprojection Error: {mean_reprojection_error:.2f} pixels")
print(f"Median Reprojection Error: {median_reprojection_error:.2f} pixels")

# Distribution of the errors
plt.hist(reprojection_errors, bins=30, edgecolor='black')
plt.title("Reprojection Error Distribution")
plt.xlabel("Reprojection Error (pixels)")
plt.ylabel("Number of Points")
plt.show()

In [None]:
# Assemble the 4x4 transform described by (rvec, tvec): the rotation vector is
# converted to a rotation matrix and placed next to the translation.
rotation_estimated = R.from_rotvec(rvec.reshape(-1)).as_matrix()
T_inC_ofW_estimated = np.eye(4, dtype=np.float32)
T_inC_ofW_estimated[:3, 3] = tvec.reshape(-1)
T_inC_ofW_estimated[:3, :3] = rotation_estimated

# Invert it to obtain the pose in the same convention as the ground truth
T_inW_ofC_estimated = np.linalg.inv(T_inC_ofW_estimated)

# Compare against the ground truth: first the initial guess, then the estimate
# (the second call is the cell's last expression, so its return value is displayed)
get_world_frame_difference(T_inW_ofC, T_inW_ofC_perturbated)
get_world_frame_difference(T_inW_ofC, T_inW_ofC_estimated)

In [None]:
# Render the scene at the estimated pose and overlay it on the query image

# Pinhole camera placed at T_inW_ofC_estimated
intrinsics = np.array([K[0,0], K[1,1], K[0,2], K[1,2]], dtype=np.float32)
image_size = np.array([w, h], dtype=np.int32)
camera = new_cameras(
    poses=T_inW_ofC_estimated,
    intrinsics=intrinsics,
    image_sizes=image_size,
    camera_models=np.array(camera_model_to_int("pinhole"), dtype=np.int32),
)

render_options = {"outputs": "depth", "output_type_dtypes": {'color': 'uint8', 'depth': 'float32'}}
outputs = model.render(camera=camera, options=render_options)
im_estimated = outputs["color"]
depth_estimated = outputs["depth"]

# A query image with an alpha channel is converted to RGB, and pixels whose
# alpha is zero are painted white
if im_query.shape[-1] != 4:
    im_query_copy = im_query
else:
    im_query_copy = cv.cvtColor(im_query, cv.COLOR_RGBA2RGB)
    im_query_copy[im_query[..., 3] == 0] = 255

# Equal-weight blend of the query and the render at the estimated pose
im_blended = cv.addWeighted(im_query_copy, 0.5, im_estimated, 0.5, 0)

# Four panels: query, initial guess, estimate, blend
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
panels = [
    (im_query_copy, "Query Image"),
    (im_reference, "Rendered Image (Initial Guess)"),
    (im_estimated, "Rendered Image (Estimate)"),
    (im_blended, "Blended Image (Query + Estimate)"),
]
for axis, (image, title) in zip(ax, panels):
    axis.imshow(image)
    axis.set_title(title)
    axis.axis("off")

plt.tight_layout()
plt.show()

#### Now, put everything together in a single loop

In [None]:
# Repeat the whole pipeline (perturb -> render -> match -> PnP) on randomly
# chosen test frames and record the pose error and run time of every trial.

M = 12.10  # Rotation perturbation in degrees
N = 0.87  # Translation perturbation in meters

num_iterations = 20  # Number of iterations for the loop
rot_errors = [None] * num_iterations
trs_errors = [None] * num_iterations
elapsed_times = [None] * num_iterations

# Query images are stored as .JPG for mipnerf360 and as .png otherwise
image_ext = ".JPG" if dataset == "mipnerf360" else ".png"

for i in range(num_iterations):
    # Pick a query frame and load its image and ground-truth pose
    filename = rng.choice(filenames)
    print(f"Selected frame: {filename}")
    im_query = np.array(Image.open(Path(scene_path, reading_dir, filename + image_ext)))
    T_inW_ofC = frames[filename]

    tic = time.time()

    # Create random pose perturbation and render a reference image
    T_inW_ofC_perturbated = perturb_SE3(T_inW_ofC,
                                        rotation_magnitude=M,
                                        translation_magnitude=N,
                                        axis_mode='random', magnitude_mode='gaussian',
                                        rng=rng)

    camera = new_cameras(
        poses=T_inW_ofC_perturbated,
        intrinsics=np.array([K[0,0], K[1,1], K[0,2], K[1,2]], dtype=np.float32),
        image_sizes=np.array([w, h], dtype=np.int32),
        camera_models=np.array(camera_model_to_int("pinhole"), dtype=np.int32),
    )
    outputs = model.render(camera=camera, options={"outputs": "depth", "output_type_dtypes": {'color': 'uint8', 'depth': 'float32'}})
    im_reference = outputs["color"]
    depth_reference = outputs["depth"]

    # Do feature detection and matching (SuperPoint + SuperGlue) on both the query and the reference image
    pts_query_raw_cv, pts_reference_raw_cv, matches = do_feature_matching_SPSG(im_query, im_reference, do_visualize=False)

    # Extract the matched 2D points
    pts_query, pts_reference, pts_query_cv, pts_reference_cv = get_2d_feature_points(matches, pts_query_raw_cv, pts_reference_raw_cv)

    # Check all the shapes
    assert pts_query.shape[0] == pts_reference.shape[0], "Number of points do not match"
    assert len(pts_query_cv) == len(pts_reference_cv) == len(matches), "Number of points do not match"

    # Start from the initial guess; it is replaced only when pose estimation
    # succeeds. The estimate must NOT be rebuilt from rvec/tvec after this
    # block: on a failed trial those would be missing or left over from an
    # earlier iteration.
    T_inW_ofC_estimated = T_inW_ofC_perturbated

    # Ensure there are enough matches for pose estimation
    if len(matches) >= 4:
        # Compute 3D points in the world and camera frames
        p_inW, p_inC = compute_3d_points(
            pts_reference, depth_reference, K, T_inW_ofC_perturbated, depth_scale=4.0
        )

        # Estimate the camera pose using solvePnPRansac
        rvec, tvec, inliers = estimate_camera_pose(
            p_inW, pts_query, K, reprojection_error=5.0, iterations=50
        )

        if inliers is not None and len(inliers) > 0:
            # Construct the estimated transformation matrix
            T_inC_ofW_estimated = np.eye(4, dtype=np.float32)
            T_inC_ofW_estimated[:3, :3] = R.from_rotvec(rvec.flatten()).as_matrix()
            T_inC_ofW_estimated[:3, 3] = tvec.flatten()

            # Compute the inverse to get the camera-to-world transformation
            T_inW_ofC_estimated = np.linalg.inv(T_inC_ofW_estimated)
        else:
            # Fallback to the perturbated pose if pose estimation fails
            print(f"Pose estimation failed for frame {filename}. Using perturbated pose.")
    else:
        # The frame is still scored, using the perturbated pose as its estimate
        print(f"Not enough matches for frame {filename}. Using perturbated pose.")

    toc = time.time()

    # Report the error of the initial guess, then of the final estimate
    get_world_frame_difference(T_inW_ofC, T_inW_ofC_perturbated)
    rot_error, trs_error = get_world_frame_difference(T_inW_ofC, T_inW_ofC_estimated)

    rot_errors[i] = rot_error
    trs_errors[i] = trs_error
    elapsed_times[i] = toc - tic
    print(f"Iteration {i+1}/{num_iterations}: Rotation Error = {rot_error:.2f} degrees, Translation Error = {trs_error:.2f} meters, Elapsed Time = {toc - tic:.2f} seconds")

In [None]:
# NaN-aware means of the per-iteration rotation error, translation error and run time
avg_rot_error, avg_trs_error, avg_elapsed_time = (
    np.nanmean(series) for series in (rot_errors, trs_errors, elapsed_times)
)

print(f"\nAverage Rotation Error: {avg_rot_error:.2f} degrees")
print(f"Average Translation Error: {avg_trs_error:.2f} meters")
print(f"Average Elapsed Time: {avg_elapsed_time:.2f} seconds")

In [None]:
# Count the number of outcomes with <5deg and <0.05m criteria.
# Converting with dtype=float turns any slot that was never filled (None) into
# NaN, which compares as False instead of raising a TypeError.
rot_errors_arr = np.asarray(rot_errors, dtype=float)
trs_errors_arr = np.asarray(trs_errors, dtype=float)

count_5deg = np.sum(rot_errors_arr < 5)
count_5cm = np.sum(trs_errors_arr < 0.05)

print(f"Number of outcomes with <5deg: {count_5deg} (ratio: {count_5deg / len(rot_errors):.2f})")
# The translation ratio must be computed from the translation count
print(f"Number of outcomes with <0.05m: {count_5cm} (ratio: {count_5cm / len(trs_errors):.2f})")

In [None]:
# In this tutorial, we used `ExitStack` to simplify context management:
# the context returned by `load_checkpoint` was entered on `stack` so that
# the model stayed usable across cells.
# We need to close the context to release the memory.
# Run this cell last - NOTE(review): the model is presumably unusable once
# its context has been exited; confirm before rendering again.
stack.close() 