diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index 030d44b7d7f..0a86fb0b800 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -150,6 +150,7 @@ struct Connection::DispatchOperations { void operator()(const AclUpdateMessage& msg); void operator()(const MigrationRequestMessage& msg); void operator()(CheckpointMessage msg); + void operator()(const InvalidationMessage& msg); template void operator()(unique_ptr& ptr) { operator()(*ptr.get()); @@ -215,6 +216,9 @@ size_t Connection::MessageHandle::UsedMemory() const { size_t operator()(const CheckpointMessage& msg) { return 0; // no access to internal type, memory usage negligible } + size_t operator()(const InvalidationMessage& msg) { + return 0; + } }; return sizeof(MessageHandle) + visit(MessageSize{}, this->handle); @@ -282,6 +286,19 @@ void Connection::DispatchOperations::operator()(CheckpointMessage msg) { msg.bc.Dec(); } +void Connection::DispatchOperations::operator()(const InvalidationMessage& msg) { + RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder; + rbuilder->SetResp3(true); + rbuilder->StartCollection(2, facade::RedisReplyBuilder::CollectionType::PUSH); + rbuilder->SendBulkString("invalidate"); + if (msg.invalidate_due_to_flush) { + rbuilder->SendNull(); + } else { + std::vector keys{msg.key}; + rbuilder->SendStringArr(keys); + } +} + Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx, ServiceInterface* service) : io_buf_(kMinReadSize), http_listener_(http_listener), ctx_(ctx), service_(service), name_{} { @@ -313,6 +330,10 @@ Connection::Connection(Protocol protocol, util::HttpListenerBase* http_listener, migration_enabled_ = absl::GetFlag(FLAGS_migrate_connections); + // Create shared_ptr with empty value and associate it with `this` pointer (aliasing constructor). + // We use it for reference counting and accessing `this` (without managing it). + self_ = {std::make_shared(std::monostate{}), this}; + #ifdef DFLY_USE_SSL // Increment reference counter so Listener won't free the context while we're // still using it. @@ -1168,11 +1189,19 @@ void Connection::Migrate(util::fb2::ProactorBase* dest) { // connections CHECK(!cc_->async_dispatch); CHECK_EQ(cc_->subscriptions, 0); // are bound to thread local caches + CHECK_EQ(self_.use_count(), 1u); // references cache our thread and backpressure CHECK(!dispatch_fb_.IsJoinable()); // can't move once it started listener()->Migrate(this, dest); } +Connection::WeakRef Connection::Borrow() { + DCHECK(self_); + DCHECK_GT(cc_->subscriptions, 0); + + return WeakRef(self_, queue_backpressure_, socket_->proactor()->GetPoolIndex()); +} + void Connection::ShutdownThreadLocal() { pipeline_req_pool_.clear(); } @@ -1189,6 +1218,10 @@ void Connection::SendPubMessageAsync(PubMessage msg) { SendAsync({PubMessagePtr{new (ptr) PubMessage{move(msg)}, MessageDeleter{}}}); } +void Connection::SendInvalidationMessageAsync(InvalidationMessage msg) { + SendAsync({std::move(msg)}); +} + void Connection::SendMonitorMessageAsync(string msg) { SendAsync({MonitorMessage{move(msg)}}); } @@ -1277,10 +1310,6 @@ void Connection::RecycleMessage(MessageHandle msg) { } } -void Connection::EnsureAsyncMemoryBudget() { - queue_backpressure_->EnsureBelowLimit(); -} - std::string Connection::LocalBindStr() const { if (socket_->IsUDS()) return "unix-domain-socket"; @@ -1327,6 +1356,15 @@ void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest) { migration_request_ = dest; } +void Connection::EnableTracking() { + tracking_enabled_ = true; + cc_->subscriptions++; +} + +void Connection::DisableTracking() { + tracking_enabled_ = false; +} + Connection::MemoryUsage Connection::GetMemoryUsage() const { size_t mem = sizeof(*this) + dfly::HeapSize(dispatch_q_) + dfly::HeapSize(name_) + dfly::HeapSize(tmp_parse_args_) + dfly::HeapSize(tmp_cmd_vec_) + @@ -1344,6 +1382,35 @@ Connection::MemoryUsage Connection::GetMemoryUsage() const { }; } +Connection::WeakRef::WeakRef(std::shared_ptr ptr, QueueBackpressure* backpressure, + unsigned thread) + : ptr_{ptr}, backpressure_{backpressure}, thread_{thread}, client_id_{ptr->GetClientId()} { +} + +unsigned Connection::WeakRef::Thread() const { + return thread_; +} + +Connection* Connection::WeakRef::Get() const { + // DCHECK_EQ(ProactorBase::me()->GetPoolIndex(), int(thread_)); + return ptr_.lock().get(); +} + +uint32_t Connection::WeakRef::GetClientId() const { + return client_id_; +} + +bool Connection::WeakRef::EnsureMemoryBudget() const { + // Simple optimization: If a connection was closed, don't check memory budget. + if (!ptr_.expired()) { + // We don't rely on the connection ptr staying valid because we only access + // the threads backpressure + backpressure_->EnsureBelowLimit(); + return true; + } + return false; +} + void Connection::DecreaseStatsOnClose() { stats_->read_buf_capacity -= io_buf_.Capacity(); diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 93a1c2c2e7e..fe25854e94f 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -53,6 +53,8 @@ class SinkReplyBuilder; // For pipelined requests, monitor and pubsub messages it uses // a separate dispatch queue that is processed on a separate fiber. class Connection : public util::Connection { + struct QueueBackpressure; + public: Connection(Protocol protocol, util::HttpListenerBase* http_listener, SSL_CTX* ctx, ServiceInterface* service); @@ -75,7 +77,12 @@ class Connection : public util::Connection { size_t message_len); }; - // Pipeline message, accumulated command to be executed. + struct InvalidationMessage { + std::string_view key; + bool invalidate_due_to_flush = false; + }; + + // Pipeline message, accumulated command to be execute.d struct PipelineMessage { PipelineMessage(size_t nargs, size_t capacity) : args(nargs), storage(capacity) { } @@ -134,7 +141,7 @@ class Connection : public util::Connection { bool IsPubMsg() const; std::variant + MigrationRequestMessage, CheckpointMessage, InvalidationMessage> handle; }; @@ -143,11 +150,47 @@ class Connection : public util::Connection { enum Phase { SETUP, READ_SOCKET, PROCESS, SHUTTING_DOWN, PRECLOSE, NUM_PHASES }; + // Weak reference to a connection, invalidated upon connection close. + // Used to dispatch async operations for the connection without worrying about pointer lifetime. + struct WeakRef { + public: + // Get residing thread of connection. Thread-safe. + unsigned Thread() const; + + // Get pointer to connection if still valid, nullptr if expired. + // Can only be called from connection's thread. + Connection* Get() const; + + uint32_t GetClientId() const; + + // Ensure owner thread's memory budget. If expired, skips and returns false. Thread-safe. + bool EnsureMemoryBudget() const; + + bool operator==(const WeakRef rhs) const { + auto rhs_ptr = rhs.ptr_.lock(); + auto lhs_ptr = ptr_.lock(); + return (rhs_ptr == lhs_ptr); + }; + + private: + friend class Connection; + + WeakRef(std::shared_ptr ptr, QueueBackpressure* backpressure, unsigned thread); + + std::weak_ptr ptr_; + QueueBackpressure* backpressure_; + unsigned thread_; + uint32_t client_id_; + }; + public: // Add PubMessage to dispatch queue. // Virtual because behavior is overridden in test_utils. virtual void SendPubMessageAsync(PubMessage); + // Add InvalidationMessage to dispatch queue. + void SendInvalidationMessageAsync(InvalidationMessage); + // Add monitor message to dispatch queue. void SendMonitorMessageAsync(std::string); @@ -176,6 +219,9 @@ class Connection : public util::Connection { // Migrate this connecton to a different thread. void Migrate(util::fb2::ProactorBase* dest); + // Borrow weak reference to connection. Can be called from any thread. + WeakRef Borrow(); + static void ShutdownThreadLocal(); bool IsCurrentlyDispatching() const; @@ -220,6 +266,16 @@ class Connection : public util::Connection { // Connections will migrate at most once, and only when the flag --migrate_connections is true. void RequestAsyncMigration(util::fb2::ProactorBase* dest); + // Set the flag to enable client side tracking + void EnableTracking(); + + // Set the flag to disable client side tracking + void DisableTracking(); + + bool IsTrackingOn() const { + return tracking_enabled_; + } + protected: void OnShutdown() override; void OnPreMigrateThread() override; @@ -340,6 +396,9 @@ class Connection : public util::Connection { RespVec tmp_parse_args_; CmdArgVec tmp_cmd_vec_; + // Used to keep track of borrowed references. Does not really own itself + std::shared_ptr self_; + // Pointer to corresponding queue backpressure struct. // Needed for access from different threads by EnsureAsyncMemoryBudget(). QueueBackpressure* queue_backpressure_; @@ -354,6 +413,9 @@ class Connection : public util::Connection { // Per-thread queue backpressure structs. static thread_local QueueBackpressure tl_queue_backpressure_; + + // whether client tracking is enabled + bool tracking_enabled_ = false; }; } // namespace facade diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index bb7ab82d282..bc438c6dde8 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -30,6 +30,8 @@ constexpr char kCRLF[] = "\r\n"; constexpr char kErrPref[] = "-ERR "; constexpr char kSimplePref[] = "+"; +constexpr char kRET[] = "$1\r\n\r\r\n"; + constexpr unsigned kConvFlags = DoubleToStringConverter::UNIQUE_ZERO | DoubleToStringConverter::EMIT_POSITIVE_EXPONENT_SIGN; @@ -261,6 +263,10 @@ char* RedisReplyBuilder::FormatDouble(double val, char* dest, unsigned dest_len) RedisReplyBuilder::RedisReplyBuilder(::io::Sink* sink) : SinkReplyBuilder(sink) { } +bool RedisReplyBuilder::IsResp3() const { + return is_resp3_; +} + void RedisReplyBuilder::SetResp3(bool is_resp3) { is_resp3_ = is_resp3; } @@ -291,7 +297,6 @@ void RedisReplyBuilder::SendProtocolError(std::string_view str) { void RedisReplyBuilder::SendSimpleString(std::string_view str) { iovec v[3] = {IoVec(kSimplePref), IoVec(str), IoVec(kCRLF)}; - Send(v, ABSL_ARRAYSIZE(v)); } @@ -331,6 +336,32 @@ void RedisReplyBuilder::SendBulkString(std::string_view str) { return Send(v, ABSL_ARRAYSIZE(v)); } +void RedisReplyBuilder::SendVerbatimString(std::string_view str, VerbatimFormat format) { + if (!is_resp3_) + return SendBulkString(str); + + char tmp[absl::numbers_internal::kFastToBufferSize + 3]; + tmp[0] = '='; + // + 4 because format is three byte, and need to be followed by a ":" + char* next = absl::numbers_internal::FastIntToBuffer(uint32_t(str.size() + 4), tmp + 1); + *next++ = '\r'; + *next++ = '\n'; + + std::string_view lenpref{tmp, size_t(next - tmp)}; + + std::string_view format_str; + if (format == VerbatimFormat::TXT) + format_str = "txt:"; + else if (format == VerbatimFormat::MARKDOWN) + format_str = "mkd:"; + else { + DVLOG(1) << "Unknown verbatim reply format: " << format; + return; + } + iovec v[4] = {IoVec(lenpref), IoVec(format_str), IoVec(str), IoVec(kCRLF)}; + return Send(v, ABSL_ARRAYSIZE(v)); +} + void RedisReplyBuilder::SendLong(long num) { string str = absl::StrCat(":", num, kCRLF); SendRaw(str); diff --git a/src/facade/reply_builder.h b/src/facade/reply_builder.h index 091b068d0fe..72048d65c9a 100644 --- a/src/facade/reply_builder.h +++ b/src/facade/reply_builder.h @@ -212,10 +212,14 @@ class RedisReplyBuilder : public SinkReplyBuilder { public: enum CollectionType { ARRAY, SET, MAP, PUSH }; + enum VerbatimFormat { TXT, MARKDOWN }; + using StrSpan = std::variant, absl::Span>; RedisReplyBuilder(::io::Sink* stream); + bool IsResp3() const; + void SetResp3(bool is_resp3); void SendError(std::string_view str, std::string_view type = {}) override; @@ -247,6 +251,8 @@ class RedisReplyBuilder : public SinkReplyBuilder { static char* FormatDouble(double val, char* dest, unsigned dest_len); + void SendVerbatimString(std::string_view str, VerbatimFormat format = TXT); + protected: struct WrappedStrSpan : public StrSpan { size_t Size() const; diff --git a/src/server/channel_store.cc b/src/server/channel_store.cc index 8eeae4ee133..4307d515105 100644 --- a/src/server/channel_store.cc +++ b/src/server/channel_store.cc @@ -25,18 +25,12 @@ bool Matches(string_view pattern, string_view channel) { } // namespace -ChannelStore::Subscriber::Subscriber(ConnectionContext* cntx, uint32_t tid) - : conn_cntx(cntx), borrow_token(cntx->conn_state.subscribe_info->borrow_token), thread_id(tid) { -} - -ChannelStore::Subscriber::Subscriber(uint32_t tid) - : conn_cntx(nullptr), borrow_token(0), thread_id(tid) { +bool ChannelStore::Subscriber::ByThread(const Subscriber& lhs, const Subscriber& rhs) { + return ByThreadId(lhs, rhs.Thread()); } -bool ChannelStore::Subscriber::ByThread(const Subscriber& lhs, const Subscriber& rhs) { - if (lhs.thread_id == rhs.thread_id) - return (lhs.conn_cntx != nullptr) < (rhs.conn_cntx != nullptr); - return lhs.thread_id < rhs.thread_id; +bool ChannelStore::Subscriber::ByThreadId(const Subscriber& lhs, const unsigned thread) { + return lhs.Thread() < thread; } ChannelStore::UpdatablePointer::UpdatablePointer(const UpdatablePointer& other) { @@ -120,12 +114,8 @@ void ChannelStore::Fill(const SubscribeMap& src, const string& pattern, vectorreserve(out->size() + src.size()); for (const auto [cntx, thread_id] : src) { CHECK(cntx->conn_state.subscribe_info); - - Subscriber s(cntx, thread_id); - s.pattern = pattern; - s.borrow_token.Inc(); - - out->push_back(std::move(s)); + Subscriber sub{cntx->conn()->Borrow(), pattern}; + out->push_back(std::move(sub)); } } diff --git a/src/server/channel_store.h b/src/server/channel_store.h index a60e6c59aeb..49e650d46b4 100644 --- a/src/server/channel_store.h +++ b/src/server/channel_store.h @@ -7,6 +7,7 @@ #include +#include "facade/dragonfly_connection.h" #include "server/conn_context.h" namespace dfly { @@ -39,22 +40,15 @@ class ChannelStore { friend class ChannelStoreUpdater; public: - struct Subscriber { - Subscriber(ConnectionContext* cntx, uint32_t tid); - Subscriber(uint32_t tid); - - Subscriber(Subscriber&&) noexcept = default; - Subscriber& operator=(Subscriber&&) noexcept = default; - - Subscriber(const Subscriber&) = delete; - void operator=(const Subscriber&) = delete; + struct Subscriber : public facade::Connection::WeakRef { + Subscriber(WeakRef ref, const std::string& pattern) + : facade::Connection::WeakRef(std::move(ref)), pattern(pattern) { + } // Sort by thread-id. Subscriber without owner comes first. static bool ByThread(const Subscriber& lhs, const Subscriber& rhs); + static bool ByThreadId(const Subscriber& lhs, const unsigned thread); - ConnectionContext* conn_cntx; - BlockingCounter borrow_token; // to keep connection alive - uint32_t thread_id; std::string pattern; // non-empty if registered via psubscribe }; diff --git a/src/server/conn_context.h b/src/server/conn_context.h index e05f9970c0b..a27d842e166 100644 --- a/src/server/conn_context.h +++ b/src/server/conn_context.h @@ -117,8 +117,6 @@ struct ConnectionState { // TODO: to provide unique_strings across service. This will allow us to use string_view here. absl::flat_hash_set channels; absl::flat_hash_set patterns; - - BlockingCounter borrow_token{0}; }; struct ReplicationInfo { @@ -208,6 +206,7 @@ class ConnectionContext : public facade::ConnectionContext { subscriptions++; // required to support the monitoring monitor = enable; } + void SendSubscriptionChangedResponse(std::string_view action, std::optional topic, unsigned count); }; diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index c93641dd2c3..cc777701f15 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -15,6 +15,7 @@ extern "C" { #include "server/journal/journal.h" #include "server/server_state.h" #include "server/tiered_storage.h" +#include "util/fibers/proactor_base.h" ABSL_FLAG(bool, enable_heartbeat_eviction, true, "Enable eviction during heartbeat when memory is under pressure."); @@ -219,9 +220,13 @@ unsigned PrimeEvictionPolicy::Evict(const PrimeTable::HotspotBuckets& eb, PrimeT return 0; } + std::string tmp; + std::string_view key = last_slot_it->first.GetSlice(&tmp); + DbTable* table = db_slice_->GetDBTable(cntx_.db_index); PerformDeletion(last_slot_it, db_slice_->shard_owner(), table); ++evicted_; + db_slice_->SendInvalidationTrackingMessage(key); } me->ShiftRight(bucket_it); @@ -320,7 +325,7 @@ void DbSlice::Reserve(DbIndex db_ind, size_t key_size) { db->prime.Reserve(key_size); } -auto DbSlice::Find(const Context& cntx, string_view key, unsigned req_obj_type) const +auto DbSlice::Find(const Context& cntx, string_view key, unsigned req_obj_type) -> OpResult { auto it = FindExt(cntx, key).first; @@ -334,7 +339,7 @@ auto DbSlice::Find(const Context& cntx, string_view key, unsigned req_obj_type) return it; } -pair DbSlice::FindExt(const Context& cntx, string_view key) const { +pair DbSlice::FindExt(const Context& cntx, string_view key) { pair res; if (!IsDbValid(cntx.db_index)) @@ -547,15 +552,15 @@ bool DbSlice::Del(DbIndex db_ind, PrimeIterator it) { auto& db = db_arr_[db_ind]; auto obj_type = it->second.ObjType(); + string tmp; + string_view key = it->first.GetSlice(&tmp); if (doc_del_cb_ && (obj_type == OBJ_JSON || obj_type == OBJ_HASH)) { - string tmp; - string_view key = it->first.GetSlice(&tmp); DbContext cntx{db_ind, GetCurrentTimeMs()}; doc_del_cb_(key, cntx, it->second); } PerformDeletion(it, shard_owner(), db.get()); - + SendInvalidationTrackingMessage(key); return true; } @@ -571,6 +576,7 @@ void DbSlice::FlushSlotsFb(const SlotSet& slot_ids) { SlotId sid = ClusterConfig::KeySlot(key); if (slot_ids.contains(sid) && it.GetVersion() < next_version) { PerformDeletion(it, shard_owner(), db_arr_[0].get()); + SendInvalidationTrackingMessage(key); } return true; }; @@ -616,7 +622,10 @@ void DbSlice::FlushDb(DbIndex db_ind) { if (db_ptr->stats.tiered_entries > 0) { for (auto it = db_ptr->prime.begin(); it != db_ptr->prime.end(); ++it) { if (it->second.IsExternal()) { + std::string tmp; + std::string_view key = it->first.GetSlice(&tmp); PerformDeletion(it, shard_owner(), db_ptr.get()); + SendInvalidationTrackingMessage(key); } } } @@ -652,6 +661,9 @@ void DbSlice::FlushDb(DbIndex db_ind) { for (auto& db : all_dbs) db.reset(); mi_heap_collect(ServerState::tlocal()->data_heap(), true); + + // clear all the tracking table. + client_tracking_map_.clear(); } void DbSlice::AddExpire(DbIndex db_ind, PrimeIterator main_it, uint64_t at) { @@ -973,10 +985,11 @@ void DbSlice::PostUpdate(DbIndex db_ind, PrimeIterator it, std::string_view key, if (ClusterConfig::IsEnabled()) { db.slots_stats[ClusterConfig::KeySlot(key)].total_writes += 1; } + + SendInvalidationTrackingMessage(key); } -pair DbSlice::ExpireIfNeeded(const Context& cntx, - PrimeIterator it) const { +pair DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) { DCHECK(it->second.HasExpire()); auto& db = db_arr_[cntx.db_index]; @@ -992,24 +1005,23 @@ pair DbSlice::ExpireIfNeeded(const Context& cntx, return make_pair(it, expire_it); string tmp_key_buf; - string_view tmp_key; + string_view tmp_key = it->first.GetSlice(&tmp_key_buf); // Replicate expiry if (auto journal = owner_->journal(); journal) { - tmp_key = it->first.GetSlice(&tmp_key_buf); RecordExpiry(cntx.db_index, tmp_key); } auto obj_type = it->second.ObjType(); if (doc_del_cb_ && (obj_type == OBJ_JSON || obj_type == OBJ_HASH)) { - if (tmp_key.empty()) - tmp_key = it->first.GetSlice(&tmp_key_buf); doc_del_cb_(tmp_key, cntx, it->second); } PerformDeletion(it, expire_it, shard_owner(), db.get()); ++events_.expired_keys; + SendInvalidationTrackingMessage(tmp_key); + return make_pair(PrimeIterator{}, ExpireIterator{}); } @@ -1169,6 +1181,8 @@ void DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t increase_goal_bytes PerformDeletion(evict_it, shard_owner(), db_table.get()); ++evicted; + SendInvalidationTrackingMessage(key); + used_memory_after = owner_->UsedMemory(); // returns when whichever condition is met first if ((evicted == max_eviction_per_hb) || @@ -1231,6 +1245,7 @@ size_t DbSlice::EvictObjects(size_t memory_to_free, PrimeIterator it, DbTable* t return current < used_memory_start ? used_memory_start - current : 0; }; + std::string tmp; for (unsigned i = 0; !evict_succeeded && i < kNumStashBuckets; ++i) { unsigned stash_bid = i + PrimeTable::Segment_t::kNumBuckets; const auto& bucket = segment->GetBucket(stash_bid); @@ -1246,8 +1261,13 @@ size_t DbSlice::EvictObjects(size_t memory_to_free, PrimeIterator it, DbTable* t if (evict_it == it || evict_it->first.IsSticky()) continue; + string_view key = evict_it->first.GetSlice(&tmp); + PerformDeletion(evict_it, shard_owner(), table); ++evicted; + + SendInvalidationTrackingMessage(key); + if (freed_memory_fun() > memory_to_free) { evict_succeeded = true; break; @@ -1272,9 +1292,12 @@ size_t DbSlice::EvictObjects(size_t memory_to_free, PrimeIterator it, DbTable* t if (evict_it == it || evict_it->first.IsSticky()) continue; + string_view key = evict_it->first.GetSlice(&tmp); PerformDeletion(evict_it, shard_owner(), table); ++evicted; + SendInvalidationTrackingMessage(key); + if (freed_memory_fun() > memory_to_free) { evict_succeeded = true; break; @@ -1338,4 +1361,60 @@ void DbSlice::ResetUpdateEvents() { events_.update = 0; } +// todo: perhaps we need to limit the total number of entry in the tracking +// table so that we have a way to limit the amount of memory used for +// client tracking +void DbSlice::TrackKeys(facade::Connection::WeakRef conn, int32_t tid, + const std::vector& keys) { + DVLOG(2) << "Start tracking the following keys for thread ID: " << tid + << ", client ID: " << conn.Get()->GetClientId(); + if (conn.Get() == nullptr) + return; + + for (auto key : keys) { + // std::string_view k = key; + std::string k{key.begin(), key.end()}; + DVLOG(2) << " " << k; + std::pair p{conn, tid}; + if (client_tracking_map_.find(k) == client_tracking_map_.end()) { + absl::flat_hash_set, Hash> tracker_set{p}; + client_tracking_map_.insert({k, tracker_set}); + } else { + client_tracking_map_[k].insert(p); + } + } +} + +void DbSlice::SendInvalidationTrackingMessage(std::string_view key) { + std::string k{key.begin(), key.end()}; + if (client_tracking_map_.find(k) != client_tracking_map_.end()) { + // notify all the clients. + auto& client_set = client_tracking_map_[k]; + DVLOG(2) << "Garbage collect clients that are no longer tracking... "; + auto is_closed_or_not_tracking = [](std::pair p) { + return ((p.first.Get() == nullptr) || (!p.first.Get()->IsTrackingOn())); + }; + absl::erase_if(client_set, is_closed_or_not_tracking); + DVLOG(2) << "Number of clients left: " << client_set.size(); + + if (!client_set.empty()) { + auto cb = [key, client_set](unsigned idx, util::ProactorBase*) { + for (auto it = client_set.begin(); it != client_set.end(); ++it) { + if ((unsigned int)it->second != idx) + continue; + + facade::Connection* conn = it->first.Get(); + if (conn == nullptr) + continue; + conn->SendInvalidationMessageAsync({key}); + return; + } + }; + shard_set->pool()->DispatchBrief(std::move(cb)); + } + // remove this key from the tracking table as the key no longer exists + client_tracking_map_.erase(k); + } +} + } // namespace dfly diff --git a/src/server/db_slice.h b/src/server/db_slice.h index 1c4ae56008f..b8630394ddd 100644 --- a/src/server/db_slice.h +++ b/src/server/db_slice.h @@ -4,6 +4,7 @@ #pragma once +#include "facade/dragonfly_connection.h" #include "facade/op_status.h" #include "server/common.h" #include "server/conn_context.h" @@ -147,11 +148,10 @@ class DbSlice { return ExpirePeriod{time_ms - expire_base_[0]}; } - OpResult Find(const Context& cntx, std::string_view key, - unsigned req_obj_type) const; + OpResult Find(const Context& cntx, std::string_view key, unsigned req_obj_type); // Returns (value, expire) dict entries if key exists, null if it does not exist or has expired. - std::pair FindExt(const Context& cntx, std::string_view key) const; + std::pair FindExt(const Context& cntx, std::string_view key); // Returns (iterator, args-index) if found, KEY_NOTFOUND otherwise. // If multiple keys are found, returns the first index in the ArgSlice. @@ -269,8 +269,7 @@ class DbSlice { // Check whether 'it' has not expired. Returns it if it's still valid. Otherwise, erases it // from both tables and return PrimeIterator{}. - std::pair ExpireIfNeeded(const Context& cntx, - PrimeIterator it) const; + std::pair ExpireIfNeeded(const Context& cntx, PrimeIterator it); // Iterate over all expire table entries and delete expired. void ExpireAllIfNeeded(); @@ -334,6 +333,14 @@ class DbSlice { expire_allowed_ = is_allowed; } + // Start tracking keys for the client with client_id + // void TrackKeys(ConnectionContext*, int32_t, const std::vector&); + void TrackKeys(facade::Connection::WeakRef, int32_t, const std::vector&); + + // Send invalidatoin message when a key being tracked is updated/deleted. + // A connection that has been closed will be garbage collected along the way. + void SendInvalidationTrackingMessage(std::string_view key); + private: // Releases a single key. `key` must have been normalized by GetLockKey(). void ReleaseNormalized(IntentLock::Mode m, DbIndex db_index, std::string_view key, @@ -384,6 +391,16 @@ class DbSlice { // Registered by shard indices on when first document index is created. DocDeletionCallback doc_del_cb_; + + struct Hash { + size_t operator()(const std::pair& p) const { + return std::hash()(p.first.GetClientId()); + } + }; + + absl::flat_hash_map, Hash>> + client_tracking_map_; }; } // namespace dfly diff --git a/src/server/hset_family.cc b/src/server/hset_family.cc index 851dce09ae2..f1b3952d75b 100644 --- a/src/server/hset_family.cc +++ b/src/server/hset_family.cc @@ -1134,10 +1134,26 @@ void HSetFamily::HRandField(CmdArgList args, ConnectionContext* cntx) { }; OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (result) { - (*cntx)->SendStringArr(*result); + if ((result->size() == 1) && (args.size() == 1)) + (*cntx)->SendBulkString(result->front()); + else { + if (((*cntx)->IsResp3()) && (args.size() == 3)) { // has withvalues + (*cntx)->StartArray(result->size() / 2); + for (unsigned int i = 0; i < result->size() / 2; ++i) { + StringVec sv{(*result)[i * 2], (*result)[i * 2 + 1]}; + (*cntx)->SendStringArr(sv); + } + } else { + (*cntx)->SendStringArr(*result); + } + } } else if (result.status() == OpStatus::KEY_NOTFOUND) { - (*cntx)->SendNull(); + if (args.size() == 1) + (*cntx)->SendNull(); + else + (*cntx)->SendEmptyArray(); } else { (*cntx)->SendError(result.status()); } diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 4dcdb787b2d..ea8145629fe 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -1040,6 +1040,13 @@ std::optional Service::VerifyCommandState(const CommandId* cid, CmdA return VerifyConnectionAclStatus(cid, &dfly_cntx, "has no ACL permissions"); } +OpResult OpTrackKeys(const OpArgs& op_args, ConnectionContext* cntx, uint32_t tid, + vector& keys, CmdArgList args) { + auto& db_slice = op_args.shard->db_slice(); + db_slice.TrackKeys(cntx->conn()->Borrow(), tid, keys); + return true; +} + void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) { CHECK(!args.empty()); DCHECK_NE(0u, shard_set->size()) << "Init was not called"; @@ -1140,6 +1147,28 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) dfly_cntx->reply_builder()->CloseConnection(); } + // if this is a read command, and client tracking has enabled, + // start tracking updates to the keys in this read command + // notify the client when there is update, see PostUpdate() in db_slice.cc + if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn()) { + // let's pass thread id and connection to db_slice for tracking + int32_t tid = ProactorBase::me()->GetPoolIndex(); + // uint32_t client_id = dfly_cntx->conn()->GetClientId(); + auto cb = [&](Transaction* t, EngineShard* shard) { + auto keys = t->GetShardArgs(shard->shard_id()); + vector keys_to_track{keys.begin(), keys.end()}; + return OpTrackKeys(t->GetOpArgs(shard), dfly_cntx, tid, keys_to_track, args); + }; + + if (dfly_cntx->transaction == nullptr) { + DVLOG(2) << "transaction is a nullptr"; + } else { + DVLOG(2) << "transaction is not a nullptr"; + dfly_cntx->transaction->Refurbish(); + dfly_cntx->transaction->ScheduleSingleHopT(cb); + } + } + if (!dispatching_in_multi) { dfly_cntx->transaction = nullptr; } @@ -1455,6 +1484,9 @@ void Service::Quit(CmdArgList args, ConnectionContext* cntx) { (*cntx)->SendOk(); using facade::SinkReplyBuilder; + if (cntx->conn()->IsTrackingOn()) + cntx->conn()->DisableTracking(); + SinkReplyBuilder* builder = cntx->reply_builder(); builder->CloseConnection(); @@ -2081,12 +2113,12 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { // thus not adding any overhead to backpressure checks. optional last_thread; for (auto& sub : subscribers) { - DCHECK_LE(last_thread.value_or(0), sub.thread_id); - if (last_thread && *last_thread == sub.thread_id) // skip same thread + DCHECK_LE(last_thread.value_or(0), sub.Thread()); + if (last_thread && *last_thread == sub.Thread()) // skip same thread continue; - sub.conn_cntx->conn()->EnsureAsyncMemoryBudget(); - last_thread = sub.thread_id; + if (sub.EnsureMemoryBudget()) // Invalid pointers are skipped + last_thread = sub.Thread(); } auto subscribers_ptr = make_shared(std::move(subscribers)); @@ -2096,14 +2128,13 @@ void Service::Publish(CmdArgList args, ConnectionContext* cntx) { auto cb = [subscribers_ptr, buf, channel, msg](unsigned idx, util::ProactorBase*) { auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx, - ChannelStore::Subscriber::ByThread); - - while (it != subscribers_ptr->end() && it->thread_id == idx) { - facade::Connection* conn = it->conn_cntx->conn(); - DCHECK(conn); - conn->SendPubMessageAsync( - {std::move(it->pattern), std::move(buf), channel.size(), msg.size()}); - it->borrow_token.Dec(); + ChannelStore::Subscriber::ByThreadId); + + while (it != subscribers_ptr->end() && it->Thread() == idx) { + if (auto* ptr = it->Get(); ptr) { + ptr->SendPubMessageAsync( + {std::move(it->pattern), std::move(buf), channel.size(), msg.size()}); + } it++; } }; @@ -2333,20 +2364,12 @@ void Service::OnClose(facade::ConnectionContext* cntx) { if (conn_state.subscribe_info) { // Clean-ups related to PUBSUB if (!conn_state.subscribe_info->channels.empty()) { - auto token = conn_state.subscribe_info->borrow_token; server_cntx->UnsubscribeAll(false); - - // Check that all borrowers finished processing. - // token is increased in channel_slice (the publisher side). - token.Wait(); } if (conn_state.subscribe_info) { DCHECK(!conn_state.subscribe_info->patterns.empty()); - - auto token = conn_state.subscribe_info->borrow_token; server_cntx->PUnsubscribeAll(false); - token.Wait(); // Same as above } DCHECK(!conn_state.subscribe_info); @@ -2357,6 +2380,10 @@ void Service::OnClose(facade::ConnectionContext* cntx) { DeactivateMonitoring(server_cntx); server_family_.OnClose(server_cntx); + + // disable client tracking. + if (server_cntx->conn()->IsTrackingOn()) + server_cntx->conn()->DisableTracking(); } string Service::GetContextInfo(facade::ConnectionContext* cntx) { @@ -2414,7 +2441,7 @@ void Service::Register(CommandRegistry* registry) { using CI = CommandId; registry->StartFamily(); *registry - << CI{"QUIT", CO::READONLY | CO::FAST, 1, 0, 0, acl::kQuit}.HFUNC(Quit) + << CI{"QUIT", CO::FAST, 1, 0, 0, acl::kQuit}.HFUNC(Quit) << CI{"MULTI", CO::NOSCRIPT | CO::FAST | CO::LOADING, 1, 0, 0, acl::kMulti}.HFUNC(Multi) << CI{"WATCH", CO::LOADING, -2, 1, -1, acl::kWatch}.HFUNC(Watch) << CI{"UNWATCH", CO::LOADING, 1, 0, 0, acl::kUnwatch}.HFUNC(Unwatch) diff --git a/src/server/server_family.cc b/src/server/server_family.cc index e327b13241f..d99ed156836 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -1164,10 +1164,29 @@ string GetPassword() { return ""; } +void ServerFamily::SendInvalidationMessages() const { + // send invalidation message (caused by flushdb) to all the clients which + // turned on client tracking + auto cb = [](unsigned thread_index, util::Connection* conn) { + facade::ConnectionContext* fc = static_cast(conn)->cntx(); + if (fc) { + ConnectionContext* cntx = static_cast(fc); + if (cntx->conn()->IsTrackingOn()) { + facade::Connection::InvalidationMessage x; + x.invalidate_due_to_flush = true; + cntx->conn()->SendInvalidationMessageAsync(x); + } + } + }; + for (auto* listener : listeners_) { + listener->TraverseConnections(cb); + } +} + void ServerFamily::FlushDb(CmdArgList args, ConnectionContext* cntx) { DCHECK(cntx->transaction); Drakarys(cntx->transaction, cntx->transaction->GetDbIndex()); - + SendInvalidationMessages(); cntx->reply_builder()->SendOk(); } @@ -1179,6 +1198,7 @@ void ServerFamily::FlushAll(CmdArgList args, ConnectionContext* cntx) { DCHECK(cntx->transaction); Drakarys(cntx->transaction, DbSlice::kDbAll); + SendInvalidationMessages(); (*cntx)->SendOk(); } @@ -1243,6 +1263,23 @@ void ServerFamily::Client(CmdArgList args, ConnectionContext* cntx) { return (*cntx)->SendOk(); } + if (sub_cmd == "TRACKING" && args.size() == 2) { + if ((*cntx)->IsResp3()) { + ToUpper(&args[1]); + string_view switch_state = ArgS(args, 1); + if (switch_state == "ON") { + cntx->conn()->EnableTracking(); + return (*cntx)->SendOk(); + } else if (switch_state == "OFF") { + cntx->conn()->DisableTracking(); + return (*cntx)->SendOk(); + } + } else { + LOG_FIRST_N(ERROR, 10) + << "Client tracking is currently not supported in RESP2, please use RESP3."; + } + } + LOG_FIRST_N(ERROR, 10) << "Subcommand " << sub_cmd << " not supported"; return (*cntx)->SendError(UnknownSubCmd(sub_cmd, "CLIENT"), kSyntaxErrType); } @@ -1291,7 +1328,7 @@ void ServerFamily::ClientList(CmdArgList args, ConnectionContext* cntx) { string result = absl::StrJoin(client_info, "\n"); result.append("\n"); - return (*cntx)->SendBulkString(result); + return (*cntx)->SendVerbatimString(result); } void ServerFamily::ClientPause(CmdArgList args, ConnectionContext* cntx) { @@ -1844,7 +1881,7 @@ void ServerFamily::Info(CmdArgList args, ConnectionContext* cntx) { append("cluster_enabled", ClusterConfig::IsEnabledOrEmulated()); } - (*cntx)->SendBulkString(info); + (*cntx)->SendVerbatimString(info); } void ServerFamily::Hello(CmdArgList args, ConnectionContext* cntx) { diff --git a/src/server/server_family.h b/src/server/server_family.h index 08de412f623..38948328616 100644 --- a/src/server/server_family.h +++ b/src/server/server_family.h @@ -256,6 +256,8 @@ class ServerFamily { void SnapshotScheduling(); + void SendInvalidationMessages() const; + Fiber snapshot_schedule_fb_; Future load_result_; diff --git a/src/server/server_family_test.cc b/src/server/server_family_test.cc index e0a12c26956..185ca9a821f 100644 --- a/src/server/server_family_test.cc +++ b/src/server/server_family_test.cc @@ -200,4 +200,11 @@ TEST_F(ServerFamilyTest, ClientPause) { EXPECT_GT((absl::Now() - start), absl::Milliseconds(50)); } +TEST_F(ServerFamilyTest, ClientTracking) { + // client tracking only works for RESP3 + auto resp = Run({"hello", "3"}); + resp = Run({"client", "tracking", "on"}); + EXPECT_THAT(resp.GetString(), "OK"); +} + } // namespace dfly diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index e9fed427289..aa1cd4f4cb8 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -66,6 +66,7 @@ static vector SplitLines(const std::string& src) { TestConnection::TestConnection(Protocol protocol, io::StringSink* sink) : facade::Connection(protocol, nullptr, nullptr, nullptr), sink_(sink) { cc_.reset(new dfly::ConnectionContext(sink_, this)); + SetSocket(ProactorBase::me()->CreateSocket()); } void TestConnection::SendPubMessageAsync(PubMessage pmsg) { diff --git a/src/server/transaction.cc b/src/server/transaction.cc index ebf29b03409..3e50c167029 100644 --- a/src/server/transaction.cc +++ b/src/server/transaction.cc @@ -906,6 +906,12 @@ void Transaction::Conclude() { Execute(std::move(cb), true); } +void Transaction::Refurbish() { + txid_ = 0; + coordinator_state_ = 0; + cb_ptr_ = nullptr; +} + void Transaction::EnableShard(ShardId sid) { unique_shard_cnt_ = 1; unique_shard_id_ = sid; diff --git a/src/server/transaction.h b/src/server/transaction.h index 6d76e24cf4b..64d175e2880 100644 --- a/src/server/transaction.h +++ b/src/server/transaction.h @@ -323,6 +323,8 @@ class Transaction { // Utility to run a single hop on a no-key command static void RunOnceAsCommand(const CommandId* cid, RunnableType cb); + void Refurbish(); + private: // Holds number of locks for each IntentLock::Mode: shared and exlusive. struct LockCnt {