Skip to content

Commit

Permalink
Implement METRIC.NaNEuclidean
Browse files Browse the repository at this point in the history
Summary:
#3355

A couple open questions:
- Given L2 was squared, I figured I would leave this one as squared as well?
- Also, wasn't sure if we wanted to return nan when present == 0 or -1?

Differential Revision: D57017608
  • Loading branch information
Amir Sadoughi authored and facebook-github-bot committed May 6, 2024
1 parent 0cc0e19 commit 8fe159e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 0 deletions.
1 change: 1 addition & 0 deletions faiss/MetricType.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ enum MetricType {
METRIC_JensenShannon,
METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i))
///< where a_i, b_i > 0
METRIC_NaNEuclidean,
};

/// all vector indices are this type
Expand Down
19 changes: 19 additions & 0 deletions faiss/utils/extra_distances-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,23 @@ inline float VectorDistance<METRIC_Jaccard>::operator()(
return accu_num / accu_den;
}

template <>
inline float VectorDistance<METRIC_NaNEuclidean>::operator()(
const float* x,
const float* y) const {
// https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.nan_euclidean_distances.html
float accu = 0;
size_t present = 0;
for (size_t i = 0; i < d; i++) {
if (!std::isnan(x[i]) && !std::isnan(y[i])) {
float diff = x[i] - y[i];
accu += diff * diff;
present++;
}
}
if (present == 0) {
return std::numeric_limits<float>::quiet_NaN();
}
return static_cast<float>(d) / static_cast<float>(present) * accu;
}
} // namespace faiss
3 changes: 3 additions & 0 deletions faiss/utils/extra_distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ void pairwise_extra_distances(
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down Expand Up @@ -195,6 +196,7 @@ void knn_extra_metrics(
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down Expand Up @@ -242,6 +244,7 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down
20 changes: 20 additions & 0 deletions tests/test_extra_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,26 @@ def test_jaccard(self):
new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_Jaccard)
self.assertTrue(np.allclose(ref_dis, new_dis))

def test_nan_euclidean(self):
xq, yb = self.make_example()
ref_dis = np.array([
[scipy.spatial.distance.sqeuclidean(x, y) for y in yb]
for x in xq
])
new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_NaNEuclidean)
self.assertTrue(np.allclose(ref_dis, new_dis))

x = [[3, np.nan, np.nan, 6]]
q = [[1, np.nan, np.nan, 5]]
dis = [(4 / 2 * ((3 - 1)**2 + (6 - 5)**2))]
new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean)
self.assertTrue(np.allclose(new_dis, dis))

x = [[np.nan] * 4]
q = [[np.nan] * 4]
new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean)
self.assertTrue(np.isnan(new_dis[0]))


class TestKNN(unittest.TestCase):
""" test that the knn search gives the same as distance matrix + argmin """
Expand Down

0 comments on commit 8fe159e

Please sign in to comment.