Skip to content

Commit

Permalink
prepared_statement: move parameters
Browse files Browse the repository at this point in the history
In the Java API, we had a bug where we take ownership of and free
parameters passed into executeWithParams.

Inspecting the method itself, it was taking a shared_ptr, but then
performing a deep copy, which is nonsense. Instead, we should take a
unique_ptr, since we need to copy the parameters to guarantee that they
are not modified for the duration of the query.

This commit also fixes three other issues. First, the Java tests weren't running
any tests from ConnectionTest.java, which is why we didn't observe this
bug. Additionally, the constructor of KuzuConnection uses an assertion,
but assertions are disabled by default, which causes our tests to fail
(and if the assertion is skipped, we segfault).

Also, since rust-lang/cc-rs#900 has been
fixed, we can remove the version pinning of `cc` on MacOS.
  • Loading branch information
Riolku committed Nov 16, 2023
1 parent 093e627 commit d956f01
Show file tree
Hide file tree
Showing 23 changed files with 156 additions and 143 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ jobs:
run: |
ulimit -n 10240
source /Users/runner/.cargo/env
cargo update -p cc --precise '1.0.83'
make rusttest
- name: Rust example
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ nodejstest: nodejs
javatest: java
ifeq ($(OS),Windows_NT)
$(call mkdirp,tools/java_api/build/test) && cd tools/java_api/ && \
javac -d build/test -cp ".;build/kuzu_java.jar;third_party/junit-platform-console-standalone-1.9.3.jar" -sourcepath src/test/java/com/kuzudb/test/*.java && \
javac -d build/test -cp ".;build/kuzu_java.jar;third_party/junit-platform-console-standalone-1.9.3.jar" src/test/java/com/kuzudb/test/*.java && \
java -jar third_party/junit-platform-console-standalone-1.9.3.jar -cp ".;build/kuzu_java.jar;build/test/" --scan-classpath --include-package=com.kuzudb.java_test --details=verbose
else
$(call mkdirp,tools/java_api/build/test) && cd tools/java_api/ && \
javac -d build/test -cp ".:build/kuzu_java.jar:third_party/junit-platform-console-standalone-1.9.3.jar" -sourcepath src/test/java/com/kuzudb/test/*.java && \
javac -d build/test -cp ".:build/kuzu_java.jar:third_party/junit-platform-console-standalone-1.9.3.jar" src/test/java/com/kuzudb/test/*.java && \
java -jar third_party/junit-platform-console-standalone-1.9.3.jar -cp ".:build/kuzu_java.jar:build/test/" --scan-classpath --include-package=com.kuzudb.java_test --details=verbose
endif

Expand Down
19 changes: 14 additions & 5 deletions src/c_api/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,28 @@ kuzu_prepared_statement* kuzu_connection_prepare(kuzu_connection* connection, co
auto* c_prepared_statement = new kuzu_prepared_statement;
c_prepared_statement->_prepared_statement = prepared_statement;
c_prepared_statement->_bound_values =
new std::unordered_map<std::string, std::shared_ptr<Value>>;
new std::unordered_map<std::string, std::unique_ptr<Value>>;
return c_prepared_statement;
}

kuzu_query_result* kuzu_connection_execute(
kuzu_connection* connection, kuzu_prepared_statement* prepared_statement) {
auto prepared_statement_ptr =
static_cast<PreparedStatement*>(prepared_statement->_prepared_statement);
auto bound_values = static_cast<std::unordered_map<std::string, std::shared_ptr<Value>>*>(
auto bound_values = static_cast<std::unordered_map<std::string, std::unique_ptr<Value>>*>(
prepared_statement->_bound_values);
auto query_result = static_cast<Connection*>(connection->_connection)
->executeWithParams(prepared_statement_ptr, *bound_values)
.release();

// Must copy the parameters for safety, and so that the parameters in the prepared statement
// stay the same.
std::unordered_map<std::string, std::unique_ptr<Value>> copied_bound_values;
for (auto& [name, value] : *bound_values) {
copied_bound_values.emplace(name, value->copy());
}

auto query_result =
static_cast<Connection*>(connection->_connection)
->executeWithParams(prepared_statement_ptr, std::move(copied_bound_values))
.release();
if (query_result == nullptr) {
return nullptr;
}
Expand Down
72 changes: 36 additions & 36 deletions src/c_api/prepared_statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ using namespace kuzu::common;
using namespace kuzu::main;

void kuzu_prepared_statement_bind_cpp_value(kuzu_prepared_statement* prepared_statement,
const char* param_name, const std::shared_ptr<Value>& value) {
auto* bound_values = static_cast<std::unordered_map<std::string, std::shared_ptr<Value>>*>(
const char* param_name, std::unique_ptr<Value> value) {
auto* bound_values = static_cast<std::unordered_map<std::string, std::unique_ptr<Value>>*>(
prepared_statement->_bound_values);
bound_values->insert({param_name, value});
bound_values->insert({param_name, std::move(value)});
}

void kuzu_prepared_statement_destroy(kuzu_prepared_statement* prepared_statement) {
Expand All @@ -22,7 +22,7 @@ void kuzu_prepared_statement_destroy(kuzu_prepared_statement* prepared_statement
delete static_cast<PreparedStatement*>(prepared_statement->_prepared_statement);
}
if (prepared_statement->_bound_values != nullptr) {
delete static_cast<std::unordered_map<std::string, std::shared_ptr<Value>>*>(
delete static_cast<std::unordered_map<std::string, std::unique_ptr<Value>>*>(
prepared_statement->_bound_values);
}
delete prepared_statement;
Expand All @@ -48,97 +48,97 @@ char* kuzu_prepared_statement_get_error_message(kuzu_prepared_statement* prepare

void kuzu_prepared_statement_bind_bool(
kuzu_prepared_statement* prepared_statement, const char* param_name, bool value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_int64(
kuzu_prepared_statement* prepared_statement, const char* param_name, int64_t value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_int32(
kuzu_prepared_statement* prepared_statement, const char* param_name, int32_t value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_int16(
kuzu_prepared_statement* prepared_statement, const char* param_name, int16_t value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_int8(
kuzu_prepared_statement* prepared_statement, const char* param_name, int8_t value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_uint64(
kuzu_prepared_statement* prepared_statement, const char* param_name, uint64_t value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_uint32(
kuzu_prepared_statement* prepared_statement, const char* param_name, uint32_t value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_uint16(
kuzu_prepared_statement* prepared_statement, const char* param_name, uint16_t value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_uint8(
kuzu_prepared_statement* prepared_statement, const char* param_name, uint8_t value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_double(
kuzu_prepared_statement* prepared_statement, const char* param_name, double value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_float(
kuzu_prepared_statement* prepared_statement, const char* param_name, float value) {
auto value_ptr = std::make_shared<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(value);
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_date(
kuzu_prepared_statement* prepared_statement, const char* param_name, kuzu_date_t value) {
auto value_ptr = std::make_shared<Value>(date_t(value.days));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(date_t(value.days));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_timestamp(
kuzu_prepared_statement* prepared_statement, const char* param_name, kuzu_timestamp_t value) {
auto value_ptr = std::make_shared<Value>(timestamp_t(value.value));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(timestamp_t(value.value));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_interval(
kuzu_prepared_statement* prepared_statement, const char* param_name, kuzu_interval_t value) {
auto value_ptr = std::make_shared<Value>(interval_t(value.months, value.days, value.micros));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(interval_t(value.months, value.days, value.micros));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_string(
kuzu_prepared_statement* prepared_statement, const char* param_name, const char* value) {
auto value_ptr =
std::make_shared<Value>(LogicalType{LogicalTypeID::STRING}, std::string(value));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, std::string(value));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}

void kuzu_prepared_statement_bind_value(
kuzu_prepared_statement* prepared_statement, const char* param_name, kuzu_value* value) {
auto value_ptr = std::make_shared<Value>(*static_cast<Value*>(value->_value));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, value_ptr);
auto value_ptr = std::make_unique<Value>(*static_cast<Value*>(value->_value));
kuzu_prepared_statement_bind_cpp_value(prepared_statement, param_name, std::move(value_ptr));
}
4 changes: 2 additions & 2 deletions src/include/common/enums/statement_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ struct StatementTypeUtils {
case StatementType::ALTER:
case StatementType::CREATE_MACRO:
case StatementType::COPY_FROM:
return true;
default:
return false;
default:
return true;
}
}
};
Expand Down
8 changes: 8 additions & 0 deletions src/include/common/types/value/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ class Value {
* @return a Value with the same value as other.
*/
KUZU_API Value(const Value& other);

