Skip to content

Commit

Permalink
Merge branch 'main' into acl_part_7_add_acl_deluser
Browse files Browse the repository at this point in the history
  • Loading branch information
kostasrim committed Sep 1, 2023
2 parents 12ff002 + 9ca7dba commit 2ed4699
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 17 deletions.
27 changes: 24 additions & 3 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "facade/memcache_parser.h"
#include "facade/redis_parser.h"
#include "facade/service_interface.h"
#include "server/conn_context.h"
#include "util/fibers/proactor_base.h"

#ifdef DFLY_USE_SSL
Expand Down Expand Up @@ -130,6 +131,7 @@ struct Connection::DispatchOperations {
void operator()(const PubMessage& msg);
void operator()(Connection::PipelineMessage& msg);
void operator()(const MonitorMessage& msg);
void operator()(const AclUpdateMessage& msg);

template <typename T, typename D> void operator()(unique_ptr<T, D>& ptr) {
operator()(*ptr.get());
Expand Down Expand Up @@ -191,7 +193,11 @@ size_t Connection::MessageHandle::UsedMemory() const {
arg.storage.capacity();
};
auto monitor_size = [](const MonitorMessage& arg) -> size_t { return arg.capacity(); };
return sizeof(MessageHandle) + visit(Overloaded{pub_size, msg_size, monitor_size}, this->handle);
auto acl_update_size = [](const AclUpdateMessage& msg) -> size_t {
return sizeof(AclUpdateMessage);
};
return sizeof(MessageHandle) +
visit(Overloaded{pub_size, msg_size, monitor_size, acl_update_size}, this->handle);
}

bool Connection::MessageHandle::IsPipelineMsg() const {
Expand All @@ -203,6 +209,13 @@ void Connection::DispatchOperations::operator()(const MonitorMessage& msg) {
rbuilder->SendSimpleString(msg);
}

void Connection::DispatchOperations::operator()(const AclUpdateMessage& msg) {
auto* ctx = static_cast<dfly::ConnectionContext*>(self->cntx());
if (ctx && msg.username == ctx->authed_username) {
ctx->acl_categories = msg.categories;
}
}

void Connection::DispatchOperations::operator()(const PubMessage& pub_msg) {
RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder;
++stats->async_writes_cnt;
Expand Down Expand Up @@ -929,6 +942,10 @@ void Connection::SendMonitorMessageAsync(string msg) {
SendAsync({MonitorMessage{move(msg)}});
}

void Connection::SendAclUpdateAsync(AclUpdateMessage msg) {
SendAsync({msg});
}

void Connection::SendAsync(MessageHandle msg) {
DCHECK(cc_);
DCHECK(owner());
Expand All @@ -945,7 +962,11 @@ void Connection::SendAsync(MessageHandle msg) {
}

dispatch_q_bytes_.fetch_add(msg.UsedMemory(), memory_order_relaxed);
dispatch_q_.push_back(move(msg));
if (std::holds_alternative<AclUpdateMessage>(msg.handle)) {
dispatch_q_.push_front(std::move(msg));
} else {
dispatch_q_.push_back(std::move(msg));
}

// Don't notify if a sync dispatch is in progress, it will wake after finishing.
// This might only happen if we started receving messages while `SUBSCRIBE`
Expand Down Expand Up @@ -975,7 +996,7 @@ std::string Connection::RemoteEndpointAddress() const {
return re.address().to_string();
}

ConnectionContext* Connection::cntx() {
facade::ConnectionContext* Connection::cntx() {
return cc_.get();
}

Expand Down
10 changes: 9 additions & 1 deletion src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class Connection : public util::Connection {

struct MonitorMessage : public std::string {};

struct AclUpdateMessage {
std::string_view username;
uint64_t categories{0};
};

struct PipelineMessage {
PipelineMessage(size_t nargs, size_t capacity) : args(nargs), storage(capacity) {
}
Expand Down Expand Up @@ -111,7 +116,7 @@ class Connection : public util::Connection {

bool IsPipelineMsg() const;

std::variant<MonitorMessage, PubMessagePtr, PipelineMessagePtr> handle;
std::variant<MonitorMessage, PubMessagePtr, PipelineMessagePtr, AclUpdateMessage> handle;
};

enum Phase { READ_SOCKET, PROCESS };
Expand All @@ -124,6 +129,9 @@ class Connection : public util::Connection {
// Add monitor message to dispatch queue.
void SendMonitorMessageAsync(std::string);

// Add acl update to dispatch queue.
void SendAclUpdateAsync(AclUpdateMessage msg);

// Must be called before Send_Async to ensure the connection dispatch queue is not overfilled.
// Blocks until free space is available.
void EnsureAsyncMemoryBudget();
Expand Down
5 changes: 1 addition & 4 deletions src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,7 @@ void AclFamily::StreamUpdatesToAllProactorConnections(std::string_view user, uin
auto update_cb = [user, update_cat]([[maybe_unused]] size_t id, util::Connection* conn) {
DCHECK(conn);
auto connection = static_cast<facade::Connection*>(conn);
auto ctx = static_cast<ConnectionContext*>(connection->cntx());
if (ctx && user == ctx->authed_username) {
ctx->acl_categories = update_cat;
}
connection->SendAclUpdateAsync(facade::Connection::AclUpdateMessage{user, update_cat});
};

if (main_listener_) {
Expand Down
10 changes: 5 additions & 5 deletions src/server/acl/user_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ UserRegistry::RegistryViewWithLock UserRegistry::GetRegistryWithLock() const {
return {std::move(lock), registry_};
}

UserRegistry::UserViewWithLock::UserViewWithLock(std::shared_lock<util::SharedMutex> lk,
const User& user, bool exists)
UserRegistry::UserWithWriteLock::UserWithWriteLock(std::unique_lock<util::SharedMutex> lk,
const User& user, bool exists)
: user(user), exists(exists), registry_lk_(std::move(lk)) {
}

UserRegistry::UserViewWithLock UserRegistry::MaybeAddAndUpdateWithLock(std::string_view username,
User::UpdateRequest req) {
std::shared_lock<util::SharedMutex> lock(mu_);
UserRegistry::UserWithWriteLock UserRegistry::MaybeAddAndUpdateWithLock(std::string_view username,
User::UpdateRequest req) {
std::unique_lock<util::SharedMutex> lock(mu_);
const bool exists = registry_.contains(username);
auto& user = registry_[username];
user.Update(std::move(req));
Expand Down
8 changes: 4 additions & 4 deletions src/server/acl/user_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ class UserRegistry {
RegistryViewWithLock GetRegistryWithLock() const;

// Helper class for accessing a user with a ReadLock outside the scope of UserRegistry
class UserViewWithLock {
class UserWithWriteLock {
public:
UserViewWithLock(std::shared_lock<util::SharedMutex> lk, const User& user, bool exists);
UserWithWriteLock(std::unique_lock<util::SharedMutex> lk, const User& user, bool exists);
const User& user;
const bool exists;

private:
std::shared_lock<util::SharedMutex> registry_lk_;
std::unique_lock<util::SharedMutex> registry_lk_;
};

UserViewWithLock MaybeAddAndUpdateWithLock(std::string_view username, User::UpdateRequest req);
UserWithWriteLock MaybeAddAndUpdateWithLock(std::string_view username, User::UpdateRequest req);

private:
RegistryType registry_;
Expand Down
20 changes: 20 additions & 0 deletions tests/dragonfly/acl_family_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,23 @@ async def test_acl_del_user_while_running_lua_script(df_server):
assert res == b"10000"

await admin_client.close()


@pytest.mark.asyncio
async def test_acl_with_long_running_script(df_server):
client = aioredis.Redis(port=df_server.port)
await client.execute_command("ACL SETUSER roman ON >yoman +@string +@scripting")
await client.execute_command("AUTH roman yoman")
admin_client = aioredis.Redis(port=df_server.port)

await asyncio.gather(
client.eval(script, 4, "key", "key1", "key2", "key3"),
admin_client.execute_command("ACL SETUSER -@string -@scripting"),
)

for i in range(1, 4):
res = await admin_client.get(f"key{i}")
assert res == b"10000"

await client.close()
await admin_client.close()

0 comments on commit 2ed4699

Please sign in to comment.