Skip to content

Commit

Permalink
Fix NaN handling for map subscript and add test for map() (facebookin…
Browse files Browse the repository at this point in the history
…cubator#9881)

Summary:

Ensures that map subscript identifies NaN as a key where NaNs with any binary
representation are considered equal.

Differential Revision: D57634535
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed May 22, 2024
1 parent ffc781e commit 7165198
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 11 deletions.
4 changes: 3 additions & 1 deletion velox/docs/functions/presto/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ Map Functions
.. function:: map(array(K), array(V)) -> map(K,V)
:noindex:

Returns a map created using the given key/value arrays. Keys are not allowed to be null or to contain nulls. ::
Returns a map created using the given key/value arrays. Keys are not allowed to be null or to contain nulls.
For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. ::

SELECT map(ARRAY[1,3], ARRAY[2,4]); -- {1 -> 2, 3 -> 4}

Expand Down Expand Up @@ -147,6 +148,7 @@ Map Functions
:noindex:

Returns value for given ``key``. Return null if the key is not contained in the map.
For REAL and DOUBLE, NANs (Not-a-Number) are considered equal and can be used as keys.
Corresponds to SQL subscript operator [].

SELECT name_to_age_map['Bob'] AS bob_age;
Expand Down
12 changes: 11 additions & 1 deletion velox/functions/lib/SubscriptUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ namespace facebook::velox::functions {

namespace {

template <typename T>
inline bool isPrimitiveEqual(const T& lhs, const T& rhs) {
if constexpr (std::is_floating_point_v<T>) {
return util::floating_point::NaNAwareEquals<T>{}(lhs, rhs);
} else {
return lhs == rhs;
}
}

template <TypeKind Kind>
struct SimpleType {
using type = typename TypeTraits<Kind>::NativeType;
Expand Down Expand Up @@ -128,7 +137,8 @@ VectorPtr applyMapTyped(
} else {
// Search map without caching.
for (size_t offset = offsetStart; offset < offsetEnd; ++offset) {
if (decodedMapKeys->valueAt<TKey>(offset) == searchKey) {
if (isPrimitiveEqual<TKey>(
decodedMapKeys->valueAt<TKey>(offset), searchKey)) {
rawIndices[row] = offset;
found = true;
break;
Expand Down
9 changes: 3 additions & 6 deletions velox/functions/lib/SubscriptUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/expression/VectorReaders.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/type/Type.h"
#include "velox/vector/BaseVector.h"
#include "velox/vector/ComplexVector.h"
Expand Down Expand Up @@ -77,12 +78,8 @@ class LookupTable : public LookupTableBase {
using inner_allocator_t =
memory::StlAllocator<std::pair<key_t const, vector_size_t>>;

using inner_map_t = folly::F14FastMap<
key_t,
vector_size_t,
folly::f14::DefaultHasher<key_t>,
folly::f14::DefaultKeyEqual<key_t>,
inner_allocator_t>;
using inner_map_t = typename util::floating_point::
HashMapNaNAwareTypeTraits<key_t, vector_size_t, inner_allocator_t>::Type;

using outer_allocator_t =
memory::StlAllocator<std::pair<vector_size_t const, inner_map_t>>;
Expand Down
96 changes: 96 additions & 0 deletions velox/functions/prestosql/tests/ElementAtTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,93 @@ class ElementAtTest : public FunctionBaseTest {
"{10: 10, 11: 11, 12: 12}",
});
}

template <typename T>
void testFloatingPointCornerCases() {
static const T kNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSNaN = std::numeric_limits<T>::signaling_NaN();

auto values = makeFlatVector<int32_t>({1, 2, 3, 4, 5});
auto expected = makeConstant<int32_t>(3, 1);

auto elementAt = [&](auto map, auto search) {
return evaluate("element_at(C0, C1)", makeRowVector({map, search}));
};

// Case 1: Verify NaNs identified even with different binary
// representations.
auto keysIdenticalNaNs = makeFlatVector<T>({1, 2, kNaN, 4, 5});
auto mapVector = makeMapVector({0}, keysIdenticalNaNs, values);
test::assertEqualVectors(
expected, elementAt(mapVector, makeConstant<T>(kNaN, 1)));
test::assertEqualVectors(
expected, elementAt(mapVector, makeConstant<T>(kSNaN, 1)));

// Case 2: Verify for equality of +0.0 and -0.0.
auto keysDifferentZeros = makeFlatVector<T>({1, 2, -0.0, 4, 5});
mapVector = makeMapVector({0}, keysDifferentZeros, values);
test::assertEqualVectors(
expected, elementAt(mapVector, makeConstant<T>(0.0, 1)));
test::assertEqualVectors(
expected, elementAt(mapVector, makeConstant<T>(-0.0, 1)));

// Case 3: Verify NaNs are identified when nested inside complex type keys
{
auto rowKeys = makeRowVector(
{makeFlatVector<T>({1, 2, kNaN, 4, 5, 6}),
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6})});
auto mapOfRowKeys = makeMapVector(
{0, 3}, rowKeys, makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}));
auto elementValue = makeRowVector(
{makeFlatVector<T>({kSNaN, 1}), makeFlatVector<int32_t>({3, 1})});
auto element = BaseVector::wrapInConstant(2, 0, elementValue);
auto expected = makeNullableFlatVector<int32_t>({3, std::nullopt});
auto result = evaluate(
"element_at(C0, C1)", makeRowVector({mapOfRowKeys, element}));
test::assertEqualVectors(expected, result);
}
// case 4: Verify NaNs are identified when employing caching.
exec::ExprSet exprSet({}, &execCtx_);
auto inputs = makeRowVector({});
exec::EvalCtx evalCtx(&execCtx_, &exprSet, inputs.get());

SelectivityVector rows(1);
auto inputMap = makeMapVector({0}, keysIdenticalNaNs, values);

auto keys = makeFlatVector<T>(std::vector<T>({kSNaN}));
std::vector<VectorPtr> args = {inputMap, keys};

facebook::velox::functions::MapSubscript mapSubscriptWithCaching(true);

auto checkStatus = [&](bool cachingEnabled,
bool materializedMapIsNull,
const VectorPtr& firtSeen) {
EXPECT_EQ(cachingEnabled, mapSubscriptWithCaching.cachingEnabled());
EXPECT_EQ(firtSeen, mapSubscriptWithCaching.firstSeenMap());
EXPECT_EQ(
materializedMapIsNull,
nullptr == mapSubscriptWithCaching.lookupTable());
};

// Initial state.
checkStatus(true, true, nullptr);

auto result1 = mapSubscriptWithCaching.applyMap(rows, args, evalCtx);
// Nothing has been materialized yet since the input is seen only once.
checkStatus(true, true, args[0]);

auto result2 = mapSubscriptWithCaching.applyMap(rows, args, evalCtx);
checkStatus(true, false, args[0]);

auto result3 = mapSubscriptWithCaching.applyMap(rows, args, evalCtx);
checkStatus(true, false, args[0]);

// all the result should be the same.
expected = makeConstant<int32_t>(3, 1);
test::assertEqualVectors(expected, result2);
test::assertEqualVectors(result1, result2);
test::assertEqualVectors(result2, result3);
}
};

template <>
Expand Down Expand Up @@ -1086,3 +1173,12 @@ TEST_F(ElementAtTest, testCachingOptimzation) {
test::assertEqualVectors(result, result1);
}
}

TEST_F(ElementAtTest, floatingPointCornerCases) {
// Verify that different code paths (keys of simple types, complex types and
// optimized caching) correctly identify NaNs and treat all NaNs with
// different binary representations as equal. Also verifies that -/+ 0.0 are
// considered equal.
testFloatingPointCornerCases<float>();
testFloatingPointCornerCases<double>();
}
64 changes: 63 additions & 1 deletion velox/functions/prestosql/tests/MapTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,64 @@ using namespace facebook::velox::functions::test;

namespace {

class MapTest : public FunctionBaseTest {};
class MapTest : public FunctionBaseTest {
public:
template <typename T>
void testFloatingPointCornerCases() {
static const T kNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSNaN = std::numeric_limits<T>::signaling_NaN();
// Case 1: Check for duplicate NaNs with the same binary representation.
VectorPtr keysIdenticalNaNs =
makeNullableArrayVector<T>({{1, 2, kNaN, 4, 5, kNaN}});
// Case 2: Check for duplicate NaNs with different binary representation.
VectorPtr keysDifferentNaNs =
makeNullableArrayVector<T>({{1, 2, kNaN, 4, 5, kSNaN}});
// Case 3: Check for duplicate NaNs when the keys vector is a constant. This
// is to ensure the code path for constant keys is exercised.
VectorPtr keysConstant =
BaseVector::wrapInConstant(1, 0, keysDifferentNaNs);
// Case 4: Check for duplicate NaNs when the keys vector wrapped in a
// dictionary.
VectorPtr keysInDictionary =
wrapInDictionary(makeIndices(1, folly::identity), keysDifferentNaNs);
// Case 5: Check for equality of +0.0 and -0.0.
VectorPtr keysDifferentZeros =
makeNullableArrayVector<T>({{1, 2, -0.0, 4, 5, 0.0}});
auto values = makeNullableArrayVector<int32_t>({{1, 2, 3, 4, 5, 6}});

auto checkDuplicate = [&](VectorPtr& keys, std::string expectedError) {
VELOX_ASSERT_THROW(
evaluate("map(c0, c1)", makeRowVector({keys, values})),
expectedError);

ASSERT_NO_THROW(
evaluate("try(map(c0, c1))", makeRowVector({keys, values})));

// Trying the map version with allowing duplicates.
functions::prestosql::registerMapAllowingDuplicates("map2");
ASSERT_NO_THROW(evaluate("map2(c0, c1)", makeRowVector({keys, values})));
};

checkDuplicate(
keysIdenticalNaNs, "Duplicate map keys (NaN) are not allowed");
checkDuplicate(
keysDifferentNaNs, "Duplicate map keys (NaN) are not allowed");
checkDuplicate(keysConstant, "Duplicate map keys (NaN) are not allowed");
checkDuplicate(
keysInDictionary, "Duplicate map keys (NaN) are not allowed");
checkDuplicate(
keysDifferentZeros, "Duplicate map keys (0) are not allowed");

// Case 6: Check for duplicate NaNs nested inside a complex key.
VectorPtr arrayOfRows = makeArrayVector(
{0},
makeRowVector(
{makeFlatVector<T>({1, 2, kNaN, 4, 5, kSNaN}),
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 3})}));
checkDuplicate(
arrayOfRows, "Duplicate map keys ({NaN, 3}) are not allowed");
}
};