/**
* @param other the value to move from.
* @return a Value with the same value as other.
*/
KUZU_API Value(Value&& other) = default;
KUZU_API Value& operator=(Value&& other) = default;

/**
* @brief Sets the data type of the Value.
* @param dataType_ the data type to set to.
Expand Down
16 changes: 8 additions & 8 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ class Connection {
template<typename... Args>
inline std::unique_ptr<QueryResult> execute(
PreparedStatement* preparedStatement, std::pair<std::string, Args>... args) {
std::unordered_map<std::string, std::shared_ptr<common::Value>> inputParameters;
return executeWithParams(preparedStatement, inputParameters, args...);
std::unordered_map<std::string, std::unique_ptr<common::Value>> inputParameters;
return executeWithParams(preparedStatement, std::move(inputParameters), args...);
}
/**
* @brief Executes the given prepared statement with inputParams and returns the result.
Expand All @@ -93,7 +93,7 @@ class Connection {
* @return the result of the query.
*/
KUZU_API std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<common::Value>>& inputParams);
std::unordered_map<std::string, std::unique_ptr<common::Value>> inputParams);
/**
* @brief interrupts all queries currently executing within this connection.
*/
Expand Down Expand Up @@ -151,16 +151,16 @@ class Connection {

template<typename T, typename... Args>
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<common::Value>>& params,
std::unordered_map<std::string, std::unique_ptr<common::Value>> params,
std::pair<std::string, T> arg, std::pair<std::string, Args>... args) {
auto name = arg.first;
auto val = std::make_shared<common::Value>((T)arg.second);
params.insert({name, val});
return executeWithParams(preparedStatement, params, args...);
auto val = std::make_unique<common::Value>((T)arg.second);
params.insert({name, std::move(val)});
return executeWithParams(preparedStatement, std::move(params), args...);
}

