In [69]:
import cv2
import numpy as np
import h5py
from collections import defaultdict
from scipy.spatial.transform import Rotation

import os
from pprint import pprint
import gc

In [70]:
def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
    if len(kpts0) < 5:
        return None

    f_mean = np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])
    norm_thresh = thresh / f_mean

    kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
    kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]

    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf,
        method=cv2.RANSAC)

    if E is None:
        return None

    best_num_inliers = 0
    ret = None
    for _E in np.split(E, len(E) / 3):
        n, R, t, _ = cv2.recoverPose(
            _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
        if n > best_num_inliers:
            best_num_inliers = n
            ret = (R, t[:, 0], mask.ravel() > 0)
    return ret


def angle_error_mat(R1, R2):
    cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
    cos = np.clip(cos, -1., 1.)  # numercial errors can make it out of bounds
    return np.rad2deg(np.abs(np.arccos(cos)))


def angle_error_vec(v1, v2):
    n = np.linalg.norm(v1) * np.linalg.norm(v2)
    return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))


def compute_pose_error(T_0to1, est_pose):
    R_gt = T_0to1[:3, :3]
    t_gt = T_0to1[:3, 3]
    R = est_pose[:3, :3]
    t = est_pose[:3, 3]
    error_t = angle_error_vec(t, t_gt)
    error_t = np.minimum(error_t, 180 - error_t)  # ambiguity of E estimation
    error_R = angle_error_mat(R, R_gt)
    return error_t, error_R

