In [15]:
import numpy as np
import torch
import cv2
from scipy.spatial import cKDTree

from hfnet.evaluation.utils.descriptors import matches_cv2np, normalize as norm

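# 2D-3D Matching

For each 2D descriptor, find its two nearest 3D descriptors. Keep the nearest one if it passes the ratio test, or if both neighbours carry the same label.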

In [3]:
# Synthetic 2D-3D matching problem: every 3D descriptor is a noisy copy of a
# randomly chosen 2D descriptor, and l_3d records which 2D row it was copied from.
np.random.seed(0)  # fixed seed so every run benchmarks the same data
n_2d = 2000
n_3d = 3000
local_dim = 256  # descriptor dimensionality

d_2d = norm(np.random.rand(n_2d, local_dim) - .5)
l_3d = np.random.randint(0, high=n_2d, size=n_3d)
d_3d = norm(d_2d[l_3d] + np.random.normal(scale=0.3, size=(n_3d, local_dim)))

## OpenCV

In [86]:
%%time
matcher = cv2.BFMatcher(cv2.NORM_L2)
matches = matcher.knnMatch(d_2d.astype(np.float32), d_3d.astype(np.float32), k=2)
matches1, matches2 = list(zip(*matches))
(matches1, dist1) = matches_cv2np(matches1)
(matches2, dist2) = matches_cv2np(matches2)
good = (l_3d[matches1[:, 1]] == l_3d[matches2[:, 1]])
good = good | (dist1/dist2 < 0.95)
matches = matches1[good]

CPU times: user 1.31 s, sys: 25.1 ms, total: 1.34 s
Wall time: 180 ms


## Numpy

In [229]:
%%time

dist = 2*(1 - d_2d @ d_3d.T)
ind = np.argpartition(dist, 2, axis=-1)[:, :2]
dist_nn = np.take_along_axis(dist, ind, axis=-1)
labels_nn = l_3d[ind]

thresh = 0.95**2
match_ok = (labels_nn[:, 0] == labels_nn[:, 1])
match_ok |= (dist_nn[:, 0] <= thresh*dist_nn[:, 1])
matches = np.stack([np.where(match_ok)[0], ind[match_ok][:, 0]])

CPU times: user 655 ms, sys: 1.64 s, total: 2.29 s
Wall time: 378 ms


## PyTorch

In [24]:
# Torch views of the NumPy data (torch.from_numpy shares memory with the source arrays).
td_2d, td_3d, tl_3d = map(torch.from_numpy, (d_2d, d_3d, l_3d))

In [9]:
%%time
with torch.no_grad():
    td_3d.t_()
    dist = 2*(1 - td_2d @ td_3d)

    dist_nn, ind = dist.topk(2, dim=-1, largest=False)
    labels_nn = tl_3d[ind]

    thresh = 0.95**2
    match_ok = (labels_nn[:, 0] == labels_nn[:, 1])
    match_ok |= (dist_nn[:, 0] <= thresh*dist_nn[:, 1])
    matches = torch.stack([torch.nonzero(match_ok)[:, 0], ind[match_ok][:, 0]])

CPU times: user 203 ms, sys: 73.7 ms, total: 276 ms
Wall time: 178 ms


In [19]:
@torch.jit.script
def jit_matching(desc1, desc2, ratio_thresh, labels):
    # Ratio-test matcher that also accepts ambiguous matches whose two nearest
    # neighbours share a label, compiled with TorchScript.
    # All arguments are unannotated, so TorchScript types each one as a Tensor
    # (pass ratio_thresh as a tensor, e.g. torch.tensor(0.95)).
    # Returns a (num_matches, 2) tensor of [row in desc1, nearest row in desc2].
    #
    # Squared L2 distances, assuming the descriptor rows are unit-norm.
    sq_dist = 2 * (1 - torch.matmul(desc1, desc2.t()))
    # Two nearest neighbours per row of desc1, ascending by distance.
    nn_dist, nn_idx = torch.topk(sq_dist, 2, dim=-1, largest=False)
    nn_labels = labels[nn_idx]
    # The distances are squared, so the ratio threshold is squared too.
    passes_ratio = nn_dist[:, 0] <= (ratio_thresh ** 2) * nn_dist[:, 1]
    same_label = nn_labels[:, 0] == nn_labels[:, 1]
    keep = passes_ratio | same_label
    rows = torch.nonzero(keep)[:, 0]
    cols = nn_idx[:, 0][keep]
    return torch.stack([rows, cols], dim=1)

In [36]:
# Serialize the TorchScript matcher to disk (reloadable with torch.jit.load).
matcher_path = 'pytorch_matching.pt'
jit_matching.save(matcher_path)

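# Global Matching

Retrieve the k nearest database descriptors for a single query descriptor, comparing a kd-tree with brute-force NumPy and PyTorch.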

In [3]:
# Global retrieval problem: one query descriptor against a database of n descriptors.
np.random.seed(0)  # fixed seed so every run benchmarks the same data
n = 20000
k = 10
global_dim = 1024  # descriptor dimensionality
query = norm(np.random.rand(global_dim) - .5)
db = norm(np.random.rand(n, global_dim) - .5)

## kd-tree

In [4]:
# Build the tree once, in its own cell, so that only the query cost is timed below.
index = cKDTree(data=db)

In [5]:
%%time
d, ind = index.query(query, k=k)

CPU times: user 50.7 ms, sys: 21.4 ms, total: 72.1 ms
Wall time: 71.2 ms


## Numpy

In [214]:
%%time
dist = 2 * (1 - db @ query)
ind = np.argpartition(dist, k)[:k]
ind = ind[np.argsort(dist[ind])]

CPU times: user 187 ms, sys: 576 ms, total: 763 ms
Wall time: 97.2 ms


## PyTorch

In [215]:
# Torch views of the query and database (torch.from_numpy shares memory with the arrays).
tdb, tquery = (torch.from_numpy(arr) for arr in (db, query))

In [227]:
%%time
with torch.no_grad():
    dist = 2*(1 - tdb @ tquery)
    _, ind = dist.topk(k, largest=False)

CPU times: user 54.3 ms, sys: 0 ns, total: 54.3 ms
Wall time: 10.4 ms
