In [1]:
import datetime
import time

import h5py
import numpy as np
import torch
from tqdm.notebook import tqdm

import poselib
from relscalenet.models.relscale_cache import RelScaleNetCached, write_hdf5
from relscalenet.dataset_reader import EvaluationDataset
import relscalenet.geometry as geom

### Set parameters

In [2]:
# Relative-pose solver: "ours" passes the predicted relative depths to poselib,
# "5pt" runs poselib's standard relative-pose estimator without them.
SOLVER = "ours"  # ours or 5pt
VALID_SOLVERS = ("ours", "5pt")
# Fail fast on a typo here instead of with an obscure NameError inside the
# (slow) pose-estimation loop after the model has already been loaded.
if SOLVER not in VALID_SOLVERS:
    raise ValueError(f"SOLVER must be one of {VALID_SOLVERS}, got {SOLVER!r}")

# Keypoint/matcher identifier; selects the input matches file and names the outputs.
KEYPOINTS = "spsg"

weights_path = "weights/model_final.pth"
data_path = f"data/scannet1500_{KEYPOINTS}.h5"
images_path = "data/scannet1500-images/images"
# Relative-scale predictions are loaded from / stored in this cache file.
cache_path = f"{KEYPOINTS}_relscale_cache.h5"

# Output file for per-pair ground-truth and estimated poses.
pose_estimates_path = f"{KEYPOINTS}_pose_estimates_scannet.h5"

# RANSAC options, passed unchanged to the poselib estimators.
RANSAC_OPT = {
    'max_epipolar_error': 1.5,  # NOTE(review): presumably pixels — confirm against poselib docs
    'min_iterations': 1000,
    'max_iterations': 100000,
    'success_prob': 0.9999,
    'dyn_num_trials_mult': 3,
}
# Refinement (bundle) options, passed unchanged to the poselib estimators.
BUNDLE_OPT = {
    'loss_scale': 1.0,
    'loss_type': 'TRIVIAL'
}

### Run pose estimation
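Note that this runs faster if the relative scales have already been predicted and stored in the cache file (`cache_path`); cached predictions are loaded instead of being recomputed.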

In [3]:
# Validate the solver before the (expensive) model load: an unknown value would
# otherwise leave the estimate/timing variables undefined inside the loop.
if SOLVER not in ("ours", "5pt"):
    raise ValueError(f"Unknown SOLVER {SOLVER!r}; expected 'ours' or '5pt'")

device = "cuda" if torch.cuda.is_available() else "cpu"
relscalenet = RelScaleNetCached(weights_path, device)
relscalenet.load_from_h5(cache_path)

dataset = EvaluationDataset(data_path, images_path)

pose_estimates = {}
try:
    for pair in tqdm(dataset, desc="Estimating poses"):

        x1, x2 = pair.matches()
        R_gt, t_gt = pair.relative_pose()

        if len(x1) < 5:
            # Too few matches to estimate a pose: record the pair as a failure
            # (no estimate) so it still counts in the evaluation.
            pose_estimates[pair.key] = {
                'R_gt': R_gt, 't_gt': t_gt,
                'R_est': None, 't_est': None,
                'runtime': np.nan,
                'inl_ratio': 0.0,
            }
            continue

        # Predict relative scale (or load from cache)
        im1_path, im2_path = pair.image_paths()
        relscale = relscalenet.predict_image_pair(im1_path, im2_path, x1, x2)

        # Convert relative scale to relative depth using the ratio of the two
        # cameras' mean focal lengths.
        # NOTE(review): assumes params[0:2] are (fx, fy) — confirm camera model.
        cam1, cam2 = pair.cameras()
        f1 = cam1['params'][0:2].mean()
        f2 = cam2['params'][0:2].mean()
        reldepth = (relscale * f2 / f1).flatten()

        # Run pose estimation. Only the solver call is timed; perf_counter is a
        # monotonic high-resolution clock, unlike the wall-clock datetime.now().
        t_start = time.perf_counter()
        if SOLVER == "ours":
            estimated_pose, info = poselib.estimate_relative_pose_w_relative_depth(
                x1, x2, reldepth, cam1, cam2, RANSAC_OPT, BUNDLE_OPT)
        else:  # SOLVER == "5pt" (validated above)
            estimated_pose, info = poselib.estimate_relative_pose(
                x1, x2, cam1, cam2, RANSAC_OPT, BUNDLE_OPT)
        runtime = time.perf_counter() - t_start

        # Store results
        pose_estimates[pair.key] = {
            'R_gt': R_gt, 't_gt': t_gt,
            'R_est': estimated_pose.R, 't_est': estimated_pose.t,
            'runtime': runtime,
            'inl_ratio': info['inlier_ratio'],
        }
finally:
    # Always release the relative-scale cache file, even if a pair fails mid-loop.
    relscalenet.cache.close()

# Skipped pairs contribute an inlier ratio of 0.0 to this median.
print("Estimated pose median inlier ratio:", np.median([v['inl_ratio'] for v in pose_estimates.values()]))
print(f"Saving results to {pose_estimates_path}...")
write_hdf5(pose_estimates, pose_estimates_path)

Estimating poses:   0%|          | 0/1500 [00:00<?, ?it/s]

Estimated pose median inlier ratio: 0.5421123747325751
Saving results to spsg_pose_estimates_scannet.h5...


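### Calculate errors
For each pair, the pose error is the larger of the rotation error and the translation-direction error. Pairs without a stored estimate count as an infinite error. The AUC of these errors is reported for each threshold in `AUC_THRESHOLDS`.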

In [4]:
AUC_THRESHOLDS = [5., 10., 20.]

# Per-pair rotation error, translation-direction error and solver runtime.
results = {metric: [] for metric in ('r_err', 't_err', 'runtime')}

with h5py.File(pose_estimates_path, "r") as h5_file:
    for pair_key in tqdm(h5_file.keys()):
        group = h5_file[pair_key]

        R_gt = group['R_gt'][()]
        t_gt = group['t_gt'][()]

        has_estimate = 'R_est' in group and 't_est' in group
        if has_estimate:
            R_est = group['R_est'][()]
            t_est = group['t_est'][()]
            runtime = group['runtime'][()]

            rot_err = geom.rotation_angle(R_est.transpose() @ R_gt)
            trans_err = geom.angle(t_est, t_gt)

            results['r_err'].append(rot_err)
            results['t_err'].append(trans_err)
            results['runtime'].append(runtime)
        else:
            # No estimate was stored for this pair: treat it as an infinite error.
            # It contributes nothing to the runtime average.
            results['r_err'].append(np.inf)
            results['t_err'].append(np.inf)

# Summarise: the pose error of a pair is the worse of its two errors.
max_errs = np.maximum(results['r_err'], results['t_err'])
aucs = [100 * auc for auc in geom.pose_auc(max_errs, AUC_THRESHOLDS)]
avg_runtime = np.mean(results['runtime'])

for idx, threshold in enumerate(AUC_THRESHOLDS):
    print(f'AUC@{int(threshold)} = {aucs[idx]:.2f}')
print(f'Avg. runtime: {avg_runtime * 1000:.1f} ms.')

  0%|          | 0/1500 [00:00<?, ?it/s]

AUC@5 = 21.34
AUC@10 = 38.23
AUC@20 = 53.42
Avg. runtime: 14.6 ms.
