From 91fed71e699a64483e480a87265293241c6a42a9 Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Mon, 27 Nov 2023 18:29:52 +0300 Subject: [PATCH 01/24] fix(connection): Add ConnectionRef to replace pubsub wait token Signed-off-by: Vladislav Oleshko --- src/facade/dragonfly_connection.cc | 41 +++++++++++++++++++++++++----- src/facade/dragonfly_connection.h | 31 ++++++++++++++++++++++ src/server/channel_store.cc | 21 ++++----------- src/server/channel_store.h | 16 +++--------- src/server/conn_context.h | 3 +-- src/server/main_service.cc | 31 ++++++++-------------- 6 files changed, 85 insertions(+), 58 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 2cf2267e2be7..d45fa9606aec 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -313,6 +313,9 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, migration_enabled_ = absl::GetFlag(FLAGS_migrate_connections); + // Create dummy value for valid control block and then use aliasing contrutor to return `this` + self_ = {make_shared(nullptr), this}; + #ifdef DFLY_USE_SSL // Increment reference counter so Listener won't free the context while we're // still using it. @@ -612,6 +615,8 @@ void Connection::ConnectionFlow(FiberSocketBase* peer) { service_->OnClose(cc_.get()); + self_.reset(); // Drop manually, no more new references should be created + stats_->read_buf_capacity -= io_buf_.Capacity(); // Update num_replicas if this was a replica connection. @@ -827,9 +832,6 @@ void Connection::HandleMigrateRequest() { if (cc_->subscriptions == 0) { migration_request_ = nullptr; this->Migrate(dest); - - // We're now running in `dest` thread - queue_backpressure_ = &tl_queue_backpressure_; } DCHECK(dispatch_q_.empty()); @@ -1162,6 +1164,13 @@ void Connection::Migrate(util::fb2::ProactorBase* dest) { listener()->Migrate(this, dest); } +Connection::BorrowedRef Connection::Borrow(unsigned thread) { + DCHECK(self_); + DCHECK_GT(cc_->subscriptions, 0); + + return BorrowedRef{self_, queue_backpressure_, thread}; +} + void Connection::ShutdownThreadLocal() { pipeline_req_pool_.clear(); } @@ -1266,10 +1275,6 @@ void Connection::RecycleMessage(MessageHandle msg) { } } -void Connection::EnsureAsyncMemoryBudget() { - queue_backpressure_->EnsureBelowLimit(); -} - std::string Connection::LocalBindStr() const { if (socket_->IsUDS()) return "unix-domain-socket"; @@ -1333,4 +1338,26 @@ Connection::MemoryUsage Connection::GetMemoryUsage() const { }; } +Connection::BorrowedRef::BorrowedRef(std::shared_ptr ptr, + QueueBackpressure* backpressure, unsigned thread) + : ptr_{ptr}, backpressure_{backpressure}, thread_{thread} { +} + +unsigned Connection::BorrowedRef::Thread() const { + return thread_; +} + +Connection* Connection::BorrowedRef::Get() const { + DCHECK_EQ(ProactorBase::GetIndex(), int(thread_)); + return ptr_.lock().get(); +} + +bool Connection::BorrowedRef::EnsureMemoryBudget() const { + if (!ptr_.expired()) { + backpressure_->EnsureBelowLimit(); + return true; + } + return false; +} + } // namespace facade diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index c4ca814354d3..1e600b950cf4 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -53,6 +53,8 @@ class SinkReplyBuilder; // For pipelined requests, monitor and pubsub messages it uses // a separate dispatch queue that is processed on a separate fiber. class Connection : public util::Connection { + struct QueueBackpressure; + public: Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx, ServiceInterface* service); @@ -143,6 +145,29 @@ class Connection : public util::Connection { enum Phase { SETUP, READ_SOCKET, PROCESS, SHUTTING_DOWN, PRECLOSE, NUM_PHASES }; + // Stores an non-owning weak reference to a connection. + struct BorrowedRef { + public: + // Get residing thread of connection. Thread-safe. + unsigned Thread() const; + + // Get pointer to connection if still valid, nullptr if expired. + // Can only be called from connection's thread. + Connection* Get() const; + + // Ensure owner thread's memory budget. If expired, skips and returns false. Thread-safe. + bool EnsureMemoryBudget() const; + + private: + friend class Connection; + + BorrowedRef(std::shared_ptr ptr, QueueBackpressure* backpressure, unsigned thread); + + std::weak_ptr ptr_; + QueueBackpressure* backpressure_; + unsigned thread_; + }; + public: // Add PubMessage to dispatch queue. // Virtual because behavior is overridden in test_utils. @@ -176,6 +201,9 @@ class Connection : public util::Connection { // Migrate this connecton to a different thread. void Migrate(util::fb2::ProactorBase* dest); + // Borrow weak reference to connection. Can be called from any thread. + BorrowedRef Borrow(unsigned thread); + static void ShutdownThreadLocal(); bool IsCurrentlyDispatching() const; @@ -338,6 +366,9 @@ class Connection : public util::Connection { RespVec tmp_parse_args_; CmdArgVec tmp_cmd_vec_; + // Used to keep track of borrowed references. Does not really own itself + std::shared_ptr self_; + // Pointer to corresponding queue backpressure struct. // Needed for access from different threads by EnsureAsyncMemoryBudget(). QueueBackpressure* queue_backpressure_; diff --git a/src/server/channel_store.cc b/src/server/channel_store.cc index 8eeae4ee1331..853e6e6bb8d5 100644 --- a/src/server/channel_store.cc +++ b/src/server/channel_store.cc @@ -25,18 +25,12 @@ bool Matches(string_view pattern, string_view channel) { } // namespace -ChannelStore::Subscriber::Subscriber(ConnectionContext* cntx, uint32_t tid) - : conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) { -} - -ChannelStore::Subscriber::Subscriber(uint32_t tid) - : conn_cntx(nullptr), borrow_token(0), thread_id(tid) { +bool ChannelStore::Subscriber::ByThread(const Subscriber& lhs, const Subscriber& rhs) { + return ByThreadId(lhs, rhs.Thread()); } -bool ChannelStore::Subscriber::ByThread(const Subscriber& lhs, const Subscriber& rhs) { - if (lhs.thread_id == rhs.thread_id) - return (lhs.conn_cntx != nullptr) < (rhs.conn_cntx != nullptr); - return lhs.thread_id < rhs.thread_id; +bool ChannelStore::Subscriber::ByThreadId(const Subscriber& lhs, const unsigned thread) { + return lhs.Thread() < thread; } ChannelStore::UpdatablePointer::UpdatablePointer(const UpdatablePointer& other) { @@ -120,12 +114,7 @@ void ChannelStore::Fill(const SubscribeMap& src, const string& pattern, vectorreserve(out->size() + src.size()); for (const auto [cntx, thread_id] : src) { CHECK(cntx->conn_state.subscribe_info); - - Subscriber s(cntx, thread_id); - s.pattern = pattern; - s.borrow_token.Inc(); - - out->push_back(std::move(s)); + out->push_back({cntx->conn()->Borrow(thread_id), pattern}); } } diff --git a/src/server/channel_store.h b/src/server/channel_store.h index a60e6c59aeb1..66d1efb1fe67 100644 --- a/src/server/channel_store.h +++ b/src/server/channel_store.h @@ -7,6 +7,7 @@ #include +#include "facade/dragonfly_connection.h" #include "server/conn_context.h" namespace dfly { @@ -39,22 +40,11 @@ class ChannelStore { friend class ChannelStoreUpdater; public: - struct Subscriber { - Subscriber(ConnectionContext* cntx, uint32_t tid); - Subscriber(uint32_t tid); - - Subscriber(Subscriber&&) noexcept = default; - Subscriber& operator=(Subscriber&&) noexcept = default; - - Subscriber(const Subscriber&) = delete; - void operator=(const Subscriber&) = delete; - + struct Subscriber : public facade::Connection::BorrowedRef { // Sort by thread-id. Subscriber without owner comes first. static bool ByThread(const Subscriber& lhs, const Subscriber& rhs); + static bool ByThreadId(const Subscriber& lhs, const unsigned thread); - ConnectionContext* conn_cntx; - BlockingCounter borrow_token; // to keep connection alive - uint32_t thread_id; std::string pattern; // non-empty if registered via psubscribe }; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index e05f9970c0bc..a27d842e1669 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -117,8 +117,6 @@ struct ConnectionState { // TODO: to provide unique_strings across service. This will allow us to use string_view here. absl::flat_hash_set channels; absl::flat_hash_set patterns; - - BlockingCounter borrow_token{0}; }; struct ReplicationInfo { @@ -208,6 +206,7 @@ class ConnectionContext : public facade::ConnectionContext { subscriptions++; // required to support the monitoring monitor = enable; } + void SendSubscriptionChangedResponse(std::string_view action, std::optional topic, unsigned count); }; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 68cf13e6df9f..9ace4060ac75 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -2081,12 +2081,12 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { // thus not adding any overhead to backpressure checks. optional last_thread; for (auto& sub : subscribers) { - DCHECK_LE(last_thread.value_or(0), sub.thread_id); - if (last_thread && *last_thread == sub.thread_id) // skip same thread + DCHECK_LE(last_thread.value_or(0), sub.Thread()); + if (last_thread && *last_thread == sub.Thread()) // skip same thread continue; - sub.conn_cntx->conn()->EnsureAsyncMemoryBudget(); - last_thread = sub.thread_id; + if (sub.EnsureMemoryBudget()) // Invalid pointers are skipped + last_thread = sub.Thread(); } auto subscribers_ptr = make_shared(std::move(subscribers)); @@ -2096,14 +2096,13 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { auto cb = [subscribers_ptr, buf, channel, msg](unsigned idx, util::ProactorBase*) { auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx, - ChannelStore::Subscriber::ByThread); - - while (it != subscribers_ptr->end() && it->thread_id == idx) { - facade::Connection* conn = it->conn_cntx->conn(); - DCHECK(conn); - conn->SendPubMessageAsync( - {std::move(it->pattern), std::move(buf), channel.size(), msg.size()}); - it->borrow_token.Dec(); + ChannelStore::Subscriber::ByThreadId); + + while (it != subscribers_ptr->end() && it->Thread() == idx) { + if (auto* ptr = it->Get(); ptr) { + ptr->SendPubMessageAsync( + {std::move(it->pattern), std::move(buf), channel.size(), msg.size()}); + } it++; } }; @@ -2333,20 +2332,12 @@ void Service::OnClose(facade::ConnectionContext* cntx) { if (conn_state.subscribe_info) { // Clean-ups related to PUBSUB if (!conn_state.subscribe_info->channels.empty()) { - auto token = conn_state.subscribe_info->borrow_token; server_cntx->UnsubscribeAll(false); - - // Check that all borrowers finished processing. - // token is increased in channel_slice (the publisher side). - token.Wait(); } if (conn_state.subscribe_info) { DCHECK(!conn_state.subscribe_info->patterns.empty()); - - auto token = conn_state.subscribe_info->borrow_token; server_cntx->PUnsubscribeAll(false); - token.Wait(); // Same as above } DCHECK(!conn_state.subscribe_info); From 0300a1552e66e92f3b474904162eaccca0115357 Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Mon, 27 Nov 2023 21:54:20 +0300 Subject: [PATCH 02/24] fix: add check to migrate Signed-off-by: Vladislav Oleshko --- src/facade/dragonfly_connection.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index d45fa9606aec..f995d3f98695 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -1159,6 +1159,7 @@ void Connection::Migrate(util::fb2::ProactorBase* dest) { // connections CHECK(!cc_->async_dispatch); CHECK_EQ(cc_->subscriptions, 0); // are bound to thread local caches + CHECK_EQ(self_.use_count(), 1u); // references cache our thread and backpressure CHECK(!dispatch_fb_.IsJoinable()); // can't move once it started listener()->Migrate(this, dest); From 21ac976c2503b43836d84842c25a6eb701d062b4 Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Mon, 27 Nov 2023 22:02:28 +0300 Subject: [PATCH 03/24] fix: rename to weakref Signed-off-by: Vladislav Oleshko --- src/facade/dragonfly_connection.cc | 16 ++++++++-------- src/facade/dragonfly_connection.h | 9 +++++---- src/server/channel_store.h | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index f995d3f98695..c837c488b323 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -313,7 +313,7 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, migration_enabled_ = absl::GetFlag(FLAGS_migrate_connections); - // Create dummy value for valid control block and then use aliasing contrutor to return `this` + // Create dummy value for valid control block and then use aliasing construtor to return `this` self_ = {make_shared(nullptr), this}; #ifdef DFLY_USE_SSL @@ -1165,11 +1165,11 @@ void Connection::Migrate(util::fb2::ProactorBase* dest) { listener()->Migrate(this, dest); } -Connection::BorrowedRef Connection::Borrow(unsigned thread) { +Connection::WeakRef Connection::Borrow(unsigned thread) { DCHECK(self_); DCHECK_GT(cc_->subscriptions, 0); - return BorrowedRef{self_, queue_backpressure_, thread}; + return WeakRef{self_, queue_backpressure_, thread}; } void Connection::ShutdownThreadLocal() { @@ -1339,21 +1339,21 @@ Connection::MemoryUsage Connection::GetMemoryUsage() const { }; } -Connection::BorrowedRef::BorrowedRef(std::shared_ptr ptr, - QueueBackpressure* backpressure, unsigned thread) +Connection::WeakRef::WeakRef(std::shared_ptr ptr, QueueBackpressure* backpressure, + unsigned thread) : ptr_{ptr}, backpressure_{backpressure}, thread_{thread} { } -unsigned Connection::BorrowedRef::Thread() const { +unsigned Connection::WeakRef::Thread() const { return thread_; } -Connection* Connection::BorrowedRef::Get() const { +Connection* Connection::WeakRef::Get() const { DCHECK_EQ(ProactorBase::GetIndex(), int(thread_)); return ptr_.lock().get(); } -bool Connection::BorrowedRef::EnsureMemoryBudget() const { +bool Connection::WeakRef::EnsureMemoryBudget() const { if (!ptr_.expired()) { backpressure_->EnsureBelowLimit(); return true; diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 1e600b950cf4..63352b25c12a 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -145,8 +145,9 @@ class Connection : public util::Connection { enum Phase { SETUP, READ_SOCKET, PROCESS, SHUTTING_DOWN, PRECLOSE, NUM_PHASES }; - // Stores an non-owning weak reference to a connection. - struct BorrowedRef { + // Weak reference to a connection, invalidated upon connection close. + // Used to dispatch async operations for the connection without worrying about pointer lifetime. + struct WeakRef { public: // Get residing thread of connection. Thread-safe. unsigned Thread() const; @@ -161,7 +162,7 @@ class Connection : public util::Connection { private: friend class Connection; - BorrowedRef(std::shared_ptr ptr, QueueBackpressure* backpressure, unsigned thread); + WeakRef(std::shared_ptr ptr, QueueBackpressure* backpressure, unsigned thread); std::weak_ptr ptr_; QueueBackpressure* backpressure_; @@ -202,7 +203,7 @@ class Connection : public util::Connection { void Migrate(util::fb2::ProactorBase* dest); // Borrow weak reference to connection. Can be called from any thread. - BorrowedRef Borrow(unsigned thread); + WeakRef Borrow(unsigned thread); static void ShutdownThreadLocal(); diff --git a/src/server/channel_store.h b/src/server/channel_store.h index 66d1efb1fe67..ae9178b0b684 100644 --- a/src/server/channel_store.h +++ b/src/server/channel_store.h @@ -40,7 +40,7 @@ class ChannelStore { friend class ChannelStoreUpdater; public: - struct Subscriber : public facade::Connection::BorrowedRef { + struct Subscriber : public facade::Connection::WeakRef { // Sort by thread-id. Subscriber without owner comes first. static bool ByThread(const Subscriber& lhs, const Subscriber& rhs); static bool ByThreadId(const Subscriber& lhs, const unsigned thread); From 814f73db9bc98304a8d90f867a82d9b73a1799dc Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Mon, 27 Nov 2023 23:16:21 +0300 Subject: [PATCH 04/24] fix: add back migration backpressure update --- src/facade/dragonfly_connection.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index c837c488b323..5058ada75954 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -832,6 +832,9 @@ void Connection::HandleMigrateRequest() { if (cc_->subscriptions == 0) { migration_request_ = nullptr; this->Migrate(dest); + + // We're now running in `dest` thread + queue_backpressure_ = &tl_queue_backpressure_; } DCHECK(dispatch_q_.empty()); From 93007107a627f6805c36bf0b9f72211ab26ac46e Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Mon, 27 Nov 2023 23:46:35 +0300 Subject: [PATCH 05/24] fix: use monostate --- src/facade/dragonfly_connection.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 5058ada75954..c173169a4be9 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -313,8 +313,9 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, migration_enabled_ = absl::GetFlag(FLAGS_migrate_connections); - // Create dummy value for valid control block and then use aliasing construtor to return `this` - self_ = {make_shared(nullptr), this}; + // Create shared_ptr with empty value and associate it with `this` pointer (aliasing constructor). + // We use it for reference counting and accessing `this` (without managing it). + self_ = {std::make_shared(std::monostate{}), this}; #ifdef DFLY_USE_SSL // Increment reference counter so Listener won't free the context while we're From 402fb31b574734221dd7b066ff38a394ccd582f9 Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Mon, 27 Nov 2023 23:53:03 +0300 Subject: [PATCH 06/24] fix: rebase on new helio GetPoolIndex Signed-off-by: Vladislav Oleshko --- src/facade/dragonfly_connection.cc | 6 +++--- src/facade/dragonfly_connection.h | 2 +- src/server/channel_store.cc | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index c173169a4be9..f62eafccb16d 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -1169,11 +1169,11 @@ void Connection::Migrate(util::fb2::ProactorBase* dest) { listener()->Migrate(this, dest); } -Connection::WeakRef Connection::Borrow(unsigned thread) { +Connection::WeakRef Connection::Borrow() { DCHECK(self_); DCHECK_GT(cc_->subscriptions, 0); - return WeakRef{self_, queue_backpressure_, thread}; + return WeakRef(self_, queue_backpressure_, socket_->proactor()->GetPoolIndex()); } void Connection::ShutdownThreadLocal() { @@ -1353,7 +1353,7 @@ unsigned Connection::WeakRef::Thread() const { } Connection* Connection::WeakRef::Get() const { - DCHECK_EQ(ProactorBase::GetIndex(), int(thread_)); + DCHECK_EQ(ProactorBase::me()->GetPoolIndex(), int(thread_)); return ptr_.lock().get(); } diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 63352b25c12a..57c0a68299f1 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -203,7 +203,7 @@ class Connection : public util::Connection { void Migrate(util::fb2::ProactorBase* dest); // Borrow weak reference to connection. Can be called from any thread. - WeakRef Borrow(unsigned thread); + WeakRef Borrow(); static void ShutdownThreadLocal(); diff --git a/src/server/channel_store.cc b/src/server/channel_store.cc index 853e6e6bb8d5..10a782ee355b 100644 --- a/src/server/channel_store.cc +++ b/src/server/channel_store.cc @@ -114,7 +114,7 @@ void ChannelStore::Fill(const SubscribeMap& src, const string& pattern, vectorreserve(out->size() + src.size()); for (const auto [cntx, thread_id] : src) { CHECK(cntx->conn_state.subscribe_info); - out->push_back({cntx->conn()->Borrow(thread_id), pattern}); + out->push_back({cntx->conn()->Borrow(), pattern}); } } From e2eb0b1442461e856a6803b67298eb997300daf7 Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Tue, 28 Nov 2023 09:52:01 +0300 Subject: [PATCH 07/24] fix: fix internal compiler error?? Signed-off-by: Vladislav Oleshko --- src/server/channel_store.cc | 3 ++- src/server/channel_store.h | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/server/channel_store.cc b/src/server/channel_store.cc index 10a782ee355b..4307d5151059 100644 --- a/src/server/channel_store.cc +++ b/src/server/channel_store.cc @@ -114,7 +114,8 @@ void ChannelStore::Fill(const SubscribeMap& src, const string& pattern, vectorreserve(out->size() + src.size()); for (const auto [cntx, thread_id] : src) { CHECK(cntx->conn_state.subscribe_info); - out->push_back({cntx->conn()->Borrow(), pattern}); + Subscriber sub{cntx->conn()->Borrow(), pattern}; + out->push_back(std::move(sub)); } } diff --git a/src/server/channel_store.h b/src/server/channel_store.h index ae9178b0b684..49e650d46b4b 100644 --- a/src/server/channel_store.h +++ b/src/server/channel_store.h @@ -41,6 +41,10 @@ class ChannelStore { public: struct Subscriber : public facade::Connection::WeakRef { + Subscriber(WeakRef ref, const std::string& pattern) + : facade::Connection::WeakRef(std::move(ref)), pattern(pattern) { + } + // Sort by thread-id. Subscriber without owner comes first. static bool ByThread(const Subscriber& lhs, const Subscriber& rhs); static bool ByThreadId(const Subscriber& lhs, const unsigned thread); From e35dbac2a5484f36933ad03b4d21913e85d224db Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Tue, 28 Nov 2023 17:23:45 +0300 Subject: [PATCH 08/24] fix: fix test framework --- src/server/test_utils.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index e9fed427289c..aa1cd4f4cb8f 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -66,6 +66,7 @@ static vector SplitLines(const std::string& src) { TestConnection::TestConnection(Protocol protocol, io::StringSink* sink) : facade::Connection(protocol, nullptr, nullptr, nullptr), sink_(sink) { cc_.reset(new dfly::ConnectionContext(sink_, this)); + SetSocket(ProactorBase::me()->CreateSocket()); } void TestConnection::SendPubMessageAsync(PubMessage pmsg) { From 2cdb8a9bf262e3eace1174844bd3376cbfbeeac4 Mon Sep 17 00:00:00 2001 From: Vladislav Oleshko Date: Tue, 28 Nov 2023 23:51:59 +0300 Subject: [PATCH 09/24] fix: add comment about EnsureMemoryBudget --- src/facade/dragonfly_connection.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index f62eafccb16d..70de2f3c3c49 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -1358,7 +1358,10 @@ Connection* Connection::WeakRef::Get() const { } bool Connection::WeakRef::EnsureMemoryBudget() const { + // Simple optimization: If a connection was closed, don't check memory budget. if (!ptr_.expired()) { + // We don't rely on the connection ptr staying valid because we only access + // the threads backpressure backpressure_->EnsureBelowLimit(); return true; } From fbd1748218aa4dbd7380b14c1c6adfbcd39cf04e Mon Sep 17 00:00:00 2001 From: Yue Li Date: Wed, 29 Nov 2023 12:33:05 -0800 Subject: [PATCH 10/24] support client side tracking using resp3 --- src/facade/dragonfly_connection.cc | 22 +++++++ src/facade/dragonfly_connection.h | 38 +++++++++++- src/facade/reply_builder.cc | 9 ++- src/facade/reply_builder.h | 4 +- src/server/db_slice.cc | 98 ++++++++++++++++++++++++++---- src/server/db_slice.h | 30 +++++++-- src/server/main_service.cc | 88 +++++++++++++++++++++++++++ src/server/server_family.cc | 35 ++++++++++- src/server/server_family.h | 2 + src/server/server_family_test.cc | 7 +++ 10 files changed, 312 insertions(+), 21 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 70de2f3c3c49..d6892a2aa2b1 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -150,6 +150,7 @@ struct Connection::DispatchOperations { void operator()(const AclUpdateMessage& msg); void operator()(const MigrationRequestMessage& msg); void operator()(CheckpointMessage msg); + void operator()(const InvalidationMessage& msg); template void operator()(unique_ptr& ptr) { operator()(*ptr.get()); @@ -215,6 +216,9 @@ size_t Connection::MessageHandle::UsedMemory() const { size_t operator()(const CheckpointMessage& msg) { return 0; // no access to internal type, memory usage negligible } + size_t operator()(const InvalidationMessage& msg) { + return 0; + } }; return sizeof(MessageHandle) + visit(MessageSize{}, this->handle); @@ -278,10 +282,24 @@ void Connection::DispatchOperations::operator()(const MigrationRequestMessage& m // no-op } + void Connection::DispatchOperations::operator()(CheckpointMessage msg) { msg.bc.Dec(); } +void Connection::DispatchOperations::operator()(const InvalidationMessage& msg) { + RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder; + rbuilder->SetResp3(true); + rbuilder->StartCollection(2, facade::RedisReplyBuilder::CollectionType::PUSH); + rbuilder->SendBulkString("invalidate"); + if (msg.invalidate_due_to_flush) { + rbuilder->SendNull(); + } else { + std::vector keys{msg.key}; + rbuilder->SendStringArr(keys); + } +} + Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx, ServiceInterface* service) : io_buf_(kMinReadSize), http_listener_(http_listener), ctx_(ctx), service_(service), name_{} { @@ -1192,6 +1210,10 @@ void Connection::SendPubMessageAsync(PubMessage msg) { SendAsync({PubMessagePtr{new (ptr) PubMessage{move(msg)}, MessageDeleter{}}}); } +void Connection::SendInvalidationMessageAsync(InvalidationMessage msg) { + SendAsync({std::move(msg)}); +} + void Connection::SendMonitorMessageAsync(string msg) { SendAsync({MonitorMessage{move(msg)}}); } diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 57c0a68299f1..c29633ea930d 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -77,7 +77,23 @@ class Connection : public util::Connection { size_t message_len); }; - // Pipeline message, accumulated command to be executed. + + struct InvalidationMessage { + std::string_view key; + bool invalidate_due_to_flush = false; + }; + + struct MonitorMessage : public std::string {}; + + struct AclUpdateMessage { + std::vector username; + std::vector categories; + std::vector> commands; + }; + + struct MigrationRequestMessage {}; + + // Pipeline message, accumulated command to be execute.d struct PipelineMessage { PipelineMessage(size_t nargs, size_t capacity) : args(nargs), storage(capacity) { } @@ -136,7 +152,7 @@ class Connection : public util::Connection { bool IsPubMsg() const; std::variant + MigrationRequestMessage, CheckpointMessage, InvalidationMessage> handle; }; @@ -174,6 +190,9 @@ class Connection : public util::Connection { // Virtual because behavior is overridden in test_utils. virtual void SendPubMessageAsync(PubMessage); + // Add InvalidationMessage to dispatch queue. + void SendInvalidationMessageAsync(InvalidationMessage); + // Add monitor message to dispatch queue. void SendMonitorMessageAsync(std::string); @@ -249,6 +268,18 @@ class Connection : public util::Connection { // Connections will migrate at most once, and only when the flag --migrate_connections is true. void RequestAsyncMigration(util::fb2::ProactorBase* dest); + void EnableTracking() { + tracking_enabled_ = true; + } + + void DisableTracking() { + tracking_enabled_ = false; + } + + bool IsTrackingOn() const { + return tracking_enabled_; + } + protected: void OnShutdown() override; void OnPreMigrateThread() override; @@ -384,6 +415,9 @@ class Connection : public util::Connection { // Per-thread queue backpressure structs. static thread_local QueueBackpressure tl_queue_backpressure_; + + // whether client tracking is enabled + bool tracking_enabled_ = false; }; } // namespace facade diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index bb7ab82d2823..426fd309449f 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -30,6 +30,8 @@ constexpr char kCRLF[] = "\r\n"; constexpr char kErrPref[] = "-ERR "; constexpr char kSimplePref[] = "+"; +constexpr char kRET[] = "$1\r\n\r\r\n"; + constexpr unsigned kConvFlags = DoubleToStringConverter::UNIQUE_ZERO | DoubleToStringConverter::EMIT_POSITIVE_EXPONENT_SIGN; @@ -289,9 +291,14 @@ void RedisReplyBuilder::SendProtocolError(std::string_view str) { SendError(absl::StrCat("-ERR Protocol error: ", str), "protocol_error"); } +void RedisReplyBuilder::SendRET() { + // iovec v[3] = {IoVec(kSimplePref), IoVec(str), IoVec(kCRLF)}; + iovec v[1] = {IoVec(kRET)}; + Send(v, ABSL_ARRAYSIZE(v)); +} + void RedisReplyBuilder::SendSimpleString(std::string_view str) { iovec v[3] = {IoVec(kSimplePref), IoVec(str), IoVec(kCRLF)}; - Send(v, ABSL_ARRAYSIZE(v)); } diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index 091b068d0fe7..56dcb00a410e 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -154,6 +154,7 @@ class SinkReplyBuilder { bool HasReplied() const; virtual size_t UsedMemory() const; + std::string batch_; protected: void SendRaw(std::string_view str); // Sends raw without any formatting. @@ -164,7 +165,6 @@ class SinkReplyBuilder { void StartAggregate(); void StopAggregate(); - std::string batch_; ::io::Sink* sink_; std::error_code ec_; @@ -247,6 +247,8 @@ class RedisReplyBuilder : public SinkReplyBuilder { static char* FormatDouble(double val, char* dest, unsigned dest_len); + void SendRET(); + protected: struct WrappedStrSpan : public StrSpan { size_t Size() const; diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 39d6110832ba..e498f209bb3f 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -15,6 +15,7 @@ extern "C" { #include "server/journal/journal.h" #include "server/server_state.h" #include "server/tiered_storage.h" +#include "util/fibers/proactor_base.h" ABSL_FLAG(bool, enable_heartbeat_eviction, true, "Enable eviction during heartbeat when memory is under pressure."); @@ -219,9 +220,13 @@ unsigned PrimeEvictionPolicy::Evict(const PrimeTable::HotspotBuckets& eb, PrimeT return 0; } + std::string tmp; + std::string_view key = last_slot_it->first.GetSlice(&tmp); + DbTable* table = db_slice_->GetDBTable(cntx_.db_index); PerformDeletion(last_slot_it, db_slice_->shard_owner(), table); ++evicted_; + db_slice_->SendInvalidationTrackingMessage(key); } me->ShiftRight(bucket_it); @@ -320,7 +325,7 @@ void DbSlice::Reserve(DbIndex db_ind, size_t key_size) { db->prime.Reserve(key_size); } -auto DbSlice::Find(const Context& cntx, string_view key, unsigned req_obj_type) const +auto DbSlice::Find(const Context& cntx, string_view key, unsigned req_obj_type) -> OpResult { auto it = FindExt(cntx, key).first; @@ -334,7 +339,7 @@ auto DbSlice::Find(const Context& cntx, string_view key, unsigned req_obj_type) return it; } -pair DbSlice::FindExt(const Context& cntx, string_view key) const { +pair DbSlice::FindExt(const Context& cntx, string_view key) { pair res; if (!IsDbValid(cntx.db_index)) @@ -547,15 +552,15 @@ bool DbSlice::Del(DbIndex db_ind, PrimeIterator it) { auto& db = db_arr_[db_ind]; auto obj_type = it->second.ObjType(); + string tmp; + string_view key = it->first.GetSlice(&tmp); if (doc_del_cb_ && (obj_type == OBJ_JSON || obj_type == OBJ_HASH)) { - string tmp; - string_view key = it->first.GetSlice(&tmp); DbContext cntx{db_ind, GetCurrentTimeMs()}; doc_del_cb_(key, cntx, it->second); } PerformDeletion(it, shard_owner(), db.get()); - + SendInvalidationTrackingMessage(key); return true; } @@ -571,6 +576,7 @@ void DbSlice::FlushSlotsFb(const SlotSet& slot_ids) { SlotId sid = ClusterConfig::KeySlot(key); if (slot_ids.contains(sid) && it.GetVersion() < next_version) { PerformDeletion(it, shard_owner(), db_arr_[0].get()); + SendInvalidationTrackingMessage(key); } return true; }; @@ -616,7 +622,10 @@ void DbSlice::FlushDb(DbIndex db_ind) { if (db_ptr->stats.tiered_entries > 0) { for (auto it = db_ptr->prime.begin(); it != db_ptr->prime.end(); ++it) { if (it->second.IsExternal()) { + std::string tmp; + std::string_view key = it->first.GetSlice(&tmp); PerformDeletion(it, shard_owner(), db_ptr.get()); + SendInvalidationTrackingMessage(key); } } } @@ -652,6 +661,9 @@ void DbSlice::FlushDb(DbIndex db_ind) { for (auto& db : all_dbs) db.reset(); mi_heap_collect(ServerState::tlocal()->data_heap(), true); + + // clear all the tracking table. + client_tracking_map_.clear(); } void DbSlice::AddExpire(DbIndex db_ind, PrimeIterator main_it, uint64_t at) { @@ -973,10 +985,11 @@ void DbSlice::PostUpdate(DbIndex db_ind, PrimeIterator it, std::string_view key, if (ClusterConfig::IsEnabled()) { db.slots_stats[ClusterConfig::KeySlot(key)].total_writes += 1; } + + SendInvalidationTrackingMessage(key); } -pair DbSlice::ExpireIfNeeded(const Context& cntx, - PrimeIterator it) const { +pair DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) { DCHECK(it->second.HasExpire()); auto& db = db_arr_[cntx.db_index]; @@ -992,24 +1005,23 @@ pair DbSlice::ExpireIfNeeded(const Context& cntx, return make_pair(it, expire_it); string tmp_key_buf; - string_view tmp_key; + string_view tmp_key = it->first.GetSlice(&tmp_key_buf); // Replicate expiry if (auto journal = owner_->journal(); journal) { - tmp_key = it->first.GetSlice(&tmp_key_buf); RecordExpiry(cntx.db_index, tmp_key); } auto obj_type = it->second.ObjType(); if (doc_del_cb_ && (obj_type == OBJ_JSON || obj_type == OBJ_HASH)) { - if (tmp_key.empty()) - tmp_key = it->first.GetSlice(&tmp_key_buf); doc_del_cb_(tmp_key, cntx, it->second); } PerformDeletion(it, expire_it, shard_owner(), db.get()); ++events_.expired_keys; + SendInvalidationTrackingMessage(tmp_key); + return make_pair(PrimeIterator{}, ExpireIterator{}); } @@ -1169,6 +1181,8 @@ void DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t increase_goal_bytes PerformDeletion(evict_it, shard_owner(), db_table.get()); ++evicted; + SendInvalidationTrackingMessage(key); + used_memory_after = owner_->UsedMemory(); // returns when whichever condition is met first if ((evicted == max_eviction_per_hb) || @@ -1231,6 +1245,7 @@ size_t DbSlice::EvictObjects(size_t memory_to_free, PrimeIterator it, DbTable* t return current < used_memory_start ? used_memory_start - current : 0; }; + std::string tmp; for (unsigned i = 0; !evict_succeeded && i < kNumStashBuckets; ++i) { unsigned stash_bid = i + PrimeTable::Segment_t::kNumBuckets; const auto& bucket = segment->GetBucket(stash_bid); @@ -1246,8 +1261,13 @@ size_t DbSlice::EvictObjects(size_t memory_to_free, PrimeIterator it, DbTable* t if (evict_it == it || evict_it->first.IsSticky()) continue; + string_view key = evict_it->first.GetSlice(&tmp); + PerformDeletion(evict_it, shard_owner(), table); ++evicted; + + SendInvalidationTrackingMessage(key); + if (freed_memory_fun() > memory_to_free) { evict_succeeded = true; break; @@ -1272,9 +1292,12 @@ size_t DbSlice::EvictObjects(size_t memory_to_free, PrimeIterator it, DbTable* t if (evict_it == it || evict_it->first.IsSticky()) continue; + string_view key = evict_it->first.GetSlice(&tmp); PerformDeletion(evict_it, shard_owner(), table); ++evicted; + SendInvalidationTrackingMessage(key); + if (freed_memory_fun() > memory_to_free) { evict_succeeded = true; break; @@ -1338,4 +1361,57 @@ void DbSlice::ResetUpdateEvents() { events_.update = 0; } +// todo: perhaps we need to limit the total number of entry in the tracking +// table so that we have a way to limit the amount of memory used for +// client tracking +void DbSlice::TrackKeys(ConnectionContext* cntx, int32_t tid, + const std::vector& keys) { + DVLOG(2) << "Start tracking the following keys for thread ID: " << tid + << ", client ID: " << cntx->conn()->GetClientId(); + for (auto key : keys) { + // std::string_view k = key; + std::string k{key.begin(), key.end()}; + DVLOG(2) << " " << k; + if (client_tracking_map_.find(k) == client_tracking_map_.end()) { + std::pair p{cntx, tid}; + absl::flat_hash_set, Hash> tracker_set{p}; + client_tracking_map_.insert({k, tracker_set}); + } else { + std::pair p{cntx, tid}; + client_tracking_map_[k].insert(p); + } + } +} + +void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { + std::string k{key.begin(), key.end()}; + if (client_tracking_map_.find(k) != client_tracking_map_.end()) { + // notify all the clients. + auto& client_set = client_tracking_map_[k]; + DVLOG(2) << "Garbage collect clients that are no longer tracking... "; + auto is_not_tracking = [](std::pair p) { + return (!p.first->conn()->IsTrackingOn()); + }; + absl::erase_if(client_set, is_not_tracking); + DVLOG(2) << "Number of clients left: " << client_set.size(); + + if (!client_set.empty()) { + auto cb = [key, client_set](unsigned idx, util::ProactorBase*) { + for (auto it = client_set.begin(); it != client_set.end(); ++it) { + if ((unsigned int)it->second != idx) + continue; + facade::Connection* conn = it->first->conn(); + DCHECK(conn); + facade::Connection::InvalidationMessage x{key}; + conn->SendInvalidationMessageAsync(x); + return; + } + }; + shard_set->pool()->DispatchBrief(std::move(cb)); + } + // remove this key from the tracking table as the key no longer exists + client_tracking_map_.erase(k); + } +} + } // namespace dfly diff --git a/src/server/db_slice.h b/src/server/db_slice.h index f828e057b29a..57b9c99284c7 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -4,6 +4,7 @@ #pragma once +#include "facade/dragonfly_connection.h" #include "facade/op_status.h" #include "server/common.h" #include "server/conn_context.h" @@ -147,11 +148,10 @@ class DbSlice { return ExpirePeriod{time_ms - expire_base_[0]}; } - OpResult Find(const Context& cntx, std::string_view key, - unsigned req_obj_type) const; + OpResult Find(const Context& cntx, std::string_view key, unsigned req_obj_type); // Returns (value, expire) dict entries if key exists, null if it does not exist or has expired. - std::pair FindExt(const Context& cntx, std::string_view key) const; + std::pair FindExt(const Context& cntx, std::string_view key); // Returns (iterator, args-index) if found, KEY_NOTFOUND otherwise. // If multiple keys are found, returns the first index in the ArgSlice. @@ -269,8 +269,7 @@ class DbSlice { // Check whether 'it' has not expired. Returns it if it's still valid. Otherwise, erases it // from both tables and return PrimeIterator{}. - std::pair ExpireIfNeeded(const Context& cntx, - PrimeIterator it) const; + std::pair ExpireIfNeeded(const Context& cntx, PrimeIterator it); // Iterate over all expire table entries and delete expired. void ExpireAllIfNeeded(); @@ -334,6 +333,13 @@ class DbSlice { expire_allowed_ = is_allowed; } + // Start tracking keys for the client with client_id + void TrackKeys(ConnectionContext*, int32_t, const std::vector&); + + // Send invalidatoin message when a key being tracked is updated/deleted. + // A connection that has been closed will be garbage collected along the way. + void SendInvalidationTrackingMessage(std::string_view key); + private: // Releases a single key. `key` must have been normalized by GetLockKey(). void ReleaseNormalized(IntentLock::Mode m, DbIndex db_index, std::string_view key, @@ -384,6 +390,20 @@ class DbSlice { // Registered by shard indices on when first document index is created. DocDeletionCallback doc_del_cb_; + + struct Hash { + size_t operator()(const std::pair& p) const { + // return std::hash()(p.second); + return std::hash()(p.first->conn()->GetClientId()); + } + }; + + // maps keys to the IDs of the clients that are tracking this key. + // absl::flat_hash_map, Hash> > client_tracking_map_; + absl::flat_hash_map, Hash>> + client_tracking_map_; }; } // namespace dfly diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 9ace4060ac75..5e6ba19ed576 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1040,6 +1040,13 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions"); } +OpResult OpTrackKeys(const OpArgs& op_args, ConnectionContext* cntx, uint32_t tid, + vector& keys, CmdArgList args) { + auto& db_slice = op_args.shard->db_slice(); + db_slice.TrackKeys(cntx, tid, keys); + return true; +} + void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) { CHECK(!args.empty()); DCHECK_NE(0u, shard_set->size()) << "Init was not called"; @@ -1140,6 +1147,83 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) dfly_cntx->reply_builder()->CloseConnection(); } + // if this is a read command, and client tracking has enabled, + // start tracking updates to the keys in this read command + // notify the client when there is update, see PostUpdate() in db_slice.cc + if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn()) { + OpResult key_index_res = DetermineKeys(cid, args_no_cmd); + if (!key_index_res) + return (*cntx)->SendError(key_index_res.status()); + + const auto& key_index = *key_index_res; + vector keys_to_track; + for (unsigned i = key_index.start; i < key_index.end; i += key_index.step) { + string_view key = ArgS(args_no_cmd, i); + keys_to_track.push_back(key); + } + + // let's pass thread id and connection to db_slice for tracking + int32_t tid = util::ProactorBase::GetIndex(); + + // uint32_t client_id = dfly_cntx->conn()->GetClientId(); + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, tid, keys_to_track, args); + }; + + // OpResult remember_key_result = + dfly_cntx->transaction->ScheduleSingleHopT(cb); + + //(*dfly_cntx)->StartCollection(2, RedisReplyBuilder::CollectionType::PUSH); + // std::string inval = "invalidate"; + //(*dfly_cntx)->SendBulkString(inval); + //(*dfly_cntx)->SendStringArr(keys_to_track); + //(*dfly_cntx)->SendRET(); + + //(*dfly_cntx)->StartCollection(2, RedisReplyBuilder::CollectionType::PUSH); + //(*dfly_cntx)->SendBulkString(inval); + //(*dfly_cntx)->SendStringArr(keys_to_track); + } + + /* + if (is_read_cmd) { + + OpResult key_index_res = DetermineKeys(cid, args_no_cmd); + if (!key_index_res) { + } + + const auto& key_index = *key_index_res; + std::vector keys; + // Iterate keys and check to which slot they belong. + for (unsigned i = key_index.start; i < key_index.end; i += key_index.step) { + string_view key = ArgS(args_no_cmd, i); + keys.push_back(key); + } + + uint32_t client_id = dfly_cntx->conn()->GetClientId(); + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpRememberKey(t->GetOpArgs(shard), client_id, keys, args); + }; + + + + //OpStatus status = + // dfly_cntx->transaction->InitByArgs(dfly_cntx->conn_state.db_index, args_no_cmd); + OpResult remember_key_result = dfly_cntx->transaction->ScheduleSingleHopT(cb); + + + // send PUSH message to client to invalidate the key + (*dfly_cntx)->StartCollection(2, RedisReplyBuilder::CollectionType::PUSH); + std::string inval = "invalidate"; + (*dfly_cntx)->SendBulkString(inval); + (*dfly_cntx)->SendStringArr(keys); + + } + */ + + uint64_t end_ns = ProactorBase::GetMonotonicTimeNs(); + request_latency_usec.IncBy(cid->name(), (end_ns - start_ns) / 1000); + if (!dispatching_in_multi) { dfly_cntx->transaction = nullptr; } @@ -2346,6 +2430,10 @@ void Service::OnClose(facade::ConnectionContext* cntx) { DeactivateMonitoring(server_cntx); server_family_.OnClose(server_cntx); + + // disable client tracking. + if (server_cntx->conn()->IsTrackingOn()) + server_cntx->conn()->DisableTracking(); } string Service::GetContextInfo(facade::ConnectionContext* cntx) { diff --git a/src/server/server_family.cc b/src/server/server_family.cc index e327b13241f0..f1b98ee08e91 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -1164,10 +1164,29 @@ string GetPassword() { return ""; } +void ServerFamily::SendInvalidationMessages() const { + // send invalidation message (caused by flushdb) to all the clients which + // turned on client tracking + auto cb = [](unsigned thread_index, util::Connection* conn) { + facade::ConnectionContext* fc = static_cast(conn)->cntx(); + if (fc) { + ConnectionContext* cntx = static_cast(fc); + if (cntx->conn()->IsTrackingOn()) { + facade::Connection::InvalidationMessage x; + x.invalidate_due_to_flush = true; + cntx->conn()->SendInvalidationMessageAsync(x); + } + } + }; + for (auto* listener : listeners_) { + listener->TraverseConnections(cb); + } +} + void ServerFamily::FlushDb(CmdArgList args, ConnectionContext* cntx) { DCHECK(cntx->transaction); Drakarys(cntx->transaction, cntx->transaction->GetDbIndex()); - + SendInvalidationMessages(); cntx->reply_builder()->SendOk(); } @@ -1179,6 +1198,7 @@ void ServerFamily::FlushAll(CmdArgList args, ConnectionContext* cntx) { DCHECK(cntx->transaction); Drakarys(cntx->transaction, DbSlice::kDbAll); + SendInvalidationMessages(); (*cntx)->SendOk(); } @@ -1243,6 +1263,19 @@ void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) { return (*cntx)->SendOk(); } + if (sub_cmd == "TRACKING" && args.size() == 2) { + ToUpper(&args[1]); + string_view switch_state = ArgS(args, 1); + if (switch_state == "ON") { + cntx->conn()->EnableTracking(); + return (*cntx)->SendOk(); + } else if (switch_state == "OFF") { + cntx->conn()->DisableTracking(); + // todo: the client id in tracking table will be garbage collected. + return (*cntx)->SendOk(); + } + } + LOG_FIRST_N(ERROR, 10) << "Subcommand " << sub_cmd << " not supported"; return (*cntx)->SendError(UnknownSubCmd(sub_cmd, "CLIENT"), kSyntaxErrType); } diff --git a/src/server/server_family.h b/src/server/server_family.h index 08de412f6232..389483286161 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -256,6 +256,8 @@ class ServerFamily { void SnapshotScheduling(); + void SendInvalidationMessages() const; + Fiber snapshot_schedule_fb_; Future load_result_; diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index e0a12c26956e..481638f8d747 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -199,5 +199,12 @@ TEST_F(ServerFamilyTest, ClientPause) { Run({"set", "key", "value2"}); EXPECT_GT((absl::Now() - start), absl::Milliseconds(50)); } + +TEST_F(ServerFamilyTest, ClientTracking) { + // client tracking only works for RESP3 + auto resp = Run({"hello", "3"}); + resp = RUN({"client", "tracking", "on"}); + EXPECT_THAT(resp.GetString(), "OK"); +} } // namespace dfly From 31fd09782773dc6692d91d07f2c6a44ca64ab69b Mon Sep 17 00:00:00 2001 From: Yue Li Date: Wed, 29 Nov 2023 22:21:47 -0800 Subject: [PATCH 11/24] support verbatim string for resp3 and let INFO command uses it. --- src/facade/reply_builder.cc | 32 ++++++++++++++++++++++++++------ src/facade/reply_builder.h | 4 +++- src/server/server_family.cc | 2 +- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index 426fd309449f..0fc09a19fc46 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -291,12 +291,6 @@ void RedisReplyBuilder::SendProtocolError(std::string_view str) { SendError(absl::StrCat("-ERR Protocol error: ", str), "protocol_error"); } -void RedisReplyBuilder::SendRET() { - // iovec v[3] = {IoVec(kSimplePref), IoVec(str), IoVec(kCRLF)}; - iovec v[1] = {IoVec(kRET)}; - Send(v, ABSL_ARRAYSIZE(v)); -} - void RedisReplyBuilder::SendSimpleString(std::string_view str) { iovec v[3] = {IoVec(kSimplePref), IoVec(str), IoVec(kCRLF)}; Send(v, ABSL_ARRAYSIZE(v)); @@ -338,6 +332,32 @@ void RedisReplyBuilder::SendBulkString(std::string_view str) { return Send(v, ABSL_ARRAYSIZE(v)); } +void RedisReplyBuilder::SendVerbatimString(std::string_view str, VerbatimFormat format) { + if (!is_resp3_) + return SendBulkString(str); + + char tmp[absl::numbers_internal::kFastToBufferSize + 3]; + tmp[0] = '='; + // + 4 because format is three byte, and need to be followed by a ":" + char* next = absl::numbers_internal::FastIntToBuffer(uint32_t(str.size() + 4), tmp + 1); + *next++ = '\r'; + *next++ = '\n'; + + std::string_view lenpref{tmp, size_t(next - tmp)}; + + std::string_view format_str; + if (format == VerbatimFormat::TXT) + format_str = "txt:"; + else if (format == VerbatimFormat::MARKDOWN) + format_str = "mkd:"; + else { + DVLOG(1) << "Unknown verbatim reply format: " << format; + return; + } + iovec v[4] = {IoVec(lenpref), IoVec(format_str), IoVec(str), IoVec(kCRLF)}; + return Send(v, ABSL_ARRAYSIZE(v)); +} + void RedisReplyBuilder::SendLong(long num) { string str = absl::StrCat(":", num, kCRLF); SendRaw(str); diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index 56dcb00a410e..cd266b2e2b4a 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -212,6 +212,8 @@ class RedisReplyBuilder : public SinkReplyBuilder { public: enum CollectionType { ARRAY, SET, MAP, PUSH }; + enum VerbatimFormat { TXT, MARKDOWN }; + using StrSpan = std::variant, absl::Span>; RedisReplyBuilder(::io::Sink* stream); @@ -247,7 +249,7 @@ class RedisReplyBuilder : public SinkReplyBuilder { static char* FormatDouble(double val, char* dest, unsigned dest_len); - void SendRET(); + void SendVerbatimString(std::string_view str, VerbatimFormat format = TXT); protected: struct WrappedStrSpan : public StrSpan { diff --git a/src/server/server_family.cc b/src/server/server_family.cc index f1b98ee08e91..aed1957bec27 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -1877,7 +1877,7 @@ void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) { append("cluster_enabled", ClusterConfig::IsEnabledOrEmulated()); } - (*cntx)->SendBulkString(info); + (*cntx)->SendVerbatimString(info); } void ServerFamily::Hello(CmdArgList args, ConnectionContext* cntx) { From 0cbdcfe1dabb8f4d40050258e6f10fe93c88ca23 Mon Sep 17 00:00:00 2001 From: Yue Li Date: Wed, 29 Nov 2023 23:26:30 -0800 Subject: [PATCH 12/24] tidy --- src/server/main_service.cc | 49 -------------------------------------- 1 file changed, 49 deletions(-) diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 5e6ba19ed576..a777749803d7 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1169,58 +1169,9 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) auto cb = [&](Transaction* t, EngineShard* shard) { return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, tid, keys_to_track, args); }; - - // OpResult remember_key_result = dfly_cntx->transaction->ScheduleSingleHopT(cb); - - //(*dfly_cntx)->StartCollection(2, RedisReplyBuilder::CollectionType::PUSH); - // std::string inval = "invalidate"; - //(*dfly_cntx)->SendBulkString(inval); - //(*dfly_cntx)->SendStringArr(keys_to_track); - //(*dfly_cntx)->SendRET(); - - //(*dfly_cntx)->StartCollection(2, RedisReplyBuilder::CollectionType::PUSH); - //(*dfly_cntx)->SendBulkString(inval); - //(*dfly_cntx)->SendStringArr(keys_to_track); } - /* - if (is_read_cmd) { - - OpResult key_index_res = DetermineKeys(cid, args_no_cmd); - if (!key_index_res) { - } - - const auto& key_index = *key_index_res; - std::vector keys; - // Iterate keys and check to which slot they belong. - for (unsigned i = key_index.start; i < key_index.end; i += key_index.step) { - string_view key = ArgS(args_no_cmd, i); - keys.push_back(key); - } - - uint32_t client_id = dfly_cntx->conn()->GetClientId(); - - auto cb = [&](Transaction* t, EngineShard* shard) { - return OpRememberKey(t->GetOpArgs(shard), client_id, keys, args); - }; - - - - //OpStatus status = - // dfly_cntx->transaction->InitByArgs(dfly_cntx->conn_state.db_index, args_no_cmd); - OpResult remember_key_result = dfly_cntx->transaction->ScheduleSingleHopT(cb); - - - // send PUSH message to client to invalidate the key - (*dfly_cntx)->StartCollection(2, RedisReplyBuilder::CollectionType::PUSH); - std::string inval = "invalidate"; - (*dfly_cntx)->SendBulkString(inval); - (*dfly_cntx)->SendStringArr(keys); - - } - */ - uint64_t end_ns = ProactorBase::GetMonotonicTimeNs(); request_latency_usec.IncBy(cid->name(), (end_ns - start_ns) / 1000); From 9992a8562452e58a4e0df50d7b984cbdd9089cad Mon Sep 17 00:00:00 2001 From: Yue Li Date: Wed, 29 Nov 2023 23:34:33 -0800 Subject: [PATCH 13/24] clean --- src/facade/reply_builder.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index cd266b2e2b4a..d7f975b930c6 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -165,6 +165,7 @@ class SinkReplyBuilder { void StartAggregate(); void StopAggregate(); + std::string batch_; ::io::Sink* sink_; std::error_code ec_; From 203f9ed44042a8e090b397a5d2e62c704021b71e Mon Sep 17 00:00:00 2001 From: Yue Li Date: Thu, 30 Nov 2023 04:22:23 -0800 Subject: [PATCH 14/24] Refurbish transaction before its reuse --- src/server/main_service.cc | 3 +++ src/server/transaction.cc | 6 ++++++ src/server/transaction.h | 2 ++ 3 files changed, 11 insertions(+) diff --git a/src/server/main_service.cc b/src/server/main_service.cc index a777749803d7..11cde342fdcb 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1165,10 +1165,13 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) // let's pass thread id and connection to db_slice for tracking int32_t tid = util::ProactorBase::GetIndex(); + DVLOG(2) << "Ready to schedul transaction"; + // uint32_t client_id = dfly_cntx->conn()->GetClientId(); auto cb = [&](Transaction* t, EngineShard* shard) { return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, tid, keys_to_track, args); }; + dfly_cntx->transaction->Refurbish(); dfly_cntx->transaction->ScheduleSingleHopT(cb); } diff --git a/src/server/transaction.cc b/src/server/transaction.cc index ebf29b03409d..3e50c167029b 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -906,6 +906,12 @@ void Transaction::Conclude() { Execute(std::move(cb), true); } +void Transaction::Refurbish() { + txid_ = 0; + coordinator_state_ = 0; + cb_ptr_ = nullptr; +} + void Transaction::EnableShard(ShardId sid) { unique_shard_cnt_ = 1; unique_shard_id_ = sid; diff --git a/src/server/transaction.h b/src/server/transaction.h index 6d76e24cf4bd..64d175e28809 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -323,6 +323,8 @@ class Transaction { // Utility to run a single hop on a no-key command static void RunOnceAsCommand(const CommandId* cid, RunnableType cb); + void Refurbish(); + private: // Holds number of locks for each IntentLock::Mode: shared and exlusive. struct LockCnt { From 15e62b38fae576ce42dd3f96ce4cc1d7ffc41201 Mon Sep 17 00:00:00 2001 From: Yue Li Date: Thu, 30 Nov 2023 11:41:39 -0800 Subject: [PATCH 15/24] if hrandfield returns one element, just send bulkstring --- src/server/hset_family.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index 851dce09ae27..4e3fe4484aa5 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -1135,7 +1135,10 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) { OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); if (result) { - (*cntx)->SendStringArr(*result); + if (result->size() == 1) + (*cntx)->SendBulkString(result->front()); + else + (*cntx)->SendStringArr(*result); } else if (result.status() == OpStatus::KEY_NOTFOUND) { (*cntx)->SendNull(); } else { From 8e0fc9ae3c0ee263ae30abaabf9d5e24eb5ec2e1 Mon Sep 17 00:00:00 2001 From: Yue Li Date: Thu, 30 Nov 2023 17:25:14 -0800 Subject: [PATCH 16/24] only track keys of its own shard. --- src/server/main_service.cc | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 11cde342fdcb..8681939c90ef 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1151,24 +1151,12 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) // start tracking updates to the keys in this read command // notify the client when there is update, see PostUpdate() in db_slice.cc if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn()) { - OpResult key_index_res = DetermineKeys(cid, args_no_cmd); - if (!key_index_res) - return (*cntx)->SendError(key_index_res.status()); - - const auto& key_index = *key_index_res; - vector keys_to_track; - for (unsigned i = key_index.start; i < key_index.end; i += key_index.step) { - string_view key = ArgS(args_no_cmd, i); - keys_to_track.push_back(key); - } - // let's pass thread id and connection to db_slice for tracking int32_t tid = util::ProactorBase::GetIndex(); - - DVLOG(2) << "Ready to schedul transaction"; - // uint32_t client_id = dfly_cntx->conn()->GetClientId(); auto cb = [&](Transaction* t, EngineShard* shard) { + auto keys = t->GetShardArgs(shard->shard_id()); + vector keys_to_track{keys.begin(), keys.end()}; return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, tid, keys_to_track, args); }; dfly_cntx->transaction->Refurbish(); From 7f44f1dab256775a4ae0f8c5682abba44147f1c7 Mon Sep 17 00:00:00 2001 From: Yue Li Date: Fri, 1 Dec 2023 03:35:05 -0800 Subject: [PATCH 17/24] review comments --- src/server/db_slice.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index e498f209bb3f..61f4f0e5f2ec 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1402,8 +1402,7 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { continue; facade::Connection* conn = it->first->conn(); DCHECK(conn); - facade::Connection::InvalidationMessage x{key}; - conn->SendInvalidationMessageAsync(x); + conn->SendInvalidationMessageAsync({key}); return; } }; From 32bd6b347d8f17c5bb94d1379ba52e44030a746e Mon Sep 17 00:00:00 2001 From: Yue Li Date: Fri, 1 Dec 2023 04:55:57 -0800 Subject: [PATCH 18/24] rebase against vlad's fork (safe-ptr branch) --- src/facade/dragonfly_connection.h | 11 ----------- src/facade/reply_builder.h | 1 - src/server/main_service.cc | 5 +---- 3 files changed, 1 insertion(+), 16 deletions(-) diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index c29633ea930d..ce792458e7ff 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -77,22 +77,11 @@ class Connection : public util::Connection { size_t message_len); }; - struct InvalidationMessage { std::string_view key; bool invalidate_due_to_flush = false; }; - struct MonitorMessage : public std::string {}; - - struct AclUpdateMessage { - std::vector username; - std::vector categories; - std::vector> commands; - }; - - struct MigrationRequestMessage {}; - // Pipeline message, accumulated command to be execute.d struct PipelineMessage { PipelineMessage(size_t nargs, size_t capacity) : args(nargs), storage(capacity) { diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index d7f975b930c6..af5d9e5020c6 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -154,7 +154,6 @@ class SinkReplyBuilder { bool HasReplied() const; virtual size_t UsedMemory() const; - std::string batch_; protected: void SendRaw(std::string_view str); // Sends raw without any formatting. diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 8681939c90ef..79ee3607e241 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1152,7 +1152,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) // notify the client when there is update, see PostUpdate() in db_slice.cc if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn()) { // let's pass thread id and connection to db_slice for tracking - int32_t tid = util::ProactorBase::GetIndex(); + int32_t tid = ProactorBase::me()->GetPoolIndex(); // uint32_t client_id = dfly_cntx->conn()->GetClientId(); auto cb = [&](Transaction* t, EngineShard* shard) { auto keys = t->GetShardArgs(shard->shard_id()); @@ -1163,9 +1163,6 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) dfly_cntx->transaction->ScheduleSingleHopT(cb); } - uint64_t end_ns = ProactorBase::GetMonotonicTimeNs(); - request_latency_usec.IncBy(cid->name(), (end_ns - start_ns) / 1000); - if (!dispatching_in_multi) { dfly_cntx->transaction = nullptr; } From ed6c75edf8f10123f2c8fe012be6127f80af67c3 Mon Sep 17 00:00:00 2001 From: Yue Li Date: Fri, 1 Dec 2023 05:27:58 -0800 Subject: [PATCH 19/24] fix typo --- src/server/server_family_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index 481638f8d747..185ca9a821f8 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -199,11 +199,11 @@ TEST_F(ServerFamilyTest, ClientPause) { Run({"set", "key", "value2"}); EXPECT_GT((absl::Now() - start), absl::Milliseconds(50)); } - + TEST_F(ServerFamilyTest, ClientTracking) { // client tracking only works for RESP3 auto resp = Run({"hello", "3"}); - resp = RUN({"client", "tracking", "on"}); + resp = Run({"client", "tracking", "on"}); EXPECT_THAT(resp.GetString(), "OK"); } From 002bc5899b06990ce2a350ac12c566dfcfe4ada0 Mon Sep 17 00:00:00 2001 From: Yue Li Date: Sun, 3 Dec 2023 03:47:29 -0800 Subject: [PATCH 20/24] use weakref and fix hrandfield output --- src/facade/dragonfly_connection.cc | 12 ++++++++++-- src/facade/dragonfly_connection.h | 16 ++++++++++------ src/server/db_slice.cc | 24 ++++++++++++++---------- src/server/db_slice.h | 13 +++++-------- src/server/hset_family.cc | 5 ++++- src/server/main_service.cc | 2 +- 6 files changed, 44 insertions(+), 28 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 08eefc423a16..9123740ca594 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -282,7 +282,6 @@ void Connection::DispatchOperations::operator()(const MigrationRequestMessage& m // no-op } - void Connection::DispatchOperations::operator()(CheckpointMessage msg) { msg.bc.Dec(); } @@ -1357,6 +1356,15 @@ void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest) { migration_request_ = dest; } +void Connection::EnableTracking() { + tracking_enabled_ = true; + cc_->subscriptions++; +} + +void Connection::DisableTracking() { + tracking_enabled_ = false; +} + Connection::MemoryUsage Connection::GetMemoryUsage() const { size_t mem = sizeof(*this) + dfly::HeapSize(dispatch_q_) + dfly::HeapSize(name_) + dfly::HeapSize(tmp_parse_args_) + dfly::HeapSize(tmp_cmd_vec_) + @@ -1384,7 +1392,7 @@ unsigned Connection::WeakRef::Thread() const { } Connection* Connection::WeakRef::Get() const { - DCHECK_EQ(ProactorBase::me()->GetPoolIndex(), int(thread_)); + // DCHECK_EQ(ProactorBase::me()->GetPoolIndex(), int(thread_)); return ptr_.lock().get(); } diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 14e0de700e56..80ccb90af94c 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -164,6 +164,12 @@ class Connection : public util::Connection { // Ensure owner thread's memory budget. If expired, skips and returns false. Thread-safe. bool EnsureMemoryBudget() const; + bool operator==(const WeakRef rhs) const { + auto rhs_ptr = rhs.ptr_.lock(); + auto lhs_ptr = ptr_.lock(); + return (rhs_ptr == lhs_ptr); + }; + private: friend class Connection; @@ -257,13 +263,11 @@ class Connection : public util::Connection { // Connections will migrate at most once, and only when the flag --migrate_connections is true. void RequestAsyncMigration(util::fb2::ProactorBase* dest); - void EnableTracking() { - tracking_enabled_ = true; - } + // Set the flag to enable client side tracking + void EnableTracking(); - void DisableTracking() { - tracking_enabled_ = false; - } + // Set the flag to disable client side tracking + void DisableTracking(); bool IsTrackingOn() const { return tracking_enabled_; diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index 7362bec04353..cc777701f153 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -1364,20 +1364,22 @@ void DbSlice::ResetUpdateEvents() { // todo: perhaps we need to limit the total number of entry in the tracking // table so that we have a way to limit the amount of memory used for // client tracking -void DbSlice::TrackKeys(ConnectionContext* cntx, int32_t tid, +void DbSlice::TrackKeys(facade::Connection::WeakRef conn, int32_t tid, const std::vector& keys) { DVLOG(2) << "Start tracking the following keys for thread ID: " << tid - << ", client ID: " << cntx->conn()->GetClientId(); + << ", client ID: " << conn.Get()->GetClientId(); + if (conn.Get() == nullptr) + return; + for (auto key : keys) { // std::string_view k = key; std::string k{key.begin(), key.end()}; DVLOG(2) << " " << k; + std::pair p{conn, tid}; if (client_tracking_map_.find(k) == client_tracking_map_.end()) { - std::pair p{cntx, tid}; - absl::flat_hash_set, Hash> tracker_set{p}; + absl::flat_hash_set, Hash> tracker_set{p}; client_tracking_map_.insert({k, tracker_set}); } else { - std::pair p{cntx, tid}; client_tracking_map_[k].insert(p); } } @@ -1389,10 +1391,10 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { // notify all the clients. auto& client_set = client_tracking_map_[k]; DVLOG(2) << "Garbage collect clients that are no longer tracking... "; - auto is_not_tracking = [](std::pair p) { - return (!p.first->conn()->IsTrackingOn()); + auto is_closed_or_not_tracking = [](std::pair p) { + return ((p.first.Get() == nullptr) || (!p.first.Get()->IsTrackingOn())); }; - absl::erase_if(client_set, is_not_tracking); + absl::erase_if(client_set, is_closed_or_not_tracking); DVLOG(2) << "Number of clients left: " << client_set.size(); if (!client_set.empty()) { @@ -1400,8 +1402,10 @@ void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { for (auto it = client_set.begin(); it != client_set.end(); ++it) { if ((unsigned int)it->second != idx) continue; - facade::Connection* conn = it->first->conn(); - DCHECK(conn); + + facade::Connection* conn = it->first.Get(); + if (conn == nullptr) + continue; conn->SendInvalidationMessageAsync({key}); return; } diff --git a/src/server/db_slice.h b/src/server/db_slice.h index f5570cea2d9c..652b44d8b6ff 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -334,7 +334,8 @@ class DbSlice { } // Start tracking keys for the client with client_id - void TrackKeys(ConnectionContext*, int32_t, const std::vector&); + // void TrackKeys(ConnectionContext*, int32_t, const std::vector&); + void TrackKeys(facade::Connection::WeakRef, int32_t, const std::vector&); // Send invalidatoin message when a key being tracked is updated/deleted. // A connection that has been closed will be garbage collected along the way. @@ -392,17 +393,13 @@ class DbSlice { DocDeletionCallback doc_del_cb_; struct Hash { - size_t operator()(const std::pair& p) const { - // return std::hash()(p.second); - return std::hash()(p.first->conn()->GetClientId()); + size_t operator()(const std::pair& p) const { + return std::hash()(p.first.Get()->GetClientId()); } }; - // maps keys to the IDs of the clients that are tracking this key. - // absl::flat_hash_map, Hash> > client_tracking_map_; absl::flat_hash_map, Hash>> + absl::flat_hash_set, Hash>> client_tracking_map_; }; diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index 4e3fe4484aa5..09d36c2ce4f5 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -1140,7 +1140,10 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) { else (*cntx)->SendStringArr(*result); } else if (result.status() == OpStatus::KEY_NOTFOUND) { - (*cntx)->SendNull(); + if (args.size() == 1) + (*cntx)->SendNull(); + else + (*cntx)->SendEmptyArray(); } else { (*cntx)->SendError(result.status()); } diff --git a/src/server/main_service.cc b/src/server/main_service.cc index ae4aba7144cb..4538865b80f3 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1043,7 +1043,7 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA OpResult OpTrackKeys(const OpArgs& op_args, ConnectionContext* cntx, uint32_t tid, vector& keys, CmdArgList args) { auto& db_slice = op_args.shard->db_slice(); - db_slice.TrackKeys(cntx, tid, keys); + db_slice.TrackKeys(cntx->conn()->Borrow(), tid, keys); return true; } From a17eed1e945426a4a21751b75dd6ca5b16b47fbf Mon Sep 17 00:00:00 2001 From: Yue Li Date: Sun, 3 Dec 2023 04:16:05 -0800 Subject: [PATCH 21/24] allow client tracking only when resp3 is used. --- src/facade/reply_builder.cc | 4 ++++ src/facade/reply_builder.h | 2 ++ src/server/server_family.cc | 22 +++++++++++++--------- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index 0fc09a19fc46..bc438c6dde8f 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -263,6 +263,10 @@ char* RedisReplyBuilder::FormatDouble(double val, char* dest, unsigned dest_len) RedisReplyBuilder::RedisReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) { } +bool RedisReplyBuilder::IsResp3() const { + return is_resp3_; +} + void RedisReplyBuilder::SetResp3(bool is_resp3) { is_resp3_ = is_resp3; } diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index af5d9e5020c6..72048d65c9aa 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -218,6 +218,8 @@ class RedisReplyBuilder : public SinkReplyBuilder { RedisReplyBuilder(::io::Sink* stream); + bool IsResp3() const; + void SetResp3(bool is_resp3); void SendError(std::string_view str, std::string_view type = {}) override; diff --git a/src/server/server_family.cc b/src/server/server_family.cc index aed1957bec27..6b1e1c07813b 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -1264,15 +1264,19 @@ void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) { } if (sub_cmd == "TRACKING" && args.size() == 2) { - ToUpper(&args[1]); - string_view switch_state = ArgS(args, 1); - if (switch_state == "ON") { - cntx->conn()->EnableTracking(); - return (*cntx)->SendOk(); - } else if (switch_state == "OFF") { - cntx->conn()->DisableTracking(); - // todo: the client id in tracking table will be garbage collected. - return (*cntx)->SendOk(); + if ((*cntx)->IsResp3()) { + ToUpper(&args[1]); + string_view switch_state = ArgS(args, 1); + if (switch_state == "ON") { + cntx->conn()->EnableTracking(); + return (*cntx)->SendOk(); + } else if (switch_state == "OFF") { + cntx->conn()->DisableTracking(); + return (*cntx)->SendOk(); + } + } else { + LOG_FIRST_N(ERROR, 10) + << "Client tracking is currently not supported in RESP2, please use RESP3."; } } From bfdd39990f7e347ebbed5d882c8436b97136d362 Mon Sep 17 00:00:00 2001 From: Yue Li Date: Mon, 4 Dec 2023 17:56:31 -0800 Subject: [PATCH 22/24] fix QUIT's property, makes WeakRef remember client id. --- src/facade/dragonfly_connection.cc | 6 +++++- src/facade/dragonfly_connection.h | 3 +++ src/server/db_slice.h | 2 +- src/server/main_service.cc | 15 ++++++++++++--- src/server/server_family.cc | 2 +- 5 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 9123740ca594..0a86fb0b8006 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -1384,7 +1384,7 @@ Connection::MemoryUsage Connection::GetMemoryUsage() const { Connection::WeakRef::WeakRef(std::shared_ptr ptr, QueueBackpressure* backpressure, unsigned thread) - : ptr_{ptr}, backpressure_{backpressure}, thread_{thread} { + : ptr_{ptr}, backpressure_{backpressure}, thread_{thread}, client_id_{ptr->GetClientId()} { } unsigned Connection::WeakRef::Thread() const { @@ -1396,6 +1396,10 @@ Connection* Connection::WeakRef::Get() const { return ptr_.lock().get(); } +uint32_t Connection::WeakRef::GetClientId() const { + return client_id_; +} + bool Connection::WeakRef::EnsureMemoryBudget() const { // Simple optimization: If a connection was closed, don't check memory budget. if (!ptr_.expired()) { diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 80ccb90af94c..fe25854e94fa 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -161,6 +161,8 @@ class Connection : public util::Connection { // Can only be called from connection's thread. Connection* Get() const; + uint32_t GetClientId() const; + // Ensure owner thread's memory budget. If expired, skips and returns false. Thread-safe. bool EnsureMemoryBudget() const; @@ -178,6 +180,7 @@ class Connection : public util::Connection { std::weak_ptr ptr_; QueueBackpressure* backpressure_; unsigned thread_; + uint32_t client_id_; }; public: diff --git a/src/server/db_slice.h b/src/server/db_slice.h index 652b44d8b6ff..b8630394dddb 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -394,7 +394,7 @@ class DbSlice { struct Hash { size_t operator()(const std::pair& p) const { - return std::hash()(p.first.Get()->GetClientId()); + return std::hash()(p.first.GetClientId()); } }; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 4538865b80f3..ea8145629fee 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1159,8 +1159,14 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) vector keys_to_track{keys.begin(), keys.end()}; return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, tid, keys_to_track, args); }; - dfly_cntx->transaction->Refurbish(); - dfly_cntx->transaction->ScheduleSingleHopT(cb); + + if (dfly_cntx->transaction == nullptr) { + DVLOG(2) << "transaction is a nullptr"; + } else { + DVLOG(2) << "transaction is not a nullptr"; + dfly_cntx->transaction->Refurbish(); + dfly_cntx->transaction->ScheduleSingleHopT(cb); + } } if (!dispatching_in_multi) { @@ -1478,6 +1484,9 @@ void Service::Quit(CmdArgList args, ConnectionContext* cntx) { (*cntx)->SendOk(); using facade::SinkReplyBuilder; + if (cntx->conn()->IsTrackingOn()) + cntx->conn()->DisableTracking(); + SinkReplyBuilder* builder = cntx->reply_builder(); builder->CloseConnection(); @@ -2432,7 +2441,7 @@ void Service::Register(CommandRegistry* registry) { using CI = CommandId; registry->StartFamily(); *registry - << CI{"QUIT", CO::READONLY | CO::FAST, 1, 0, 0, acl::kQuit}.HFUNC(Quit) + << CI{"QUIT", CO::FAST, 1, 0, 0, acl::kQuit}.HFUNC(Quit) << CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING, 1, 0, 0, acl::kMulti}.HFUNC(Multi) << CI{"WATCH", CO::LOADING, -2, 1, -1, acl::kWatch}.HFUNC(Watch) << CI{"UNWATCH", CO::LOADING, 1, 0, 0, acl::kUnwatch}.HFUNC(Unwatch) diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 6b1e1c07813b..d99ed1568366 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -1328,7 +1328,7 @@ void ServerFamily::ClientList(CmdArgList args, ConnectionContext* cntx) { string result = absl::StrJoin(client_info, "\n"); result.append("\n"); - return (*cntx)->SendBulkString(result); + return (*cntx)->SendVerbatimString(result); } void ServerFamily::ClientPause(CmdArgList args, ConnectionContext* cntx) { From 3991c92b371598b72bb6d360bc82196eb5619e39 Mon Sep 17 00:00:00 2001 From: Yue Li Date: Mon, 4 Dec 2023 19:51:25 -0800 Subject: [PATCH 23/24] fix reply format of hrandfield when count is negative --- src/server/hset_family.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index 09d36c2ce4f5..229a6d1e4625 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -1135,7 +1135,7 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) { OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); if (result) { - if (result->size() == 1) + if ((result->size() == 1) && (args.size() == 1)) (*cntx)->SendBulkString(result->front()); else (*cntx)->SendStringArr(*result); From 508bf2181c502e373870996298d35ef4458e40a6 Mon Sep 17 00:00:00 2001 From: Yue Li Date: Tue, 5 Dec 2023 16:13:23 -0800 Subject: [PATCH 24/24] fix reply format of hrandfield when count and withvalues are used. --- src/server/hset_family.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index 229a6d1e4625..f1b3952d75bc 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -1134,11 +1134,21 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) { }; OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (result) { if ((result->size() == 1) && (args.size() == 1)) (*cntx)->SendBulkString(result->front()); - else - (*cntx)->SendStringArr(*result); + else { + if (((*cntx)->IsResp3()) && (args.size() == 3)) { // has withvalues + (*cntx)->StartArray(result->size() / 2); + for (unsigned int i = 0; i < result->size() / 2; ++i) { + StringVec sv{(*result)[i * 2], (*result)[i * 2 + 1]}; + (*cntx)->SendStringArr(sv); + } + } else { + (*cntx)->SendStringArr(*result); + } + } } else if (result.status() == OpStatus::KEY_NOTFOUND) { if (args.size() == 1) (*cntx)->SendNull();