In [None]:
import numpy as np
from pathlib import Path
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

from hfnet.evaluation.image_retrieval import compute_recall, is_gt_match_2D
from hfnet.datasets.nclt import Nclt
from hfnet.settings import DATA_PATH, EXPER_PATH
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
def get_data(seq, experiment):
    """Load the poses of an NCLT sequence together with its exported global descriptors.

    Args:
        seq: NCLT sequence name (e.g. '2012-01-08').
        experiment: export sub-directory under ``EXPER_PATH/exports`` that holds
            one ``<time>.npz`` file per image of ``seq``.

    Returns:
        ``(im_poses, descriptors)``: ``im_poses`` is what ``Nclt.get_pose_file(seq)``
        returns, and ``descriptors`` is a numpy array built from the
        ``global_descriptor`` entry of each npz file, in the order of
        ``im_poses['time']``.
    """
    im_poses = Nclt.get_pose_file(seq)
    export_dir = Path(EXPER_PATH, 'exports', experiment, seq)

    def _load_descriptor(timestamp):
        # Defensive copy taken before the archive is closed by the context manager.
        with np.load(export_dir / f'{timestamp}.npz') as archive:
            return archive['global_descriptor'].copy()

    descriptors = np.array([_load_descriptor(stamp) for stamp in im_poses['time']])
    return im_poses, descriptors

def nclt_recall(ref_seq, query_seqs, experiment, distance_thresh=10, angle_thresh=np.pi/2, *arg, **kwarg):
    """Image-retrieval recall of one or more NCLT query sequences against a reference.

    Args:
        ref_seq: name of the reference (database) sequence.
        query_seqs: name of a single query sequence, or an iterable of names.
            The queries of all sequences are pooled before evaluation.
        experiment: export sub-directory holding the global descriptors
            (see ``get_data``).
        distance_thresh: distance threshold forwarded to ``is_gt_match_2D``
            (units not visible here; presumably metres, so confirm).
        angle_thresh: angular threshold forwarded to ``is_gt_match_2D``
            (default ``np.pi/2``, so presumably radians).
        *arg, **kwarg: forwarded unchanged to ``compute_recall``
            (e.g. ``max_num_nn``, ``pca_dim``).

    Returns:
        Whatever ``compute_recall`` returns for the pooled queries.

    Raises:
        ValueError: if ``query_seqs`` is empty.
    """
    if isinstance(query_seqs, str):
        # A bare string would otherwise be iterated character by character.
        query_seqs = [query_seqs]
    query_seqs = list(query_seqs)
    if not query_seqs:
        # Fail before the (expensive) reference descriptors are loaded.
        raise ValueError('nclt_recall needs at least one query sequence')

    ref_poses, ref_descriptors = get_data(ref_seq, experiment)
    query_data = [get_data(s, experiment) for s in query_seqs]
    query_poses = np.concatenate([poses for poses, _ in query_data], axis=0)
    query_descriptors = np.concatenate([desc for _, desc in query_data], axis=0)

    gt_matches = is_gt_match_2D(query_poses, ref_poses, distance_thresh, angle_thresh)
    return compute_recall(ref_descriptors, query_descriptors, gt_matches, *arg, **kwarg)

In [None]:
experiments = [
    'netvlad/nclt',
    # experiments
]
ref_seq = '2012-01-08'
query_seqs = ['2013-02-23', '2012-08-20']

MAX_NUM_NN = 30  # number of retrieved neighbors evaluated by compute_recall
RECALL_AT = 10   # N of the Recall@N value printed per curve (must be <= MAX_NUM_NN)
# PCA dimensions evaluated for every experiment. 0 is the unprojected baseline
# (solid curve); add e.g. 512 to also plot a projected variant (dashed curve).
PCA_DIMS = [0]

fig, ax = plt.subplots(dpi=150)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
for idx, experiment in enumerate(experiments):
    # Cycle the palette rather than silently dropping experiments beyond its length.
    color = colors[idx % len(colors)]
    for pca_dim in PCA_DIMS:
        recall = nclt_recall(ref_seq, query_seqs, experiment,
                             max_num_nn=MAX_NUM_NN, pca_dim=pca_dim)
        if pca_dim:
            label, style = f'{experiment}_proj{pca_dim}', dict(linewidth=1.3, linestyle='--')
        else:
            label, style = experiment, dict(linewidth=1)
        ax.plot(1 + np.arange(len(recall)), 100 * recall, label=label, color=color, **style)
        print(f'{label:<70} Recall@{RECALL_AT}: {recall[RECALL_AT - 1]:.3f}')

ax.set_xticks([1] + list(range(10, MAX_NUM_NN + 1, 10)))
ax.grid(color=[0.85] * 3)
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2))
ax.set(xlabel='Number of queried neighbors', ylabel='Recall@N (%)',
       title=f'NCLT retrieval: {", ".join(query_seqs)} vs {ref_seq}');