In [None]:
from utils import (
    getMasterOutout,
    scale_intrinsics,
    CameraMatrix,
    run_pnp,
    getImageFromIndex,
)

# Working resolution of the images MASt3R processes. It is passed to scale_intrinsics
# right after the source image's (w, h) from PIL, so presumably ordered (width, height).
master_size = [512, 384]

#imports for visualizing matches
import numpy as np

from matplotlib import pyplot as pl
from mpl_toolkits.axes_grid1 import make_axes_locatable

import cv2 #for pnp
from pyproj import Proj, transform #cartographic transformations and coordinate conversions

#supressing unnecessary warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
import os
from PIL import Image


import torch
from mast3r.model import AsymmetricMASt3R
import json
# Cap the CPU thread pools at 8 so this notebook stays within its CPU allowance.
# NOTE: MKL / OpenMP / NumExpr typically read these variables when their runtime is first
# initialised. That may already have happened when numpy, cv2 and torch were imported
# above; in that case the env vars only affect libraries loaded later and child processes.
# torch.set_num_threads applies the cap to torch's already-initialised intra-op pool.
import os
NUM_CPU_THREADS = 8
os.environ["MKL_NUM_THREADS"] = str(NUM_CPU_THREADS)
os.environ["NUMEXPR_NUM_THREADS"] = str(NUM_CPU_THREADS)
os.environ["OMP_NUM_THREADS"] = str(NUM_CPU_THREADS)
torch.set_num_threads(NUM_CPU_THREADS)

def camera_b_to_a(camera_a_to_world, camera_b_to_world):
    """Return the transform taking camera-B coordinates into camera-A coordinates.

    With both inputs given as camera-to-world matrices:
        T_b_to_a = inv(T_a_to_world) . T_b_to_world
    i.e. go from B to world, then from world into A.
    """
    a_from_world = np.linalg.inv(camera_a_to_world)
    return np.dot(a_from_world, camera_b_to_world)

# Load the MASt3R checkpoint named below and move it onto the second CUDA device.
device = 'cuda:1'
model_name = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
model = AsymmetricMASt3R.from_pretrained(model_name)
model = model.to(device)

In [None]:
from torch.utils.data import Dataset
from sevenScenesDatasets import SevenScenesNBDataset
# Dataset root and the pairs file (presumably one query plus its anchor images per line).
root_dir, pairs_file = (
    '/datasets/7scenes_org',
    '/home/bjangley/VPR/7scenes/pairs2/test_tuples_multiimagerelposenet.txt',
)
# Result files. NOTE: the processing cell further down reassigns output_file itself.
output_file, output_file2 = (
    '/home/bjangley/VPR/7scenes/pairs2/results_dud.txt',
    '/home/bjangley/VPR/7scenes/pairs2/results_n30_v1.txt',
)
dataset = SevenScenesNBDataset(root_dir, pairs_file, mast3r_output=None)

In [None]:
# Spot-check one neighbourhood before running the full loop (index 4233 chosen ad hoc).
item = dataset[4233]
print(item['query_path'])
print(len(dataset))
# NOTE(review): the dataset was built with mast3r_output=None — confirm this key holds a meaningful pose here.
print(item['mast3r_query_pose'])

In [None]:
from tqdm import tqdm
import json

# Estimate each query's pose from every anchor in its neighbourhood with MASt3R matches + PnP,
# and write one line per query to output_file:
#   query_path, then for every anchor 16 values of the PnP transform followed by 16 values of
#   the resulting query pose (anchor_pose . transform), both flattened 4x4 matrices.
# A failed PnP is written as 32 zeros so every anchor block keeps the same width.
K = CameraMatrix(585,585,320,240)  # NOTE(review): presumably (fx, fy, cx, cy) of the original images — confirm
n_matches = 30
output_file = '/home/bjangley/VPR/7scenes/pairs2/results_n30_v1.txt'
total_items = len(dataset)
FAILED_ANCHOR_BLOCK = " ".join(["0"] * 32)

