diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index c35499f78af5b..24314e929e77b 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -482,6 +482,21 @@ Value ExpressionArray::serialize(bool explain) const { return Value(std::move(expressions)); } +intrusive_ptr ExpressionArray::optimize() { + bool allValuesConstant = true; + for (auto&& expr : vpOperand) { + expr = expr->optimize(); + if (!dynamic_cast(expr.get())) { + allValuesConstant = false; + } + } + // If all values in ExpressionArray are constant evaluate to ExpressionConstant. + if (allValuesConstant) { + return ExpressionConstant::create(getExpressionContext(), evaluate(Document())); + } + return this; +} + const char* ExpressionArray::getOpName() const { // This should never be called, but is needed to inherit from ExpressionNary. return "$array"; @@ -2716,31 +2731,106 @@ Value ExpressionIndexOfArray::evaluate(const Document& root) const { arrayArg.isArray()); std::vector array = arrayArg.getArray(); + auto args = evaluateAndValidateArguments(root, vpOperand, array.size()); + for (int i = args.startIndex; i < args.endIndex; i++) { + if (getExpressionContext()->getValueComparator().evaluate(array[i] == args.targetOfSearch)) { + return Value(static_cast(i)); + } + } - Value searchItem = vpOperand[1]->evaluate(root); + return Value(-1); +} - size_t startIndex = 0; - if (vpOperand.size() > 2) { - Value startIndexArg = vpOperand[2]->evaluate(root); +ExpressionIndexOfArray::Arguments ExpressionIndexOfArray::evaluateAndValidateArguments( + const Document& root, const ExpressionVector& operands, size_t arrayLength) const { + + int startIndex = 0; + if (operands.size() > 2) { + Value startIndexArg = operands[2]->evaluate(root); uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index"); - startIndex = static_cast(startIndexArg.coerceToInt()); + startIndex = startIndexArg.coerceToInt(); } - size_t endIndex = array.size(); - if (vpOperand.size() > 3) { - Value endIndexArg = vpOperand[3]->evaluate(root); + int endIndex = arrayLength; + if (operands.size() > 3) { + Value endIndexArg = operands[3]->evaluate(root); uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index"); // Don't let 'endIndex' exceed the length of the array. - endIndex = std::min(array.size(), static_cast(endIndexArg.coerceToInt())); + endIndex = std::min(static_cast(arrayLength), endIndexArg.coerceToInt()); } + return {vpOperand[1]->evaluate(root), startIndex, endIndex}; +} - for (size_t i = startIndex; i < endIndex; i++) { - if (getExpressionContext()->getValueComparator().evaluate(array[i] == searchItem)) { - return Value(static_cast(i)); +/** + * This class handles the case where IndexOfArray is given an ExpressionConstant + * instead of using a vector and searching through it we can use a unordered_map + * for O(1) lookup time. + */ +class ExpressionIndexOfArray::Optimized : public ExpressionIndexOfArray { +public: + Optimized(const boost::intrusive_ptr& expCtx, + const ValueUnorderedMap>& indexMap, + const ExpressionVector& operands) + : ExpressionIndexOfArray(expCtx), _indexMap(std::move(indexMap)) { + vpOperand = operands; + } + + virtual Value evaluate(const Document& root) const { + auto args = evaluateAndValidateArguments(root, vpOperand, _indexMap.size()); + auto indexVec = _indexMap.find(args.targetOfSearch); + + if (indexVec == _indexMap.end()) + return Value(-1); + + // Search through the vector of indecies for first index in our range. + for (auto index : indexVec->second) { + if (index >= args.startIndex && index < args.endIndex) { + return Value(index); + } } + // The value we are searching for exists but is not in our range. + return Value(-1); } - return Value(-1); +private: + // Maps the values in the array to the positions at which they occur. We need to remember the + // positions so that we can verify they are in the appropriate range. + const ValueUnorderedMap> _indexMap; +}; + +intrusive_ptr ExpressionIndexOfArray::optimize() { + // This will optimize all arguments to this expression. + auto optimized = ExpressionNary::optimize(); + if(optimized.get() != this){ + return optimized; + } + // If the input array is an ExpressionConstant we can optimize using a unordered_map instead of an + // array. + if (auto constantArray = dynamic_cast(vpOperand[0].get())) { + const Value valueArray = constantArray->getValue(); + if (valueArray.nullish()) { + return ExpressionConstant::create(getExpressionContext(), Value(BSONNULL)); + } + uassert(50749, + str::stream() << "First operand of $indexOfArray must be an array. First " + << "argument is of type: " + << typeName(valueArray.getType()), + valueArray.isArray()); + + auto arr = valueArray.getArray(); + // To handle the case of duplicate values the values need to map to a vector of indecies. + auto indexMap = + getExpressionContext()->getValueComparator().makeUnorderedValueMap>(); + + for (int i = 0; i < int(arr.size()); i++) { + if (indexMap.find(arr[i]) == indexMap.end()) { + indexMap.emplace(arr[i], vector()); + } + indexMap[arr[i]].push_back(i); + } + return new Optimized(getExpressionContext(), indexMap, vpOperand); + } + return this; } REGISTER_EXPRESSION(indexOfArray, ExpressionIndexOfArray::parse); diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 5b4a0286b7c03..7fb281cc758e9 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -728,6 +728,7 @@ class ExpressionArray final : public ExpressionVariadic { Value evaluate(const Document& root) const final; Value serialize(bool explain) const final; + boost::intrusive_ptr optimize() final; const char* getOpName() const final; }; @@ -1218,13 +1219,38 @@ class ExpressionIn final : public ExpressionFixedArity { }; -class ExpressionIndexOfArray final : public ExpressionRangedArity { +class ExpressionIndexOfArray : public ExpressionRangedArity { public: explicit ExpressionIndexOfArray(const boost::intrusive_ptr& expCtx) : ExpressionRangedArity(expCtx) {} - Value evaluate(const Document& root) const final; + Value evaluate(const Document& root) const; + boost::intrusive_ptr optimize() final; const char* getOpName() const final; + +protected: + struct Arguments { + Arguments(Value targetOfSearch, int startIndex, int endIndex) + : targetOfSearch(targetOfSearch), startIndex(startIndex), endIndex(endIndex) {} + + Value targetOfSearch; + int startIndex; + int endIndex; + }; + /** + * When given 'operands' which correspond to the arguments to $indexOfArray, evaluates and + * validates the target value, starting index, and ending index arguments and returns their + * values as a Arguments struct. The starting index and ending index are optional, so as default + * 'startIndex' will be 0 and 'endIndex' will be the length of the input array. Throws a + * UserException if the values are found to be invalid in some way, e.g. if the indexes are not + * numbers. + */ + Arguments evaluateAndValidateArguments(const Document& root, + const ExpressionVector& operands, + size_t arrayLength) const; + +private: + class Optimized; }; @@ -2013,4 +2039,4 @@ class ExpressionConvert final : public Expression { boost::intrusive_ptr _onError; boost::intrusive_ptr _onNull; }; -} +} \ No newline at end of file diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 6e0ab87fcc327..c629963de2c38 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -2208,6 +2208,152 @@ TEST(ExpressionPowTest, NegativeOneRaisedToNegativeOddExponentShouldOutPutNegati }); } +TEST(ExpressionArray, ExpressionArrayWithAllConstantValuesShouldOptimizeToExpressionConstant) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + VariablesParseState vps = expCtx->variablesParseState; + + // ExpressionArray of constant values should optimize to ExpressionConsant. + BSONObj bsonarrayOfConstants = BSON("" << BSON_ARRAY(1 << 2 << 3 << 4)); + BSONElement elementArray = bsonarrayOfConstants.firstElement(); + auto expressionArr = ExpressionArray::parse(expCtx, elementArray, vps); + auto optimizedToConstant = expressionArr->optimize(); + auto exprConstant = dynamic_cast(optimizedToConstant.get()); + ASSERT_TRUE(exprConstant); + + // ExpressionArray with not all constant values should not optimize to ExpressionConstant. + BSONObj bsonarray = BSON("" << BSON_ARRAY(1 << "$x" << 3 << 4)); + BSONElement elementArrayNotConstant = bsonarray.firstElement(); + auto expressionArrNotConstant = ExpressionArray::parse(expCtx, elementArrayNotConstant, vps); + auto notOptimized = expressionArrNotConstant->optimize(); + auto notExprConstant = dynamic_cast(notOptimized.get()); + ASSERT_FALSE(notExprConstant); +} + +TEST(ExpressionArray, ExpressionArrayShouldOptimizeSubExpressionToExpressionConstant) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + VariablesParseState vps = expCtx->variablesParseState; + + + // ExpressionArray with constant values and sub expression that evaluates to constant should + // optimize to Expression constant. + BSONObj bsonarrayWithSubExpression = + BSON("" << BSON_ARRAY(1 << BSON("$add" << BSON_ARRAY(1 << 1)) << 3 << 4)); + BSONElement elementArrayWithSubExpression = bsonarrayWithSubExpression.firstElement(); + auto expressionArrWithSubExpression = + ExpressionArray::parse(expCtx, elementArrayWithSubExpression, vps); + auto optimizedToConstantWithSubExpression = expressionArrWithSubExpression->optimize(); + auto constantExpression = + dynamic_cast(optimizedToConstantWithSubExpression.get()); + ASSERT_TRUE(constantExpression); +} + +TEST(ExpressionIndexOfArray, ExpressionIndexOfArrayShouldOptimizeArguments) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + + auto expIndexOfArray = Expression::parseExpression( + expCtx, // 2, 1, 1 + BSON("$indexOfArray" << BSON_ARRAY( + BSON_ARRAY(BSON("$add" << BSON_ARRAY(1 << 1)) << 1 << 1 << 2) + // Value we are searching for = 2. + << BSON("$add" << BSON_ARRAY(1 << 1)) + // Start index = 1. + << BSON("$add" << BSON_ARRAY(0 << 1)) + // End index = 4. + << BSON("$add" << BSON_ARRAY(1 << 3)))), + expCtx->variablesParseState); + auto argsOptimizedToConstants = expIndexOfArray->optimize(); + auto shouldBeIndexOfArray = dynamic_cast(argsOptimizedToConstants.get()); + ASSERT_TRUE(shouldBeIndexOfArray); + ASSERT_VALUE_EQ(Value(3), shouldBeIndexOfArray->getValue()); +} + +TEST(ExpressionIndexOfArray, + ExpressionIndexOfArrayShouldOptimizeNullishInputArrayToExpressionConstant) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + VariablesParseState vps = expCtx->variablesParseState; + + auto expIndex = Expression::parseExpression( + expCtx, fromjson("{ $indexOfArray : [ undefined , 1, 1]}"), expCtx->variablesParseState); + + auto isExpIndexOfArray = dynamic_cast(expIndex.get()); + ASSERT_TRUE(isExpIndexOfArray); + + auto nullishValueOptimizedToExpConstant = isExpIndexOfArray->optimize(); + auto shouldBeExpressionConstant = + dynamic_cast(nullishValueOptimizedToExpConstant.get()); + ASSERT_TRUE(shouldBeExpressionConstant); + // Nullish input array should become a Value(BSONNULL). + ASSERT_VALUE_EQ(Value(BSONNULL), shouldBeExpressionConstant->getValue()); +} + +TEST(ExpressionIndexOfArray, + OptimizedExpressionIndexOfArrayWithConstantArgumentsShouldEvaluateProperly) { + + intrusive_ptr expCtx(new ExpressionContextForTest()); + + auto expIndexOfArray = Expression::parseExpression( + expCtx, + // Search for $x. + fromjson("{ $indexOfArray : [ [0, 1, 2, 3, 4, 5, 'val'] , '$x'] }"), + expCtx->variablesParseState); + auto optimizedIndexOfArray = expIndexOfArray->optimize(); + ASSERT_VALUE_EQ(Value(0), optimizedIndexOfArray->evaluate(Document{{"x", 0}})); + ASSERT_VALUE_EQ(Value(1), optimizedIndexOfArray->evaluate(Document{{"x", 1}})); + ASSERT_VALUE_EQ(Value(2), optimizedIndexOfArray->evaluate(Document{{"x", 2}})); + ASSERT_VALUE_EQ(Value(3), optimizedIndexOfArray->evaluate(Document{{"x", 3}})); + ASSERT_VALUE_EQ(Value(4), optimizedIndexOfArray->evaluate(Document{{"x", 4}})); + ASSERT_VALUE_EQ(Value(5), optimizedIndexOfArray->evaluate(Document{{"x", 5}})); + ASSERT_VALUE_EQ(Value(6), optimizedIndexOfArray->evaluate(Document{{"x", string("val")}})); + + auto optimizedIndexNotFound = optimizedIndexOfArray->optimize(); + // Should evaluate to -1 if not found. + ASSERT_VALUE_EQ(Value(-1), optimizedIndexNotFound->evaluate(Document{{"x", 10}})); + ASSERT_VALUE_EQ(Value(-1), optimizedIndexNotFound->evaluate(Document{{"x", 100}})); + ASSERT_VALUE_EQ(Value(-1), optimizedIndexNotFound->evaluate(Document{{"x", 1000}})); + ASSERT_VALUE_EQ(Value(-1), optimizedIndexNotFound->evaluate(Document{{"x", string("string")}})); + ASSERT_VALUE_EQ(Value(-1), optimizedIndexNotFound->evaluate(Document{{"x", -1}})); +} + +TEST(ExpressionIndexOfArray, + OptimizedExpressionIndexOfArrayWithConstantArgumentsShouldEvaluateProperlyWithRange) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + + auto expIndexOfArray = Expression::parseExpression( + expCtx, + // Search for 4 between 3 and 5. + fromjson("{ $indexOfArray : [ [0, 1, 2, 3, 4, 5] , '$x', 3, 5] }"), + expCtx->variablesParseState); + auto optimizedIndexOfArray = expIndexOfArray->optimize(); + ASSERT_VALUE_EQ(Value(4), optimizedIndexOfArray->evaluate(Document{{"x", 4}})); + + // Should evaluate to -1 if not found in range. + ASSERT_VALUE_EQ(Value(-1), optimizedIndexOfArray->evaluate(Document{{"x", 0}})); +} + +TEST(ExpressionIndexOfArray, + OptimizedExpressionIndexOfArrayWithConstantArrayShouldEvaluateProperlyWithDuplicateValues) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + + auto expIndexOfArrayWithDuplicateValues = + Expression::parseExpression(expCtx, + // Search for 4 between 3 and 5. + fromjson("{ $indexOfArray : [ [0, 1, 2, 2, 3, 4, 5] , '$x'] }"), + expCtx->variablesParseState); + auto optimizedIndexOfArrayWithDuplicateValues = expIndexOfArrayWithDuplicateValues->optimize(); + ASSERT_VALUE_EQ(Value(2), + optimizedIndexOfArrayWithDuplicateValues->evaluate(Document{{"x", 2}})); + // Duplicate Values in a range. + auto expIndexInRangeWithhDuplicateValues = Expression::parseExpression( + expCtx, + // Search for 2 between 4 and 6. + fromjson("{ $indexOfArray : [ [0, 1, 2, 2, 2, 2, 4, 5] , '$x', 4, 6] }"), + expCtx->variablesParseState); + auto optimizedIndexInRangeWithDuplcateValues = expIndexInRangeWithhDuplicateValues->optimize(); + // Should evaluate to 4. + ASSERT_VALUE_EQ(Value(4), + optimizedIndexInRangeWithDuplcateValues->evaluate(Document{{"x", 2}})); +} + namespace FieldPath { /** The provided field path does not pass validation. */