From 4d2b73270d7d0231006d1dde2139b69f126e2433 Mon Sep 17 00:00:00 2001 From: Bikramjeet Vig Date: Wed, 22 May 2024 11:46:57 -0700 Subject: [PATCH] Fix NaN handling in map_subset (#9893) Summary: Ensure NaN values of any binary representation are treated as equal and can be identified as keys in a map. Differential Revision: D57681657 --- velox/docs/functions/presto/map.rst | 3 +- velox/functions/prestosql/MapSubset.h | 3 +- .../prestosql/tests/MapSubsetTest.cpp | 67 ++++++++++++++++++- velox/vector/SimpleVector.h | 11 ++- velox/vector/tests/utils/VectorMaker.h | 1 + 5 files changed, 80 insertions(+), 5 deletions(-) diff --git a/velox/docs/functions/presto/map.rst b/velox/docs/functions/presto/map.rst index 7b97b3892138..8b4541f4ecb8 100644 --- a/velox/docs/functions/presto/map.rst +++ b/velox/docs/functions/presto/map.rst @@ -87,7 +87,8 @@ Map Functions .. function:: map_subset(map(K,V), array(k)) -> map(K,V) - Constructs a map from those entries of ``map`` for which the key is in the array given:: + Constructs a map from those entries of ``map`` for which the key is in the array given. + For keys containing REAL and DOUBLE, NaNs (Not-a-Number) are considered equal. :: SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[10]); -- {} SELECT map_subset(MAP(ARRAY[1,2], ARRAY['a','b']), ARRAY[1]); -- {1->'a'} diff --git a/velox/functions/prestosql/MapSubset.h b/velox/functions/prestosql/MapSubset.h index 36f7459b7693..238922e8f052 100644 --- a/velox/functions/prestosql/MapSubset.h +++ b/velox/functions/prestosql/MapSubset.h @@ -17,6 +17,7 @@ #include "velox/expression/ComplexViewTypes.h" #include "velox/functions/Udf.h" +#include "velox/type/FloatingPointUtil.h" namespace facebook::velox::functions { @@ -84,7 +85,7 @@ struct MapSubsetPrimitiveFunction { } bool constantSearchKeys_{false}; - folly::F14FastSet> searchKeys_; + util::floating_point::HashSetNaNAware> searchKeys_; }; /// Fast path for constant string keys: map_subset(m, array['a', 'b', 'c']). 
diff --git a/velox/functions/prestosql/tests/MapSubsetTest.cpp b/velox/functions/prestosql/tests/MapSubsetTest.cpp index de8e1f6afc1c..8d120003d639 100644 --- a/velox/functions/prestosql/tests/MapSubsetTest.cpp +++ b/velox/functions/prestosql/tests/MapSubsetTest.cpp @@ -21,7 +21,67 @@ using namespace facebook::velox::test; namespace facebook::velox::functions { namespace { -class MapSubsetTest : public test::FunctionBaseTest {}; +class MapSubsetTest : public test::FunctionBaseTest { + public: + template + void testFloatNaNs() { + static const auto kNaN = std::numeric_limits::quiet_NaN(); + static const auto kSNaN = std::numeric_limits::signaling_NaN(); + + // Case 1: Non-constant search keys. + auto data = makeRowVector( + {makeMapVectorFromJson({ + "{1:10, NaN:20, 3:null, 4:40, 5:50, 6:60}", + "{NaN:20}", + }), + makeArrayVector({{1, kNaN, 5}, {kSNaN, 3}})}); + + auto expected = makeMapVectorFromJson({ + "{1:10, NaN:20, 5:50}", + "{NaN:20}", + }); + auto result = evaluate("map_subset(c0, c1)", data); + assertEqualVectors(expected, result); + + // Case 2: Constant search keys. + data = makeRowVector( + {makeMapVectorFromJson({ + "{1:10, NaN:20, 3:null, 4:40, 5:50, 6:60}", + "{NaN:20}", + }), + BaseVector::wrapInConstant(2, 0, makeArrayVector({{1, kNaN, 5}}))}); + expected = makeMapVectorFromJson({ + "{1:10, NaN:20, 5:50}", + "{NaN:20}", + }); + result = evaluate("map_subset(c0, c1)", data); + assertEqualVectors(expected, result); + + // Case 3: Map with Complex type as key. 
+ // Map: { [{1, NaN,3}: 1, {4, 5}: 2], [{NaN, 3}: 3, {1, 2}: 4] } + data = makeRowVector({ + makeMapVector( + {0, 2}, + makeArrayVector({{1, kNaN, 3}, {4, 5}, {kSNaN, 3}, {1, 2}}), + makeFlatVector({1, 2, 3, 4})), + makeNestedArrayVectorFromJson({ + "[[1, NaN, 3], [4, 5]]", + "[[1, 2, 3], [NaN, 3]]", + }), + }); + expected = makeMapVector( + {0, 2}, + makeArrayVectorFromJson({ + "[1, NaN, 3]", + "[4, 5]", + "[NaN, 3]", + }), + makeFlatVector({1, 2, 3})); + + result = evaluate("map_subset(c0, c1)", data); + assertEqualVectors(expected, result); + } +}; TEST_F(MapSubsetTest, bigintKey) { auto data = makeRowVector({ @@ -133,5 +193,10 @@ TEST_F(MapSubsetTest, arrayKey) { assertEqualVectors(expected, result); } +TEST_F(MapSubsetTest, floatNaNs) { + testFloatNaNs(); + testFloatNaNs(); +} + } // namespace } // namespace facebook::velox::functions diff --git a/velox/vector/SimpleVector.h b/velox/vector/SimpleVector.h index 1e1f60ddfe16..b039c0ebd8b4 100644 --- a/velox/vector/SimpleVector.h +++ b/velox/vector/SimpleVector.h @@ -29,6 +29,7 @@ #include "velox/functions/lib/string/StringCore.h" #include "velox/type/DecimalUtil.h" +#include "velox/type/FloatingPointUtil.h" #include "velox/type/Type.h" #include "velox/vector/BaseVector.h" #include "velox/vector/TypeAliases.h" @@ -171,8 +172,14 @@ class SimpleVector : public BaseVector { * @return the hash of the value at the given index in this vector */ uint64_t hashValueAt(vector_size_t index) const override { - return isNullAt(index) ? 
BaseVector::kNullHash - : folly::hasher{}(valueAt(index)); + if (isNullAt(index)) { + return BaseVector::kNullHash; + } + if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareHash{}(valueAt(index)); + } else { + return folly::hasher{}(valueAt(index)); + } } std::optional isSorted() const { diff --git a/velox/vector/tests/utils/VectorMaker.h b/velox/vector/tests/utils/VectorMaker.h index 225c2f67cdd2..a047930dd035 100644 --- a/velox/vector/tests/utils/VectorMaker.h +++ b/velox/vector/tests/utils/VectorMaker.h @@ -822,6 +822,7 @@ class VectorMaker { folly::json::serialization_opts options; options.convert_int_keys = true; options.allow_non_string_keys = true; + options.allow_nan_inf = true; folly::dynamic mapObject = folly::parseJson(jsonMap, options); if (mapObject.isNull()) { // Null map.