void bindParametersNoLock(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<common::Value>>& inputParams);
std::unordered_map<std::string, std::unique_ptr<common::Value>> inputParams);

std::unique_ptr<QueryResult> executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx = 0u);
Expand Down
12 changes: 8 additions & 4 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ uint64_t Connection::getQueryTimeOut() {
}

std::unique_ptr<QueryResult> Connection::executeWithParams(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<Value>>& inputParams) {
std::unordered_map<std::string, std::unique_ptr<Value>> inputParams) {
lock_t lck{mtx};
if (!preparedStatement->isSuccess()) {
return queryResultWithError(preparedStatement->errMsg);
}
try {
bindParametersNoLock(preparedStatement, inputParams);
bindParametersNoLock(preparedStatement, std::move(inputParams));
} catch (Exception& exception) {
std::string errMsg = exception.what();
return queryResultWithError(errMsg);
Expand All @@ -172,7 +172,7 @@ std::unique_ptr<QueryResult> Connection::executeWithParams(PreparedStatement* pr
}

void Connection::bindParametersNoLock(PreparedStatement* preparedStatement,
std::unordered_map<std::string, std::shared_ptr<Value>>& inputParams) {
std::unordered_map<std::string, std::unique_ptr<Value>> inputParams) {
auto& parameterMap = preparedStatement->parameterMap;
for (auto& [name, value] : inputParams) {
if (!parameterMap.contains(name)) {
Expand All @@ -184,7 +184,11 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement,
value->getDataType()->toString() + " but expects " +
expectParam->getDataType()->toString() + ".");
}
parameterMap.at(name)->copyValueFrom(*value);
// The much more natural `parameterMap.at(name) = std::move(v)` fails.
// The reason is that other parts of the code rely on the existing Value object to be
// modified in-place, not replaced in this map.
Value* v = value.release();
*parameterMap.at(name) = std::move(*v);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/prepared_statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace kuzu {
namespace main {

bool PreparedStatement::allowActiveTransaction() const {
return !StatementTypeUtils::allowActiveTransaction(preparedSummary.statementType);
return StatementTypeUtils::allowActiveTransaction(preparedSummary.statementType);
}

bool PreparedStatement::isTransactionStatement() const {
Expand Down
2 changes: 1 addition & 1 deletion test/c_api/connection_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ TEST_F(CApiConnectionTest, Execute) {
auto connection = getConnection();
auto query = "MATCH (a:person) WHERE a.isStudent = $1 RETURN COUNT(*)";
auto statement = kuzu_connection_prepare(connection, query);
kuzu_prepared_statement_bind_bool(statement, (char*)"1", true);
kuzu_prepared_statement_bind_bool(statement, "1", true);
auto result = kuzu_connection_execute(connection, statement);
ASSERT_NE(result, nullptr);
ASSERT_NE(result->_query_result, nullptr);
Expand Down
Loading

0 comments on commit d956f01

Please sign in to comment.