Skip to content

Commit 6670ed9

Browse files
committed
Bug#36072058 simplify MysqlRoutingClassicConnection class
MysqlRoutingClassicConnectionBase contains a ProtocolSplicerBase class which wraps a client-side and server-side connection. That class doesn't offer anything extra and can be removed. Change ====== - moved ClassicProtocolState into its own header - split ClassicProtocolState into client-side and server-side - templatized TlsSwitchableConnection for the Protocol-class to avoid heap allocation for std::make_unique<ProtocolStateBase> - embed Channel directly into TlsSwitchableConnection instead of std::unqiue_ptr<Channel> - embed TlsSwitchableConnection directly in MysqlRoutingClassicConnectionBase instead of std::unique_ptr<TlsSwitchableConnection> - removed now unused ProtocolStateBase - removed now unused ProtocolSplicerBase - added "connection" based recv_msg<>, send_msg<> and used them where possible ... which leads to a lot of changes like: - refer to Channel and ProtocolState by 'ref' instead of 'pointer' - replace all calls through socket_splicer()->{method} by calls to {method} Change-Id: I947bd95deb28f5bf5a185b2addfe2388dd385924
1 parent 2e364cf commit 6670ed9

File tree

65 files changed

+2759
-3087
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+2759
-3087
lines changed

