Skip to content

Commit

Permalink
Fix issue 2589
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Dec 19, 2023
1 parent c0cce2b commit 2ed267f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 110 deletions.
107 changes: 20 additions & 87 deletions src/expression_evaluator/case_evaluator.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
#include "expression_evaluator/case_evaluator.h"

#include <cmath>

#include "common/types/date_t.h"
#include "common/types/interval_t.h"
#include "common/types/ku_string.h"
#include "common/types/timestamp_t.h"
#include "common/types/types.h"

using namespace kuzu::common;
using namespace kuzu::processor;
using namespace kuzu::storage;
Expand Down Expand Up @@ -41,16 +33,16 @@ void CaseExpressionEvaluator::evaluate() {
alternativeEvaluator->thenEvaluator->evaluate();
auto thenVector = alternativeEvaluator->thenEvaluator->resultVector.get();
if (alternativeEvaluator->whenEvaluator->isResultFlat()) {
fillAllSwitch(*thenVector);
fillAll(thenVector);
} else {
fillSelectedSwitch(*whenSelVector, *thenVector);
fillSelected(*whenSelVector, thenVector);
}
if (filledMask.count() == resultVector->state->selVector->selectedSize) {
return;
}
}
elseEvaluator->evaluate();
fillAllSwitch(*elseEvaluator->resultVector);
fillAll(elseEvaluator->resultVector.get());
}

