Skip to content

Commit

Permalink
[E129<-T0955-MG] Expand ExecutionContext with label related informati…
Browse files Browse the repository at this point in the history
…on (#467)

* added

* Added FineGrainedAccessChecker to Context

* fixed
  • Loading branch information
niko4299 committed Jul 21, 2022
1 parent f85ee31 commit a2643cc
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 13 deletions.
18 changes: 18 additions & 0 deletions src/memgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,24 @@ class AuthQueryHandler final : public memgraph::query::AuthQueryHandler {
}
}

memgraph::auth::User *GetUser(const std::string &username) override {
if (!std::regex_match(username, name_regex_)) {
throw memgraph::query::QueryRuntimeException("Invalid user name.");
}
try {
auto locked_auth = auth_->Lock();
auto user = locked_auth->GetUser(username);
if (!user) {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
}

return new memgraph::auth::User(*user);

} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}

void GrantPrivilege(const std::string &user_or_role,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels) override {
Expand Down
2 changes: 2 additions & 0 deletions src/query/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <type_traits>

#include "query/common.hpp"
#include "query/fine_grained_access_checker.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/metadata.hpp"
#include "query/parameters.hpp"
Expand Down Expand Up @@ -72,6 +73,7 @@ struct ExecutionContext {
ExecutionStats execution_stats;
TriggerContextCollector *trigger_context_collector{nullptr};
utils::AsyncTimer timer;
FineGrainedAccessChecker *fine_grained_access_checker{nullptr};
};

static_assert(std::is_move_assignable_v<ExecutionContext>, "ExecutionContext must be move assignable!");
Expand Down
24 changes: 24 additions & 0 deletions src/query/fine_grained_access_checker.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

#pragma once

#include "auth/models.hpp"
#include "query/frontend/ast/ast.hpp"
#include "storage/v2/id_types.hpp"

namespace memgraph::query {
class FineGrainedAccessChecker {
public:
virtual bool IsUserAuthorizedLabels(const std::vector<memgraph::storage::LabelId> &label,
memgraph::query::DbAccessor *dba) const = 0;
};
} // namespace memgraph::query
57 changes: 44 additions & 13 deletions src/query/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "query/db_accessor.hpp"
#include "query/dump.hpp"
#include "query/exceptions.hpp"
#include "query/fine_grained_access_checker.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/ast/ast_visitor.hpp"
#include "query/frontend/ast/cypher_main_visitor.hpp"
Expand Down Expand Up @@ -259,6 +260,24 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler {
private:
storage::Storage *db_;
};

class FineGrainedAccessChecker final : public memgraph::query::FineGrainedAccessChecker {
public:
explicit FineGrainedAccessChecker(memgraph::auth::User *user) : user_{user} {}

bool IsUserAuthorizedLabels(const std::vector<memgraph::storage::LabelId> &labels,
memgraph::query::DbAccessor *dba) const final {
auto labelPermissions = user_->GetFineGrainedAccessPermissions();

return std::any_of(labels.begin(), labels.end(), [&labelPermissions, &dba](const auto label) {
return labelPermissions.Has(dba->LabelToName(label)) == memgraph::auth::PermissionLevel::GRANT;
});
}

private:
memgraph::auth::User *user_;
};

/// returns false if the replication role can't be set
/// @throw QueryRuntimeException if an error ocurred.

Expand Down Expand Up @@ -898,7 +917,7 @@ struct PullPlanVector {
struct PullPlan {
explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters &parameters, bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector = nullptr,
std::optional<size_t> memory_limit = {});
std::optional<plan::ProfilingStatsWithTotalTime> Pull(AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols,
Expand Down Expand Up @@ -927,7 +946,8 @@ struct PullPlan {

PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &parameters, const bool is_profile_query,
DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory,
TriggerContextCollector *trigger_context_collector, const std::optional<size_t> memory_limit)
std::optional<std::string> username, TriggerContextCollector *trigger_context_collector,
const std::optional<size_t> memory_limit)
: plan_(plan),
cursor_(plan->plan().MakeCursor(execution_memory)),
frame_(plan->symbol_table().max_position(), execution_memory),
Expand All @@ -938,6 +958,12 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par
ctx_.evaluation_context.parameters = parameters;
ctx_.evaluation_context.properties = NamesToProperties(plan->ast_storage().properties_, dba);
ctx_.evaluation_context.labels = NamesToLabels(plan->ast_storage().labels_, dba);
#ifdef MG_ENTERPRISE
if (username.has_value()) {
memgraph::auth::User *user = interpreter_context->auth->GetUser(*username);
ctx_.fine_grained_access_checker = new FineGrainedAccessChecker{user};
}
#endif
if (interpreter_context->config.execution_timeout_sec > 0) {
ctx_.timer = utils::AsyncTimer{interpreter_context->config.execution_timeout_sec};
}
Expand Down Expand Up @@ -1111,6 +1137,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper)
PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
InterpreterContext *interpreter_context, DbAccessor *dba,
utils::MemoryResource *execution_memory, std::vector<Notification> *notifications,
const std::string *username,
TriggerContextCollector *trigger_context_collector = nullptr) {
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query);