fails = {}  # query_path -> anchor indices whose PnP failed (kept in memory for inspection)
start_index = 4234  # first dataset index to process; the file must already hold rows 0..start_index-1
mode = 'w' if start_index == 0 else 'a'

if mode == 'a':
    # Rows are matched to dataset indices purely by line number, so refuse to append to a
    # file that does not already contain exactly `start_index` complete rows.
    n_existing_rows, last_line = 0, ''
    if os.path.exists(output_file):
        with open(output_file, 'r') as existing:
            for last_line in existing:
                n_existing_rows += 1
    if last_line and not last_line.endswith('\n'):
        raise ValueError(f"{output_file} ends with a partial row; remove it before resuming")
    if n_existing_rows != start_index:
        raise ValueError(f"{output_file} has {n_existing_rows} rows but start_index={start_index}")

with open(output_file, mode) as f, \
        tqdm(total=total_items, initial=start_index, desc="Processing dataset") as pbar:
    for i in range(start_index, total_items):
        neighbourhood = dataset[i]
        query_path = neighbourhood['query_path']
        # Only the size is needed: reading the header avoids decoding the whole image.
        with Image.open(query_path) as query_image:
            w, h = query_image.size
        K_scaled = scale_intrinsics(K, w, h, master_size[0], master_size[1])

        row = [f"{query_path}"]
        neighbourhood_fails = []
        for index, anchor_path in enumerate(neighbourhood['anchors_path']):
            (filtered_matches_im0, filtered_matches_im1, _matches_im0, _matches_im1,
             pts3d_im0, _pts3d_im1, _conf_im0, _conf_im1,
             _desc_conf_im0, _desc_conf_im1) = getMasterOutout(
                model, device, anchor_path, query_path, n_matches,
                visualizeMatches=False, verboseFlag=False)

            # Column 0 of the anchor matches indexes pts3d_im0's second axis and column 1 its
            # first axis (presumably (x, y) pixels into a (H, W, 3) pointmap — confirm).
            anchor_pts3d = pts3d_im0[filtered_matches_im0[:, 1], filtered_matches_im0[:, 0], :]
            ret_val, transformation = run_pnp(
                filtered_matches_im1.astype(np.float32),
                anchor_pts3d.astype(np.float32),
                K_scaled.astype(np.float32),
            )
            if ret_val:
                mast3r_query_pose = np.dot(neighbourhood['anchor_poses'][index], transformation)
                row.append(' '.join(map(str, transformation.flatten())))
                row.append(' '.join(map(str, mast3r_query_pose.flatten())))
            else:
                neighbourhood_fails.append(index)
                row.append(FAILED_ANCHOR_BLOCK)

        # Write the whole row at once and flush, so an interrupted run never leaves a partial
        # line behind and the file can be resumed safely.
        f.write(' '.join(row) + '\n')
        f.flush()
        if neighbourhood_fails:
            fails[query_path] = neighbourhood_fails
        pbar.update(1)




In [None]:
def readtxt(results_file, expected_chunks=89):
    """Check that every line of a space-separated text file has `expected_chunks` fields.

    Prints one message per offending line and returns a list of
    ``(line_number, first_field, n_fields)`` tuples for those lines (0-based line numbers;
    empty list when every line matches).
    """
    bad_lines = []
    with open(results_file, 'r') as f:
        for line_no, line in enumerate(f):
            chunks = line.strip().split(' ')
            if len(chunks) != expected_chunks:
                query_path = chunks[0]
                bad_lines.append((line_no, query_path, len(chunks)))
                print(f"error: line {line_no} ({query_path}) has {len(chunks)} fields, "
                      f"expected {expected_chunks}")
    return bad_lines

# Sanity-check the field count of every line in the pairs file.
file = '/home/bjangley/VPR/7scenes/pairs2/test_tuples_multiimagerelposenet.txt'
readtxt(results_file=file)

In [None]:
import numpy as np
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm
import matplotlib.pyplot as plt

