From 07289f4672044f933e531a5554e4a70dbefbaeb4 Mon Sep 17 00:00:00 2001 From: Bikramjeet Vig Date: Wed, 22 May 2024 11:47:44 -0700 Subject: [PATCH] Fix NaN handling for map subscript and add test for map() (#9881) Summary: Ensures that map subscript identifies NaN as a key where NaNs with any binary representation are considered equal. Differential Revision: D57634535 --- velox/docs/functions/presto/map.rst | 4 +- velox/functions/lib/SubscriptUtil.cpp | 12 ++- velox/functions/lib/SubscriptUtil.h | 9 +- .../prestosql/tests/ElementAtTest.cpp | 96 +++++++++++++++++++ velox/functions/prestosql/tests/MapTest.cpp | 64 ++++++++++++- velox/type/FloatingPointUtil.h | 55 ++++++++++- 6 files changed, 229 insertions(+), 11 deletions(-) diff --git a/velox/docs/functions/presto/map.rst b/velox/docs/functions/presto/map.rst index b91addaa5b5b..7b97b3892138 100644 --- a/velox/docs/functions/presto/map.rst +++ b/velox/docs/functions/presto/map.rst @@ -39,7 +39,8 @@ Map Functions .. function:: map(array(K), array(V)) -> map(K,V) :noindex: - Returns a map created using the given key/value arrays. Keys are not allowed to be null or to contain nulls. :: + Returns a map created using the given key/value arrays. Keys are not allowed to be null or to contain nulls. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. :: SELECT map(ARRAY[1,3], ARRAY[2,4]); -- {1 -> 2, 3 -> 4} @@ -147,6 +148,7 @@ Map Functions :noindex: Returns value for given ``key``. Return null if the key is not contained in the map. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal and can be used as keys. Corresponds to SQL subscript operator []. SELECT name_to_age_map['Bob'] AS bob_age; diff --git a/velox/functions/lib/SubscriptUtil.cpp b/velox/functions/lib/SubscriptUtil.cpp index 76ad87cc4844..bcb884c746d7 100644 --- a/velox/functions/lib/SubscriptUtil.cpp +++ b/velox/functions/lib/SubscriptUtil.cpp @@ -28,6 +28,15 @@ namespace facebook::velox::functions { namespace { +template +inline bool isPrimitiveEqual(const T& lhs, const T& rhs) { + if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareEquals{}(lhs, rhs); + } else { + return lhs == rhs; + } +} + template struct SimpleType { using type = typename TypeTraits::NativeType; @@ -128,7 +137,8 @@ VectorPtr applyMapTyped( } else { // Search map without caching. for (size_t offset = offsetStart; offset < offsetEnd; ++offset) { - if (decodedMapKeys->valueAt(offset) == searchKey) { + if (isPrimitiveEqual( + decodedMapKeys->valueAt(offset), searchKey)) { rawIndices[row] = offset; found = true; break; diff --git a/velox/functions/lib/SubscriptUtil.h b/velox/functions/lib/SubscriptUtil.h index 2698a17390ba..18daf68e03a9 100644 --- a/velox/functions/lib/SubscriptUtil.h +++ b/velox/functions/lib/SubscriptUtil.h @@ -21,6 +21,7 @@ #include "velox/expression/Expr.h" #include "velox/expression/VectorFunction.h" #include "velox/expression/VectorReaders.h" +#include "velox/type/FloatingPointUtil.h" #include "velox/type/Type.h" #include "velox/vector/BaseVector.h" #include "velox/vector/ComplexVector.h" @@ -77,12 +78,8 @@ class LookupTable : public LookupTableBase { using inner_allocator_t = memory::StlAllocator>; - using inner_map_t = folly::F14FastMap< - key_t, - vector_size_t, - folly::f14::DefaultHasher, - folly::f14::DefaultKeyEqual, - inner_allocator_t>; + using inner_map_t = typename util::floating_point:: + HashMapNaNAwareTypeTraits::Type; using outer_allocator_t = memory::StlAllocator>; diff --git a/velox/functions/prestosql/tests/ElementAtTest.cpp b/velox/functions/prestosql/tests/ElementAtTest.cpp index ede4c6571801..c7f2706cde76 100644 --- a/velox/functions/prestosql/tests/ElementAtTest.cpp +++ b/velox/functions/prestosql/tests/ElementAtTest.cpp @@ -87,6 +87,93 @@ class ElementAtTest : public FunctionBaseTest { "{10: 10, 11: 11, 12: 12}", }); } + + template + void testFloatingPointCornerCases() { + static const T kNaN = std::numeric_limits::quiet_NaN(); + static const T kSNaN = std::numeric_limits::signaling_NaN(); + + auto values = makeFlatVector({1, 2, 3, 4, 5}); + auto expected = makeConstant(3, 1); + + auto elementAt = [&](auto map, auto search) { + return evaluate("element_at(C0, C1)", makeRowVector({map, search})); + }; + + // Case 1: Verify NaNs identified even with different binary + // representations. + auto keysIdenticalNaNs = makeFlatVector({1, 2, kNaN, 4, 5}); + auto mapVector = makeMapVector({0}, keysIdenticalNaNs, values); + test::assertEqualVectors( + expected, elementAt(mapVector, makeConstant(kNaN, 1))); + test::assertEqualVectors( + expected, elementAt(mapVector, makeConstant(kSNaN, 1))); + + // Case 2: Verify for equality of +0.0 and -0.0. + auto keysDifferentZeros = makeFlatVector({1, 2, -0.0, 4, 5}); + mapVector = makeMapVector({0}, keysDifferentZeros, values); + test::assertEqualVectors( + expected, elementAt(mapVector, makeConstant(0.0, 1))); + test::assertEqualVectors( + expected, elementAt(mapVector, makeConstant(-0.0, 1))); + + // Case 3: Verify NaNs are identified when nested inside complex type keys + { + auto rowKeys = makeRowVector( + {makeFlatVector({1, 2, kNaN, 4, 5, 6}), + makeFlatVector({1, 2, 3, 4, 5, 6})}); + auto mapOfRowKeys = makeMapVector( + {0, 3}, rowKeys, makeFlatVector({1, 2, 3, 4, 5, 6})); + auto elementValue = makeRowVector( + {makeFlatVector({kSNaN, 1}), makeFlatVector({3, 1})}); + auto element = BaseVector::wrapInConstant(2, 0, elementValue); + auto expected = makeNullableFlatVector({3, std::nullopt}); + auto result = evaluate( + "element_at(C0, C1)", makeRowVector({mapOfRowKeys, element})); + test::assertEqualVectors(expected, result); + } + // case 4: Verify NaNs are identified when employing caching. + exec::ExprSet exprSet({}, &execCtx_); + auto inputs = makeRowVector({}); + exec::EvalCtx evalCtx(&execCtx_, &exprSet, inputs.get()); + + SelectivityVector rows(1); + auto inputMap = makeMapVector({0}, keysIdenticalNaNs, values); + + auto keys = makeFlatVector(std::vector({kSNaN})); + std::vector args = {inputMap, keys}; + + facebook::velox::functions::MapSubscript mapSubscriptWithCaching(true); + + auto checkStatus = [&](bool cachingEnabled, + bool materializedMapIsNull, + const VectorPtr& firtSeen) { + EXPECT_EQ(cachingEnabled, mapSubscriptWithCaching.cachingEnabled()); + EXPECT_EQ(firtSeen, mapSubscriptWithCaching.firstSeenMap()); + EXPECT_EQ( + materializedMapIsNull, + nullptr == mapSubscriptWithCaching.lookupTable()); + }; + + // Initial state. + checkStatus(true, true, nullptr); + + auto result1 = mapSubscriptWithCaching.applyMap(rows, args, evalCtx); + // Nothing has been materialized yet since the input is seen only once. + checkStatus(true, true, args[0]); + + auto result2 = mapSubscriptWithCaching.applyMap(rows, args, evalCtx); + checkStatus(true, false, args[0]); + + auto result3 = mapSubscriptWithCaching.applyMap(rows, args, evalCtx); + checkStatus(true, false, args[0]); + + // all the result should be the same. + expected = makeConstant(3, 1); + test::assertEqualVectors(expected, result2); + test::assertEqualVectors(result1, result2); + test::assertEqualVectors(result2, result3); + } }; template <> @@ -1086,3 +1173,12 @@ TEST_F(ElementAtTest, testCachingOptimzation) { test::assertEqualVectors(result, result1); } } + +TEST_F(ElementAtTest, floatingPointCornerCases) { + // Verify that different code paths (keys of simple types, complex types and + // optimized caching) correctly identify NaNs and treat all NaNs with + // different binary representations as equal. Also verifies that -/+ 0.0 are + // considered equal. + testFloatingPointCornerCases(); + testFloatingPointCornerCases(); +} diff --git a/velox/functions/prestosql/tests/MapTest.cpp b/velox/functions/prestosql/tests/MapTest.cpp index d3f38fa07cbf..8626c5e17a16 100644 --- a/velox/functions/prestosql/tests/MapTest.cpp +++ b/velox/functions/prestosql/tests/MapTest.cpp @@ -25,7 +25,64 @@ using namespace facebook::velox::functions::test; namespace { -class MapTest : public FunctionBaseTest {}; +class MapTest : public FunctionBaseTest { + public: + template + void testFloatingPointCornerCases() { + static const T kNaN = std::numeric_limits::quiet_NaN(); + static const T kSNaN = std::numeric_limits::signaling_NaN(); + // Case 1: Check for duplicate NaNs with the same binary representation. + VectorPtr keysIdenticalNaNs = + makeNullableArrayVector({{1, 2, kNaN, 4, 5, kNaN}}); + // Case 2: Check for duplicate NaNs with different binary representation. + VectorPtr keysDifferentNaNs = + makeNullableArrayVector({{1, 2, kNaN, 4, 5, kSNaN}}); + // Case 3: Check for duplicate NaNs when the keys vector is a constant. This + // is to ensure the code path for constant keys is exercised. + VectorPtr keysConstant = + BaseVector::wrapInConstant(1, 0, keysDifferentNaNs); + // Case 4: Check for duplicate NaNs when the keys vector wrapped in a + // dictionary. + VectorPtr keysInDictionary = + wrapInDictionary(makeIndices(1, folly::identity), keysDifferentNaNs); + // Case 5: Check for equality of +0.0 and -0.0. + VectorPtr keysDifferentZeros = + makeNullableArrayVector({{1, 2, -0.0, 4, 5, 0.0}}); + auto values = makeNullableArrayVector({{1, 2, 3, 4, 5, 6}}); + + auto checkDuplicate = [&](VectorPtr& keys, std::string expectedError) { + VELOX_ASSERT_THROW( + evaluate("map(c0, c1)", makeRowVector({keys, values})), + expectedError); + + ASSERT_NO_THROW( + evaluate("try(map(c0, c1))", makeRowVector({keys, values}))); + + // Trying the map version with allowing duplicates. + functions::prestosql::registerMapAllowingDuplicates("map2"); + ASSERT_NO_THROW(evaluate("map2(c0, c1)", makeRowVector({keys, values}))); + }; + + checkDuplicate( + keysIdenticalNaNs, "Duplicate map keys (NaN) are not allowed"); + checkDuplicate( + keysDifferentNaNs, "Duplicate map keys (NaN) are not allowed"); + checkDuplicate(keysConstant, "Duplicate map keys (NaN) are not allowed"); + checkDuplicate( + keysInDictionary, "Duplicate map keys (NaN) are not allowed"); + checkDuplicate( + keysDifferentZeros, "Duplicate map keys (0) are not allowed"); + + // Case 6: Check for duplicate NaNs nested inside a complex key. + VectorPtr arrayOfRows = makeArrayVector( + {0}, + makeRowVector( + {makeFlatVector({1, 2, kNaN, 4, 5, kSNaN}), + makeFlatVector({1, 2, 3, 4, 5, 3})})); + checkDuplicate( + arrayOfRows, "Duplicate map keys ({NaN, 3}) are not allowed"); + } +}; TEST_F(MapTest, noNulls) { auto size = 1'000; @@ -170,6 +227,11 @@ TEST_F(MapTest, duplicateKeys) { ASSERT_NO_THROW(evaluate("map2(c0, c1)", makeRowVector({keys, values}))); } +TEST_F(MapTest, floatingPointCornerCases) { + testFloatingPointCornerCases(); + testFloatingPointCornerCases(); +} + TEST_F(MapTest, fewerValuesThanKeys) { auto size = 1'000; diff --git a/velox/type/FloatingPointUtil.h b/velox/type/FloatingPointUtil.h index 26615847cbca..1dbfc71c7987 100644 --- a/velox/type/FloatingPointUtil.h +++ b/velox/type/FloatingPointUtil.h @@ -20,6 +20,7 @@ #include #include +#include #include namespace facebook::velox { @@ -83,8 +84,9 @@ struct NaNAwareHash { } }; -// Utility struct to provide a clean way of defining a hash set type using -// folly::F14FastSet with overrides for floating point types. +// Utility struct to provide a clean way of defining a hash set and map type +// using folly::F14FastSet and folly::F14FastMap respectively with overrides for +// floating point types. template class HashSetNaNAware : public folly::F14FastSet {}; @@ -97,6 +99,55 @@ template <> class HashSetNaNAware : public folly:: F14FastSet, NaNAwareEquals> {}; + +template < + typename Key, + typename Mapped, + typename Alloc = folly::f14::DefaultAlloc>> +struct HashMapNaNAwareTypeTraits { + using Type = folly::F14FastMap< + Key, + Mapped, + folly::f14::DefaultHasher, + folly::f14::DefaultKeyEqual, + Alloc>; +}; + +template +struct HashMapNaNAwareTypeTraits { + using Type = folly::F14FastMap< + float, + Mapped, + NaNAwareHash, + NaNAwareEquals, + Alloc>; +}; + +template +struct HashMapNaNAwareTypeTraits { + using Type = folly::F14FastMap< + double, + Mapped, + NaNAwareHash, + NaNAwareEquals, + Alloc>; +}; + +/* template +class HashMapNaNAware : public folly::F14FastMap< + float, + Mapped, + NaNAwareHash, + NaNAwareEquals, + Alloc> {}; + +template +class HashMapNaNAware : public folly::F14FastMap< + double, + Mapped, + NaNAwareHash, + NaNAwareEquals, + Alloc> {}; */ } // namespace util::floating_point /// A static class that holds helper functions for DOUBLE type.