Expand Down Expand Up @@ -1154,8 +1181,9 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
header.push_back(
utils::FindOr(parsed_query.stripped_query.named_expressions(), symbol.token_position(), symbol.name()).first);
}
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, trigger_context_collector, memory_limit);
auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory,
StringPointerToOptional(username), trigger_context_collector, memory_limit);
return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary](
AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
Expand Down Expand Up @@ -1215,7 +1243,8 @@ PreparedQuery PrepareExplainQuery(ParsedQuery parsed_query, std::map<std::string

PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory) {
DbAccessor *dba, utils::MemoryResource *execution_memory,
const std::string *username) {
const std::string kProfileQueryStart = "profile ";

MG_ASSERT(utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()), kProfileQueryStart),
Expand Down Expand Up @@ -1265,12 +1294,14 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query,
parsed_inner_query.parameters, parsed_inner_query.is_cacheable ? &interpreter_context->plan_cache : nullptr, dba);
auto rw_type_checker = plan::ReadWriteTypeChecker();
auto optional_username = StringPointerToOptional(username);

rw_type_checker.InferRWType(const_cast<plan::LogicalOperator &>(cypher_query_plan->plan()));

return PreparedQuery{{"OPERATOR", "ACTUAL HITS", "RELATIVE TIME", "ABSOLUTE TIME"},
std::move(parsed_query.required_privileges),
[plan = std::move(cypher_query_plan), parameters = std::move(parsed_inner_query.parameters),
summary, dba, interpreter_context, execution_memory, memory_limit,
summary, dba, interpreter_context, execution_memory, memory_limit, optional_username,
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
stats_and_total_time = std::optional<plan::ProfilingStatsWithTotalTime>{},
Expand All @@ -1279,7 +1310,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
// No output symbols are given so that nothing is streamed.
if (!stats_and_total_time) {
stats_and_total_time = PullPlan(plan, parameters, true, dba, interpreter_context,
execution_memory, nullptr, memory_limit)
execution_memory, optional_username, nullptr, memory_limit)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(*stats_and_total_time));
}
Expand Down Expand Up @@ -1414,7 +1445,7 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans

PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory) {
DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username) {
if (in_explicit_transaction) {
throw UserModificationInMulticommandTxException();
}
Expand All @@ -1434,8 +1465,8 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
[fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }),
0.0, AstStorage{}, symbol_table));

auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory);
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, StringPointerToOptional(username));
return PreparedQuery{
callback.header, std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols),
Expand Down Expand Up @@ -2148,15 +2179,15 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
if (utils::Downcast<CypherQuery>(parsed_query.query)) {
prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory,
&query_execution->notifications,
&query_execution->notifications, username,
trigger_context_collector_ ? &*trigger_context_collector_ : nullptr);
} else if (utils::Downcast<ExplainQuery>(parsed_query.query)) {
prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory_with_exception);
} else if (utils::Downcast<ProfileQuery>(parsed_query.query)) {
prepared_query = PrepareProfileQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception);
&query_execution->execution_memory_with_exception, username);
} else if (utils::Downcast<DumpQuery>(parsed_query.query)) {
prepared_query = PrepareDumpQuery(std::move(parsed_query), &query_execution->summary, &*execution_db_accessor_,
&query_execution->execution_memory);
Expand All @@ -2166,7 +2197,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} else if (utils::Downcast<AuthQuery>(parsed_query.query)) {
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception);
&query_execution->execution_memory_with_exception, username);
} else if (utils::Downcast<InfoQuery>(parsed_query.query)) {
prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, interpreter_context_->db,
Expand Down
3 changes: 3 additions & 0 deletions src/query/interpreter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ class AuthQueryHandler {

virtual std::vector<std::vector<TypedValue>> GetPrivileges(const std::string &user_or_role) = 0;

/// @throw QueryRuntimeException if an error ocurred.
virtual memgraph::auth::User *GetUser(const std::string &username) = 0;

/// @throw QueryRuntimeException if an error ocurred.
virtual void GrantPrivilege(const std::string &user_or_role, const std::vector<AuthQuery::Privilege> &privileges,
const std::vector<std::string> &labels) = 0;
Expand Down

0 comments on commit a2643cc

Please sign in to comment.