def compute_pose_error(est_pose, gt_pose):
    """Return (position error, rotation error in degrees) between two 4x4 poses.

    Position error is the Euclidean distance between the translation columns. Rotation
    error is the angle between the two rotations: 2 * arccos(|<q_est, q_gt>|).
    """
    translation_gap = est_pose[:3, 3] - gt_pose[:3, 3]
    pos_error = np.linalg.norm(translation_gap)

    # scipy returns scalar-last quaternions (x, y, z, w); roll to scalar-first (w, x, y, z).
    q_est, q_gt = (np.roll(R.from_matrix(pose[:3, :3]).as_quat(), 1) for pose in (est_pose, gt_pose))
    # |dot| handles the q / -q double cover; clipping protects arccos from round-off above 1.
    cos_half_angle = np.clip(np.abs(np.dot(q_est, q_gt)), -1.0, 1.0)
    rot_error = np.degrees(2 * np.arccos(cos_half_angle))

    return pos_error, rot_error


def _iter_results_rows(results_file, n_anchors, desc):
    """Yield ``(query_idx, query_path, estimates)`` for every non-blank line of a results file.

    Each line is expected to be ``query_path`` followed, per anchor, by 32 numbers: a flattened
    4x4 relative transform, then the flattened 4x4 query pose. ``estimates`` lists
    ``(anchor_idx, pose)`` (``pose`` a 4x4 float array) for anchors whose pose block is not all
    zeros; all zeros is what the processing loop writes when PnP fails. ``query_idx`` is the
    0-based line number, so it equals the dataset index only if the file was written in order.
    """
    with open(results_file, 'r') as f:
        for query_idx, line in enumerate(tqdm(f, desc=desc)):
            chunks = line.split()
            if not chunks:
                continue  # tolerate blank/trailing lines instead of raising IndexError
            estimates = []
            for anchor_idx in range(n_anchors):
                # Skip query_path, the earlier anchors' 32 values and this anchor's transform.
                pose_start = 1 + anchor_idx * 32 + 16
                pose_values = [float(x) for x in chunks[pose_start:pose_start + 16]]
                # All zeros = failed PnP; an empty slice = the line has fewer anchors. Skip both.
                if all(v == 0 for v in pose_values):
                    continue
                estimates.append((anchor_idx, np.array(pose_values).reshape(4, 4)))
            yield query_idx, chunks[0], estimates


def evaluateresults(dataset, results_file, error_threshold, n_anchors=9):
    """Compare every successful query-pose estimate in `results_file` against ground truth.

    Args:
        dataset: object whose ``_load_pose(query_path)`` returns the ground-truth pose as a
            4x4 array-like (translation in ``[:3, 3]``).
        results_file: results text file (one query per line, 32 numbers per anchor).
        error_threshold: position error above which a (query_idx, anchor_idx) pair is flagged.
        n_anchors: number of anchor blocks read per line (default 9).

    Returns:
        (pos_errors, rot_errors, large_error_indices): per-estimate position errors, rotation
        errors in degrees, and the ``(query_idx, anchor_idx)`` pairs whose position error
        exceeds `error_threshold`.
    """
    pos_errors = []
    rot_errors = []
    large_error_indices = []

    for query_idx, query_path, estimates in _iter_results_rows(results_file, n_anchors, "Evaluating poses"):
        gt_pose = dataset._load_pose(query_path)
        for anchor_idx, mast3r_query_pose in estimates:
            pos_error, rot_error = compute_pose_error(mast3r_query_pose, gt_pose)
            pos_errors.append(pos_error)
            rot_errors.append(rot_error)
            if pos_error > error_threshold:
                large_error_indices.append((query_idx, anchor_idx))

    return pos_errors, rot_errors, large_error_indices


