diff --git a/src/memgraph.cpp b/src/memgraph.cpp index 457aa83646..86354c1f6a 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -751,6 +751,24 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler { } } + memgraph::auth::User *GetUser(const std::string &username) override { + if (!std::regex_match(username, name_regex_)) { + throw memgraph::query::QueryRuntimeException("Invalid user name."); + } + try { + auto locked_auth = auth_->Lock(); + auto user = locked_auth->GetUser(username); + if (!user) { + throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username); + } + + return new memgraph::auth::User(*user); + + } catch (const memgraph::auth::AuthException &e) { + throw memgraph::query::QueryRuntimeException(e.what()); + } + } + void GrantPrivilege(const std::string &user_or_role, const std::vector &privileges, const std::vector &labels) override { diff --git a/src/query/context.hpp b/src/query/context.hpp index 12b1f0cacb..0be220dfbb 100644 --- a/src/query/context.hpp +++ b/src/query/context.hpp @@ -14,6 +14,7 @@ #include #include "query/common.hpp" +#include "query/fine_grained_access_checker.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/metadata.hpp" #include "query/parameters.hpp" @@ -72,6 +73,7 @@ struct ExecutionContext { ExecutionStats execution_stats; TriggerContextCollector *trigger_context_collector{nullptr}; utils::AsyncTimer timer; + FineGrainedAccessChecker *fine_grained_access_checker{nullptr}; }; static_assert(std::is_move_assignable_v, "ExecutionContext must be move assignable!"); diff --git a/src/query/fine_grained_access_checker.hpp b/src/query/fine_grained_access_checker.hpp new file mode 100644 index 0000000000..cd508734ad --- /dev/null +++ b/src/query/fine_grained_access_checker.hpp @@ -0,0 +1,24 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#pragma once + +#include "auth/models.hpp" +#include "query/frontend/ast/ast.hpp" +#include "storage/v2/id_types.hpp" + +namespace memgraph::query { +class FineGrainedAccessChecker { + public: + virtual bool IsUserAuthorizedLabels(const std::vector &label, + memgraph::query::DbAccessor *dba) const = 0; +}; +} // namespace memgraph::query diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 8f05bc2077..f45e17794b 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -29,6 +29,7 @@ #include "query/db_accessor.hpp" #include "query/dump.hpp" #include "query/exceptions.hpp" +#include "query/fine_grained_access_checker.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast_visitor.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp" @@ -259,6 +260,24 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler { private: storage::Storage *db_; }; + +class FineGrainedAccessChecker final : public memgraph::query::FineGrainedAccessChecker { + public: + explicit FineGrainedAccessChecker(memgraph::auth::User *user) : user_{user} {} + + bool IsUserAuthorizedLabels(const std::vector &labels, + memgraph::query::DbAccessor *dba) const final { + auto labelPermissions = user_->GetFineGrainedAccessPermissions(); + + return std::any_of(labels.begin(), labels.end(), [&labelPermissions, &dba](const auto label) { + return labelPermissions.Has(dba->LabelToName(label)) == memgraph::auth::PermissionLevel::GRANT; + }); + } + + private: + memgraph::auth::User *user_; +}; + /// returns false if the replication role can't be set /// @throw QueryRuntimeException if an error ocurred. @@ -898,7 +917,7 @@ struct PullPlanVector { struct PullPlan { explicit PullPlan(std::shared_ptr plan, const Parameters ¶meters, bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - TriggerContextCollector *trigger_context_collector = nullptr, + std::optional username, TriggerContextCollector *trigger_context_collector = nullptr, std::optional memory_limit = {}); std::optional Pull(AnyStream *stream, std::optional n, const std::vector &output_symbols, @@ -927,7 +946,8 @@ struct PullPlan { PullPlan::PullPlan(const std::shared_ptr plan, const Parameters ¶meters, const bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - TriggerContextCollector *trigger_context_collector, const std::optional memory_limit) + std::optional username, TriggerContextCollector *trigger_context_collector, + const std::optional memory_limit) : plan_(plan), cursor_(plan->plan().MakeCursor(execution_memory)), frame_(plan->symbol_table().max_position(), execution_memory), @@ -938,6 +958,12 @@ PullPlan::PullPlan(const std::shared_ptr plan, const Parameters &par ctx_.evaluation_context.parameters = parameters; ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba); ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba); +#ifdef MG_ENTERPRISE + if (username.has_value()) { + memgraph::auth::User *user = interpreter_context->auth->GetUser(*username); + ctx_.fine_grained_access_checker = new FineGrainedAccessChecker{user}; + } +#endif if (interpreter_context->config.execution_timeout_sec > 0) { ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec}; } @@ -1111,6 +1137,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper) PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map *summary, InterpreterContext *interpreter_context, DbAccessor *dba, utils::MemoryResource *execution_memory, std::vector *notifications, + const std::string *username, TriggerContextCollector *trigger_context_collector = nullptr) { auto *cypher_query = utils::Downcast(parsed_query.query); @@ -1154,8 +1181,9 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map(plan, parsed_query.parameters, false, dba, interpreter_context, - execution_memory, trigger_context_collector, memory_limit); + auto pull_plan = + std::make_shared(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, + StringPointerToOptional(username), trigger_context_collector, memory_limit); return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary]( AnyStream *stream, std::optional n) -> std::optional { @@ -1215,7 +1243,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map *summary, InterpreterContext *interpreter_context, - DbAccessor *dba, utils::MemoryResource *execution_memory) { + DbAccessor *dba, utils::MemoryResource *execution_memory, + const std::string *username) { const std::string kProfileQueryStart = "profile "; MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart), @@ -1265,12 +1294,14 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query, parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba); auto rw_type_checker = plan::ReadWriteTypeChecker(); + auto optional_username = StringPointerToOptional(username); + rw_type_checker.InferRWType(const_cast(cypher_query_plan->plan())); return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"}, std::move(parsed_query.required_privileges), [plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters), - summary, dba, interpreter_context, execution_memory, memory_limit, + summary, dba, interpreter_context, execution_memory, memory_limit, optional_username, // We want to execute the query we are profiling lazily, so we delay // the construction of the corresponding context. stats_and_total_time = std::optional{}, @@ -1279,7 +1310,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra // No output symbols are given so that nothing is streamed. if (!stats_and_total_time) { stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context, - execution_memory, nullptr, memory_limit) + execution_memory, optional_username, nullptr, memory_limit) .Pull(stream, {}, {}, summary); pull_plan = std::make_shared(ProfilingStatsToTable(*stats_and_total_time)); } @@ -1414,7 +1445,7 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::map *summary, InterpreterContext *interpreter_context, - DbAccessor *dba, utils::MemoryResource *execution_memory) { + DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username) { if (in_explicit_transaction) { throw UserModificationInMulticommandTxException(); } @@ -1434,8 +1465,8 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa [fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }), 0.0, AstStorage{}, symbol_table)); - auto pull_plan = - std::make_shared(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory); + auto pull_plan = std::make_shared(plan, parsed_query.parameters, false, dba, interpreter_context, + execution_memory, StringPointerToOptional(username)); return PreparedQuery{ callback.header, std::move(parsed_query.required_privileges), [pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols), @@ -2148,7 +2179,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, &*execution_db_accessor_, &query_execution->execution_memory, - &query_execution->notifications, + &query_execution->notifications, username, trigger_context_collector_ ? &*trigger_context_collector_ : nullptr); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, @@ -2156,7 +2187,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, &*execution_db_accessor_, - &query_execution->execution_memory_with_exception); + &query_execution->execution_memory_with_exception, username); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_, &query_execution->execution_memory); @@ -2166,7 +2197,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, &*execution_db_accessor_, - &query_execution->execution_memory_with_exception); + &query_execution->execution_memory_with_exception, username); } else if (utils::Downcast(parsed_query.query)) { prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, interpreter_context_->db, diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 6c949830fe..155f5688af 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -99,6 +99,9 @@ class AuthQueryHandler { virtual std::vector> GetPrivileges(const std::string &user_or_role) = 0; + /// @throw QueryRuntimeException if an error ocurred. + virtual memgraph::auth::User *GetUser(const std::string &username) = 0; + /// @throw QueryRuntimeException if an error ocurred. virtual void GrantPrivilege(const std::string &user_or_role, const std::vector &privileges, const std::vector &labels) = 0;