In [71]:
def load_retrieval_pairs(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        pairs = []
        for line in f:
            pair = line.strip().split(" ")
            pairs.append(pair)
    f.close()
    return pairs

def load_images(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        images = defaultdict(dict)
        for line in f:
            if line.startswith("#"): continue
            timestamp, sensor_id, image_path = line.strip().split(", ")
            images[image_path] = {
                "timestamp": timestamp,
                "sensor_id": sensor_id
            }
    return images

def load_intrinsics(file_path):
    sensors = defaultdict(dict)
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith("#"): continue
            line = line.strip().split(", ")
            if len(line) < 6: continue
            
            sensor_id = line[0]
            width, height = line[4:6]
            fx, fy, cx, cy = line[6:]
            K = np.array([
                [fx, 0, cx],
                [0, fy, cy],
                [0, 0, 1],
            ], dtype=float)
            
            sensors[sensor_id] = {
                'K': K,
                'width': int(width),
                'height': int(height),
            }
    return sensors

def load_rigs(file_path):
    if file_path is None:
        q = {'x': 0.0, 'y': 0.0, 'z': 0.0, 'w': 1.0}    # No rotate
        t = {'x': 0.0, 'y': 0.0, 'z': 0.0}              # No translate
        q_xyzw = np.array([q['x'], q['y'], q['z'], q['w']])
        Q = Rotation.from_quat(q_xyzw).as_matrix()
        T = np.array([t['x'], t['y'], t['z']])
        
        cam2rig = np.eye(4)
        cam2rig[:3, :3] = Q
        cam2rig[:3, 3] = T
        
        return {
            'rig_sensors': {
                'cam2rig': cam2rig,
            }
        }
        
    rigs = defaultdict(dict)
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith("#"): continue
            line = line.strip().split(", ")
            
            # rig_id = line[0].split("_")[-1]
            # sensor_id = line[1].split("/")[-1]
            
            rig_id = line[0]
            sensor_id = line[1]
            q = {
                'w': float(line[2]),
                'x': float(line[3]),
                'y': float(line[4]),
                'z': float(line[5]),
            }
            t = {
                'x': float(line[6]),
                'y': float(line[7]),
                'z': float(line[8]),
            }
            
            q_xyzw = np.array([q['x'], q['y'], q['z'], q['w']])
            Q = Rotation.from_quat(q_xyzw).as_matrix()
            T = np.array([t['x'], t['y'], t['z']])
            
            cam2rig = np.eye(4)
            cam2rig[:3, :3] = Q
            cam2rig[:3, 3] = T
            
            rigs[rig_id][sensor_id] = {
                'cam2rig': cam2rig,
            }
    return rigs

def load_poses(file_path):
    poses = defaultdict(dict)
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith("#"): continue
            line = line.strip().split(", ")
            timestamp = line[0]
            device_id = line[1]
            
            q = {
                'w': float(line[2]),
                'x': float(line[3]),
                'y': float(line[4]),
                'z': float(line[5]),
            }
            t = {
                'x': float(line[6]),
                'y': float(line[7]),
                'z': float(line[8]),
            }
            
            q_xyzw = np.array([q['x'], q['y'], q['z'], q['w']])
            Q = Rotation.from_quat(q_xyzw).as_matrix()
            T = np.array([t['x'], t['y'], t['z']])
            pose = np.eye(4)
            pose[:3, :3] = Q
            pose[:3, 3] = T
            poses[timestamp] = {
                'pose': pose,
                'device_id': device_id
            }
    return poses

def load_keypoints(file_path):
    """
    Load keypoints and descriptors from H5 file.
    
    Supports two structures:
    
    Structure 1 (ios_query):
    - {session}/raw_data/{subsession}/images/{image_id}/
      - keypoints, descriptors, scores, image_size
    
    Structure 2 (spot_query):
    - {session}/raw_data/{subsession}/{camera}/{image_id}/
      - keypoints, descriptors, scores, image_size
    
    Structure 3 (hl_query):
    - {session}/raw_data/{subsession}/images/{camera}/{image_id}/
      - keypoints, descriptors, scores, image_size
    
    Returns nested dicts with flexible structure based on hierarchy depth.
    """
    
    results = defaultdict(dict)
    
    def recursive_load(group, result_key):
        """
        Recursively traverse H5 group and load data when keypoints are found.
        """
        # Check if this group contains keypoints (leaf node)
        if 'keypoints' in group:
            keypoints = np.array(group['keypoints'][:], dtype=np.int32)
            descriptors = np.array(group['descriptors'][:], dtype=np.float32)
            scores = np.array(group['scores'][:], dtype=np.float32)
            image_size = np.array(group['image_size'][:], dtype=np.int32)
            
            results[result_key]['keypoints'] = keypoints
            results[result_key]['descriptors'] = descriptors
            results[result_key]['scores'] = scores
            results[result_key]['image_size'] = image_size
            return
        
        # If not leaf node, iterate through children
        for key in group.keys():
            item = group[key]
            if isinstance(item, h5py.Group):
                # Recurse into this group
                recursive_load(item, f"{result_key}{'/' if result_key != '' else ''}{key}")
    
    with h5py.File(file_path, 'r') as f:
        recursive_load(f, result_key="")
    
    return results

def load_matches(file_path):
    """
    Load matching results from H5 file.
    
    Structure:
    - Group for each query image
      - Group for each map image
        - matches0: matched indices in map image (-1 means no match)
        - matching_scores0: confidence scores for matches
    """
    
    results = defaultdict(dict)
    
    def recursive_load(group, result_key):
        """
        Recursively traverse H5 group and load data when keypoints are found.
        """
        # Check if this group contains keypoints (leaf node)
        if 'matches0' in group:
            matches = np.array(group['matches0'][:], dtype=np.int32)
            scores = np.array(group['matching_scores0'][:], dtype=np.float32)
            
            results[result_key]['matches0'] = matches
            results[result_key]['matching_scores0'] = scores
            return
        
        # If not leaf node, iterate through children
        for key in group.keys():
            item = group[key]
            if isinstance(item, h5py.Group):
                # Recurse into this group
                recursive_load(item, f"{result_key} {key}" if result_key != "" else key)
    
    with h5py.File(file_path, 'r') as f:
        recursive_load(f, result_key="")
    
    return results

In [72]:
def get_K(Ks, images, image):
    sensor_id = images[image]['sensor_id']
    return Ks[sensor_id]['K']

def get_device_id(query_img, timestamp, query_device):
    query_img = query_img.split("/")
    sub = query_img[0]
    
    if query_device == "ios":
        query_device = f"{sub}/cam_phone_{timestamp}"
    if query_device == "hl":
        query_device = f"{sub}/hetrig_{timestamp}"
    if query_device == "spot":
        query_device = f"{sub}/{timestamp}-body"
    
    return query_device

def estimate_poses(pairs, all_matches, all_kpts0, all_kpts1, query_images, map_images, query_Ks, map_Ks):
    est_poses = {}
    for query_img, map_img in pairs:
        query_device = query_img.split("_")[0]
        
        # Matches
        matches, _ = all_matches[f"{query_img.replace('/', '-')} {map_img.replace('/', '-')}"].values()

        # Keypoints
        kpts0 = all_kpts0[query_img]['keypoints']
        kpts1 = all_kpts1[map_img]['keypoints']
            
        # Keep the matching keypoints.
        valid = matches > -1
        mkpts0 = kpts0[valid]
        mkpts1 = kpts1[matches[valid]]

        # Estimate the pose and compute the pose error.
        query_img = "/".join(query_img.split("/")[2:])
        map_img = "/".join(map_img.split("/")[2:])
        
        K0 = get_K(query_Ks, query_images, query_img)
        K1 = get_K(map_Ks, map_images, map_img)

        thresh = 1.  # In pixels relative to resized image size.
        est_pose = estimate_pose(mkpts0, mkpts1, K0, K1, thresh)
        if est_pose is None: continue
        
        query_timestamp = query_images[query_img]['timestamp']
        map_timestamp = map_images[map_img]['timestamp']
        est_poses[f"{query_timestamp}-{map_timestamp}"] = {
            "pose": est_pose,
            "device_id": get_device_id(query_img, query_timestamp, query_device)
        }
    return est_poses

def compute_errors(pairs, query_images, map_images, est_poses, query_poses, map_poses, query_rigs, map_rigs, top, Rt_threshold, r_margin):    
    all_err_t, all_err_R = defaultdict(list), defaultdict(list)
    good_pairs, bad_pairs = [], []
    for query_img, map_img in pairs:
        try:
            query_img = "/".join(query_img.split("/")[2:])
            map_img = "/".join(map_img.split("/")[2:])
            query_timestamp, query_sensor = query_images[query_img]['timestamp'],  query_images[query_img]['sensor_id']
            map_timestamp, map_sensor = map_images[map_img]['timestamp'],  map_images[map_img]['sensor_id']
            
            est_pose = est_poses[f"{query_timestamp}-{map_timestamp}"]['pose']
            
            query_pose = get_groundtruth(query_poses, query_rigs, query_timestamp, query_sensor)
            map_pose = get_groundtruth(map_poses, map_rigs, map_timestamp, map_sensor)
            
            # TODD: T_0to1 is transform from query to map
            T_0to1 = np.linalg.inv(map_pose) @ query_pose

            err_t, err_R = compute_pose_error(T_0to1, est_pose)
            
            # Collect
            all_err_t[query_img].append(err_t)
            all_err_R[query_img].append(err_R)
            
            # Classify
            if err_R < r_margin:
                good_pairs.append((query_img, map_img, str(err_R)))
            else:
                bad_pairs.append((query_img, map_img, str(err_R)))
        except:
            continue
    
    all_err_t = convert_to_list(all_err_t, top)
    all_err_R = convert_to_list(all_err_R, top)
    
    th_r, th_t = Rt_threshold
    recall = np.mean((all_err_R < th_r) & (all_err_t < th_t))
    return all_err_t.mean(), all_err_R.mean(), recall, good_pairs, bad_pairs

def get_q_t(pose):
    Q, T, _ = pose
    r = Rotation.from_matrix(Q)
    q_xyzw = r.as_quat()
    t = T
    return q_xyzw, t

def save_est_poses(file_path, est_poses):
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        lines = ""
        lines += "# timestamp, device_id, qw, qx, qy, qz, tx, ty, tz, *covar\n"
        for timestamp, data in est_poses.items():
            line = [timestamp]
            line.append(data['device_id'])
            q_xyzw, t = get_q_t(data['pose'])
            line.append(str(q_xyzw[3]))
            line.append(str(q_xyzw[0]))
            line.append(str(q_xyzw[1]))
            line.append(str(q_xyzw[2]))
            line.append(str(t[0]))
            line.append(str(t[1]))
            line.append(str(t[2]))
            lines += ", ".join(line) + "\n"
        f.write(lines)
    f.close()
    print(f"Saved estimated poses to {file_path}")
    
def get_groundtruth(poses, rigs, timestamp, cam):
    rig2world = poses[timestamp]['pose']
    rig_id = poses[timestamp]['device_id']
    
    if "rig_sensors" in rigs.keys():
        cam2rig = rigs["rig_sensors"]['cam2rig']
    else:
        cam2rig = rigs[rig_id][cam]['cam2rig']
    
    return rig2world @ cam2rig

def filter_top(results, n=5):
    for k, v in results.items():
        results[k] = sorted(v)[:n]
    return results

def convert_to_list(results, top=5):
    results = filter_top(results, top)
    result_list = []
    for value_list in results.values():
        result_list.extend(value_list)
    return np.array(result_list)

def save_pairs(pairs, save_dir, type):
    save_path = f"{save_dir}/{type}_pairs.txt"
    with open(save_path, 'w', encoding='utf-8') as f:
        lines = "# query_img, map_img, R_err\n"
        for pair in pairs:
            line = ", ".join(pair)
            lines += line + '\n'
        f.write(lines)
    print(f"Saved {type} pairs to {save_path}")
    f.close()

In [73]:
def evaluate_pair(query_device, map_device, top, Rt_threshold, r_margin):
    # Pairs
    PAIRS_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/long/benchmarking_results/pair_selection/{query_device}_query/{map_device}_map/megaloc/pairs.txt"
    pairs = load_retrieval_pairs(PAIRS_PATH)

    # Images
    QUERY_IMAGES_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{query_device}_query/images.txt"
    MAP_IMAGES_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{map_device}_map/images.txt"
    query_images = load_images(QUERY_IMAGES_PATH)
    map_images = load_images(MAP_IMAGES_PATH)
    
    SAVE_DIR = f"estimate_pose/{query_device}_query/{map_device}_map"
    SAVE_PATH = f"{SAVE_DIR}/est_poses.txt"
    
    if not os.path.exists(SAVE_PATH):
        # Keypoints
        KPTS0_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/long/benchmarking_results/extraction/{query_device}_query/superpoint/features.h5"
        all_kpts0 = load_keypoints(KPTS0_PATH)
        KPTS1_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/long/benchmarking_results/extraction/{map_device}_map/superpoint/features.h5"
        all_kpts1 = load_keypoints(KPTS1_PATH)

        # Matches
        MATCHES_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/long/benchmarking_results/matching/{query_device}_query/{map_device}_map/superpoint/lightglue/matches.h5"
        all_matches = load_matches(MATCHES_PATH)

        # Intrinsics
        QUERY_SENSORS_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{query_device}_query/sensors.txt"
        MAP_SENSORS_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{map_device}_map/sensors.txt"
        query_Ks = load_intrinsics(QUERY_SENSORS_PATH)
        map_Ks = load_intrinsics(MAP_SENSORS_PATH)
        
        est_poses = estimate_poses(pairs, all_matches, all_kpts0, all_kpts1, query_images, map_images, query_Ks, map_Ks)
        save_est_poses(SAVE_PATH, est_poses)
    else:
        est_poses = load_poses(SAVE_PATH)
        
    # Poses
    QUERY_POSES_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{query_device}_query/proc/alignment_trajectories.txt"
    MAP_POSES_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{map_device}_map/trajectories.txt"
    query_poses = load_poses(QUERY_POSES_PATH)
    map_poses = load_poses(MAP_POSES_PATH)

    # Rigs
    QUERY_RIGS_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{query_device}_query/rigs.txt" if query_device != "ios" else None
    MAP_RIGS_PATH = f"/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{map_device}_map/rigs.txt" if map_device != "ios" else None
    query_rigs = load_rigs(QUERY_RIGS_PATH)
    map_rigs = load_rigs(MAP_RIGS_PATH)
        
    all_err_t, all_err_R, recall, good_pairs, bad_pairs = compute_errors(pairs, query_images, map_images, est_poses, query_poses, map_poses, query_rigs, map_rigs, top, Rt_threshold, r_margin)
    save_pairs(good_pairs, SAVE_DIR, 'good')
    save_pairs(bad_pairs, SAVE_DIR, 'bad')
    gc.collect()
    
    return all_err_t, all_err_R, recall

In [74]:
QUERY_DEVICES = ["ios", "hl", "spot"]
MAP_DEVICES = ["ios", "hl", "spot"]

top = 5
Rt_threshold = (20.0, 20.0)
r_margin = 5.0

results = defaultdict(dict)
for query_device in QUERY_DEVICES:
    for map_device in MAP_DEVICES:
        print(f"{'*'*10} {query_device}-{map_device} (top {top}) {'*'*10}")
        err_t, err_R, recall = evaluate_pair(query_device, map_device, top, Rt_threshold, r_margin)
        results[query_device][map_device] = {
            "err_t": err_t,
            "err_R": err_R, 
            "recall": recall
        }
        print(f"{err_t:.2f} || {err_R:.2f} || {recall:.2f}")


********** ios-ios (top 5) **********
Saved good pairs to estimate_pose/ios_query/ios_map/good_pairs.txt
Saved bad pairs to estimate_pose/ios_query/ios_map/bad_pairs.txt
4.78 || 4.85 || 0.93
********** ios-hl (top 5) **********
Saved good pairs to estimate_pose/ios_query/hl_map/good_pairs.txt
Saved bad pairs to estimate_pose/ios_query/hl_map/bad_pairs.txt
15.93 || 20.09 || 0.68
********** ios-spot (top 5) **********
Saved good pairs to estimate_pose/ios_query/spot_map/good_pairs.txt
Saved bad pairs to estimate_pose/ios_query/spot_map/bad_pairs.txt
27.57 || 36.57 || 0.46
********** hl-ios (top 5) **********
Saved good pairs to estimate_pose/hl_query/ios_map/good_pairs.txt
Saved bad pairs to estimate_pose/hl_query/ios_map/bad_pairs.txt
8.59 || 10.39 || 0.83
********** hl-hl (top 5) **********
Saved good pairs to estimate_pose/hl_query/hl_map/good_pairs.txt
Saved bad pairs to estimate_pose/hl_query/hl_map/bad_pairs.txt
8.70 || 3.25 || 0.86
********** hl-spot (top 5) **********
Saved good 

In [75]:
from pprint import pprint
pprint(results)

defaultdict(<class 'dict'>,
            {'hl': {'hl': {'err_R': 3.2541068861789033,
                           'err_t': 8.701040292548322,
                           'recall': 0.8603712671509282},
                    'ios': {'err_R': 10.387546828152527,
                            'err_t': 8.585823968380561,
                            'recall': 0.8258746948738812},
                    'spot': {'err_R': 26.268981845563747,
                             'err_t': 22.59320267734755,
                             'recall': 0.5378356387306753}},
             'ios': {'hl': {'err_R': 20.09436891100037,
                            'err_t': 15.925424598319045,
                            'recall': 0.6817359855334539},
                     'ios': {'err_R': 4.854526889174639,
                             'err_t': 4.779456928839683,
                             'recall': 0.9344552701505757},
                     'spot': {'err_R': 36.57144754878939,
                              'err_t': 27.567702270

In [76]:
def show_matrix(results, metric):
    # Get all labels
    labels = list(results.keys())

    # Print header
    print(f"{'':>6}", end=' ')
    for col in labels:
        print(f"{col:>10}", end=' ')
    print()

    # Print rows
    for row in labels:
        print(f"{row:>6}", end=' ')
        for col in labels:
            value = results[row][col][metric]
            print(f"{value:10.4f}", end=' ')
        print()

metric = 'err_R' # err_R, err_t, acc_R, recall
show_matrix(results, metric)

              ios         hl       spot 
   ios     4.8545    20.0944    36.5714 
    hl    10.3875     3.2541    26.2690 
  spot    27.3850    35.4545     3.5310 


In [77]:
!cd ~/Workspace/crocodl-benchmark && export CAPTURE_DIR=./capture && bash ./evaluate/evaluate.sh

Checking Python dependencies...
Dependencies OK
You are running with parameters: 
  Capture: ./capture
  Output: ./capture/evaluation_results
  Benchmarking dir: long/benchmarking_results
  Local feature method: superpoint
  Matching method: lightglue
  Global feature method: megaloc
  Scenes: arche_d2
  Devices map: ios hl spot
  Devices query: ios hl spot
  Position threshold: 0.5 meters
  Rotation threshold: 5 degrees
Running evaluation...
Starting cross-device pose estimation evaluation
Configuration:
  Capture dir: ./capture
  Benchmarking dir: long/benchmarking_results
  Local feature method: superpoint
  Matching method: lightglue
  Global feature method: megaloc
  Scenes: ['arche_d2']
  Map devices: ['ios', 'hl', 'spot']
  Query devices: ['ios', 'hl', 'spot']
  Position threshold: 0.5
  Rotation threshold: 5.0

Processing scene: arche_d2
Evaluating arche_d2: ios query vs ios map
Evaluating arche_d2: ios query vs hl map
Evaluating arche_d2: ios query vs spot map
Evaluating arche

In [78]:
EST_POSES_COLOR = [255, 0, 0]
TRAJ_COLOR = [0, 255, 0]
ALM_TRAJ_COLOR = [0, 0, 255]

def visualize_poses(query_device, map_device):
    import sys
    import os

    # Add project root to path
    project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
        
    from visualize.cam_pose_visualizer import CamPoseVisualizer, load_gt_and_est_poses

    visualizer = CamPoseVisualizer()
    
    est_poses_path = f"estimate_pose/{query_device}_query/{map_device}_map/est_poses.txt"
    traj_path = f'/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{query_device}_query/trajectories.txt'
    alm_traj_path = f'/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{query_device}_query/proc/alignment_trajectories.txt'
    sensors_path = f'/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{query_device}_query/sensors.txt'
    rigs_path = None if query_device == 'ios' else f'/home/long/Workspace/crocodl-benchmark/capture/ARCHE_D2/sessions/{query_device}_query/rigs.txt'
    est_poses, traj = load_gt_and_est_poses(est_poses_path, traj_path, sensors_path, rigs_path, EST_POSES_COLOR, TRAJ_COLOR)
    est_poses, alm_traj = load_gt_and_est_poses(est_poses_path, alm_traj_path, sensors_path, rigs_path, EST_POSES_COLOR, ALM_TRAJ_COLOR)
    
    visualizer.visualize(est_poses + traj + alm_traj)
    
# visualize_poses(query_device, map_device)

In [79]:
import numpy as np

def filter_top(results: dict, n: int = 5):
    for k, v in results.items():
        results[k] = sorted(v, reverse=True)[:n]
    return results

def get_error(results: dict, top: int = 5):
    results = filter_top(top)
    result_list = []
    for value_list in results.values():
        result_list.extend(value_list)
    return np.mean(np.array(result_list))

results = {
    "abc": np.arange(10, dtype=int)
}

filter_top(results)


{'abc': [9, 8, 7, 6, 5]}

In [80]:
my_dict = {'key1': [1, 2, 3], 'key2': [4, 5], 'key3': [6, 7, 8, 9]}
result_list = []
for value_list in my_dict.values():
    result_list.extend(value_list)
print(result_list)

[1, 2, 3, 4, 5, 6, 7, 8, 9]