def plot_xz_locations(dataset, results_file, large_error_indices, error_threshold, n_anchors=9):
    """Scatter ground-truth and estimated query positions in the X-Z plane.

    Ground truth is plotted once per query (blue). Estimates are green, or red when their
    ``(query_idx, anchor_idx)`` appears in `large_error_indices`.
    """
    flagged = {tuple(pair) for pair in large_error_indices}  # O(1) membership per estimate
    gt_x, gt_z = [], []
    est_x, est_z = [], []
    large_error_x, large_error_z = [], []

    for query_idx, query_path, estimates in _iter_results_rows(results_file, n_anchors, "Processing locations"):
        gt_pose = dataset._load_pose(query_path)
        gt_x.append(gt_pose[0, 3])
        gt_z.append(gt_pose[2, 3])
        for anchor_idx, mast3r_query_pose in estimates:
            if (query_idx, anchor_idx) in flagged:
                large_error_x.append(mast3r_query_pose[0, 3])
                large_error_z.append(mast3r_query_pose[2, 3])
            else:
                est_x.append(mast3r_query_pose[0, 3])
                est_z.append(mast3r_query_pose[2, 3])

    print(f"Number of queries: {len(gt_x)}")
    print(f"Number of estimates: {len(est_x) + len(large_error_x)}")
    print(f"Number of large errors (>{error_threshold}m): {len(large_error_x)}")

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.scatter(gt_x, gt_z, c='blue', label='Ground Truth', alpha=0.5, s=10)
    ax.scatter(est_x, est_z, c='green', label=f'Estimated (Error <={error_threshold}m)', alpha=0.5, s=10)
    ax.scatter(large_error_x, large_error_z, c='red', label=f'Estimated (Error >{error_threshold}m)', alpha=0.5, s=10)
    ax.set_xlabel('X coordinate')
    ax.set_ylabel('Z coordinate')
    ax.set_title(f'XZ Plot of Query Locations (Error Threshold: {error_threshold}m)')
    ax.legend()
    ax.grid(True)
    ax.axis('equal')
    plt.show()




# Set the error threshold
error_threshold = 0.25  # position error (reported as metres) above which an estimate is flagged


def summarize_results(results_file, error_threshold):
    """Evaluate one results file: print error statistics, show the XZ plot, list large errors.

    Returns the ``(query_idx, anchor_idx)`` pairs whose position error exceeds `error_threshold`.
    """
    pos_errors, rot_errors, large_error_indices = evaluateresults(dataset, results_file, error_threshold)

    print(f"Results file: {results_file}")
    if pos_errors:
        print(f"Median position error: {np.median(pos_errors):.3f} meters")
        print(f"Median rotation error: {np.median(rot_errors):.3f} degrees")
        print(f"Mean position error: {np.mean(pos_errors):.3f} meters")
        print(f"Mean rotation error: {np.mean(rot_errors):.3f} degrees")
    else:
        # np.median / np.mean of an empty list would print nan with a RuntimeWarning.
        print("No successful pose estimates to summarize.")

    plot_xz_locations(dataset, results_file, large_error_indices, error_threshold)
    print("XZ plot displayed")

    print(f"\nQuery and anchor indices with position error > {error_threshold}m:")
    for query_idx, anchor_idx in large_error_indices:
        print(f"Query {query_idx}, Anchor {anchor_idx}")
    return large_error_indices


# The file written by the processing cell (output_file) first, then the baseline run.
# large_error_indices ends up holding the last file's pairs, which the cells below inspect.
results_files = [output_file, '/home/bjangley/VPR/7scenes/pairs2/results_dud.txt']
for results_file in results_files:
    large_error_indices = summarize_results(results_file, error_threshold)


In [None]:
# Quick look at the flagged (query_idx, anchor_idx) pairs from the last evaluated file.
print(large_error_indices)
print(9 * 315)  # NOTE(review): presumably n_anchors * n_queries — confirm

In [None]:
# Inspect one query: for each anchor, report whether its position error exceeded the
# threshold, then show the query/anchor pair. Assumes line numbers in the evaluated results
# file match dataset indices (true when the file was written in dataset order).
query_idx = 3  # The index of the query you're interested in
for anchor_idx in range(9):
    if (query_idx, anchor_idx) in large_error_indices:
        print(f"Query {query_idx}, Anchor {anchor_idx}: Large error detected (> {error_threshold}m)")
    else:
        print(f"Query {query_idx}, Anchor {anchor_idx}: Error within threshold (<= {error_threshold}m)")
    # NOTE(review): print() shows whatever show_query_and_anchors returns (possibly None) — confirm it is wanted.
    print(dataset.show_query_and_anchors(query_idx,anchor_idx))