bool CaseExpressionEvaluator::select(SelectionVector& selVector) {
Expand Down Expand Up @@ -90,89 +82,30 @@ void CaseExpressionEvaluator::resolveResultVector(
resolveResultStateFromChildren(inputEvaluators);
}

template<typename T>
void CaseExpressionEvaluator::fillEntry(sel_t resultPos, const ValueVector& thenVector) {
if (filledMask[resultPos]) {
return;
}
filledMask[resultPos] = true;
auto thenPos =
thenVector.state->isFlat() ? thenVector.state->selVector->selectedPositions[0] : resultPos;
if (thenVector.isNull(thenPos)) {
resultVector->setNull(resultPos, true);
} else {
if (thenVector.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST) {
auto srcListEntry = thenVector.getValue<list_entry_t>(thenPos);
ListVector::addList(resultVector.get(), srcListEntry.size);
resultVector->copyFromVectorData(resultPos, &thenVector, thenPos);
} else {
auto val = thenVector.getValue<T>(thenPos);
resultVector->setValue<T>(resultPos, val);
}
void CaseExpressionEvaluator::fillSelected(

Check warning on line 85 in src/expression_evaluator/case_evaluator.cpp

View check run for this annotation

Codecov / codecov/patch

src/expression_evaluator/case_evaluator.cpp#L85

Added line #L85 was not covered by tests
const SelectionVector& selVector, ValueVector* srcVector) {
for (auto i = 0u; i < selVector.selectedSize; ++i) {
auto resultPos = selVector.selectedPositions[i];
fillEntry(resultPos, srcVector);
}
}

void CaseExpressionEvaluator::fillAllSwitch(const ValueVector& thenVector) {
switch (resultVector->dataType.getLogicalTypeID()) {
case LogicalTypeID::BOOL: {
fillAll<bool>(thenVector);
} break;
case LogicalTypeID::INT64: {
fillAll<int64_t>(thenVector);
} break;
case LogicalTypeID::DOUBLE: {
fillAll<double_t>(thenVector);
} break;
case LogicalTypeID::DATE: {
fillAll<date_t>(thenVector);
} break;
case LogicalTypeID::TIMESTAMP: {
fillAll<timestamp_t>(thenVector);
} break;
case LogicalTypeID::INTERVAL: {
fillAll<interval_t>(thenVector);
} break;
case LogicalTypeID::STRING: {
fillAll<ku_string_t>(thenVector);
} break;
case LogicalTypeID::VAR_LIST: {
fillAll<list_entry_t>(thenVector);
} break;
default:
KU_UNREACHABLE;
void CaseExpressionEvaluator::fillAll(ValueVector* srcVector) {
auto resultSelVector = resultVector->state->selVector.get();
for (auto i = 0u; i < resultSelVector->selectedSize; ++i) {
auto resultPos = resultSelVector->selectedPositions[i];
fillEntry(resultPos, srcVector);
}
}

void CaseExpressionEvaluator::fillSelectedSwitch(
const SelectionVector& selVector, const ValueVector& thenVector) {
switch (resultVector->dataType.getLogicalTypeID()) {
case LogicalTypeID::BOOL: {
fillSelected<bool>(selVector, thenVector);
} break;
case LogicalTypeID::INT64: {
fillSelected<int64_t>(selVector, thenVector);
} break;
case LogicalTypeID::DOUBLE: {
fillSelected<double_t>(selVector, thenVector);
} break;
case LogicalTypeID::DATE: {
fillSelected<date_t>(selVector, thenVector);
} break;
case LogicalTypeID::TIMESTAMP: {
fillSelected<timestamp_t>(selVector, thenVector);
} break;
case LogicalTypeID::INTERVAL: {
fillSelected<interval_t>(selVector, thenVector);
} break;
case LogicalTypeID::STRING: {
fillSelected<ku_string_t>(selVector, thenVector);
} break;
case LogicalTypeID::VAR_LIST: {
fillSelected<list_entry_t>(selVector, thenVector);
} break;
default:
KU_UNREACHABLE;
void CaseExpressionEvaluator::fillEntry(sel_t resultPos, ValueVector* srcVector) {
if (filledMask[resultPos]) {
return;
}
filledMask[resultPos] = true;
auto srcPos =
srcVector->state->isFlat() ? srcVector->state->selVector->selectedPositions[0] : resultPos;
resultVector->copyFromVectorData(resultPos, srcVector, srcPos);
}

} // namespace evaluator
Expand Down
26 changes: 3 additions & 23 deletions src/include/expression_evaluator/case_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,31 +48,11 @@ class CaseExpressionEvaluator : public ExpressionEvaluator {
const processor::ResultSet& resultSet, storage::MemoryManager* memoryManager) override;

private:
template<typename T>
void fillEntry(common::sel_t resultPos, const common::ValueVector& thenVector);

template<typename T>
inline void fillSelected(
const common::SelectionVector& selVector, const common::ValueVector& thenVector) {
for (auto i = 0u; i < selVector.selectedSize; ++i) {
auto resultPos = selVector.selectedPositions[i];
fillEntry<T>(resultPos, thenVector);
}
}

template<typename T>
inline void fillAll(const common::ValueVector& thenVector) {
auto resultSelVector = resultVector->state->selVector.get();
for (auto i = 0u; i < resultSelVector->selectedSize; ++i) {
auto resultPos = resultSelVector->selectedPositions[i];
fillEntry<T>(resultPos, thenVector);
}
}
void fillSelected(const common::SelectionVector& selVector, common::ValueVector* srcVector);

void fillAllSwitch(const common::ValueVector& thenVector);
void fillAll(common::ValueVector* srcVector);

void fillSelectedSwitch(
const common::SelectionVector& selVector, const common::ValueVector& thenVector);
void fillEntry(common::sel_t resultPos, common::ValueVector* srcVector);

private:
std::shared_ptr<binder::Expression> expression;
Expand Down
26 changes: 26 additions & 0 deletions test/test_files/issue/issue.test
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,29 @@ t2
---- 2
t1
t2

-CASE 2589
-STATEMENT CREATE NODE TABLE T (id STRING,PRIMARY KEY(id));
---- ok
-STATEMENT CREATE REL TABLE E (FROM T TO T);
---- ok
-STATEMENT CREATE (t1:T{id:"t1"})-[:E]->(t2:T{id:"t2"});
---- ok
-STATEMENT WITH "node_after" as selected
MATCH (t1:T{id: "t1"})-[:E]->(t2:T{id: "t2"})
WITH
CASE WHEN selected = "node_after" THEN t2
ELSE t1
END as ret
RETURN ret;
---- 1
{_ID: 0:1, _LABEL: T, id: t2}
-STATEMENT WITH "node_after" as selected
MATCH (t1:T{id: "t1"})-[:E]->(t2:T{id: "t2"})
WITH
CASE WHEN selected <> "node_after" THEN t2
ELSE t1
END as ret
RETURN ret;
---- 1
{_ID: 0:0, _LABEL: T, id: t1}

0 comments on commit 2ed267f

Please sign in to comment.