Skip to content

Commit

Permalink
Implement METRIC.NaNEuclidean (#3414)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3414

#3355

A couple open questions:
- Given L2 was squared, I figured I would leave this one as squared as well?
- Also, wasn't sure if we wanted to return nan when present == 0 or -1?

Reviewed By: mdouze

Differential Revision: D57017608

fbshipit-source-id: ba14458b92c8b055f3bf2a871565175935c8333a
  • Loading branch information
Amir Sadoughi authored and facebook-github-bot committed May 16, 2024
1 parent 72571c7 commit 1876925
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 0 deletions.
1 change: 1 addition & 0 deletions faiss/MetricType.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ enum MetricType {
METRIC_JensenShannon,
METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i))
///< where a_i, b_i > 0
METRIC_NaNEuclidean,
};

/// all vector indices are this type
Expand Down
20 changes: 20 additions & 0 deletions faiss/utils/extra_distances-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <faiss/MetricType.h>
#include <faiss/utils/distances.h>
#include <cmath>
#include <type_traits>

namespace faiss {
Expand Down Expand Up @@ -130,4 +131,23 @@ inline float VectorDistance<METRIC_Jaccard>::operator()(
return accu_num / accu_den;
}

template <>
inline float VectorDistance<METRIC_NaNEuclidean>::operator()(
const float* x,
const float* y) const {
// https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.nan_euclidean_distances.html
float accu = 0;
size_t present = 0;
for (size_t i = 0; i < d; i++) {
if (!std::isnan(x[i]) && !std::isnan(y[i])) {
float diff = x[i] - y[i];
accu += diff * diff;
present++;
}
}
if (present == 0) {
return NAN;
}
return float(d) / float(present) * accu;
}
} // namespace faiss
3 changes: 3 additions & 0 deletions faiss/utils/extra_distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ void pairwise_extra_distances(
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down Expand Up @@ -195,6 +196,7 @@ void knn_extra_metrics(
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down Expand Up @@ -242,6 +244,7 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
HANDLE_VAR(JensenShannon);
HANDLE_VAR(Lp);
HANDLE_VAR(Jaccard);
HANDLE_VAR(NaNEuclidean);
#undef HANDLE_VAR
default:
FAISS_THROW_MSG("metric type not implemented");
Expand Down
20 changes: 20 additions & 0 deletions tests/test_extra_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,26 @@ def test_jaccard(self):
new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_Jaccard)
self.assertTrue(np.allclose(ref_dis, new_dis))

def test_nan_euclidean(self):
xq, yb = self.make_example()
ref_dis = np.array([
[scipy.spatial.distance.sqeuclidean(x, y) for y in yb]
for x in xq
])
new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_NaNEuclidean)
self.assertTrue(np.allclose(ref_dis, new_dis))

x = [[3, np.nan, np.nan, 6]]
q = [[1, np.nan, np.nan, 5]]
dis = [(4 / 2 * ((3 - 1)**2 + (6 - 5)**2))]
new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean)
self.assertTrue(np.allclose(new_dis, dis))

x = [[np.nan] * 4]
q = [[np.nan] * 4]
new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean)
self.assertTrue(np.isnan(new_dis[0]))


class TestKNN(unittest.TestCase):
""" test that the knn search gives the same as distance matrix + argmin """
Expand Down

0 comments on commit 1876925

Please sign in to comment.