router/src/routing/src/await_client_or_server.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ AwaitClientOrServerProcessor::init() {
6565
*/
6666
stdx::expected<Processor::Result, std::error_code>
6767
AwaitClientOrServerProcessor::wait_both() {
68-
auto *socket_splicer = connection()->socket_splicer();
69-
7068
switch (connection()->recv_from_either()) {
7169
case MysqlRoutingClassicConnectionBase::FromEither::RecvedFromServer: {
7270
// server side sent something.
@@ -76,7 +74,7 @@ AwaitClientOrServerProcessor::wait_both() {
7674

7775
stage(Stage::WaitClientCancelled);
7876

79-
(void)socket_splicer->client_conn().cancel();
77+
(void)connection()->client_conn().cancel();
8078

8179
// end this execution branch.
8280
return Result::Void;
@@ -88,7 +86,7 @@ AwaitClientOrServerProcessor::wait_both() {
8886
// - read from client in ::wait_server_cancelled
8987
stage(Stage::WaitServerCancelled);
9088

91-
(void)socket_splicer->server_conn().cancel();
89+
(void)connection()->server_conn().cancel();
9290

9391
// end this execution branch.
9492
return Result::Void;

router/src/routing/src/basic_protocol_splicer.h

Lines changed: 20 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,6 @@ using TcpConnection = BasicConnection<net::ip::tcp>;
287287
using UnixDomainConnection = BasicConnection<local::stream_protocol>;
288288
#endif
289289

290-
class ProtocolStateBase {
291-
public:
292-
virtual ~ProtocolStateBase() = default;
293-
};
294-
295290
/**
296291
* a Connection that can be switched to TLS.
297292
*
@@ -302,6 +297,7 @@ class ProtocolStateBase {
302297
* - a tls switchable (a SSL_CTX * wrapper)
303298
* - protocol state (classic, xproto)
304299
*/
300+
template <class T>
305301
class TlsSwitchableConnection {
306302
public:
307303
// 16kb per buffer
@@ -310,29 +306,29 @@ class TlsSwitchableConnection {
310306
// 10000 connections
311307
// = 640MByte
312308
static constexpr size_t kRecvBufferSize{16UL * 1024};
309+
using protocol_state_type = T;
313310

314311
TlsSwitchableConnection(std::unique_ptr<ConnectionBase> conn,
315312
std::unique_ptr<RoutingConnectionBase> routing_conn,
316-
SslMode ssl_mode,
317-
std::unique_ptr<ProtocolStateBase> state)
313+
SslMode ssl_mode, protocol_state_type state)
318314
: conn_{std::move(conn)},
319315
routing_conn_{std::move(routing_conn)},
320316
ssl_mode_{std::move(ssl_mode)},
321-
channel_{std::make_unique<Channel>()},
317+
channel_{},
322318
protocol_{std::move(state)} {
323-
channel_->recv_buffer().reserve(kRecvBufferSize);
319+
channel_.recv_buffer().reserve(kRecvBufferSize);
324320
}
325321

326322
TlsSwitchableConnection(std::unique_ptr<ConnectionBase> conn,
327323
std::unique_ptr<RoutingConnectionBase> routing_conn,
328-
SslMode ssl_mode, std::unique_ptr<Channel> channel,
329-
std::unique_ptr<ProtocolStateBase> state)
324+
SslMode ssl_mode, Channel channel,
325+
protocol_state_type state)
330326
: conn_{std::move(conn)},
331327
routing_conn_{std::move(routing_conn)},
332328
ssl_mode_{std::move(ssl_mode)},
333329
channel_{std::move(channel)},
334330
protocol_{std::move(state)} {
335-
channel_->recv_buffer().reserve(kRecvBufferSize);
331+
channel_.recv_buffer().reserve(kRecvBufferSize);
336332
}
337333

338334
[[nodiscard]] std::vector<std::pair<std::string, std::string>>
@@ -355,15 +351,14 @@ class TlsSwitchableConnection {
355351
template <class Func>
356352
void async_recv(Func &&func) {
357353
harness_assert(conn_ != nullptr);
358-
harness_assert(channel_ != nullptr);
359354

360355
// discard everything that has been marked as 'consumed'
361-
channel_->view_discard_raw();
356+
channel_.view_discard_raw();
362357

363-
conn_->async_recv(channel_->recv_buffer(),
358+
conn_->async_recv(channel_.recv_buffer(),
364359
[this, func = std::forward<Func>(func)](
365360
std::error_code ec, size_t transferred) {
366-
channel_->view_sync_raw();
361+
channel_.view_sync_raw();
367362

368363
func(ec, transferred);
369364
});
@@ -376,7 +371,7 @@ class TlsSwitchableConnection {
376371
*/
377372
template <class Func>
378373
void async_send(Func &&func) {
379-
conn_->async_send(channel_->send_buffer(), std::forward<Func>(func));
374+
conn_->async_send(channel_.send_buffer(), std::forward<Func>(func));
380375
}
381376

382377
/**
@@ -394,9 +389,9 @@ class TlsSwitchableConnection {
394389
conn_->async_wait_error(std::forward<Func>(func));
395390
}
396391

397-
[[nodiscard]] Channel *channel() { return channel_.get(); }
392+
[[nodiscard]] Channel &channel() { return channel_; }
398393

399-
[[nodiscard]] const Channel *channel() const { return channel_.get(); }
394+
[[nodiscard]] const Channel &channel() const { return channel_; }
400395

401396
[[nodiscard]] SslMode ssl_mode() const { return ssl_mode_; }
402397

@@ -443,10 +438,10 @@ class TlsSwitchableConnection {
443438
return conn_->cancel();
444439
}
445440

446-
[[nodiscard]] ProtocolStateBase *protocol() { return protocol_.get(); }
441+
[[nodiscard]] protocol_state_type &protocol() { return protocol_; }
447442

448-
[[nodiscard]] const ProtocolStateBase *protocol() const {
449-
return protocol_.get();
443+
[[nodiscard]] const protocol_state_type &protocol() const {
444+
return protocol_;
450445
}
451446

452447
std::unique_ptr<ConnectionBase> &connection() { return conn_; }
@@ -458,7 +453,7 @@ class TlsSwitchableConnection {
458453
* - if transport is secure, the channel is secure
459454
*/
460455
[[nodiscard]] bool is_secure_transport() const {
461-
return conn_->is_secure_transport() || channel_->ssl();
456+
return conn_->is_secure_transport() || (channel_.ssl() != nullptr);
462457
}
463458

464459
private:
@@ -469,117 +464,10 @@ class TlsSwitchableConnection {
469464
SslMode ssl_mode_;
470465

471466
// socket buffers
472-
std::unique_ptr<Channel> channel_;
467+
Channel channel_;
473468

474469
// higher-level protocol
475-
std::unique_ptr<ProtocolStateBase> protocol_;
476-
};
477-
478-
/**
479-
* splices two connections together.
480-
*/
481-
class ProtocolSplicerBase {
482-
public:
483-
ProtocolSplicerBase(TlsSwitchableConnection client_conn,
484-
TlsSwitchableConnection server_conn)
485-
: client_conn_{std::move(client_conn)},
486-
server_conn_{std::move(server_conn)} {}
487-
488-
template <class Func>
489-
void async_wait_send_server(Func &&func) {
490-
server_conn_.async_wait_send(std::forward<Func>(func));
491-
}
492-
493-
template <class Func>
494-
void async_recv_server(Func &&func) {
495-
server_conn_.async_recv(std::forward<Func>(func));
496-
}
497-
498-
template <class Func>
499-
void async_send_server(Func &&func) {
500-
server_conn_.async_send(std::forward<Func>(func));
501-
}
502-
503-
template <class Func>
504-
void async_recv_client(Func &&func) {
505-
client_conn_.async_recv(std::forward<Func>(func));
506-
}
507-
508-
template <class Func>
509-
void async_send_client(Func &&func) {
510-
client_conn_.async_send(std::forward<Func>(func));
511-
}
512-
513-
template <class Func>
514-
void async_client_wait_error(Func &&func) {
515-
client_conn_.async_wait_error(std::forward<Func>(func));
516-
}
517-
518-
[[nodiscard]] TlsSwitchableConnection &client_conn() { return client_conn_; }
519-
520-
[[nodiscard]] const TlsSwitchableConnection &client_conn() const {
521-
return client_conn_;
522-
}
523-
524-
[[nodiscard]] TlsSwitchableConnection &server_conn() { return server_conn_; }
525-
526-
[[nodiscard]] const TlsSwitchableConnection &server_conn() const {
527-
return server_conn_;
528-
}
529-
530-
[[nodiscard]] SslMode source_ssl_mode() const {
531-
return client_conn().ssl_mode();
532-
}
533-
534-
[[nodiscard]] SslMode dest_ssl_mode() const {
535-
return server_conn().ssl_mode();
536-
}
537-
538-
[[nodiscard]] Channel *client_channel() { return client_conn().channel(); }
539-
540-
[[nodiscard]] const Channel *client_channel() const {
541-
return client_conn().channel();
542-
}
543-
544-
[[nodiscard]] Channel *server_channel() { return server_conn().channel(); }
545-
546-
/**
547-
* accept a TLS connection from the client_channel_.
548-
*/
549-
[[nodiscard]] stdx::expected<void, std::error_code> tls_accept() {
550-
// write socket data to SSL struct
551-
auto *channel = client_conn_.channel();
552-
553-
{
554-
const auto flush_res = channel->flush_from_recv_buf();
555-
if (!flush_res) return stdx::unexpected(flush_res.error());
556-
}
557-
558-
if (!channel->tls_init_is_finished()) {
559-
const auto res = channel->tls_accept();
560-
561-
// flush the TLS message to the send-buffer.
562-
{
563-
const auto flush_res = channel->flush_to_send_buf();
564-
if (!flush_res) {
565-
const auto ec = flush_res.error();
566-
if (ec != make_error_code(std::errc::operation_would_block)) {
567-
return stdx::unexpected(flush_res.error());
568-
}
569-
}
570-
}
571-
572-
if (!res) {
573-
return stdx::unexpected(res.error());
574-
}
575-
}
576-
577-
return {};
578-
}
579-
580-
protected:
581-
TlsSwitchableConnection client_conn_;
582-
TlsSwitchableConnection server_conn_;
470+
protocol_state_type protocol_;
583471
};
584472

585473
#endif

router/src/routing/src/classic_auth_caching_sha2.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,16 @@ std::optional<std::string> AuthCachingSha2Password::scramble(
3939

4040
stdx::expected<size_t, std::error_code>
4141
AuthCachingSha2Password::send_public_key_request(
42-
Channel *dst_channel, ClassicProtocolState *dst_protocol) {
42+
Channel &dst_channel, ClassicProtocolState &dst_protocol) {
4343
return ClassicFrame::send_msg(
4444
dst_channel, dst_protocol,
4545
classic_protocol::borrowed::message::client::AuthMethodData{
4646
kPublicKeyRequest});
4747
}
4848

4949
stdx::expected<size_t, std::error_code>
50-
AuthCachingSha2Password::send_public_key(Channel *dst_channel,
51-
ClassicProtocolState *dst_protocol,
50+
AuthCachingSha2Password::send_public_key(Channel &dst_channel,
51+
ClassicProtocolState &dst_protocol,
5252
const std::string &public_key) {
5353
return ClassicFrame::send_msg(
5454
dst_channel, dst_protocol,
@@ -57,7 +57,7 @@ AuthCachingSha2Password::send_public_key(Channel *dst_channel,
5757

5858
stdx::expected<size_t, std::error_code>
5959
AuthCachingSha2Password::send_plaintext_password_request(
60-
Channel *dst_channel, ClassicProtocolState *dst_protocol) {
60+
Channel &dst_channel, ClassicProtocolState &dst_protocol) {
6161
return ClassicFrame::send_msg(
6262
dst_channel, dst_protocol,
6363
classic_protocol::borrowed::message::server::AuthMethodData{
@@ -66,7 +66,7 @@ AuthCachingSha2Password::send_plaintext_password_request(
6666

6767
stdx::expected<size_t, std::error_code>
6868
AuthCachingSha2Password::send_plaintext_password(
69-
Channel *dst_channel, ClassicProtocolState *dst_protocol,
69+
Channel &dst_channel, ClassicProtocolState &dst_protocol,
7070
const std::string &password) {
7171
return ClassicFrame::send_msg(
7272
dst_channel, dst_protocol,
@@ -76,7 +76,7 @@ AuthCachingSha2Password::send_plaintext_password(
7676

7777
stdx::expected<size_t, std::error_code>
7878
AuthCachingSha2Password::send_encrypted_password(
79-
Channel *dst_channel, ClassicProtocolState *dst_protocol,
79+
Channel &dst_channel, ClassicProtocolState &dst_protocol,
8080
const std::string &encrypted) {
8181
return ClassicFrame::send_msg(
8282
dst_channel, dst_protocol,

router/src/routing/src/classic_auth_caching_sha2.h

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
#include <openssl/ssl.h>
3232

33+
#include "basic_protocol_splicer.h"
3334
#include "classic_auth.h"
3435
#include "classic_connection_base.h"
3536
#include "mysql/harness/stdx/expected.h"
@@ -49,24 +50,48 @@ class AuthCachingSha2Password : public AuthBase {
4950
std::string_view pwd);
5051

5152
static stdx::expected<size_t, std::error_code> send_public_key_request(
52-
Channel *dst_channel, ClassicProtocolState *dst_protocol);
53+
Channel &dst_channel, ClassicProtocolState &dst_protocol);
54+
55+
template <class Proto>
56+
static stdx::expected<size_t, std::error_code> send_public_key_request(
57+
TlsSwitchableConnection<Proto> &conn) {
58+
return send_public_key_request(conn.channel(), conn.protocol());
59+
}
5360

5461
static stdx::expected<size_t, std::error_code> send_public_key(
55-
Channel *dst_channel, ClassicProtocolState *dst_protocol,
62+
Channel &dst_channel, ClassicProtocolState &dst_protocol,
5663
const std::string &public_key);
5764

65+
template <class Proto>
66+
static stdx::expected<size_t, std::error_code> send_public_key(
67+
TlsSwitchableConnection<Proto> &conn, const std::string &public_key) {
68+
return send_public_key(conn.channel(), conn.protocol(), public_key);
69+
}
70+
5871
static stdx::expected<size_t, std::error_code>
59-
send_plaintext_password_request(Channel *dst_channel,
60-
ClassicProtocolState *dst_protocol);
72+
send_plaintext_password_request(Channel &dst_channel,
73+
ClassicProtocolState &dst_protocol);
6174

6275
static stdx::expected<size_t, std::error_code> send_plaintext_password(
63-
Channel *dst_channel, ClassicProtocolState *dst_protocol,
76+
Channel &dst_channel, ClassicProtocolState &dst_protocol,
6477
const std::string &password);
6578

79+
template <class Proto>
80+
static stdx::expected<size_t, std::error_code> send_plaintext_password(
81+
TlsSwitchableConnection<Proto> &conn, const std::string &password) {
82+
return send_plaintext_password(conn.channel(), conn.protocol(), password);
83+
}
84+
6685
static stdx::expected<size_t, std::error_code> send_encrypted_password(
67-
Channel *dst_channel, ClassicProtocolState *dst_protocol,
86+
Channel &dst_channel, ClassicProtocolState &dst_protocol,
6887
const std::string &password);
6988

89+
template <class Proto>
90+
static stdx::expected<size_t, std::error_code> send_encrypted_password(
91+
TlsSwitchableConnection<Proto> &conn, const std::string &password) {
92+
return send_encrypted_password(conn.channel(), conn.protocol(), password);
93+
}
94+
7095
static bool is_public_key_request(const std::string_view &data);
7196
static bool is_public_key(const std::string_view &data);
7297
};

0 commit comments

Comments
 (0)