diff --git a/faiss/MetricType.h b/faiss/MetricType.h index 538b0a8e72..4689d4d018 100644 --- a/faiss/MetricType.h +++ b/faiss/MetricType.h @@ -33,6 +33,7 @@ enum MetricType { METRIC_JensenShannon, METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i)) ///< where a_i, b_i > 0 + METRIC_NaNEuclidean, }; /// all vector indices are this type diff --git a/faiss/utils/extra_distances-inl.h b/faiss/utils/extra_distances-inl.h index d3768df668..79ead454a4 100644 --- a/faiss/utils/extra_distances-inl.h +++ b/faiss/utils/extra_distances-inl.h @@ -130,4 +130,23 @@ inline float VectorDistance::operator()( return accu_num / accu_den; } +template <> +inline float VectorDistance::operator()( + const float* x, + const float* y) const { + // https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.nan_euclidean_distances.html + float accu = 0; + size_t present = 0; + for (size_t i = 0; i < d; i++) { + if (!std::isnan(x[i]) && !std::isnan(y[i])) { + float diff = x[i] - y[i]; + accu += diff * diff; + present++; + } + } + if (present == 0) { + return std::numeric_limits::quiet_NaN(); + } + return static_cast(d) / static_cast(present) * accu; +} } // namespace faiss diff --git a/faiss/utils/extra_distances.cpp b/faiss/utils/extra_distances.cpp index 8c0699880d..fb225e7c9e 100644 --- a/faiss/utils/extra_distances.cpp +++ b/faiss/utils/extra_distances.cpp @@ -164,6 +164,7 @@ void pairwise_extra_distances( HANDLE_VAR(JensenShannon); HANDLE_VAR(Lp); HANDLE_VAR(Jaccard); + HANDLE_VAR(NaNEuclidean); #undef HANDLE_VAR default: FAISS_THROW_MSG("metric type not implemented"); @@ -195,6 +196,7 @@ void knn_extra_metrics( HANDLE_VAR(JensenShannon); HANDLE_VAR(Lp); HANDLE_VAR(Jaccard); + HANDLE_VAR(NaNEuclidean); #undef HANDLE_VAR default: FAISS_THROW_MSG("metric type not implemented"); @@ -242,6 +244,7 @@ FlatCodesDistanceComputer* get_extra_distance_computer( HANDLE_VAR(JensenShannon); HANDLE_VAR(Lp); HANDLE_VAR(Jaccard); + HANDLE_VAR(NaNEuclidean); #undef HANDLE_VAR default: FAISS_THROW_MSG("metric type not implemented"); diff --git a/tests/test_extra_distances.py b/tests/test_extra_distances.py index a474dd6ba7..66318f76c5 100644 --- a/tests/test_extra_distances.py +++ b/tests/test_extra_distances.py @@ -94,6 +94,26 @@ def test_jaccard(self): new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_Jaccard) self.assertTrue(np.allclose(ref_dis, new_dis)) + def test_nan_euclidean(self): + xq, yb = self.make_example() + ref_dis = np.array([ + [scipy.spatial.distance.sqeuclidean(x, y) for y in yb] + for x in xq + ]) + new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_NaNEuclidean) + self.assertTrue(np.allclose(ref_dis, new_dis)) + + x = [[3, np.nan, np.nan, 6]] + q = [[1, np.nan, np.nan, 5]] + dis = [(4 / 2 * ((3 - 1)**2 + (6 - 5)**2))] + new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean) + self.assertTrue(np.allclose(new_dis, dis)) + + x = [[np.nan] * 4] + q = [[np.nan] * 4] + new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean) + self.assertTrue(np.isnan(new_dis[0])) + class TestKNN(unittest.TestCase): """ test that the knn search gives the same as distance matrix + argmin """