Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_library(
mysql_execute.cpp
mysql_extension.cpp
mysql_filter_pushdown.cpp
mysql_parameter.cpp
mysql_result.cpp
mysql_scanner.cpp
mysql_storage.cpp
Expand Down
8 changes: 5 additions & 3 deletions src/include/mysql_connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ class MySQLConnection {

public:
static MySQLConnection Open(MySQLTypeConfig type_config, const string &connection_string);
void Execute(const string &query, MySQLConnectorInterface con_interface = MySQLConnectorInterface::BASIC);
void Execute(const string &query);
void Execute(const string &query, vector<Value> params);
unique_ptr<MySQLResult> Query(const string &query, MySQLResultStreaming streaming);
unique_ptr<MySQLResult> Query(const string &query, vector<Value> params, MySQLResultStreaming streaming);

vector<IndexInfo> GetIndexInfo(const string &table_name);

Expand All @@ -79,9 +81,9 @@ class MySQLConnection {
static bool DebugPrintQueries();

private:
unique_ptr<MySQLResult> QueryInternal(const string &query, MySQLResultStreaming streaming,
unique_ptr<MySQLResult> QueryInternal(const string &query, vector<Value> params, MySQLResultStreaming streaming,
MySQLConnectorInterface con_interface);
idx_t MySQLExecute(MYSQL_STMT *stmt, const string &query, bool streaming);
idx_t MySQLExecute(MYSQL_STMT *stmt, const string &query, vector<Value> params, bool streaming);

mutex query_lock;
shared_ptr<OwnedMySQLConnection> connection;
Expand Down
29 changes: 29 additions & 0 deletions src/include/mysql_parameter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// mysql_parameter.hpp
//
//
//===----------------------------------------------------------------------===//

#pragma once

#include "duckdb.hpp"
#include "mysql.h"

namespace duckdb {

struct MySQLParameter {
Value value;
enum_field_types buffer_type = MYSQL_TYPE_INVALID;
bool is_unsigned = false;

vector<char> bind_buffer;
unsigned long bind_length = 0;

MySQLParameter(const string &query, Value value_p);

MYSQL_BIND CreateBind();
};

} // namespace duckdb
51 changes: 44 additions & 7 deletions src/mysql_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "duckdb/parser/parser.hpp"
#include "duckdb/storage/table_storage_info.hpp"

#include "mysql_parameter.hpp"
#include "mysql_types.hpp"

namespace duckdb {
Expand Down Expand Up @@ -38,7 +39,7 @@ MySQLConnection MySQLConnection::Open(MySQLTypeConfig type_config, const string
return MySQLConnection(std::move(connection), connection_string, std::move(type_config));
}

idx_t MySQLConnection::MySQLExecute(MYSQL_STMT *stmt, const string &query, bool streaming) {
idx_t MySQLConnection::MySQLExecute(MYSQL_STMT *stmt, const string &query, vector<Value> params, bool streaming) {
if (MySQLConnection::DebugPrintQueries()) {
Printer::Print(query + "\n");
}
Expand All @@ -62,6 +63,30 @@ idx_t MySQLConnection::MySQLExecute(MYSQL_STMT *stmt, const string &query, bool
throw IOException("Failed to prepare MySQL query \"%s\": %s\n", query.c_str(), mysql_stmt_error(stmt));
}

vector<MySQLParameter> mysql_params;
vector<MYSQL_BIND> binds;
if (params.size() > 0) {
size_t expected_count = mysql_stmt_param_count(stmt);
if (expected_count != params.size()) {
throw IOException(
"Incorrect query parameters count specified, expected: %zu, actual: %zu, MySQL query \"%s\": %s\n",
expected_count, params.size(), query.c_str(), mysql_stmt_error(stmt));
}
mysql_params.reserve(params.size());
binds.reserve(params.size());
for (Value &dp : params) {
MySQLParameter mp(query, std::move(dp));
mysql_params.emplace_back(std::move(mp));
MySQLParameter &mp_ref = mysql_params.back();
binds.push_back(mp_ref.CreateBind());
}
auto res_bind = mysql_stmt_bind_param(stmt, binds.data());
if (res_bind != 0) {
throw IOException("Failed to bind parameters, count: %zu, MySQL query \"%s\": %s\n", binds.size(),
query.c_str(), mysql_stmt_error(stmt));
}
}

int res_exec = mysql_stmt_execute(stmt);
if (res_exec != 0) {
throw IOException("Failed to execute MySQL query \"%s\": %s\n", query.c_str(), mysql_stmt_error(stmt));
Expand All @@ -87,31 +112,43 @@ idx_t MySQLConnection::MySQLExecute(MYSQL_STMT *stmt, const string &query, bool
return affected_rows;
}

unique_ptr<MySQLResult> MySQLConnection::QueryInternal(const string &query, MySQLResultStreaming streaming,
unique_ptr<MySQLResult> MySQLConnection::QueryInternal(const string &query, vector<Value> params,
MySQLResultStreaming streaming,
MySQLConnectorInterface con_interface) {
auto con = GetConn();
bool result_streaming = streaming == MySQLResultStreaming::ALLOW_STREAMING;
bool basic_interface = con_interface == MySQLConnectorInterface::BASIC;

if (basic_interface) {
MySQLExecute(nullptr, query, result_streaming);
MySQLExecute(nullptr, query, params, result_streaming);
return unique_ptr<MySQLResult>(nullptr);
}

auto stmt = MySQLStatementPtr(mysql_stmt_init(con), MySQLStatementDelete);
if (!stmt) {
throw IOException("Failed to initialize MySQL query \"%s\": %s\n", query.c_str(), mysql_error(con));
}
idx_t affected_rows = MySQLExecute(stmt.get(), query, result_streaming);
idx_t affected_rows = MySQLExecute(stmt.get(), query, params, result_streaming);
return make_uniq<MySQLResult>(query, std::move(stmt), type_config, affected_rows);
}

unique_ptr<MySQLResult> MySQLConnection::Query(const string &query, MySQLResultStreaming streaming) {
return QueryInternal(query, streaming, MySQLConnectorInterface::PREPARED_STATEMENT);
return Query(query, vector<Value>(), streaming);
}

unique_ptr<MySQLResult> MySQLConnection::Query(const string &query, vector<Value> params,
MySQLResultStreaming streaming) {
return QueryInternal(query, params, streaming, MySQLConnectorInterface::PREPARED_STATEMENT);
}

void MySQLConnection::Execute(const string &query) {
Execute(query, vector<Value>());
}

void MySQLConnection::Execute(const string &query, MySQLConnectorInterface con_interface) {
QueryInternal(query, MySQLResultStreaming::FORCE_MATERIALIZATION, con_interface);
void MySQLConnection::Execute(const string &query, vector<Value> params) {
MySQLConnectorInterface con_interface =
params.size() > 0 ? MySQLConnectorInterface::PREPARED_STATEMENT : MySQLConnectorInterface::BASIC;
QueryInternal(query, std::move(params), MySQLResultStreaming::FORCE_MATERIALIZATION, con_interface);
}

bool MySQLConnection::IsOpen() {
Expand Down
22 changes: 18 additions & 4 deletions src/mysql_execute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
namespace duckdb {

struct MySQLExecuteBindData : public TableFunctionData {
explicit MySQLExecuteBindData(MySQLCatalog &mysql_catalog, string query_p)
: mysql_catalog(mysql_catalog), query(std::move(query_p)) {
explicit MySQLExecuteBindData(MySQLCatalog &mysql_catalog, string query_p, vector<Value> params_p)
: mysql_catalog(mysql_catalog), query(std::move(query_p)), params(std::move(params_p)) {
}

bool finished = false;
MySQLCatalog &mysql_catalog;
string query;
vector<Value> params;
};

static duckdb::unique_ptr<FunctionData> MySQLExecuteBind(ClientContext &context, TableFunctionBindInput &input,
Expand All @@ -37,7 +38,19 @@ static duckdb::unique_ptr<FunctionData> MySQLExecuteBind(ClientContext &context,
throw BinderException("Attached database \"%s\" does not refer to a MySQL database", db_name);
}
auto &mysql_catalog = catalog.Cast<MySQLCatalog>();
return make_uniq<MySQLExecuteBindData>(mysql_catalog, input.inputs[1].GetValue<string>());
vector<Value> params;
auto params_it = input.named_parameters.find("params");
if (params_it != input.named_parameters.end()) {
Value &struct_val = params_it->second;
if (struct_val.IsNull()) {
throw BinderException("Parameters to mysql_execute cannot be NULL");
}
if (struct_val.type().id() != LogicalTypeId::STRUCT) {
throw BinderException("Query parameters must be specified in a STRUCT");
}
params = StructValue::GetChildren(struct_val);
}
return make_uniq<MySQLExecuteBindData>(mysql_catalog, input.inputs[1].GetValue<string>(), std::move(params));
}

static void MySQLExecuteFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
Expand All @@ -49,12 +62,13 @@ static void MySQLExecuteFunc(ClientContext &context, TableFunctionInput &data_p,
if (transaction.GetAccessMode() == AccessMode::READ_ONLY) {
throw PermissionException("mysql_execute cannot be run in a read-only connection");
}
transaction.GetConnection().Execute(data.query);
transaction.GetConnection().Execute(data.query, std::move(data.params));
data.finished = true;
}

MySQLExecuteFunction::MySQLExecuteFunction()
: TableFunction("mysql_execute", {LogicalType::VARCHAR, LogicalType::VARCHAR}, MySQLExecuteFunc, MySQLExecuteBind) {
named_parameters["params"] = LogicalType::ANY;
}

} // namespace duckdb
175 changes: 175 additions & 0 deletions src/mysql_parameter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#include "mysql_parameter.hpp"

#include "duckdb/common/types/datetime.hpp"
#include "duckdb/common/types/time.hpp"
#include "duckdb/common/types/timestamp.hpp"

namespace duckdb {

template <typename NUM_TYPE>
static void FillNumberBuffer(Value &value, vector<char> &bind_buffer) {
bind_buffer.resize(sizeof(NUM_TYPE));
NUM_TYPE val = value.GetValueUnsafe<NUM_TYPE>();
std::memcpy(bind_buffer.data(), &val, sizeof(NUM_TYPE));
}

static void FillDateBuffer(Value &value, vector<char> &bind_buffer) {
MYSQL_TIME mt;
std::memset(&mt, '\0', sizeof(MYSQL_TIME));
date_t dd = DateValue::Get(value);
int32_t year, month, day;
Date::Convert(dd, year, month, day);

mt.year = static_cast<unsigned int>(std::abs(year));
mt.month = static_cast<unsigned int>(std::abs(month));
mt.day = static_cast<unsigned int>(std::abs(day));

bind_buffer.resize(sizeof(MYSQL_TIME));
std::memcpy(bind_buffer.data(), &mt, sizeof(MYSQL_TIME));
}

static void FillTimeBuffer(Value &value, vector<char> &bind_buffer) {
MYSQL_TIME mt;
std::memset(&mt, '\0', sizeof(MYSQL_TIME));
dtime_t dt = TimeValue::Get(value);
int32_t hour, minute, second, micros;
Time::Convert(dt, hour, minute, second, micros);

mt.hour = static_cast<unsigned int>(std::abs(hour));
mt.minute = static_cast<unsigned int>(std::abs(minute));
mt.second = static_cast<unsigned int>(std::abs(second));
mt.second_part = static_cast<unsigned long>(std::abs(micros));

bind_buffer.resize(sizeof(MYSQL_TIME));
std::memcpy(bind_buffer.data(), &mt, sizeof(MYSQL_TIME));
}

static void FillTimestampBuffer(Value &value, vector<char> &bind_buffer) {
MYSQL_TIME mt;
std::memset(&mt, '\0', sizeof(MYSQL_TIME));
timestamp_t ts = TimestampValue::Get(value);
date_t dd;
dtime_t dt;
Timestamp::Convert(ts, dd, dt);
int32_t year, month, day;
Date::Convert(dd, year, month, day);
int32_t hour, minute, second, micros;
Time::Convert(dt, hour, minute, second, micros);

mt.year = static_cast<unsigned int>(std::abs(year));
mt.month = static_cast<unsigned int>(std::abs(month));
mt.day = static_cast<unsigned int>(std::abs(day));
mt.hour = static_cast<unsigned int>(std::abs(hour));
mt.minute = static_cast<unsigned int>(std::abs(minute));
mt.second = static_cast<unsigned int>(std::abs(second));
mt.second_part = static_cast<unsigned long>(std::abs(micros));

bind_buffer.resize(sizeof(MYSQL_TIME));
std::memcpy(bind_buffer.data(), &mt, sizeof(MYSQL_TIME));
}

MySQLParameter::MySQLParameter(const string &query, Value value_p) : value(std::move(value_p)) {
if (value.IsNull()) {
return;
}

switch (value.type().id()) {
case LogicalTypeId::BOOLEAN:
this->buffer_type = MYSQL_TYPE_TINY;
FillNumberBuffer<bool>(value, bind_buffer);
break;
case LogicalTypeId::TINYINT:
this->buffer_type = MYSQL_TYPE_TINY;
FillNumberBuffer<int8_t>(value, bind_buffer);
break;
case LogicalTypeId::UTINYINT:
this->buffer_type = MYSQL_TYPE_TINY;
this->is_unsigned = true;
FillNumberBuffer<uint8_t>(value, bind_buffer);
break;
case LogicalTypeId::SMALLINT:
this->buffer_type = MYSQL_TYPE_SHORT;
FillNumberBuffer<int16_t>(value, bind_buffer);
break;
case LogicalTypeId::USMALLINT:
this->buffer_type = MYSQL_TYPE_SHORT;
this->is_unsigned = true;
FillNumberBuffer<uint16_t>(value, bind_buffer);
break;
case LogicalTypeId::INTEGER:
this->buffer_type = MYSQL_TYPE_LONG;
FillNumberBuffer<int32_t>(value, bind_buffer);
break;
case LogicalTypeId::UINTEGER:
this->buffer_type = MYSQL_TYPE_LONG;
this->is_unsigned = true;
FillNumberBuffer<uint32_t>(value, bind_buffer);
break;
case LogicalTypeId::BIGINT:
this->buffer_type = MYSQL_TYPE_LONGLONG;
FillNumberBuffer<int64_t>(value, bind_buffer);
break;
case LogicalTypeId::UBIGINT:
this->buffer_type = MYSQL_TYPE_LONGLONG;
this->is_unsigned = true;
FillNumberBuffer<uint64_t>(value, bind_buffer);
break;
case LogicalTypeId::FLOAT:
this->buffer_type = MYSQL_TYPE_FLOAT;
FillNumberBuffer<float>(value, bind_buffer);
break;
case LogicalTypeId::DOUBLE:
this->buffer_type = MYSQL_TYPE_DOUBLE;
FillNumberBuffer<double>(value, bind_buffer);
break;
case LogicalTypeId::DATE:
this->buffer_type = MYSQL_TYPE_DATE;
FillDateBuffer(value, bind_buffer);
break;
case LogicalTypeId::TIME:
this->buffer_type = MYSQL_TYPE_TIME;
FillTimeBuffer(value, bind_buffer);
break;
case LogicalTypeId::TIMESTAMP:
this->buffer_type = MYSQL_TYPE_DATETIME;
FillTimestampBuffer(value, bind_buffer);
break;
case LogicalTypeId::TIMESTAMP_TZ:
this->buffer_type = MYSQL_TYPE_TIMESTAMP;
FillTimestampBuffer(value, bind_buffer);
break;
case LogicalTypeId::VARCHAR:
// use string ref from the value
break;
default:
throw IOException("Unsupported parameters type: \"%s\", MySQL query \"%s\"", value.type(), query.c_str());
}
}

MYSQL_BIND MySQLParameter::CreateBind() {
MYSQL_BIND bind;
std::memset(&bind, '\0', sizeof(MYSQL_BIND));

if (value.IsNull()) {
bind.buffer_type = MYSQL_TYPE_NULL;
bind.length = &bind_length;
} else if (value.type().id() == LogicalTypeId::VARCHAR) {
const string &str = StringValue::Get(value);
bind.buffer_type = MYSQL_TYPE_VARCHAR;
bind.buffer = const_cast<char *>(str.c_str());
bind.buffer_length = str.length();
bind_length = str.length();
bind.length = &bind_length;
} else {
bind.buffer_type = buffer_type;
bind.is_unsigned = is_unsigned;
bind.buffer = bind_buffer.data();
bind.buffer_length = bind_buffer.size();
bind_length = bind_buffer.size();
bind.length = &bind_length;
}

return bind;
}

} // namespace duckdb
Loading
Loading