From 4c3ceff120c896e7865f5e86740400df4c7363cd Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 29 Apr 2024 16:54:22 +0300 Subject: [PATCH 01/17] feat: client tracking optin --- src/facade/dragonfly_connection.cc | 22 +++++++++++-- src/facade/dragonfly_connection.h | 23 ++++++++++++-- src/server/db_slice.cc | 3 +- src/server/main_service.cc | 4 ++- src/server/server_family.cc | 51 ++++++++++++++++++++++++------ src/server/server_family_test.cc | 30 ++++++++++++++++++ 6 files changed, 116 insertions(+), 17 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index c536cd27cf5..fe7fbc76a0b 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -1642,13 +1642,29 @@ void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest) { } void Connection::SetClientTrackingSwitch(bool is_on) { - tracking_enabled_ = is_on; - if (tracking_enabled_) + tracking_info_.tracking_enabled = is_on; + if (is_on) cc_->subscriptions++; } +void Connection::SetOptin(bool optin) { + tracking_info_.optin = optin; +} + +void Connection::LastCommandIsClientCaching() { + tracking_info_.last_command = true; +} + +void Connection::UpdatePrevAndLastCommand() { + tracking_info_.prev_command = std::exchange(tracking_info_.last_command, false); +} + bool Connection::IsTrackingOn() const { - return tracking_enabled_; + return tracking_info_.tracking_enabled; +} + +bool Connection::ShouldTrackKeys() const { + return !tracking_info_.optin || tracking_info_.prev_command; } void Connection::StartTrafficLogging(string_view path) { diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index c6a0c5d4a0b..386f709fb11 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -294,8 +294,16 @@ class Connection : public util::Connection { void SetClientTrackingSwitch(bool is_on); + void LastCommandIsClientCaching(); + + void UpdatePrevAndLastCommand(); + + void SetOptin(bool optin); + bool IsTrackingOn() const; + bool ShouldTrackKeys() const; + // Starts traffic logging in the calling thread. Must be a proactor thread. // Each thread creates its own log file combining requests from all the connections in // that thread. A noop if the thread is already logging. @@ -444,8 +452,19 @@ class Connection : public util::Connection { // Per-thread queue backpressure structs. static thread_local QueueBackpressure tl_queue_backpressure_; - // a flag indicating whether the client has turned on client tracking. - bool tracking_enabled_ = false; + struct ClientTracking { + // a flag indicating whether the client has turned on client tracking. + bool tracking_enabled = false; + bool optin = false; + // remember if CLIENT CACHING TRUE was the last command + // true if prev command was CLIENT CACHING TRUE + bool prev_command = false; + // true if last command was CLIENT CACHING TRUE + bool last_command = false; + }; + + ClientTracking tracking_info_; + bool skip_next_squashing_ = false; // Forcefully skip next squashing // Connection migration vars, see RequestAsyncMigration() above. diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 12382f123d0..b28543ee6e5 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1394,7 +1394,8 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { if (it != client_tracking_map_.end()) { // notify all the clients. auto& client_set = it->second; - auto cb = [key, client_set = std::move(client_set)](unsigned idx, util::ProactorBase*) { + auto cb = [key = std::string(key), client_set = std::move(client_set)](unsigned idx, + util::ProactorBase*) { for (auto it = client_set.begin(); it != client_set.end(); ++it) { if ((unsigned int)it->Thread() != idx) continue; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 37ba88f1bd0..b47af56807d 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1251,7 +1251,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) // if this is a read command, and client tracking has enabled, // start tracking all the updates to the keys in this read command if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn() && - cid->IsTransactional()) { + dfly_cntx->conn()->ShouldTrackKeys() && cid->IsTransactional()) { facade::Connection::WeakRef conn_ref = dfly_cntx->conn()->Borrow(); auto cb = [&, conn_ref](Transaction* t, EngineShard* shard) { return OpTrackKeys(t->GetOpArgs(shard), conn_ref, t->GetShardArgs(shard->shard_id())); @@ -1260,6 +1260,8 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) dfly_cntx->transaction->ScheduleSingleHopT(cb); } + cntx->conn()->UpdatePrevAndLastCommand(); + if (!dispatching_in_multi) { dfly_cntx->transaction = nullptr; } diff --git a/src/server/server_family.cc b/src/server/server_family.cc index dd8e7321272..ac9eb9a0bd9 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -450,29 +450,58 @@ void ClientPauseCmd(CmdArgList args, vector listeners, Connec } void ClientTracking(CmdArgList args, ConnectionContext* cntx) { - if (args.size() != 1) - return cntx->SendError(kSyntaxErr); - auto* rb = static_cast(cntx->reply_builder()); if (!rb->IsResp3()) return cntx->SendError( "Client tracking is currently not supported for RESP2. Please use RESP3."); - ToUpper(&args[0]); - string_view state = ArgS(args, 0); - bool is_on; - if (state == "ON") { + CmdArgParser parser{args}; + if (!parser.HasAtLeast(1) || args.size() > 2) + return cntx->SendError(kSyntaxErr); + + bool is_on = false; + bool optin = false; + if (parser.Check("ON").IgnoreCase()) { is_on = true; - } else if (state == "OFF") { - is_on = false; - } else { + } else if (!parser.Check("OFF").IgnoreCase()) { return cntx->SendError(kSyntaxErr); } + if (parser.HasNext()) { + if (parser.Check("OPTIN").IgnoreCase()) { + optin = true; + } else { + return cntx->SendError(kSyntaxErr); + } + } + cntx->conn()->SetClientTrackingSwitch(is_on); + cntx->conn()->SetOptin(optin); return cntx->SendOk(); } +void ClientCaching(CmdArgList args, ConnectionContext* cntx) { + auto* rb = static_cast(cntx->reply_builder()); + if (!rb->IsResp3()) + return cntx->SendError( + "Client caching is currently not supported for RESP2. Please use RESP3."); + + if (args.size() != 1) { + return cntx->SendError(kSyntaxErr); + } + + CmdArgParser parser{args}; + if (parser.Check("TRUE").IgnoreCase()) { + cntx->conn()->LastCommandIsClientCaching(); + } else if (!parser.Check("FALSE").IgnoreCase()) { + return cntx->SendError(kSyntaxErr); + } + + return cntx->SendError( + "Client caching is currently not properly supported. Use CLIENT CACHING TRUE with CLIENT " + "TRACKING OPTIN only"); +} + void ClientKill(CmdArgList args, absl::Span listeners, ConnectionContext* cntx) { std::function evaluator; @@ -1590,6 +1619,8 @@ void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) { return ClientTracking(sub_args, cntx); } else if (sub_cmd == "KILL") { return ClientKill(sub_args, absl::MakeSpan(listeners_), cntx); + } else if (sub_cmd == "CACHING") { + return ClientCaching(sub_args, cntx); } if (sub_cmd == "SETINFO") { diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index 88708f4faca..7c46f77e441 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -229,6 +229,36 @@ TEST_F(ServerFamilyTest, ClientTrackingReadKey) { EXPECT_EQ(InvalidationMessagesLen("IO0"), 0); } +TEST_F(ServerFamilyTest, ClientTrackingOptin) { + Run({"HELLO", "3"}); + Run({"CLIENT", "TRACKING", "ON", "OPTIN"}); + + Run({"GET", "FOO"}); + Run({"SET", "FOO", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 0); + Run({"GET", "FOO"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 0); + + Run({"CLIENT", "CACHING", "TRUE"}); + // Start tracking once + Run({"GET", "FOO"}); + Run({"SET", "FOO", "20"}); + Run({"GET", "FOO"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 1); + + Run({"GET", "BAR"}); + Run({"SET", "BAR", "20"}); + Run({"GET", "BAR"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 1); + + // Start tracking once + Run({"CLIENT", "CACHING", "TRUE"}); + Run({"GET", "BAR"}); + Run({"SET", "BAR", "20"}); + Run({"GET", "BAR"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); +} + TEST_F(ServerFamilyTest, ClientTrackingUpdateKey) { Run({"HELLO", "3"}); Run({"CLIENT", "TRACKING", "ON"}); From 76444b33b7f5fd39b77f3de766bf1637ed6a53e5 Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 29 Apr 2024 19:05:14 +0300 Subject: [PATCH 02/17] fix CLIENT CACHING command wrong args --- src/server/main_service.cc | 4 +++- src/server/server_family.cc | 8 +++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/server/main_service.cc b/src/server/main_service.cc index b47af56807d..f46f0b908df 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1260,7 +1260,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) dfly_cntx->transaction->ScheduleSingleHopT(cb); } - cntx->conn()->UpdatePrevAndLastCommand(); + if (cntx->conn()) { + cntx->conn()->UpdatePrevAndLastCommand(); + } if (!dispatching_in_multi) { dfly_cntx->transaction = nullptr; diff --git a/src/server/server_family.cc b/src/server/server_family.cc index ac9eb9a0bd9..f4df7c307fc 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -491,15 +491,13 @@ void ClientCaching(CmdArgList args, ConnectionContext* cntx) { } CmdArgParser parser{args}; - if (parser.Check("TRUE").IgnoreCase()) { + if (parser.Check("YES").IgnoreCase()) { cntx->conn()->LastCommandIsClientCaching(); - } else if (!parser.Check("FALSE").IgnoreCase()) { + } else if (!parser.Check("NO").IgnoreCase()) { return cntx->SendError(kSyntaxErr); } - return cntx->SendError( - "Client caching is currently not properly supported. Use CLIENT CACHING TRUE with CLIENT " - "TRACKING OPTIN only"); + cntx->SendOk(); } void ClientKill(CmdArgList args, absl::Span listeners, ConnectionContext* cntx) { From 7a542889cc8d0557538dd65b320c364785114c4c Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 29 Apr 2024 19:26:28 +0300 Subject: [PATCH 03/17] fix tests --- src/server/server_family_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index 7c46f77e441..bc40bb84815 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -239,7 +239,7 @@ TEST_F(ServerFamilyTest, ClientTrackingOptin) { Run({"GET", "FOO"}); EXPECT_EQ(InvalidationMessagesLen("IO0"), 0); - Run({"CLIENT", "CACHING", "TRUE"}); + Run({"CLIENT", "CACHING", "YES"}); // Start tracking once Run({"GET", "FOO"}); Run({"SET", "FOO", "20"}); @@ -252,7 +252,7 @@ TEST_F(ServerFamilyTest, ClientTrackingOptin) { EXPECT_EQ(InvalidationMessagesLen("IO0"), 1); // Start tracking once - Run({"CLIENT", "CACHING", "TRUE"}); + Run({"CLIENT", "CACHING", "YES"}); Run({"GET", "BAR"}); Run({"SET", "BAR", "20"}); Run({"GET", "BAR"}); From 8390b045584ac73581836cf259f2cd72ef7a2424 Mon Sep 17 00:00:00 2001 From: kostas Date: Wed, 1 May 2024 22:34:05 +0300 Subject: [PATCH 04/17] refactor client tracking, fix atomicity, squashing and multi/exec --- src/facade/dragonfly_connection.cc | 35 +++--------- src/facade/dragonfly_connection.h | 27 +--------- src/server/conn_context.cc | 85 +++++++++++++++++++++++++++++- src/server/conn_context.h | 80 +++++++++++++++++++++++++++- src/server/db_slice.cc | 41 +++++++------- src/server/main_service.cc | 57 +++++++------------- src/server/server_family.cc | 11 ++-- src/server/server_family_test.cc | 56 ++++++++++++++++++++ src/server/transaction.cc | 14 ++++- src/server/transaction.h | 15 ++++++ 10 files changed, 300 insertions(+), 121 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index fe7fbc76a0b..1f27ebf0006 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -254,8 +254,9 @@ thread_local vector Connection::pipeline_req_poo thread_local Connection::QueueBackpressure Connection::tl_queue_backpressure_; void Connection::QueueBackpressure::EnsureBelowLimit() { - ec.await( - [this] { return subscriber_bytes.load(memory_order_relaxed) <= subscriber_thread_limit; }); + ec.await([this] { + return done || subscriber_bytes.load(memory_order_relaxed) <= subscriber_thread_limit; + }); } struct Connection::Shutdown { @@ -885,6 +886,8 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) { // After the client disconnected. cc_->conn_closing = true; // Signal dispatch to close. evc_.notify(); + queue_backpressure_->done = true; + queue_backpressure_->ec.notify(); phase_ = SHUTTING_DOWN; VLOG(2) << "Before dispatch_fb.join()"; @@ -1114,7 +1117,7 @@ void Connection::HandleMigrateRequest() { this->Migrate(dest); } - DCHECK(dispatch_q_.empty()); + // DCHECK(dispatch_q_.empty()); // In case we Yield()ed in Migrate() above, dispatch_fb_ might have been started. LaunchDispatchFiberIfNeeded(); @@ -1641,32 +1644,6 @@ void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest) { migration_request_ = dest; } -void Connection::SetClientTrackingSwitch(bool is_on) { - tracking_info_.tracking_enabled = is_on; - if (is_on) - cc_->subscriptions++; -} - -void Connection::SetOptin(bool optin) { - tracking_info_.optin = optin; -} - -void Connection::LastCommandIsClientCaching() { - tracking_info_.last_command = true; -} - -void Connection::UpdatePrevAndLastCommand() { - tracking_info_.prev_command = std::exchange(tracking_info_.last_command, false); -} - -bool Connection::IsTrackingOn() const { - return tracking_info_.tracking_enabled; -} - -bool Connection::ShouldTrackKeys() const { - return !tracking_info_.optin || tracking_info_.prev_command; -} - void Connection::StartTrafficLogging(string_view path) { OpenTrafficLogger(path); } diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 386f709fb11..cdfbe320c1e 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -292,18 +292,6 @@ class Connection : public util::Connection { // Connections will migrate at most once, and only when the flag --migrate_connections is true. void RequestAsyncMigration(util::fb2::ProactorBase* dest); - void SetClientTrackingSwitch(bool is_on); - - void LastCommandIsClientCaching(); - - void UpdatePrevAndLastCommand(); - - void SetOptin(bool optin); - - bool IsTrackingOn() const; - - bool ShouldTrackKeys() const; - // Starts traffic logging in the calling thread. Must be a proactor thread. // Each thread creates its own log file combining requests from all the connections in // that thread. A noop if the thread is already logging. @@ -341,6 +329,8 @@ class Connection : public util::Connection { size_t subscriber_thread_limit = 0; // cached flag subscriber_thread_limit size_t pipeline_cache_limit = 0; // cached flag pipeline_cache_limit + // cancelation flag + bool done = false; }; private: @@ -452,19 +442,6 @@ class Connection : public util::Connection { // Per-thread queue backpressure structs. static thread_local QueueBackpressure tl_queue_backpressure_; - struct ClientTracking { - // a flag indicating whether the client has turned on client tracking. - bool tracking_enabled = false; - bool optin = false; - // remember if CLIENT CACHING TRUE was the last command - // true if prev command was CLIENT CACHING TRUE - bool prev_command = false; - // true if last command was CLIENT CACHING TRUE - bool last_command = false; - }; - - ClientTracking tracking_info_; - bool skip_next_squashing_ = false; // Forcefully skip next squashing // Connection migration vars, see RequestAsyncMigration() above. diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 011793636f4..79a53d18045 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -88,9 +88,9 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own acl_commands = std::vector(acl::NumberOfFamilies(), acl::ALL_COMMANDS); } -ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx, +ConnectionContext::ConnectionContext(ConnectionContext* owner, Transaction* tx, facade::CapturingReplyBuilder* crb) - : facade::ConnectionContext(nullptr, nullptr), transaction{tx} { + : facade::ConnectionContext(nullptr, nullptr), transaction{tx}, parent_cntx_(owner) { acl_commands = std::vector(acl::NumberOfFamilies(), acl::ALL_COMMANDS); if (tx) { // If we have a carrier transaction, this context is used for squashing DCHECK(owner); @@ -119,6 +119,13 @@ void ConnectionContext::ChangeMonitor(bool start) { EnableMonitoring(start); } +ConnectionState::ClientTracking& ConnectionContext::ClientTrackingInfo() { + if (parent_cntx_) { + return parent_cntx_->conn_state.tracking_info_; + } + return conn_state.tracking_info_; +} + vector ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, bool to_reply, ConnectionContext* conn) { vector result(to_reply ? args.size() : 0, 0); @@ -265,4 +272,78 @@ void ConnectionState::ExecInfo::ClearWatched() { watched_existed = 0; } +void ConnectionState::ClientTracking::SetClientTracking(bool is_on) { + tracking_enabled_ = is_on; +} + +void ConnectionState::ClientTracking::TrackClientCaching() { + executing_command_ = true; +} + +void ConnectionState::ClientTracking::UpdatePrevAndLastCommand() { + if (prev_command_ && multi_) { + return; + } + prev_command_ = std::exchange(executing_command_, false); +} + +void ConnectionState::ClientTracking::SetOptin(bool optin) { + optin_ = optin; +} + +void ConnectionState::ClientTracking::SetMulti(bool multi) { + multi_ = multi; +} + +bool ConnectionState::ClientTracking::IsTrackingOn() const { + return tracking_enabled_; +} + +bool ConnectionState::ClientTracking::ShouldTrackKeys() const { + if (!IsTrackingOn()) { + return false; + } + + return !optin_ || prev_command_; +} + +OpResult OpTrackKeys(const OpArgs slice_args, const facade::Connection::WeakRef& conn_ref, + const ShardArgs& args) { + if (conn_ref.IsExpired()) { + DVLOG(2) << "Connection expired, exiting TrackKey function."; + return OpStatus::OK; + } + + DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId() + << " with thread ID: " << conn_ref.Thread(); + + auto& db_slice = slice_args.shard->db_slice(); + // TODO: There is a bug here that we track all arguments instead of tracking only keys. + for (auto key : args) { + DVLOG(2) << "Inserting client ID " << conn_ref.GetClientId() + << " into the tracking client set of key " << key; + db_slice.TrackKey(conn_ref, key); + } + + return OpStatus::OK; +} + +void ConnectionState::ClientTracking::Track(ConnectionContext* cntx, const CommandId* cid) { + auto& info = cntx->ClientTrackingInfo(); + auto shards = cntx->transaction->GetActiveShards(); + if ((cid->opt_mask() & CO::READONLY) && cid->IsTransactional() && info.ShouldTrackKeys()) { + if (cntx->parent_cntx_) { + } + auto conn = cntx->parent_cntx_ ? cntx->parent_cntx_->conn()->Borrow() : cntx->conn()->Borrow(); + auto cb = [&, conn](unsigned i, auto* pb) { + if (shards.find(i) != shards.end()) { + auto* t = cntx->transaction; + CHECK(t); + auto* shard = EngineShard::tlocal(); + OpTrackKeys(t->GetOpArgs(shard), conn, t->GetShardArgs(shard->shard_id())); + } + }; + shard_set->pool()->AwaitFiberOnAll(std::move(cb)); + } +} } // namespace dfly diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 4c007b4e135..cfdefa6fff8 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -147,6 +147,78 @@ struct ConnectionState { size_t UsedMemory() const; + // Client tracking is a per-connection state machine that adheres to the requirements + // of the CLIENT TRACKING command. Note that the semantics described below are enforced + // by the tests in server_family_test. The rules are: + // 1. If CLIENT TRACKING is ON then each READ command must be tracked. Invalidation + // messages are sent `only once`. Subsequent changes of the same key require the + // client to re-read the key in order to receive the next invalidation message. + // 2. CLIENT TRACKING ON OPTIN turns on optional tracking. Read commands are not + // tracked unless the client issues a CLIENT CACHING YES command which conditionally + // allows the tracking of the command that follows CACHING YES). For example: + // >> CLIENT TRACKING ON + // >> CLIENT CACHING YES + // >> GET foo <--------------------- From now foo is being tracked + // However: + // >> CLIENT TRACKING ON + // >> CLIENT CACHING YES + // >> SET foo bar + // >> GET foo <--------------------- is *NOT* tracked since GET does not succeed CACHING + // Also, in the context of multi transactions, CLIENT CACHING YES is *STICKY*: + // >> CLIENT TRACKING ON + // >> CLIENT CACHING YES + // >> MULTI + // >> GET foo + // >> SET foo bar + // >> GET brother_foo + // >> EXEC + // From this point onwards `foo` and `get` keys are tracked. Same aplies if CACHING YES + // is used within the MULTI/EXEC block. + // + // The state machine implements the above rules. We need to track two commands at each time: + // 1. The command invoked previously. + // 2. The command that is invoked now (via InvokeCmd). + // Which is tracked by current_command_ and prev_command_ respectively. When CACHING YES + // is invoked the current_command_ is set to true which is later moved to the prev_command_ + // when the next command is invoked. This is needed to keep track of the different rules + // described above. Stickiness is covered similarly by the multi/exec/discard command which + // when called sets the corresponding multi_ variable to true. + class ClientTracking { + public: + // Sets to true when CLIENT TRACKING is ON + void SetClientTracking(bool is_on); + // Enable tracking on the client + void TrackClientCaching(); + + void UpdatePrevAndLastCommand(); + // Set if OPTIN subcommand is used in CLIENT TRACKING + void SetOptin(bool optin); + // When Multi command is invoked, it calls this to broadcast that we are on a multi + // transaction. + void SetMulti(bool multi); + + // Check if the keys should be tracked. Result adheres to the state machine described above. + bool ShouldTrackKeys() const; + // Check only if CLIENT TRACKING is ON + bool IsTrackingOn() const; + + // Iterates over the active shards of the transaction. If a key satisfies + // the tracking requirements, is is set for tracking. + void Track(ConnectionContext* cntx, const CommandId* cid); + + private: + // a flag indicating whether the client has turned on client tracking. + bool tracking_enabled_ = false; + bool optin_ = false; + // remember if CLIENT CACHING TRUE was the last command + // true if the previous command invoked is CLIENT CACHING TRUE + bool prev_command_ = false; + // true if the currently executing command is CLIENT CACHING TRUE + bool executing_command_ = false; + // true if we are in a multi transaction + bool multi_ = false; + }; + public: DbIndex db_index = 0; @@ -161,14 +233,14 @@ struct ConnectionState { std::optional squashing_info; std::unique_ptr script_info; std::unique_ptr subscribe_info; + ClientTracking tracking_info_; }; class ConnectionContext : public facade::ConnectionContext { public: ConnectionContext(::io::Sink* stream, facade::Connection* owner); - ConnectionContext(const ConnectionContext* owner, Transaction* tx, - facade::CapturingReplyBuilder* crb); + ConnectionContext(ConnectionContext* owner, Transaction* tx, facade::CapturingReplyBuilder* crb); struct DebugInfo { uint32_t shards_count = 0; @@ -183,6 +255,10 @@ class ConnectionContext : public facade::ConnectionContext { // TODO: to introduce proper accessors. Transaction* transaction = nullptr; const CommandId* cid = nullptr; + ConnectionContext* parent_cntx_ = nullptr; + + ConnectionState::ClientTracking& ClientTrackingInfo(); + ConnectionState conn_state; DbIndex db_index() const { diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index b28543ee6e5..87240c721d8 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1391,25 +1391,30 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { return; auto it = client_tracking_map_.find(key); - if (it != client_tracking_map_.end()) { - // notify all the clients. - auto& client_set = it->second; - auto cb = [key = std::string(key), client_set = std::move(client_set)](unsigned idx, - util::ProactorBase*) { - for (auto it = client_set.begin(); it != client_set.end(); ++it) { - if ((unsigned int)it->Thread() != idx) - continue; - facade::Connection* conn = it->Get(); - if ((conn != nullptr) && conn->IsTrackingOn()) { - std::string key_str = {key.begin(), key.end()}; - conn->SendInvalidationMessageAsync({key_str}); - } - } - }; - shard_set->pool()->DispatchBrief(std::move(cb)); - // remove this key from the tracking table as the key no longer exists - client_tracking_map_.erase(key); + if (it == client_tracking_map_.end()) { + return; } + auto& client_set = it->second; + // notify all the clients. + auto cb = [key = std::string(key), client_set = std::move(client_set)](unsigned idx, + util::ProactorBase*) { + for (auto& client : client_set) { + if (client.IsExpired()) { + continue; + } + if (client.Thread() != idx) { + continue; + } + auto* conn = client.Get(); + auto* cntx = static_cast(conn->cntx()); + if (cntx && cntx->ClientTrackingInfo().IsTrackingOn()) { + conn->SendInvalidationMessageAsync({key}); + } + } + }; + shard_set->pool()->DispatchBrief(std::move(cb)); + // remove this key from the tracking table as the key no longer exists + client_tracking_map_.erase(key); } void DbSlice::PerformDeletion(PrimeIterator del_it, DbTable* table) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index f46f0b908df..7359bc3a4a7 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -4,6 +4,8 @@ #include "server/main_service.h" +#include "facade/resp_expr.h" + #ifdef __FreeBSD__ #include #endif @@ -1127,28 +1129,6 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions", tail_args); } -OpResult OpTrackKeys(const OpArgs& op_args, const facade::Connection::WeakRef& conn_ref, - const ShardArgs& args) { - if (conn_ref.IsExpired()) { - DVLOG(2) << "Connection expired, exiting TrackKey function."; - return OpStatus::OK; - } - - DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId() - << " with thread ID: " << conn_ref.Thread(); - - DbSlice& db_slice = op_args.shard->db_slice(); - - // TODO: There is a bug here that we track all arguments instead of tracking only keys. - for (auto key : args) { - DVLOG(2) << "Inserting client ID " << conn_ref.GetClientId() - << " into the tracking client set of key " << key; - db_slice.TrackKey(conn_ref, key); - } - - return OpStatus::OK; -} - void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) { CHECK(!args.empty()); DCHECK_NE(0u, shard_set->size()) << "Init was not called"; @@ -1206,6 +1186,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) if (stored_cmd.Cid()->IsWriteOnly()) { dfly_cntx->conn_state.exec_info.is_write = true; } + dfly_cntx->conn_state.tracking_info_.UpdatePrevAndLastCommand(); return cntx->SendSimpleString("QUEUED"); } @@ -1242,28 +1223,16 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) } dfly_cntx->cid = cid; + std::string res; + for (auto arg : args_no_cmd) { + absl::StrAppend(&res, " ", facade::ToSV(arg)); + } if (!InvokeCmd(cid, args_no_cmd, dfly_cntx)) { dfly_cntx->reply_builder()->SendError("Internal Error"); dfly_cntx->reply_builder()->CloseConnection(); } - // if this is a read command, and client tracking has enabled, - // start tracking all the updates to the keys in this read command - if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn() && - dfly_cntx->conn()->ShouldTrackKeys() && cid->IsTransactional()) { - facade::Connection::WeakRef conn_ref = dfly_cntx->conn()->Borrow(); - auto cb = [&, conn_ref](Transaction* t, EngineShard* shard) { - return OpTrackKeys(t->GetOpArgs(shard), conn_ref, t->GetShardArgs(shard->shard_id())); - }; - dfly_cntx->transaction->Refurbish(); - dfly_cntx->transaction->ScheduleSingleHopT(cb); - } - - if (cntx->conn()) { - cntx->conn()->UpdatePrevAndLastCommand(); - } - if (!dispatching_in_multi) { dfly_cntx->transaction = nullptr; } @@ -1319,6 +1288,10 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo ServerState::tlocal()->RecordCmd(); + if (cntx->transaction) { + cntx->transaction->SetConnectionContextAndInvokeCid(cntx, cid); + } + #ifndef NDEBUG // Verifies that we reply to the client when needed. ReplyGuard reply_guard(cntx, cid->name()); @@ -1331,6 +1304,8 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo return false; } + cntx->conn_state.tracking_info_.UpdatePrevAndLastCommand(); + // TODO: we should probably discard more commands here, // not just the blocking ones const auto* conn = cntx->conn(); @@ -1616,6 +1591,7 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) { return cntx->SendError("MULTI calls can not be nested"); } cntx->conn_state.exec_info.state = ConnectionState::ExecInfo::EXEC_COLLECT; + cntx->conn_state.tracking_info_.SetMulti(true); // TODO: to protect against huge exec transactions. return cntx->SendOk(); } @@ -2022,6 +1998,7 @@ void Service::Discard(CmdArgList args, ConnectionContext* cntx) { } MultiCleanup(cntx); + cntx->conn_state.tracking_info_.SetMulti(false); rb->SendOk(); } @@ -2167,6 +2144,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { // EXEC should not run if any of the watched keys expired. if (!exec_info.watched_keys.empty() && !CheckWatchedKeyExpiry(cntx, registry_)) { cntx->transaction->UnlockMulti(); + cntx->conn_state.tracking_info_.SetMulti(false); return rb->SendNull(); } @@ -2218,6 +2196,7 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { VLOG(1) << "Exec unlocking " << exec_info.body.size() << " commands"; cntx->transaction->UnlockMulti(); } + cntx->conn_state.tracking_info_.SetMulti(false); cntx->cid = exec_cid_; VLOG(1) << "Exec completed"; @@ -2557,7 +2536,7 @@ void Service::OnClose(facade::ConnectionContext* cntx) { server_family_.OnClose(server_cntx); - cntx->conn()->SetClientTrackingSwitch(false); + conn_state.tracking_info_.SetClientTracking(false); } Service::ContextInfo Service::GetContextInfo(facade::ConnectionContext* cntx) const { diff --git a/src/server/server_family.cc b/src/server/server_family.cc index f4df7c307fc..046807f29fa 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -475,8 +475,11 @@ void ClientTracking(CmdArgList args, ConnectionContext* cntx) { } } - cntx->conn()->SetClientTrackingSwitch(is_on); - cntx->conn()->SetOptin(optin); + if (is_on) { + ++cntx->subscriptions; + } + cntx->ClientTrackingInfo().SetClientTracking(is_on); + cntx->ClientTrackingInfo().SetOptin(optin); return cntx->SendOk(); } @@ -492,7 +495,7 @@ void ClientCaching(CmdArgList args, ConnectionContext* cntx) { CmdArgParser parser{args}; if (parser.Check("YES").IgnoreCase()) { - cntx->conn()->LastCommandIsClientCaching(); + cntx->ClientTrackingInfo().TrackClientCaching(); } else if (!parser.Check("NO").IgnoreCase()) { return cntx->SendError(kSyntaxErr); } @@ -1526,7 +1529,7 @@ void ServerFamily::SendInvalidationMessages() const { facade::ConnectionContext* fc = static_cast(conn)->cntx(); if (fc) { ConnectionContext* cntx = static_cast(fc); - if (cntx->conn()->IsTrackingOn()) { + if (cntx->ClientTrackingInfo().IsTrackingOn()) { facade::Connection::InvalidationMessage x; x.invalidate_due_to_flush = true; cntx->conn()->SendInvalidationMessageAsync(x); diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index bc40bb84815..eb356de3d90 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -259,6 +259,62 @@ TEST_F(ServerFamilyTest, ClientTrackingOptin) { EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); } +TEST_F(ServerFamilyTest, ClientTrackingMulti) { + Run({"HELLO", "3"}); + Run({"CLIENT", "TRACKING", "ON"}); + Run({"MULTI"}); + Run({"GET", "FOO"}); + Run({"SET", "TMP", "10"}); + Run({"GET", "FOOBAR"}); + Run({"EXEC"}); + + Run({"SET", "FOO", "10"}); + Run({"SET", "FOOBAR", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); +} + +TEST_F(ServerFamilyTest, ClientTrackingMultiOptin) { + Run({"HELLO", "3"}); + // Check stickiness + Run({"CLIENT", "TRACKING", "ON", "OPTIN"}); + Run({"CLIENT", "CACHING", "YES"}); + Run({"MULTI"}); + Run({"GET", "FOO"}); + Run({"SET", "TMP", "10"}); + Run({"GET", "FOOBAR"}); + Run({"DISCARD"}); + + Run({"SET", "FOO", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 0); + + Run({"CLIENT", "CACHING", "YES"}); + Run({"MULTI"}); + Run({"GET", "FOO"}); + Run({"SET", "TMP", "10"}); + Run({"GET", "FOOBAR"}); + Run({"EXEC"}); + + Run({"SET", "FOO", "10"}); + Run({"SET", "FOOBAR", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); + + // CACHING enclosed in MULTI + Run({"MULTI"}); + Run({"GET", "TMP"}); + Run({"SET", "TMP", "10"}); + Run({"CLIENT", "CACHING", "YES"}); + Run({"GET", "FOO"}); + Run({"GET", "FOOBAR"}); + Run({"EXEC"}); + + Run({"SET", "TMP", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); + Run({"SET", "FOO", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 3); + Run({"SET", "FOOBAR", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 4); +} + TEST_F(ServerFamilyTest, ClientTrackingUpdateKey) { Run({"HELLO", "3"}); Run({"CLIENT", "TRACKING", "ON"}); diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 3ae781eec4b..e9c8f7e9494 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -838,13 +838,23 @@ OpStatus Transaction::ScheduleSingleHop(RunnableType cb) { // Runs in coordinator thread. void Transaction::Execute(RunnableType cb, bool conclude) { + auto tracking_wrap = [cb, this](Transaction* t, EngineShard* shard) -> RunnableResult { + auto res = cb(t, shard); + if (cntx_) { + cntx_->ClientTrackingInfo().Track(cntx_, invoke_cid_); + } + return res; + }; + + RunnableType wrapper = tracking_wrap; + if (multi_ && multi_->role == SQUASHED_STUB) { - local_result_ = RunSquashedMultiCb(cb); + local_result_ = RunSquashedMultiCb(wrapper); return; } local_result_ = OpStatus::OK; - cb_ptr_ = &cb; + cb_ptr_ = &wrapper; if (IsAtomicMulti()) { multi_->concluding = conclude; diff --git a/src/server/transaction.h b/src/server/transaction.h index 07ffc6260b0..cb94787096d 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -20,6 +20,7 @@ #include "facade/op_status.h" #include "server/cluster/unique_slot_checker.h" #include "server/common.h" +#include "server/conn_context.h" #include "server/journal/types.h" #include "server/table.h" #include "server/tx_base.h" @@ -362,6 +363,17 @@ class Transaction { return shard_data_[SidToId(sid)].local_mask; } + void SetConnectionContextAndInvokeCid(ConnectionContext* cntx, const CommandId* cid) { + cntx_ = cntx; + invoke_cid_ = cid; + } + + std::set GetActiveShards() { + std::set active_shards; + IterateActiveShards([&](const auto& sd, ShardId i) mutable { active_shards.insert(i); }); + return active_shards; + } + private: // Holds number of locks for each IntentLock::Mode: shared and exlusive. struct LockCnt { @@ -636,6 +648,9 @@ class Transaction { ShardId coordinator_index = 0; } stats_; + ConnectionContext* cntx_{nullptr}; + const CommandId* invoke_cid_{nullptr}; + private: struct TLTmpSpace { std::vector& GetShardIndex(unsigned size); From db25d6d26033e64955b1a92625c24a7220792297 Mon Sep 17 00:00:00 2001 From: kostas Date: Thu, 2 May 2024 13:50:22 +0300 Subject: [PATCH 05/17] remove unused code --- src/server/db_slice.cc | 5 +---- src/server/main_service.cc | 4 ---- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 87240c721d8..dfedc124de5 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1399,10 +1399,7 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { auto cb = [key = std::string(key), client_set = std::move(client_set)](unsigned idx, util::ProactorBase*) { for (auto& client : client_set) { - if (client.IsExpired()) { - continue; - } - if (client.Thread() != idx) { + if (client.IsExpired() || (client.Thread() != idx)) { continue; } auto* conn = client.Get(); diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 7359bc3a4a7..67700e77ed7 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1223,10 +1223,6 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) } dfly_cntx->cid = cid; - std::string res; - for (auto arg : args_no_cmd) { - absl::StrAppend(&res, " ", facade::ToSV(arg)); - } if (!InvokeCmd(cid, args_no_cmd, dfly_cntx)) { dfly_cntx->reply_builder()->SendError("Internal Error"); From 5c8b2159d9c62cd7aaafd37f5b24a7a01c3148bd Mon Sep 17 00:00:00 2001 From: kostas Date: Wed, 8 May 2024 19:42:17 +0300 Subject: [PATCH 06/17] address gh comments --- src/facade/dragonfly_connection.cc | 8 +--- src/facade/dragonfly_connection.h | 2 - src/server/CMakeLists.txt | 6 +-- src/server/conn_context.cc | 40 +------------------ src/server/conn_context.h | 63 ++++++++++++++++++------------ src/server/db_slice.cc | 2 +- src/server/main_service.cc | 23 +++++------ src/server/server_family.cc | 13 +++--- src/server/server_family_test.cc | 25 ++++++++++++ src/server/transaction.cc | 20 ++++------ src/server/transaction.h | 4 +- 11 files changed, 98 insertions(+), 108 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 1f27ebf0006..77c7953b19c 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -254,9 +254,8 @@ thread_local vector Connection::pipeline_req_poo thread_local Connection::QueueBackpressure Connection::tl_queue_backpressure_; void Connection::QueueBackpressure::EnsureBelowLimit() { - ec.await([this] { - return done || subscriber_bytes.load(memory_order_relaxed) <= subscriber_thread_limit; - }); + ec.await( + [this] { return subscriber_bytes.load(memory_order_relaxed) <= subscriber_thread_limit; }); } struct Connection::Shutdown { @@ -886,10 +885,7 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) { // After the client disconnected. cc_->conn_closing = true; // Signal dispatch to close. evc_.notify(); - queue_backpressure_->done = true; - queue_backpressure_->ec.notify(); phase_ = SHUTTING_DOWN; - VLOG(2) << "Before dispatch_fb.join()"; dispatch_fb_.JoinIfNeeded(); VLOG(2) << "After dispatch_fb.join()"; diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index cdfbe320c1e..fd38ed6c8ef 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -329,8 +329,6 @@ class Connection : public util::Connection { size_t subscriber_thread_limit = 0; // cached flag subscriber_thread_limit size_t pipeline_cache_limit = 0; // cached flag pipeline_cache_limit - // cancelation flag - bool done = false; }; private: diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index 57e7514511a..75345a81930 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -33,14 +33,14 @@ add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc - ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc + ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc conn_context.cc channel_store.cc ) SET(SEARCH_FILES search/search_family.cc search/doc_index.cc search/doc_accessors.cc search/aggregator.cc) -add_library(dragonfly_lib bloom_family.cc engine_shard_set.cc channel_store.cc - config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc +add_library(dragonfly_lib bloom_family.cc engine_shard_set.cc + config_registry.cc debugcmd.cc dflycmd.cc generic_family.cc hset_family.cc http_api.cc json_family.cc ${SEARCH_FILES} list_family.cc main_service.cc memory_cmd.cc rdb_load.cc rdb_save.cc replica.cc diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 79a53d18045..ea109f3f160 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -119,13 +119,6 @@ void ConnectionContext::ChangeMonitor(bool start) { EnableMonitoring(start); } -ConnectionState::ClientTracking& ConnectionContext::ClientTrackingInfo() { - if (parent_cntx_) { - return parent_cntx_->conn_state.tracking_info_; - } - return conn_state.tracking_info_; -} - vector ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, bool to_reply, ConnectionContext* conn) { vector result(to_reply ? args.size() : 0, 0); @@ -272,39 +265,12 @@ void ConnectionState::ExecInfo::ClearWatched() { watched_existed = 0; } -void ConnectionState::ClientTracking::SetClientTracking(bool is_on) { - tracking_enabled_ = is_on; -} - -void ConnectionState::ClientTracking::TrackClientCaching() { - executing_command_ = true; -} - -void ConnectionState::ClientTracking::UpdatePrevAndLastCommand() { - if (prev_command_ && multi_) { - return; - } - prev_command_ = std::exchange(executing_command_, false); -} - -void ConnectionState::ClientTracking::SetOptin(bool optin) { - optin_ = optin; -} - -void ConnectionState::ClientTracking::SetMulti(bool multi) { - multi_ = multi; -} - -bool ConnectionState::ClientTracking::IsTrackingOn() const { - return tracking_enabled_; -} - bool ConnectionState::ClientTracking::ShouldTrackKeys() const { if (!IsTrackingOn()) { return false; } - return !optin_ || prev_command_; + return !optin_ || (seq_num_ == (1 + caching_seq_num_)); } OpResult OpTrackKeys(const OpArgs slice_args, const facade::Connection::WeakRef& conn_ref, @@ -329,11 +295,9 @@ OpResult OpTrackKeys(const OpArgs slice_args, const facade::Connection::We } void ConnectionState::ClientTracking::Track(ConnectionContext* cntx, const CommandId* cid) { - auto& info = cntx->ClientTrackingInfo(); + auto& info = cntx->conn_state.tracking_info_; auto shards = cntx->transaction->GetActiveShards(); if ((cid->opt_mask() & CO::READONLY) && cid->IsTransactional() && info.ShouldTrackKeys()) { - if (cntx->parent_cntx_) { - } auto conn = cntx->parent_cntx_ ? cntx->parent_cntx_->conn()->Borrow() : cntx->conn()->Borrow(); auto cb = [&, conn](unsigned i, auto* pb) { if (shards.find(i) != shards.end()) { diff --git a/src/server/conn_context.h b/src/server/conn_context.h index cfdefa6fff8..91419738945 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -175,48 +175,61 @@ struct ConnectionState { // From this point onwards `foo` and `get` keys are tracked. Same aplies if CACHING YES // is used within the MULTI/EXEC block. // - // The state machine implements the above rules. We need to track two commands at each time: - // 1. The command invoked previously. - // 2. The command that is invoked now (via InvokeCmd). - // Which is tracked by current_command_ and prev_command_ respectively. When CACHING YES - // is invoked the current_command_ is set to true which is later moved to the prev_command_ - // when the next command is invoked. This is needed to keep track of the different rules - // described above. Stickiness is covered similarly by the multi/exec/discard command which - // when called sets the corresponding multi_ variable to true. + // The state machine implements the above rules. We need to track: + // 1. If TRACKING is ON and OPTIN + // 2. Stickiness of CACHING as described above + // + // We introduce a monotonic counter called sequence number which we increment only: + // * On InvokeCmd when we are not Collecting (multi) + // We introduce another counter called caching_seq_num which is set to seq_num + // when the users sends a CLIENT CACHING YES command + // If seq_num == caching_seq_num + 1 then we know that we should Track(). class ClientTracking { public: // Sets to true when CLIENT TRACKING is ON - void SetClientTracking(bool is_on); - // Enable tracking on the client - void TrackClientCaching(); + void SetClientTracking(bool is_on) { + tracking_enabled_ = is_on; + } + + // Increment current sequence number + void IncrementSequenceNumber() { + ++seq_num_; + } - void UpdatePrevAndLastCommand(); // Set if OPTIN subcommand is used in CLIENT TRACKING - void SetOptin(bool optin); - // When Multi command is invoked, it calls this to broadcast that we are on a multi - // transaction. - void SetMulti(bool multi); + void SetOptin(bool optin) { + optin_ = optin; + } // Check if the keys should be tracked. Result adheres to the state machine described above. bool ShouldTrackKeys() const; + // Check only if CLIENT TRACKING is ON - bool IsTrackingOn() const; + bool IsTrackingOn() const { + return tracking_enabled_; + } // Iterates over the active shards of the transaction. If a key satisfies // the tracking requirements, is is set for tracking. void Track(ConnectionContext* cntx, const CommandId* cid); + // Called by CLIENT CACHING YES and caches the current seq_num_ + void SetCachingSequenceNumber(bool is_multi) { + // We need -1 when we are in multi + caching_seq_num_ = is_multi && seq_num_ != 0 ? seq_num_ - 1 : seq_num_; + } + + void ResetCachingSequenceNumber() { + caching_seq_num_ = 0; + } + private: // a flag indicating whether the client has turned on client tracking. bool tracking_enabled_ = false; bool optin_ = false; - // remember if CLIENT CACHING TRUE was the last command - // true if the previous command invoked is CLIENT CACHING TRUE - bool prev_command_ = false; - // true if the currently executing command is CLIENT CACHING TRUE - bool executing_command_ = false; - // true if we are in a multi transaction - bool multi_ = false; + // sequence number + size_t seq_num_ = 0; + size_t caching_seq_num_ = 0; }; public: @@ -257,8 +270,6 @@ class ConnectionContext : public facade::ConnectionContext { const CommandId* cid = nullptr; ConnectionContext* parent_cntx_ = nullptr; - ConnectionState::ClientTracking& ClientTrackingInfo(); - ConnectionState conn_state; DbIndex db_index() const { diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index dfedc124de5..6a5670060cb 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1404,7 +1404,7 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { } auto* conn = client.Get(); auto* cntx = static_cast(conn->cntx()); - if (cntx && cntx->ClientTrackingInfo().IsTrackingOn()) { + if (cntx && cntx->conn_state.tracking_info_.IsTrackingOn()) { conn->SendInvalidationMessageAsync({key}); } } diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 67700e77ed7..a9c05b8c5c8 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1093,7 +1093,9 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA if (cmd_name == "SELECT" || absl::EndsWith(cmd_name, "SUBSCRIBE")) return ErrorReply{absl::StrCat("Can not call ", cmd_name, " within a transaction")}; - if (cmd_name == "WATCH" || cmd_name == "FLUSHALL" || cmd_name == "FLUSHDB") + // for some reason we get a trailing \n\r, and that's why we use StartsWith + bool client_cmd = cmd_name == "CLIENT" && !absl::StartsWith(ToSV(tail_args[0]), "CACHING"); + if (cmd_name == "WATCH" || cmd_name == "FLUSHALL" || cmd_name == "FLUSHDB" || client_cmd) return ErrorReply{absl::StrCat("'", cmd_name, "' inside MULTI is not allowed")}; } @@ -1186,7 +1188,6 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) if (stored_cmd.Cid()->IsWriteOnly()) { dfly_cntx->conn_state.exec_info.is_write = true; } - dfly_cntx->conn_state.tracking_info_.UpdatePrevAndLastCommand(); return cntx->SendSimpleString("QUEUED"); } @@ -1284,8 +1285,9 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo ServerState::tlocal()->RecordCmd(); - if (cntx->transaction) { - cntx->transaction->SetConnectionContextAndInvokeCid(cntx, cid); + auto* trans = cntx->transaction; + if (trans) { + cntx->transaction->SetConnectionContextAndInvokeCid(cntx); } #ifndef NDEBUG @@ -1300,7 +1302,10 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo return false; } - cntx->conn_state.tracking_info_.UpdatePrevAndLastCommand(); + auto cid_name = cid->name(); + if ((!trans && cid_name != "MULTI") || (trans && !trans->IsMulti())) { + cntx->conn_state.tracking_info_.IncrementSequenceNumber(); + } // TODO: we should probably discard more commands here, // not just the blocking ones @@ -1587,7 +1592,6 @@ void Service::Multi(CmdArgList args, ConnectionContext* cntx) { return cntx->SendError("MULTI calls can not be nested"); } cntx->conn_state.exec_info.state = ConnectionState::ExecInfo::EXEC_COLLECT; - cntx->conn_state.tracking_info_.SetMulti(true); // TODO: to protect against huge exec transactions. return cntx->SendOk(); } @@ -1994,7 +1998,6 @@ void Service::Discard(CmdArgList args, ConnectionContext* cntx) { } MultiCleanup(cntx); - cntx->conn_state.tracking_info_.SetMulti(false); rb->SendOk(); } @@ -2140,7 +2143,6 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { // EXEC should not run if any of the watched keys expired. if (!exec_info.watched_keys.empty() && !CheckWatchedKeyExpiry(cntx, registry_)) { cntx->transaction->UnlockMulti(); - cntx->conn_state.tracking_info_.SetMulti(false); return rb->SendNull(); } @@ -2158,7 +2160,8 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { ServerState::tlocal()->exec_freq_count[descr]++; } - if (absl::GetFlag(FLAGS_multi_exec_squash) && state == ExecEvalState::NONE) { + if (absl::GetFlag(FLAGS_multi_exec_squash) && state == ExecEvalState::NONE && + !cntx->conn_state.tracking_info_.IsTrackingOn()) { MultiCommandSquasher::Execute(absl::MakeSpan(exec_info.body), cntx, this); } else { CmdArgVec arg_vec; @@ -2192,7 +2195,6 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { VLOG(1) << "Exec unlocking " << exec_info.body.size() << " commands"; cntx->transaction->UnlockMulti(); } - cntx->conn_state.tracking_info_.SetMulti(false); cntx->cid = exec_cid_; VLOG(1) << "Exec completed"; @@ -2205,7 +2207,6 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { auto* cs = ServerState::tlocal()->channel_store(); vector subscribers = cs->FetchSubscribers(channel); int num_published = subscribers.size(); - if (!subscribers.empty()) { // Make sure neither of the threads limits is reached. // This check actually doesn't reserve any memory ahead and doesn't prevent the buffer diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 046807f29fa..81768a7a5e3 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -478,8 +478,8 @@ void ClientTracking(CmdArgList args, ConnectionContext* cntx) { if (is_on) { ++cntx->subscriptions; } - cntx->ClientTrackingInfo().SetClientTracking(is_on); - cntx->ClientTrackingInfo().SetOptin(optin); + cntx->conn_state.tracking_info_.SetClientTracking(is_on); + cntx->conn_state.tracking_info_.SetOptin(optin); return cntx->SendOk(); } @@ -495,8 +495,11 @@ void ClientCaching(CmdArgList args, ConnectionContext* cntx) { CmdArgParser parser{args}; if (parser.Check("YES").IgnoreCase()) { - cntx->ClientTrackingInfo().TrackClientCaching(); - } else if (!parser.Check("NO").IgnoreCase()) { + bool is_multi = cntx->transaction && cntx->transaction->IsMulti(); + cntx->conn_state.tracking_info_.SetCachingSequenceNumber(is_multi); + } else if (parser.Check("NO").IgnoreCase()) { + cntx->conn_state.tracking_info_.ResetCachingSequenceNumber(); + } else { return cntx->SendError(kSyntaxErr); } @@ -1529,7 +1532,7 @@ void ServerFamily::SendInvalidationMessages() const { facade::ConnectionContext* fc = static_cast(conn)->cntx(); if (fc) { ConnectionContext* cntx = static_cast(fc); - if (cntx->ClientTrackingInfo().IsTrackingOn()) { + if (cntx->conn_state.tracking_info_.IsTrackingOn()) { facade::Connection::InvalidationMessage x; x.invalidate_due_to_flush = true; cntx->conn()->SendInvalidationMessageAsync(x); diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index eb356de3d90..2a8f4576cba 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -271,6 +271,11 @@ TEST_F(ServerFamilyTest, ClientTrackingMulti) { Run({"SET", "FOO", "10"}); Run({"SET", "FOOBAR", "10"}); EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); + + Run({"MULTI"}); + auto resp = Run({"CLIENT", "TRACKING", "ON"}); + EXPECT_THAT(resp, ArgType(RespExpr::ERROR)); + Run({"DISCARD"}); } TEST_F(ServerFamilyTest, ClientTrackingMultiOptin) { @@ -301,18 +306,38 @@ TEST_F(ServerFamilyTest, ClientTrackingMultiOptin) { // CACHING enclosed in MULTI Run({"MULTI"}); Run({"GET", "TMP"}); + Run({"GET", "TMP_TMP"}); Run({"SET", "TMP", "10"}); Run({"CLIENT", "CACHING", "YES"}); Run({"GET", "FOO"}); Run({"GET", "FOOBAR"}); Run({"EXEC"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); Run({"SET", "TMP", "10"}); EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); Run({"SET", "FOO", "10"}); EXPECT_EQ(InvalidationMessagesLen("IO0"), 3); Run({"SET", "FOOBAR", "10"}); EXPECT_EQ(InvalidationMessagesLen("IO0"), 4); + + // CACHING enclosed in MULTI, ON/OFF + Run({"MULTI"}); + Run({"GET", "TMP"}); + Run({"SET", "TMP", "10"}); + Run({"CLIENT", "CACHING", "YES"}); + Run({"GET", "FOO"}); + Run({"CLIENT", "CACHING", "NO"}); + Run({"GET", "BAR"}); + Run({"EXEC"}); + + EXPECT_EQ(InvalidationMessagesLen("IO0"), 4); + Run({"SET", "FOO", "10"}); + Run({"GET", "FOO"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 5); + Run({"SET", "BAR", "10"}); + Run({"GET", "BAR"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 5); } TEST_F(ServerFamilyTest, ClientTrackingUpdateKey) { diff --git a/src/server/transaction.cc b/src/server/transaction.cc index e9c8f7e9494..2d36a07ad96 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -691,8 +691,12 @@ void Transaction::RunCallback(EngineShard* shard) { } // Log to journal only once the command finished running - if ((coordinator_state_ & COORD_CONCLUDING) || (multi_ && multi_->concluding)) + if ((coordinator_state_ & COORD_CONCLUDING) || (multi_ && multi_->concluding)) { LogAutoJournalOnShard(shard, result); + if (cntx_) { + cntx_->conn_state.tracking_info_.Track(cntx_, cid_); + } + } } // TODO: For multi-transactions we should be able to deduce mode() at run-time based @@ -838,23 +842,13 @@ OpStatus Transaction::ScheduleSingleHop(RunnableType cb) { // Runs in coordinator thread. void Transaction::Execute(RunnableType cb, bool conclude) { - auto tracking_wrap = [cb, this](Transaction* t, EngineShard* shard) -> RunnableResult { - auto res = cb(t, shard); - if (cntx_) { - cntx_->ClientTrackingInfo().Track(cntx_, invoke_cid_); - } - return res; - }; - - RunnableType wrapper = tracking_wrap; - if (multi_ && multi_->role == SQUASHED_STUB) { - local_result_ = RunSquashedMultiCb(wrapper); + local_result_ = RunSquashedMultiCb(cb); return; } local_result_ = OpStatus::OK; - cb_ptr_ = &wrapper; + cb_ptr_ = &cb; if (IsAtomicMulti()) { multi_->concluding = conclude; diff --git a/src/server/transaction.h b/src/server/transaction.h index cb94787096d..5888b8eddc4 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -363,9 +363,8 @@ class Transaction { return shard_data_[SidToId(sid)].local_mask; } - void SetConnectionContextAndInvokeCid(ConnectionContext* cntx, const CommandId* cid) { + void SetConnectionContextAndInvokeCid(ConnectionContext* cntx) { cntx_ = cntx; - invoke_cid_ = cid; } std::set GetActiveShards() { @@ -649,7 +648,6 @@ class Transaction { } stats_; ConnectionContext* cntx_{nullptr}; - const CommandId* invoke_cid_{nullptr}; private: struct TLTmpSpace { From 40e2eaa9429fadb54a9983dab398e0195955b7d9 Mon Sep 17 00:00:00 2001 From: kostas Date: Thu, 9 May 2024 15:39:56 +0300 Subject: [PATCH 07/17] address comments --- src/facade/dragonfly_connection.cc | 2 ++ src/server/conn_context.cc | 13 ++++++------- src/server/conn_context.h | 1 - src/server/db_slice.cc | 3 ++- src/server/main_service.cc | 6 +++++- src/server/transaction.h | 8 +------- 6 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 77c7953b19c..441c843df4a 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -1113,6 +1113,7 @@ void Connection::HandleMigrateRequest() { this->Migrate(dest); } + // This triggers on rueidis SingleIntegrationTest // DCHECK(dispatch_q_.empty()); // In case we Yield()ed in Migrate() above, dispatch_fb_ might have been started. @@ -1325,6 +1326,7 @@ void Connection::DispatchFiber(util::FiberSocketBase* peer) { uint64_t prev_epoch = fb2::FiberSwitchEpoch(); while (!builder->GetError()) { + DCHECK_EQ(socket()->proactor(), ProactorBase::me()); evc_.await( [this] { return cc_->conn_closing || (!dispatch_q_.empty() && !cc_->sync_dispatch); }); if (cc_->conn_closing) diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index ea109f3f160..dcad1c8e2d0 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -90,7 +90,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own ConnectionContext::ConnectionContext(ConnectionContext* owner, Transaction* tx, facade::CapturingReplyBuilder* crb) - : facade::ConnectionContext(nullptr, nullptr), transaction{tx}, parent_cntx_(owner) { + : facade::ConnectionContext(nullptr, nullptr), transaction{tx} { acl_commands = std::vector(acl::NumberOfFamilies(), acl::ALL_COMMANDS); if (tx) { // If we have a carrier transaction, this context is used for squashing DCHECK(owner); @@ -296,14 +296,13 @@ OpResult OpTrackKeys(const OpArgs slice_args, const facade::Connection::We void ConnectionState::ClientTracking::Track(ConnectionContext* cntx, const CommandId* cid) { auto& info = cntx->conn_state.tracking_info_; - auto shards = cntx->transaction->GetActiveShards(); if ((cid->opt_mask() & CO::READONLY) && cid->IsTransactional() && info.ShouldTrackKeys()) { - auto conn = cntx->parent_cntx_ ? cntx->parent_cntx_->conn()->Borrow() : cntx->conn()->Borrow(); + auto conn = cntx->conn()->Borrow(); auto cb = [&, conn](unsigned i, auto* pb) { - if (shards.find(i) != shards.end()) { - auto* t = cntx->transaction; - CHECK(t); - auto* shard = EngineShard::tlocal(); + auto* t = cntx->transaction; + CHECK(t); + auto* shard = EngineShard::tlocal(); + if (shard && t->IsActive(i)) { OpTrackKeys(t->GetOpArgs(shard), conn, t->GetShardArgs(shard->shard_id())); } }; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 91419738945..f328ec52caa 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -268,7 +268,6 @@ class ConnectionContext : public facade::ConnectionContext { // TODO: to introduce proper accessors. Transaction* transaction = nullptr; const CommandId* cid = nullptr; - ConnectionContext* parent_cntx_ = nullptr; ConnectionState conn_state; diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 6a5670060cb..365bde0f69f 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1395,7 +1395,8 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { return; } auto& client_set = it->second; - // notify all the clients. + // notify all the clients. copy key because we dispatch briefly below and we need to preserve + // lifetime auto cb = [key = std::string(key), client_set = std::move(client_set)](unsigned idx, util::ProactorBase*) { for (auto& client : client_set) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index a9c05b8c5c8..747ac1a4cd8 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1287,7 +1287,7 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo auto* trans = cntx->transaction; if (trans) { - cntx->transaction->SetConnectionContextAndInvokeCid(cntx); + cntx->transaction->SetConnectionContext(cntx); } #ifndef NDEBUG @@ -1372,6 +1372,10 @@ size_t Service::DispatchManyCommands(absl::Span args_list, for (auto args : args_list) { ToUpper(&args[0]); const auto [cid, tail_args] = FindCmd(args); + // is client tracking command + if (cid->name() == "CLIENT" && !tail_args.empty() && ToSV(tail_args[0]) == "TRACKING") { + break; + } // MULTI...EXEC commands need to be collected into a single context, so squashing is not // possible diff --git a/src/server/transaction.h b/src/server/transaction.h index 5888b8eddc4..14e0df8d6b6 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -363,16 +363,10 @@ class Transaction { return shard_data_[SidToId(sid)].local_mask; } - void SetConnectionContextAndInvokeCid(ConnectionContext* cntx) { + void SetConnectionContext(ConnectionContext* cntx) { cntx_ = cntx; } - std::set GetActiveShards() { - std::set active_shards; - IterateActiveShards([&](const auto& sd, ShardId i) mutable { active_shards.insert(i); }); - return active_shards; - } - private: // Holds number of locks for each IntentLock::Mode: shared and exlusive. struct LockCnt { From 94c08662eb07cdefde80c86658205ddd9a63d5e0 Mon Sep 17 00:00:00 2001 From: kostas Date: Thu, 9 May 2024 17:04:41 +0300 Subject: [PATCH 08/17] fix tests --- src/server/main_service.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 747ac1a4cd8..73062048a19 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1373,7 +1373,7 @@ size_t Service::DispatchManyCommands(absl::Span args_list, ToUpper(&args[0]); const auto [cid, tail_args] = FindCmd(args); // is client tracking command - if (cid->name() == "CLIENT" && !tail_args.empty() && ToSV(tail_args[0]) == "TRACKING") { + if (cid && cid->name() == "CLIENT" && !tail_args.empty() && ToSV(tail_args[0]) == "TRACKING") { break; } From 49ce6b6d21e433281386f1220aa89d39a8e8fdf9 Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 27 May 2024 13:55:57 +0300 Subject: [PATCH 09/17] fix small bug --- src/server/conn_context.cc | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index dcad1c8e2d0..25ec1ecd129 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -298,15 +298,10 @@ void ConnectionState::ClientTracking::Track(ConnectionContext* cntx, const Comma auto& info = cntx->conn_state.tracking_info_; if ((cid->opt_mask() & CO::READONLY) && cid->IsTransactional() && info.ShouldTrackKeys()) { auto conn = cntx->conn()->Borrow(); - auto cb = [&, conn](unsigned i, auto* pb) { - auto* t = cntx->transaction; - CHECK(t); - auto* shard = EngineShard::tlocal(); - if (shard && t->IsActive(i)) { - OpTrackKeys(t->GetOpArgs(shard), conn, t->GetShardArgs(shard->shard_id())); - } - }; - shard_set->pool()->AwaitFiberOnAll(std::move(cb)); + auto* t = cntx->transaction; + CHECK(t); + auto* shard = EngineShard::tlocal(); + OpTrackKeys(t->GetOpArgs(shard), conn, t->GetShardArgs(shard->shard_id())); } } } // namespace dfly From 6c45d9b1428f7572e9e22135fea86270d98cc6c4 Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 27 May 2024 18:56:03 +0300 Subject: [PATCH 10/17] add small comments --- src/server/db_slice.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 9d0b5e5dc39..e7fdab0f4a4 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1403,8 +1403,9 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { return; } auto& client_set = it->second; - // notify all the clients. copy key because we dispatch briefly below and we need to preserve - // lifetime + // Notify all the clients. We copy key because we dispatch briefly below and + // we need to preserve its lifetime + // TODO this key is further copied within DispatchFiber. Fix this. auto cb = [key = std::string(key), client_set = std::move(client_set)](unsigned idx, util::ProactorBase*) { for (auto& client : client_set) { From b9e2103ce336b04aa14e0eeaed953d6587e3e057 Mon Sep 17 00:00:00 2001 From: kostas Date: Thu, 30 May 2024 18:59:28 +0300 Subject: [PATCH 11/17] replace connection context with callback in transaction --- src/server/main_service.cc | 3 ++- src/server/transaction.cc | 4 ++-- src/server/transaction.h | 7 +++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 6e4e3afd4aa..e7d4bd2f3a8 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1292,7 +1292,8 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo auto* trans = cntx->transaction; if (trans) { - cntx->transaction->SetConnectionContext(cntx); + cntx->transaction->SetTrackingCallback( + [cntx](const auto* cid) { cntx->conn_state.tracking_info_.Track(cntx, cid); }); } #ifndef NDEBUG diff --git a/src/server/transaction.cc b/src/server/transaction.cc index be15a49ac79..28894b9d4c4 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -693,8 +693,8 @@ void Transaction::RunCallback(EngineShard* shard) { // Log to journal only once the command finished running if ((coordinator_state_ & COORD_CONCLUDING) || (multi_ && multi_->concluding)) { LogAutoJournalOnShard(shard, result); - if (cntx_) { - cntx_->conn_state.tracking_info_.Track(cntx_, cid_); + if (tracking_cb_) { + tracking_cb_(cid_); } } } diff --git a/src/server/transaction.h b/src/server/transaction.h index 90195c73c3d..ccfdcf1821e 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -20,7 +20,6 @@ #include "facade/op_status.h" #include "server/cluster/cluster_utility.h" #include "server/common.h" -#include "server/conn_context.h" #include "server/journal/types.h" #include "server/table.h" #include "server/tx_base.h" @@ -363,8 +362,8 @@ class Transaction { return shard_data_[SidToId(sid)].local_mask; } - void SetConnectionContext(ConnectionContext* cntx) { - cntx_ = cntx; + void SetTrackingCallback(std::function f) { + tracking_cb_ = std::move(f); } private: @@ -641,7 +640,7 @@ class Transaction { ShardId coordinator_index = 0; } stats_; - ConnectionContext* cntx_{nullptr}; + std::function tracking_cb_; private: struct TLTmpSpace { From c13961de1caeb6cb97af6302b5d4ce1713a6a365 Mon Sep 17 00:00:00 2001 From: kostas Date: Thu, 30 May 2024 19:04:56 +0300 Subject: [PATCH 12/17] rename and reword comment --- src/facade/dragonfly_connection.cc | 3 ++- src/server/conn_context.cc | 2 +- src/server/conn_context.h | 2 +- src/server/main_service.cc | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index e12b9a8a006..4bac140cf55 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -1118,7 +1118,8 @@ void Connection::HandleMigrateRequest() { this->Migrate(dest); } - // This triggers on rueidis SingleIntegrationTest + // This triggers when a pub/sub connection both publish and subscribe to the + // same channel. See #3035 on github for details. // DCHECK(dispatch_q_.empty()); // In case we Yield()ed in Migrate() above, dispatch_fb_ might have been started. diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 25ec1ecd129..637f79852e4 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -294,7 +294,7 @@ OpResult OpTrackKeys(const OpArgs slice_args, const facade::Connection::We return OpStatus::OK; } -void ConnectionState::ClientTracking::Track(ConnectionContext* cntx, const CommandId* cid) { +void ConnectionState::ClientTracking::TrackOnShard(ConnectionContext* cntx, const CommandId* cid) { auto& info = cntx->conn_state.tracking_info_; if ((cid->opt_mask() & CO::READONLY) && cid->IsTransactional() && info.ShouldTrackKeys()) { auto conn = cntx->conn()->Borrow(); diff --git a/src/server/conn_context.h b/src/server/conn_context.h index f328ec52caa..29ce6a68d99 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -211,7 +211,7 @@ struct ConnectionState { // Iterates over the active shards of the transaction. If a key satisfies // the tracking requirements, is is set for tracking. - void Track(ConnectionContext* cntx, const CommandId* cid); + void TrackOnShard(ConnectionContext* cntx, const CommandId* cid); // Called by CLIENT CACHING YES and caches the current seq_num_ void SetCachingSequenceNumber(bool is_multi) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index e7d4bd2f3a8..b9c1a082b81 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1293,7 +1293,7 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo auto* trans = cntx->transaction; if (trans) { cntx->transaction->SetTrackingCallback( - [cntx](const auto* cid) { cntx->conn_state.tracking_info_.Track(cntx, cid); }); + [cntx](const auto* cid) { cntx->conn_state.tracking_info_.TrackOnShard(cntx, cid); }); } #ifndef NDEBUG From f9b13f269398ace8857bd3361567ff9206673011 Mon Sep 17 00:00:00 2001 From: kostas Date: Fri, 31 May 2024 16:54:12 +0300 Subject: [PATCH 13/17] move check in invoke command --- src/server/CMakeLists.txt | 7 +++---- src/server/conn_context.cc | 31 ------------------------------- src/server/conn_context.h | 4 ---- src/server/main_service.cc | 33 +++++++++++++++++++++++++++++++-- src/server/transaction.cc | 2 +- src/server/transaction.h | 4 ++-- 6 files changed, 37 insertions(+), 44 deletions(-) diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index b114d2d7ae1..a007a431c31 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -33,14 +33,13 @@ add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc - ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc conn_context.cc channel_store.cc - ) + ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc) SET(SEARCH_FILES search/search_family.cc search/doc_index.cc search/doc_accessors.cc search/aggregator.cc) -add_library(dragonfly_lib bloom_family.cc engine_shard_set.cc - config_registry.cc debugcmd.cc dflycmd.cc +add_library(dragonfly_lib bloom_family.cc engine_shard_set.cc channel_store.cc + config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc generic_family.cc hset_family.cc http_api.cc json_family.cc ${SEARCH_FILES} list_family.cc main_service.cc memory_cmd.cc rdb_load.cc rdb_save.cc replica.cc diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 637f79852e4..fc9b0b64419 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -273,35 +273,4 @@ bool ConnectionState::ClientTracking::ShouldTrackKeys() const { return !optin_ || (seq_num_ == (1 + caching_seq_num_)); } -OpResult OpTrackKeys(const OpArgs slice_args, const facade::Connection::WeakRef& conn_ref, - const ShardArgs& args) { - if (conn_ref.IsExpired()) { - DVLOG(2) << "Connection expired, exiting TrackKey function."; - return OpStatus::OK; - } - - DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId() - << " with thread ID: " << conn_ref.Thread(); - - auto& db_slice = slice_args.shard->db_slice(); - // TODO: There is a bug here that we track all arguments instead of tracking only keys. - for (auto key : args) { - DVLOG(2) << "Inserting client ID " << conn_ref.GetClientId() - << " into the tracking client set of key " << key; - db_slice.TrackKey(conn_ref, key); - } - - return OpStatus::OK; -} - -void ConnectionState::ClientTracking::TrackOnShard(ConnectionContext* cntx, const CommandId* cid) { - auto& info = cntx->conn_state.tracking_info_; - if ((cid->opt_mask() & CO::READONLY) && cid->IsTransactional() && info.ShouldTrackKeys()) { - auto conn = cntx->conn()->Borrow(); - auto* t = cntx->transaction; - CHECK(t); - auto* shard = EngineShard::tlocal(); - OpTrackKeys(t->GetOpArgs(shard), conn, t->GetShardArgs(shard->shard_id())); - } -} } // namespace dfly diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 29ce6a68d99..38d29451ea6 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -209,10 +209,6 @@ struct ConnectionState { return tracking_enabled_; } - // Iterates over the active shards of the transaction. If a key satisfies - // the tracking requirements, is is set for tracking. - void TrackOnShard(ConnectionContext* cntx, const CommandId* cid); - // Called by CLIENT CACHING YES and caches the current seq_num_ void SetCachingSequenceNumber(bool is_multi) { // We need -1 when we are in multi diff --git a/src/server/main_service.cc b/src/server/main_service.cc index b9c1a082b81..852e3afbf13 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1266,6 +1266,27 @@ class ReplyGuard { SinkReplyBuilder* builder_ = nullptr; }; +OpResult OpTrackKeys(const OpArgs slice_args, const facade::Connection::WeakRef& conn_ref, + const ShardArgs& args) { + if (conn_ref.IsExpired()) { + DVLOG(2) << "Connection expired, exiting TrackKey function."; + return OpStatus::OK; + } + + DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId() + << " with thread ID: " << conn_ref.Thread(); + + auto& db_slice = slice_args.shard->db_slice(); + // TODO: There is a bug here that we track all arguments instead of tracking only keys. + for (auto key : args) { + DVLOG(2) << "Inserting client ID " << conn_ref.GetClientId() + << " into the tracking client set of key " << key; + db_slice.TrackKey(conn_ref, key); + } + + return OpStatus::OK; +} + bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionContext* cntx) { DCHECK(cid); DCHECK(!cid->Validate(tail_args)); @@ -1290,10 +1311,18 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo ServerState::tlocal()->RecordCmd(); + auto& info = cntx->conn_state.tracking_info_; auto* trans = cntx->transaction; + const bool is_read_only = cid->opt_mask() & CO::READONLY; if (trans) { - cntx->transaction->SetTrackingCallback( - [cntx](const auto* cid) { cntx->conn_state.tracking_info_.TrackOnShard(cntx, cid); }); + trans->SetTrackingCallback({}); + if (is_read_only && info.ShouldTrackKeys()) { + auto conn = cntx->conn()->Borrow(); + trans->SetTrackingCallback([trans, conn]() { + auto* shard = EngineShard::tlocal(); + OpTrackKeys(trans->GetOpArgs(shard), conn, trans->GetShardArgs(shard->shard_id())); + }); + } } #ifndef NDEBUG diff --git a/src/server/transaction.cc b/src/server/transaction.cc index e26d665800e..23cf886609d 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -694,7 +694,7 @@ void Transaction::RunCallback(EngineShard* shard) { if ((coordinator_state_ & COORD_CONCLUDING) || (multi_ && multi_->concluding)) { LogAutoJournalOnShard(shard, result); if (tracking_cb_) { - tracking_cb_(cid_); + tracking_cb_(); } } } diff --git a/src/server/transaction.h b/src/server/transaction.h index c32b4b5cd16..bbbb2dd73c3 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -363,7 +363,7 @@ class Transaction { return shard_data_[SidToId(sid)].local_mask; } - void SetTrackingCallback(std::function f) { + void SetTrackingCallback(std::function f) { tracking_cb_ = std::move(f); } @@ -641,7 +641,7 @@ class Transaction { ShardId coordinator_index = 0; } stats_; - std::function tracking_cb_; + std::function tracking_cb_; private: struct TLTmpSpace { From a8360daac247985bed5f042741db70bb732513cf Mon Sep 17 00:00:00 2001 From: kostas Date: Sat, 1 Jun 2024 13:43:29 +0300 Subject: [PATCH 14/17] add DCHECK --- src/server/main_service.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 852e3afbf13..5d1620d8f07 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1095,6 +1095,7 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA // for some reason we get a trailing \n\r, and that's why we use StartsWith bool client_cmd = cmd_name == "CLIENT" && !absl::StartsWith(ToSV(tail_args[0]), "CACHING"); + DCHECK(!tail_args.empty()); if (cmd_name == "WATCH" || cmd_name == "FLUSHALL" || cmd_name == "FLUSHDB" || client_cmd) return ErrorReply{absl::StrCat("'", cmd_name, "' inside MULTI is not allowed")}; } From 34304acf51c84ec17c7bc385427381db4953f716 Mon Sep 17 00:00:00 2001 From: kostas Date: Sat, 1 Jun 2024 14:06:24 +0300 Subject: [PATCH 15/17] fix dcheck --- src/server/main_service.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 5d1620d8f07..7a37d19b732 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1094,8 +1094,11 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return ErrorReply{absl::StrCat("Can not call ", cmd_name, " within a transaction")}; // for some reason we get a trailing \n\r, and that's why we use StartsWith - bool client_cmd = cmd_name == "CLIENT" && !absl::StartsWith(ToSV(tail_args[0]), "CACHING"); - DCHECK(!tail_args.empty()); + bool client_cmd = true; + if (cmd_name == "CLIENT") { + DCHECK(!tail_args.empty()); + client_cmd = !absl::StartsWith(ToSV(tail_args[0]), "CACHING"); + } if (cmd_name == "WATCH" || cmd_name == "FLUSHALL" || cmd_name == "FLUSHDB" || client_cmd) return ErrorReply{absl::StrCat("'", cmd_name, "' inside MULTI is not allowed")}; } From 4e33d726bdf4042cb47404500b44c74bba1714e5 Mon Sep 17 00:00:00 2001 From: kostas Date: Sat, 1 Jun 2024 14:46:33 +0300 Subject: [PATCH 16/17] fix condition --- src/server/main_service.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 7a37d19b732..0dbae4a5d4c 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1094,7 +1094,7 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return ErrorReply{absl::StrCat("Can not call ", cmd_name, " within a transaction")}; // for some reason we get a trailing \n\r, and that's why we use StartsWith - bool client_cmd = true; + bool client_cmd = false; if (cmd_name == "CLIENT") { DCHECK(!tail_args.empty()); client_cmd = !absl::StartsWith(ToSV(tail_args[0]), "CACHING"); From 669a204f8e07551f91160a16a97fcd3added223e Mon Sep 17 00:00:00 2001 From: kostas Date: Mon, 3 Jun 2024 16:39:58 +0300 Subject: [PATCH 17/17] address gh comments --- src/server/conn_context.cc | 2 +- src/server/conn_context.h | 3 ++- src/server/main_service.cc | 10 +++++++++- src/server/transaction.cc | 2 +- src/server/transaction.h | 4 ++-- 5 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index fc9b0b64419..d223cd592ae 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -88,7 +88,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own acl_commands = std::vector(acl::NumberOfFamilies(), acl::ALL_COMMANDS); } -ConnectionContext::ConnectionContext(ConnectionContext* owner, Transaction* tx, +ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx, facade::CapturingReplyBuilder* crb) : facade::ConnectionContext(nullptr, nullptr), transaction{tx} { acl_commands = std::vector(acl::NumberOfFamilies(), acl::ALL_COMMANDS); diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 38d29451ea6..7d57fc2c9a6 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -249,7 +249,8 @@ class ConnectionContext : public facade::ConnectionContext { public: ConnectionContext(::io::Sink* stream, facade::Connection* owner); - ConnectionContext(ConnectionContext* owner, Transaction* tx, facade::CapturingReplyBuilder* crb); + ConnectionContext(const ConnectionContext* owner, Transaction* tx, + facade::CapturingReplyBuilder* crb); struct DebugInfo { uint32_t shards_count = 0; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 0dbae4a5d4c..da85f4d5d9e 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1319,10 +1319,13 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo auto* trans = cntx->transaction; const bool is_read_only = cid->opt_mask() & CO::READONLY; if (trans) { + // Reset it, because in multi/exec the transaction pointer is the same and + // we will end up triggerring the callback on the following commands. To avoid this + // we reset it. trans->SetTrackingCallback({}); if (is_read_only && info.ShouldTrackKeys()) { auto conn = cntx->conn()->Borrow(); - trans->SetTrackingCallback([trans, conn]() { + trans->SetTrackingCallback([conn](Transaction* trans) { auto* shard = EngineShard::tlocal(); OpTrackKeys(trans->GetOpArgs(shard), conn, trans->GetShardArgs(shard->shard_id())); }); @@ -1343,6 +1346,11 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo auto cid_name = cid->name(); if ((!trans && cid_name != "MULTI") || (trans && !trans->IsMulti())) { + // Each time we execute a command we need to increase the sequence number in + // order to properly track clients when OPTIN is used. + // We don't do this for `multi/exec` because it would break the + // semantics, i.e, CACHING should stick for all commands following + // the CLIENT CACHING ON within a multi/exec block cntx->conn_state.tracking_info_.IncrementSequenceNumber(); } diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 23cf886609d..1b4e994b99e 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -694,7 +694,7 @@ void Transaction::RunCallback(EngineShard* shard) { if ((coordinator_state_ & COORD_CONCLUDING) || (multi_ && multi_->concluding)) { LogAutoJournalOnShard(shard, result); if (tracking_cb_) { - tracking_cb_(); + tracking_cb_(this); } } } diff --git a/src/server/transaction.h b/src/server/transaction.h index a71f1fd7dca..ffd8abecde8 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -363,7 +363,7 @@ class Transaction { return shard_data_[SidToId(sid)].local_mask; } - void SetTrackingCallback(std::function f) { + void SetTrackingCallback(std::function f) { tracking_cb_ = std::move(f); } @@ -648,7 +648,7 @@ class Transaction { ShardId coordinator_index = 0; } stats_; - std::function tracking_cb_; + std::function tracking_cb_; private: struct TLTmpSpace {