TEST_F(MapTest, noNulls) {
auto size = 1'000;
Expand Down Expand Up @@ -170,6 +227,11 @@ TEST_F(MapTest, duplicateKeys) {
ASSERT_NO_THROW(evaluate("map2(c0, c1)", makeRowVector({keys, values})));
}

TEST_F(MapTest, floatingPointCornerCases) {
testFloatingPointCornerCases<float>();
testFloatingPointCornerCases<double>();
}

TEST_F(MapTest, fewerValuesThanKeys) {
auto size = 1'000;

Expand Down
55 changes: 53 additions & 2 deletions velox/type/FloatingPointUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cmath>
#include <vector>

#include <folly/container/F14Map.h>
#include <folly/container/F14Set.h>

namespace facebook::velox {
Expand Down Expand Up @@ -83,8 +84,9 @@ struct NaNAwareHash {
}
};

// Utility struct to provide a clean way of defining a hash set type using
// folly::F14FastSet with overrides for floating point types.
// Utility struct to provide a clean way of defining a hash set and map type
// using folly::F14FastSet and folly::F14FastMap respectively with overrides for
// floating point types.
template <typename Key>
class HashSetNaNAware : public folly::F14FastSet<Key> {};

Expand All @@ -97,6 +99,55 @@ template <>
class HashSetNaNAware<double>
: public folly::
F14FastSet<double, NaNAwareHash<double>, NaNAwareEquals<double>> {};

template <
typename Key,
typename Mapped,
typename Alloc = folly::f14::DefaultAlloc<std::pair<Key const, Mapped>>>
struct HashMapNaNAwareTypeTraits {
using Type = folly::F14FastMap<
Key,
Mapped,
folly::f14::DefaultHasher<Key>,
folly::f14::DefaultKeyEqual<Key>,
Alloc>;
};

template <typename Mapped, typename Alloc>
struct HashMapNaNAwareTypeTraits<float, Mapped, Alloc> {
using Type = folly::F14FastMap<
float,
Mapped,
NaNAwareHash<float>,
NaNAwareEquals<float>,
Alloc>;
};

template <typename Mapped, typename Alloc>
struct HashMapNaNAwareTypeTraits<double, Mapped, Alloc> {
using Type = folly::F14FastMap<
double,
Mapped,
NaNAwareHash<double>,
NaNAwareEquals<double>,
Alloc>;
};

/* template <typename Mapped, typename Alloc>
class HashMapNaNAware<float, Mapped, Alloc> : public folly::F14FastMap<
float,
Mapped,
NaNAwareHash<float>,
NaNAwareEquals<float>,
Alloc> {};
template <typename Mapped, typename Alloc>
class HashMapNaNAware<double, Mapped, Alloc> : public folly::F14FastMap<
double,
Mapped,
NaNAwareHash<double>,
NaNAwareEquals<double>,
Alloc> {}; */
} // namespace util::floating_point

/// A static class that holds helper functions for DOUBLE type.
Expand Down

0 comments on commit 7165198

Please sign in to comment.