In [None]:
#imports for model 

from mast3r.model import AsymmetricMASt3R

#general imports for plotting and visualization 
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from PIL import Image


#my utilities for evaluation and general dataset reading
from utils import CameraMatrix, computePoseError
from myDataset import ImagePairDataset, ResultsDataset, readResultsFile
from sevenScenesDatasets import loadPose7scenes
from sevenScenesDatasets import readMultiImageRelPoseNetPairsFile #function will change per pair file
from sevenScenesDatasets import scenes_dict, getSceneIndices #these will change for the dataset


# Load the pretrained MASt3R model.
# The device defaults to 'cuda:4' (the GPU this notebook was written for) but can be
# overridden with the MAST3R_DEVICE environment variable (e.g. 'cuda:0' or 'cpu'),
# so the notebook can run on machines with a different GPU layout.
import os

device = os.environ.get("MAST3R_DEVICE", "cuda:4")
model_name = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
model = AsymmetricMASt3R.from_pretrained(model_name).to(device)

# Model hyperparameters.
# Camera intrinsics, specific to the dataset; presumably (fx, fy, cx, cy) --
# confirm against utils.CameraMatrix.
K = CameraMatrix(585, 585, 320, 240)
# NOTE(review): appears to match the 'n30' results file loaded later -- confirm.
n_matches = 30





In [None]:
# Read the pairs file and wrap the (anchor, query) image pairs in a dataset.
pairs_file = "/home/bjangley/VPR/7scenes/pairs2/test_tuples_multiimagerelposenet.txt"
root_dir = "/datasets/7scenes_org"

anchors, queries, scene_ids = readMultiImageRelPoseNetPairsFile(pairs_file, root_dir)
dataset = ImagePairDataset(anchors, queries)

# scene_ids labels the pairs with their scene (presumably one entry per pair);
# it is specific to 7-Scenes.
# Any single pair can be plotted straight from the dataset:
dataset.plotItem(1)

In [None]:


# Load the saved evaluation results and print an overall summary.
output_file = '/home/bjangley/VPR/mast3r/results_n30_withlogs.txt'
results = readResultsFile(output_file)
results.printSummary();  # trailing ';' keeps any return value from being echoed

# Accessing results:
#   item = results.getPairResults(i)   -> results for pair index i
#   item[key], with key one of:
#       'n_matches_total', 'n_matches_filtered', 'ret_val',
#       'mast3r_q2a', 'mast3r_q2world'
#
# Helper methods that return a list of pair indices:
#   results.getMatchesBelow(threshold)
#   results.getMatchesWithin(lower, upper)
#   results.getFails()   -> indices of all failures (transform = 0)

# Code to visualise important pairs. This was previously a bare string literal,
# which IPython echoed as this cell's output; it is now a plain comment.
# indices = results.getMatchesWithin(1000, 2000)
# dataset.visualizePairs(indices)

In [None]:



# Pair indices belonging to scene 0 (scenes_dict maps scene id -> name).
# Each query in the pairs file comes with 9 anchors. Some plotting code expects
# the indices grouped per query; assuming they are ordered query by query, that is:
#   np.array(scene1).reshape(-1, 9)   # -> (n_queries, 9)
scene1 = getSceneIndices(0, scene_ids=scene_ids)




In [None]:
from plotting import plotSceneResults
from scipy.spatial.transform import Rotation as R
import matplotlib.pyplot as plt
import numpy as np

# Plot the results of a single scene; scene_id is a key of scenes_dict.
scene_id = 4
scene = getSceneIndices(scene_id, scene_ids=scene_ids)
plotSceneResults(scene, dataset, results, title=scenes_dict[scene_id])

# To plot every scene instead:
# for scene_id, scene_name in scenes_dict.items():
#     plotSceneResults(getSceneIndices(scene_id, scene_ids=scene_ids),
#                      dataset, results, title=scene_name.capitalize())

In [None]:

from plotting import plotLocalizationVsMatches

# Run plotLocalizationVsMatches for a single scene (scene_id is a key of scenes_dict).
scene_id = 3
# Use scene_id rather than a literal so the plotted indices and the title
# always refer to the same scene.
scene = getSceneIndices(scene_id, scene_ids=scene_ids)
plotLocalizationVsMatches(scene, dataset, results, scenes_dict[scene_id])

# To plot across all scenes at once, pass every dataset index:
# plotLocalizationVsMatches(range(len(dataset)), dataset, results, title="All Scenes")

In [None]:
import json
from tabulate import tabulate
import numpy as np

from plotting import evaluateScene, pos_thresholds, rot_thresholds

# Passed to evaluateScene as `confidence_threshold`. It presumably decides the
# "% Below Threshold" / "% Above Threshold" split reported below -- confirm in
# plotting.evaluateScene before changing it.
CONFIDENCE_THRESHOLD = 1000

# Summary columns shown after the per-threshold accuracy columns, as
# (table header, key in the evaluateScene result dict, format template).
# Headers and row cells are both generated from this list so they cannot drift apart.
SUMMARY_COLUMNS = [
    ('%  Fail', 'percent_complete_fail', '{:.2f}%'),
    ('% Below Threshold', 'percent_below_threshold', '{:.2f}%'),
    ('% Above Threshold', 'percent_above_threshold', '{:.2f}%'),
    ('Total', 'total_estimations', '{}'),
    ('Mean Pos Error (m)', 'mean_pos_error', '{:.3f}'),
    ('Mean Rot Error (°)', 'mean_rot_error', '{:.3f}'),
]


def formatSceneRow(scene_name, scene_results):
    """Return one table row (a list of str) for a scene's evaluation results.

    The row is the scene name, then the percentage for every (position, rotation)
    threshold pair (keys like '<p>m_<r>deg'), then the SUMMARY_COLUMNS values.
    """
    row = [scene_name]
    row += [
        f"{scene_results[f'{p}m_{r}deg']:.2f}%"
        for p in pos_thresholds
        for r in rot_thresholds
    ]
    row += [template.format(scene_results[key]) for _, key, template in SUMMARY_COLUMNS]
    return row


# Evaluate every scene, keyed by scene name.
output = {}
for scene_id, scene_name in scenes_dict.items():
    scene = getSceneIndices(scene_id, scene_ids=scene_ids)
    output[scene_name] = evaluateScene(scene, dataset, results, confidence_threshold=CONFIDENCE_THRESHOLD)

# Print results in a table.
headers = (
    ['Scene']
    + [f'{p}m, {r}°' for p in pos_thresholds for r in rot_thresholds]
    + [header for header, _, _ in SUMMARY_COLUMNS]
)
table_data = [formatSceneRow(name, scene_results) for name, scene_results in output.items()]

print(tabulate(table_data, headers=headers, tablefmt="grid"))


In [None]:
from plotting import createSceneHistogram

# Histogram for a single scene (scene_id is a key of scenes_dict; e.g. 3 for the
# 'office' scene, per the original note).
scene_id = 3
scene = getSceneIndices(scene_id, scene_ids=scene_ids)
# Pass this scene's indices (not every dataset index) so the data matches the title.
createSceneHistogram(results, dataset, scene, bin_width=100, title=scenes_dict[scene_id])

# For a histogram over all pairs instead:
# createSceneHistogram(results, dataset, range(len(dataset)), bin_width=100, title="All Scenes")