diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 09a3fbeeaac..0e5b8ea409c 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -883,7 +883,6 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) { cc_->conn_closing = true; // Signal dispatch to close. evc_.notify(); phase_ = SHUTTING_DOWN; - VLOG(2) << "Before dispatch_fb.join()"; dispatch_fb_.JoinIfNeeded(); VLOG(2) << "After dispatch_fb.join()"; @@ -1119,6 +1118,10 @@ void Connection::HandleMigrateRequest() { } } + // This triggers when a pub/sub connection both publish and subscribe to the + // same channel. See #3035 on github for details. + // DCHECK(dispatch_q_.empty()); + // In case we Yield()ed in Migrate() above, dispatch_fb_ might have been started. LaunchDispatchFiberIfNeeded(); } @@ -1648,16 +1651,6 @@ void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest) { migration_request_ = dest; } -void Connection::SetClientTrackingSwitch(bool is_on) { - tracking_enabled_ = is_on; - if (tracking_enabled_) - cc_->subscriptions++; -} - -bool Connection::IsTrackingOn() const { - return tracking_enabled_; -} - void Connection::StartTrafficLogging(string_view path) { OpenTrafficLogger(path); } diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 929ad5f8b5f..71d5164e959 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -298,10 +298,6 @@ class Connection : public util::Connection { // Connections will migrate at most once, and only when the flag --migrate_connections is true. void RequestAsyncMigration(util::fb2::ProactorBase* dest); - void SetClientTrackingSwitch(bool is_on); - - bool IsTrackingOn() const; - // Starts traffic logging in the calling thread. Must be a proactor thread. // Each thread creates its own log file combining requests from all the connections in // that thread. A noop if the thread is already logging. @@ -450,8 +446,6 @@ class Connection : public util::Connection { // Per-thread queue backpressure structs. static thread_local QueueBackpressure tl_queue_backpressure_; - // a flag indicating whether the client has turned on client tracking. - bool tracking_enabled_ = false; bool skip_next_squashing_ = false; // Forcefully skip next squashing // Connection migration vars, see RequestAsyncMigration() above. diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index 833942dbe4e..a007a431c31 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -33,8 +33,7 @@ add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc - ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc - ) + ${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc) SET(SEARCH_FILES search/search_family.cc search/doc_index.cc search/doc_accessors.cc search/aggregator.cc) diff --git a/src/server/conn_context.cc b/src/server/conn_context.cc index 011793636f4..d223cd592ae 100644 --- a/src/server/conn_context.cc +++ b/src/server/conn_context.cc @@ -265,4 +265,12 @@ void ConnectionState::ExecInfo::ClearWatched() { watched_existed = 0; } +bool ConnectionState::ClientTracking::ShouldTrackKeys() const { + if (!IsTrackingOn()) { + return false; + } + + return !optin_ || (seq_num_ == (1 + caching_seq_num_)); +} + } // namespace dfly diff --git a/src/server/conn_context.h b/src/server/conn_context.h index 4c007b4e135..7d57fc2c9a6 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -147,6 +147,87 @@ struct ConnectionState { size_t UsedMemory() const; + // Client tracking is a per-connection state machine that adheres to the requirements + // of the CLIENT TRACKING command. Note that the semantics described below are enforced + // by the tests in server_family_test. The rules are: + // 1. If CLIENT TRACKING is ON then each READ command must be tracked. Invalidation + // messages are sent `only once`. Subsequent changes of the same key require the + // client to re-read the key in order to receive the next invalidation message. + // 2. CLIENT TRACKING ON OPTIN turns on optional tracking. Read commands are not + // tracked unless the client issues a CLIENT CACHING YES command which conditionally + // allows the tracking of the command that follows CACHING YES). For example: + // >> CLIENT TRACKING ON + // >> CLIENT CACHING YES + // >> GET foo <--------------------- From now foo is being tracked + // However: + // >> CLIENT TRACKING ON + // >> CLIENT CACHING YES + // >> SET foo bar + // >> GET foo <--------------------- is *NOT* tracked since GET does not succeed CACHING + // Also, in the context of multi transactions, CLIENT CACHING YES is *STICKY*: + // >> CLIENT TRACKING ON + // >> CLIENT CACHING YES + // >> MULTI + // >> GET foo + // >> SET foo bar + // >> GET brother_foo + // >> EXEC + // From this point onwards `foo` and `get` keys are tracked. Same aplies if CACHING YES + // is used within the MULTI/EXEC block. + // + // The state machine implements the above rules. We need to track: + // 1. If TRACKING is ON and OPTIN + // 2. Stickiness of CACHING as described above + // + // We introduce a monotonic counter called sequence number which we increment only: + // * On InvokeCmd when we are not Collecting (multi) + // We introduce another counter called caching_seq_num which is set to seq_num + // when the users sends a CLIENT CACHING YES command + // If seq_num == caching_seq_num + 1 then we know that we should Track(). + class ClientTracking { + public: + // Sets to true when CLIENT TRACKING is ON + void SetClientTracking(bool is_on) { + tracking_enabled_ = is_on; + } + + // Increment current sequence number + void IncrementSequenceNumber() { + ++seq_num_; + } + + // Set if OPTIN subcommand is used in CLIENT TRACKING + void SetOptin(bool optin) { + optin_ = optin; + } + + // Check if the keys should be tracked. Result adheres to the state machine described above. + bool ShouldTrackKeys() const; + + // Check only if CLIENT TRACKING is ON + bool IsTrackingOn() const { + return tracking_enabled_; + } + + // Called by CLIENT CACHING YES and caches the current seq_num_ + void SetCachingSequenceNumber(bool is_multi) { + // We need -1 when we are in multi + caching_seq_num_ = is_multi && seq_num_ != 0 ? seq_num_ - 1 : seq_num_; + } + + void ResetCachingSequenceNumber() { + caching_seq_num_ = 0; + } + + private: + // a flag indicating whether the client has turned on client tracking. + bool tracking_enabled_ = false; + bool optin_ = false; + // sequence number + size_t seq_num_ = 0; + size_t caching_seq_num_ = 0; + }; + public: DbIndex db_index = 0; @@ -161,6 +242,7 @@ struct ConnectionState { std::optional squashing_info; std::unique_ptr script_info; std::unique_ptr subscribe_info; + ClientTracking tracking_info_; }; class ConnectionContext : public facade::ConnectionContext { @@ -183,6 +265,7 @@ class ConnectionContext : public facade::ConnectionContext { // TODO: to introduce proper accessors. Transaction* transaction = nullptr; const CommandId* cid = nullptr; + ConnectionState conn_state; DbIndex db_index() const { diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index efb18e569cb..73b818e53a8 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1423,24 +1423,29 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { return; auto it = client_tracking_map_.find(key); - if (it != client_tracking_map_.end()) { - // notify all the clients. - auto& client_set = it->second; - auto cb = [key, client_set = std::move(client_set)](unsigned idx, util::ProactorBase*) { - for (auto it = client_set.begin(); it != client_set.end(); ++it) { - if ((unsigned int)it->Thread() != idx) - continue; - facade::Connection* conn = it->Get(); - if ((conn != nullptr) && conn->IsTrackingOn()) { - std::string key_str = {key.begin(), key.end()}; - conn->SendInvalidationMessageAsync({key_str}); - } - } - }; - shard_set->pool()->DispatchBrief(std::move(cb)); - // remove this key from the tracking table as the key no longer exists - client_tracking_map_.erase(key); + if (it == client_tracking_map_.end()) { + return; } + auto& client_set = it->second; + // Notify all the clients. We copy key because we dispatch briefly below and + // we need to preserve its lifetime + // TODO this key is further copied within DispatchFiber. Fix this. + auto cb = [key = std::string(key), client_set = std::move(client_set)](unsigned idx, + util::ProactorBase*) { + for (auto& client : client_set) { + if (client.IsExpired() || (client.Thread() != idx)) { + continue; + } + auto* conn = client.Get(); + auto* cntx = static_cast(conn->cntx()); + if (cntx && cntx->conn_state.tracking_info_.IsTrackingOn()) { + conn->SendInvalidationMessageAsync({key}); + } + } + }; + shard_set->pool()->DispatchBrief(std::move(cb)); + // remove this key from the tracking table as the key no longer exists + client_tracking_map_.erase(key); } void DbSlice::PerformDeletion(PrimeIterator del_it, DbTable* table) { diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 276fc5515a6..da85f4d5d9e 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -4,6 +4,8 @@ #include "server/main_service.h" +#include "facade/resp_expr.h" + #ifdef __FreeBSD__ #include #endif @@ -1091,7 +1093,13 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA if (cmd_name == "SELECT" || absl::EndsWith(cmd_name, "SUBSCRIBE")) return ErrorReply{absl::StrCat("Can not call ", cmd_name, " within a transaction")}; - if (cmd_name == "WATCH" || cmd_name == "FLUSHALL" || cmd_name == "FLUSHDB") + // for some reason we get a trailing \n\r, and that's why we use StartsWith + bool client_cmd = false; + if (cmd_name == "CLIENT") { + DCHECK(!tail_args.empty()); + client_cmd = !absl::StartsWith(ToSV(tail_args[0]), "CACHING"); + } + if (cmd_name == "WATCH" || cmd_name == "FLUSHALL" || cmd_name == "FLUSHDB" || client_cmd) return ErrorReply{absl::StrCat("'", cmd_name, "' inside MULTI is not allowed")}; } @@ -1127,28 +1135,6 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions", tail_args); } -OpResult OpTrackKeys(const OpArgs& op_args, const facade::Connection::WeakRef& conn_ref, - const ShardArgs& args) { - if (conn_ref.IsExpired()) { - DVLOG(2) << "Connection expired, exiting TrackKey function."; - return OpStatus::OK; - } - - DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId() - << " with thread ID: " << conn_ref.Thread(); - - DbSlice& db_slice = op_args.shard->db_slice(); - - // TODO: There is a bug here that we track all arguments instead of tracking only keys. - for (auto key : args) { - DVLOG(2) << "Inserting client ID " << conn_ref.GetClientId() - << " into the tracking client set of key " << key; - db_slice.TrackKey(conn_ref, key); - } - - return OpStatus::OK; -} - void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) { CHECK(!args.empty()); DCHECK_NE(0u, shard_set->size()) << "Init was not called"; @@ -1253,18 +1239,6 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) dfly_cntx->reply_builder()->CloseConnection(); } - // if this is a read command, and client tracking has enabled, - // start tracking all the updates to the keys in this read command - if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn() && - cid->IsTransactional()) { - facade::Connection::WeakRef conn_ref = dfly_cntx->conn()->Borrow(); - auto cb = [&, conn_ref](Transaction* t, EngineShard* shard) { - return OpTrackKeys(t->GetOpArgs(shard), conn_ref, t->GetShardArgs(shard->shard_id())); - }; - dfly_cntx->transaction->Refurbish(); - dfly_cntx->transaction->ScheduleSingleHopT(cb); - } - if (!dispatching_in_multi) { dfly_cntx->transaction = nullptr; } @@ -1296,6 +1270,27 @@ class ReplyGuard { SinkReplyBuilder* builder_ = nullptr; }; +OpResult OpTrackKeys(const OpArgs slice_args, const facade::Connection::WeakRef& conn_ref, + const ShardArgs& args) { + if (conn_ref.IsExpired()) { + DVLOG(2) << "Connection expired, exiting TrackKey function."; + return OpStatus::OK; + } + + DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId() + << " with thread ID: " << conn_ref.Thread(); + + auto& db_slice = slice_args.shard->db_slice(); + // TODO: There is a bug here that we track all arguments instead of tracking only keys. + for (auto key : args) { + DVLOG(2) << "Inserting client ID " << conn_ref.GetClientId() + << " into the tracking client set of key " << key; + db_slice.TrackKey(conn_ref, key); + } + + return OpStatus::OK; +} + bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionContext* cntx) { DCHECK(cid); DCHECK(!cid->Validate(tail_args)); @@ -1320,6 +1315,23 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo ServerState::tlocal()->RecordCmd(); + auto& info = cntx->conn_state.tracking_info_; + auto* trans = cntx->transaction; + const bool is_read_only = cid->opt_mask() & CO::READONLY; + if (trans) { + // Reset it, because in multi/exec the transaction pointer is the same and + // we will end up triggerring the callback on the following commands. To avoid this + // we reset it. + trans->SetTrackingCallback({}); + if (is_read_only && info.ShouldTrackKeys()) { + auto conn = cntx->conn()->Borrow(); + trans->SetTrackingCallback([conn](Transaction* trans) { + auto* shard = EngineShard::tlocal(); + OpTrackKeys(trans->GetOpArgs(shard), conn, trans->GetShardArgs(shard->shard_id())); + }); + } + } + #ifndef NDEBUG // Verifies that we reply to the client when needed. ReplyGuard reply_guard(cntx, cid->name()); @@ -1332,6 +1344,16 @@ bool Service::InvokeCmd(const CommandId* cid, CmdArgList tail_args, ConnectionCo return false; } + auto cid_name = cid->name(); + if ((!trans && cid_name != "MULTI") || (trans && !trans->IsMulti())) { + // Each time we execute a command we need to increase the sequence number in + // order to properly track clients when OPTIN is used. + // We don't do this for `multi/exec` because it would break the + // semantics, i.e, CACHING should stick for all commands following + // the CLIENT CACHING ON within a multi/exec block + cntx->conn_state.tracking_info_.IncrementSequenceNumber(); + } + // TODO: we should probably discard more commands here, // not just the blocking ones const auto* conn = cntx->conn(); @@ -1397,6 +1419,10 @@ size_t Service::DispatchManyCommands(absl::Span args_list, for (auto args : args_list) { ToUpper(&args[0]); const auto [cid, tail_args] = FindCmd(args); + // is client tracking command + if (cid && cid->name() == "CLIENT" && !tail_args.empty() && ToSV(tail_args[0]) == "TRACKING") { + break; + } // MULTI...EXEC commands need to be collected into a single context, so squashing is not // possible @@ -2186,7 +2212,8 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) { ServerState::tlocal()->exec_freq_count[descr]++; } - if (absl::GetFlag(FLAGS_multi_exec_squash) && state == ExecEvalState::NONE) { + if (absl::GetFlag(FLAGS_multi_exec_squash) && state == ExecEvalState::NONE && + !cntx->conn_state.tracking_info_.IsTrackingOn()) { MultiCommandSquasher::Execute(absl::MakeSpan(exec_info.body), cntx, this); } else { CmdArgVec arg_vec; @@ -2232,7 +2259,6 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { auto* cs = ServerState::tlocal()->channel_store(); vector subscribers = cs->FetchSubscribers(channel); int num_published = subscribers.size(); - if (!subscribers.empty()) { // Make sure neither of the threads limits is reached. // This check actually doesn't reserve any memory ahead and doesn't prevent the buffer @@ -2559,7 +2585,7 @@ void Service::OnClose(facade::ConnectionContext* cntx) { server_family_.OnClose(server_cntx); - cntx->conn()->SetClientTrackingSwitch(false); + conn_state.tracking_info_.SetClientTracking(false); } Service::ContextInfo Service::GetContextInfo(facade::ConnectionContext* cntx) const { diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 119020e7f7c..c1b87f3f198 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -456,29 +456,62 @@ void ClientPauseCmd(CmdArgList args, vector listeners, Connec } void ClientTracking(CmdArgList args, ConnectionContext* cntx) { - if (args.size() != 1) - return cntx->SendError(kSyntaxErr); - auto* rb = static_cast(cntx->reply_builder()); if (!rb->IsResp3()) return cntx->SendError( "Client tracking is currently not supported for RESP2. Please use RESP3."); - ToUpper(&args[0]); - string_view state = ArgS(args, 0); - bool is_on; - if (state == "ON") { + CmdArgParser parser{args}; + if (!parser.HasAtLeast(1) || args.size() > 2) + return cntx->SendError(kSyntaxErr); + + bool is_on = false; + bool optin = false; + if (parser.Check("ON").IgnoreCase()) { is_on = true; - } else if (state == "OFF") { - is_on = false; - } else { + } else if (!parser.Check("OFF").IgnoreCase()) { return cntx->SendError(kSyntaxErr); } - cntx->conn()->SetClientTrackingSwitch(is_on); + if (parser.HasNext()) { + if (parser.Check("OPTIN").IgnoreCase()) { + optin = true; + } else { + return cntx->SendError(kSyntaxErr); + } + } + + if (is_on) { + ++cntx->subscriptions; + } + cntx->conn_state.tracking_info_.SetClientTracking(is_on); + cntx->conn_state.tracking_info_.SetOptin(optin); return cntx->SendOk(); } +void ClientCaching(CmdArgList args, ConnectionContext* cntx) { + auto* rb = static_cast(cntx->reply_builder()); + if (!rb->IsResp3()) + return cntx->SendError( + "Client caching is currently not supported for RESP2. Please use RESP3."); + + if (args.size() != 1) { + return cntx->SendError(kSyntaxErr); + } + + CmdArgParser parser{args}; + if (parser.Check("YES").IgnoreCase()) { + bool is_multi = cntx->transaction && cntx->transaction->IsMulti(); + cntx->conn_state.tracking_info_.SetCachingSequenceNumber(is_multi); + } else if (parser.Check("NO").IgnoreCase()) { + cntx->conn_state.tracking_info_.ResetCachingSequenceNumber(); + } else { + return cntx->SendError(kSyntaxErr); + } + + cntx->SendOk(); +} + void ClientKill(CmdArgList args, absl::Span listeners, ConnectionContext* cntx) { std::function evaluator; @@ -1539,7 +1572,7 @@ void ServerFamily::SendInvalidationMessages() const { facade::ConnectionContext* fc = static_cast(conn)->cntx(); if (fc) { ConnectionContext* cntx = static_cast(fc); - if (cntx->conn()->IsTrackingOn()) { + if (cntx->conn_state.tracking_info_.IsTrackingOn()) { facade::Connection::InvalidationMessage x; x.invalidate_due_to_flush = true; cntx->conn()->SendInvalidationMessageAsync(x); @@ -1630,6 +1663,8 @@ void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) { return ClientTracking(sub_args, cntx); } else if (sub_cmd == "KILL") { return ClientKill(sub_args, absl::MakeSpan(listeners_), cntx); + } else if (sub_cmd == "CACHING") { + return ClientCaching(sub_args, cntx); } if (sub_cmd == "SETINFO") { diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index 88708f4faca..2a8f4576cba 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -229,6 +229,117 @@ TEST_F(ServerFamilyTest, ClientTrackingReadKey) { EXPECT_EQ(InvalidationMessagesLen("IO0"), 0); } +TEST_F(ServerFamilyTest, ClientTrackingOptin) { + Run({"HELLO", "3"}); + Run({"CLIENT", "TRACKING", "ON", "OPTIN"}); + + Run({"GET", "FOO"}); + Run({"SET", "FOO", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 0); + Run({"GET", "FOO"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 0); + + Run({"CLIENT", "CACHING", "YES"}); + // Start tracking once + Run({"GET", "FOO"}); + Run({"SET", "FOO", "20"}); + Run({"GET", "FOO"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 1); + + Run({"GET", "BAR"}); + Run({"SET", "BAR", "20"}); + Run({"GET", "BAR"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 1); + + // Start tracking once + Run({"CLIENT", "CACHING", "YES"}); + Run({"GET", "BAR"}); + Run({"SET", "BAR", "20"}); + Run({"GET", "BAR"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); +} + +TEST_F(ServerFamilyTest, ClientTrackingMulti) { + Run({"HELLO", "3"}); + Run({"CLIENT", "TRACKING", "ON"}); + Run({"MULTI"}); + Run({"GET", "FOO"}); + Run({"SET", "TMP", "10"}); + Run({"GET", "FOOBAR"}); + Run({"EXEC"}); + + Run({"SET", "FOO", "10"}); + Run({"SET", "FOOBAR", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); + + Run({"MULTI"}); + auto resp = Run({"CLIENT", "TRACKING", "ON"}); + EXPECT_THAT(resp, ArgType(RespExpr::ERROR)); + Run({"DISCARD"}); +} + +TEST_F(ServerFamilyTest, ClientTrackingMultiOptin) { + Run({"HELLO", "3"}); + // Check stickiness + Run({"CLIENT", "TRACKING", "ON", "OPTIN"}); + Run({"CLIENT", "CACHING", "YES"}); + Run({"MULTI"}); + Run({"GET", "FOO"}); + Run({"SET", "TMP", "10"}); + Run({"GET", "FOOBAR"}); + Run({"DISCARD"}); + + Run({"SET", "FOO", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 0); + + Run({"CLIENT", "CACHING", "YES"}); + Run({"MULTI"}); + Run({"GET", "FOO"}); + Run({"SET", "TMP", "10"}); + Run({"GET", "FOOBAR"}); + Run({"EXEC"}); + + Run({"SET", "FOO", "10"}); + Run({"SET", "FOOBAR", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); + + // CACHING enclosed in MULTI + Run({"MULTI"}); + Run({"GET", "TMP"}); + Run({"GET", "TMP_TMP"}); + Run({"SET", "TMP", "10"}); + Run({"CLIENT", "CACHING", "YES"}); + Run({"GET", "FOO"}); + Run({"GET", "FOOBAR"}); + Run({"EXEC"}); + + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); + Run({"SET", "TMP", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 2); + Run({"SET", "FOO", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 3); + Run({"SET", "FOOBAR", "10"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 4); + + // CACHING enclosed in MULTI, ON/OFF + Run({"MULTI"}); + Run({"GET", "TMP"}); + Run({"SET", "TMP", "10"}); + Run({"CLIENT", "CACHING", "YES"}); + Run({"GET", "FOO"}); + Run({"CLIENT", "CACHING", "NO"}); + Run({"GET", "BAR"}); + Run({"EXEC"}); + + EXPECT_EQ(InvalidationMessagesLen("IO0"), 4); + Run({"SET", "FOO", "10"}); + Run({"GET", "FOO"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 5); + Run({"SET", "BAR", "10"}); + Run({"GET", "BAR"}); + EXPECT_EQ(InvalidationMessagesLen("IO0"), 5); +} + TEST_F(ServerFamilyTest, ClientTrackingUpdateKey) { Run({"HELLO", "3"}); Run({"CLIENT", "TRACKING", "ON"}); diff --git a/src/server/transaction.cc b/src/server/transaction.cc index 343b0e04a06..1b4e994b99e 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -691,8 +691,12 @@ void Transaction::RunCallback(EngineShard* shard) { } // Log to journal only once the command finished running - if ((coordinator_state_ & COORD_CONCLUDING) || (multi_ && multi_->concluding)) + if ((coordinator_state_ & COORD_CONCLUDING) || (multi_ && multi_->concluding)) { LogAutoJournalOnShard(shard, result); + if (tracking_cb_) { + tracking_cb_(this); + } + } } // TODO: For multi-transactions we should be able to deduce mode() at run-time based diff --git a/src/server/transaction.h b/src/server/transaction.h index a88ee43cdb7..ffd8abecde8 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -363,6 +363,10 @@ class Transaction { return shard_data_[SidToId(sid)].local_mask; } + void SetTrackingCallback(std::function f) { + tracking_cb_ = std::move(f); + } + // Remove once BZPOP is stabilized std::string DEBUGV18_BlockInfo() { return "claimed=" + std::to_string(blocking_barrier_.IsClaimed()) + @@ -644,6 +648,8 @@ class Transaction { ShardId coordinator_index = 0; } stats_; + std::function tracking_cb_; + private: struct TLTmpSpace { std::vector& GetShardIndex(unsigned size);