Skip to content

Commit cce8da2

Browse files
committed
Bug#35115810 stmt-execute with new-params-bound = 0 [2/2]
The COM_STMT_EXECUTE message is handled by the StmtExecute Codec which currently doesn't support messages with "new-params-bound = 0" which is used when a prepared statement is executed twice, with the same data-types, but possibly different values. In the new-param-bound = 0 case, the types of the first execute of a statement must be remembered. Change ------ - handle new-param-bound = 0 properly Change-Id: Ic71a9fd00b0f56d8cfc2d707873c164fed7495fd
1 parent 1250a9a commit cce8da2

File tree

5 files changed

+115
-11
lines changed

5 files changed

+115
-11
lines changed

router/src/mysql_protocol/include/mysqlrouter/classic_protocol_codec_base.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ stdx::expected<size_t, std::error_code> encode(const T &v,
110110
*
111111
* @param buffer buffer to read from
112112
* @param caps protocol capabilities
113+
* @tparam T the message class
113114
* @returns number of bytes read from 'buffers' and a T on success, or
114115
* std::error_code on error
115116
*/
@@ -119,6 +120,28 @@ stdx::expected<std::pair<size_t, T>, std::error_code> decode(
119120
return Codec<T>::decode(buffer, caps);
120121
}
121122

123+
/**
124+
* decode a message from a buffer.
125+
*
126+
* @param buffer buffer to read from
127+
* @param caps protocol capabilities
128+
* @param args arguments that shall be forwarded to T's decode()
129+
* @tparam T the message class
130+
* @tparam Args Types of the extra arguments to be forwarded to T's decode()
131+
* function.
132+
* @returns number of bytes read from 'buffers' and a T on success, or
133+
* std::error_code on error
134+
*/
135+
template <class T, class... Args>
136+
stdx::expected<std::pair<size_t, T>, std::error_code> decode(
137+
const net::const_buffer &buffer, capabilities::value_type caps,
138+
// clang-format off
139+
Args &&... args
140+
// clang-format on
141+
) {
142+
return Codec<T>::decode(buffer, caps, std::forward<Args>(args)...);
143+
}
144+
122145
namespace impl {
123146

124147
/**

router/src/mysql_protocol/include/mysqlrouter/classic_protocol_codec_message.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2111,17 +2111,13 @@ class Codec<borrowable::message::client::StmtExecute<Borrowed>>
21112111
auto new_params_bound_res = accu.template step<bw::FixedInt<1>>();
21122112
if (!accu.result()) return stdx::make_unexpected(accu.result().error());
21132113

2114-
auto new_params_bound = new_params_bound_res->value();
2115-
if (new_params_bound != 1) {
2116-
// new-params-bound is required as long as the decoder doesn't know the
2117-
// old param-defs
2118-
return stdx::make_unexpected(make_error_code(codec_errc::invalid_input));
2119-
}
2120-
21212114
std::vector<typename value_type::ParamDef> types;
21222115
std::vector<std::optional<typename value_type::string_type>> values;
21232116

2124-
if (new_params_bound == 1) {
2117+
auto new_params_bound = new_params_bound_res->value();
2118+
if (new_params_bound == 0) {
2119+
types = *metadata_res;
2120+
} else if (new_params_bound == 1) {
21252121
types.reserve(param_count);
21262122

21272123
for (size_t n{}; n < param_count; ++n) {
@@ -2130,6 +2126,8 @@ class Codec<borrowable::message::client::StmtExecute<Borrowed>>
21302126

21312127
types.push_back(type_res->value());
21322128
}
2129+
} else {
2130+
return stdx::make_unexpected(make_error_code(codec_errc::invalid_input));
21332131
}
21342132

21352133
const auto nullbits = nullbits_res->value();

router/src/routing/src/classic_frame.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,56 @@ class ClassicFrame {
108108
}
109109
};
110110

111+
/**
112+
* receive a StmtExecute message from a channel.
113+
*
114+
* specialization of recv_msg<> as StmtExecute needs a the data from the
115+
* StmtPrepareOk.
116+
*/
117+
template <>
118+
inline stdx::expected<classic_protocol::borrowed::message::client::StmtExecute,
119+
std::error_code>
120+
ClassicFrame::recv_msg<
121+
classic_protocol::borrowed::message::client::StmtExecute>(
122+
Channel *src_channel, ClassicProtocolState *src_protocol,
123+
classic_protocol::capabilities::value_type caps) {
124+
using msg_type = classic_protocol::borrowed::message::client::StmtExecute;
125+
126+
auto read_res = ClassicFrame::recv_frame_sequence(src_channel, src_protocol);
127+
if (!read_res) return stdx::make_unexpected(read_res.error());
128+
129+
const auto &recv_buf = src_channel->recv_plain_view();
130+
131+
auto frame_decode_res = classic_protocol::decode<
132+
classic_protocol::frame::Frame<classic_protocol::borrowed::wire::String>>(
133+
net::buffer(recv_buf), caps);
134+
if (!frame_decode_res) {
135+
return stdx::make_unexpected(frame_decode_res.error());
136+
}
137+
138+
src_protocol->seq_id(frame_decode_res->second.seq_id());
139+
140+
auto decode_res = classic_protocol::decode<msg_type>(
141+
net::buffer(frame_decode_res->second.payload().value()), caps,
142+
[src_protocol](auto stmt_id)
143+
-> stdx::expected<std::vector<msg_type::ParamDef>, std::error_code> {
144+
const auto it = src_protocol->prepared_statements().find(stmt_id);
145+
if (it == src_protocol->prepared_statements().end()) {
146+
return stdx::make_unexpected(make_error_code(
147+
classic_protocol::codec_errc::statement_id_not_found));
148+
}
149+
150+
std::vector<msg_type::ParamDef> params;
151+
params.reserve(it->second.parameters.size());
152+
153+
for (const auto &param : it->second.parameters) {
154+
params.emplace_back(param.type_and_flags);
155+
}
156+
157+
return params;
158+
});
159+
160+
return decode_res->second;
161+
}
162+
111163
#endif

router/src/routing/src/classic_stmt_execute_forwarder.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include "classic_connection_base.h"
2828
#include "classic_frame.h"
29+
#include "hexify.h"
2930
#include "mysql/harness/stdx/expected.h"
3031
#include "mysql/harness/tls_error.h"
3132
#include "mysqld_error.h" // mysql-server error-codes
@@ -61,7 +62,26 @@ StmtExecuteForwarder::process() {
6162
stdx::expected<Processor::Result, std::error_code>
6263
StmtExecuteForwarder::command() {
6364
if (auto &tr = tracer()) {
64-
tr.trace(Tracer::Event().stage("stmt_execute::command"));
65+
auto *socket_splicer = connection()->socket_splicer();
66+
auto *src_channel = socket_splicer->client_channel();
67+
auto *src_protocol = connection()->client_protocol();
68+
69+
auto msg_res = ClassicFrame::recv_msg<
70+
classic_protocol::borrowed::message::client::StmtExecute>(src_channel,
71+
src_protocol);
72+
if (!msg_res) return recv_client_failed(msg_res.error());
73+
74+
const auto &recv_buf = src_channel->recv_plain_view();
75+
76+
tr.trace(Tracer::Event().stage(
77+
"stmt_execute::command:\nstmt-id: " + //
78+
std::to_string(msg_res->statement_id()) + "\n" + //
79+
"flags: " + msg_res->flags().to_string() + "\n" + //
80+
"new-params-bound: " + std::to_string(msg_res->new_params_bound()) +
81+
"\n" + //
82+
"types::size(): " + std::to_string(msg_res->types().size()) + "\n" +
83+
"values::size(): " + std::to_string(msg_res->values().size()) + "\n" +
84+
mysql_harness::hexify(recv_buf)));
6585
}
6686

6787
auto &server_conn = connection()->socket_splicer()->server_conn();

router/tests/integration/test_routing_direct.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2161,6 +2161,17 @@ TEST_P(ConnectionTest, classic_protocol_prepare_execute) {
21612161
};
21622162
ASSERT_NO_ERROR(stmt.bind_params(params));
21632163

2164+
// execute again to trigger a StmtExecute with new-params-bound = 1.
2165+
{
2166+
auto exec_res = stmt.execute();
2167+
ASSERT_NO_ERROR(exec_res);
2168+
2169+
for ([[maybe_unused]] auto res : *exec_res) {
2170+
// drain the resultsets.
2171+
}
2172+
}
2173+
2174+
// execute again to trigger a StmtExecute with new-params-bound = 0.
21642175
{
21652176
auto exec_res = stmt.execute();
21662177
ASSERT_NO_ERROR(exec_res);
@@ -2174,7 +2185,7 @@ TEST_P(ConnectionTest, classic_protocol_prepare_execute) {
21742185
auto events_res = changed_event_counters(cli);
21752186
ASSERT_NO_ERROR(events_res);
21762187

2177-
EXPECT_THAT(*events_res, ElementsAre(Pair("statement/com/Execute", 1),
2188+
EXPECT_THAT(*events_res, ElementsAre(Pair("statement/com/Execute", 2),
21782189
Pair("statement/com/Prepare", 1)));
21792190
}
21802191

@@ -2186,7 +2197,7 @@ TEST_P(ConnectionTest, classic_protocol_prepare_execute) {
21862197
ASSERT_NO_ERROR(events_res);
21872198

21882199
EXPECT_THAT(*events_res,
2189-
ElementsAre(Pair("statement/com/Execute", 1),
2200+
ElementsAre(Pair("statement/com/Execute", 2),
21902201
Pair("statement/com/Prepare", 1),
21912202
// explicit
21922203
Pair("statement/com/Reset Connection", 1),

0 commit comments

Comments
 (0)