Skip to content

Commit

Permalink
Report failing access policy (#4529)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Oct 21, 2022
1 parent 748d873 commit 8aa7b9f
Show file tree
Hide file tree
Showing 13 changed files with 364 additions and 91 deletions.
9 changes: 4 additions & 5 deletions edb/edgeql/compiler/clauses.py
Expand Up @@ -40,12 +40,11 @@


def compile_where_clause(
ir_stmt: irast.FilteredStmt,
where: Optional[qlast.Base], *,
ctx: context.ContextLevel) -> None:
where: Optional[qlast.Base], *, ctx: context.ContextLevel
) -> Optional[irast.Set]:

if where is None:
return
return None

if ctx.partial_path_prefix:
pathctx.register_set_in_scope(ctx.partial_path_prefix, ctx=ctx)
Expand All @@ -57,7 +56,7 @@ def compile_where_clause(
bool_t = ctx.env.get_track_schema_type(sn.QualName('std', 'bool'))
ir_set = setgen.scoped_set(ir_expr, typehint=bool_t, ctx=subctx)

ir_stmt.where = ir_set
return ir_set


def compile_orderby_clause(
Expand Down
21 changes: 14 additions & 7 deletions edb/edgeql/compiler/inference/cardinality.py
Expand Up @@ -1098,11 +1098,16 @@ def _infer_dml_check_cardinality(
ctx: inference_context.InfCtx,
) -> None:
pctx = ctx._replace(singletons=ctx.singletons | {ir.result.path_id})
for pol in [
*ir.read_policy_exprs.values(), *ir.write_policy_exprs.values()
]:
pol.cardinality = infer_cardinality(
pol.expr, scope_tree=scope_tree, ctx=pctx)
for read_pol in ir.read_policies.values():
read_pol.cardinality = infer_cardinality(
read_pol.expr, scope_tree=scope_tree, ctx=pctx
)

for write_pol in ir.write_policies.values():
for p in write_pol.policies:
p.cardinality = infer_cardinality(
p.expr, scope_tree=scope_tree, ctx=pctx
)

if ir.conflict_checks:
for on_conflict in ir.conflict_checks:
Expand Down Expand Up @@ -1258,8 +1263,10 @@ def __infer_insert_stmt(
# ... except if UNLESS CONFLICT is used
else:
return _infer_on_conflict_cardinality(
ir.on_conflict, type_has_rewrites=bool(ir.write_policy_exprs),
scope_tree=scope_tree, ctx=ctx,
ir.on_conflict,
type_has_rewrites=bool(ir.write_policies),
scope_tree=scope_tree,
ctx=ctx,
)


Expand Down
9 changes: 5 additions & 4 deletions edb/edgeql/compiler/inference/multiplicity.py
Expand Up @@ -683,11 +683,12 @@ def _infer_mutating_stmt(
for clause in ir.conflict_checks:
_infer_on_conflict_clause(clause, scope_tree=scope_tree, ctx=ctx)

for policy in ir.write_policy_exprs.values():
infer_multiplicity(policy.expr, scope_tree=scope_tree, ctx=ctx)
for write_pol in ir.write_policies.values():
for policy in write_pol.policies:
infer_multiplicity(policy.expr, scope_tree=scope_tree, ctx=ctx)

for policy in ir.read_policy_exprs.values():
infer_multiplicity(policy.expr, scope_tree=scope_tree, ctx=ctx)
for read_pol in ir.read_policies.values():
infer_multiplicity(read_pol.expr, scope_tree=scope_tree, ctx=ctx)


def _infer_on_conflict_clause(
Expand Down
94 changes: 72 additions & 22 deletions edb/edgeql/compiler/policies.py
Expand Up @@ -79,7 +79,6 @@ def has_own_policies(


def compile_pol(
stype: s_objtypes.ObjectType,
pol: s_policies.AccessPolicy, *,
ctx: context.ContextLevel,
) -> irast.Set:
Expand All @@ -99,6 +98,7 @@ def compile_pol(
expr = expr_field.qlast
else:
expr = qlast.BooleanConstant(value='true')

if condition := pol.get_condition(schema):
expr = qlast.BinOp(op='AND', left=condition.qlast, right=expr)

Expand All @@ -125,6 +125,8 @@ def get_rewrite_filter(
) -> Optional[qlast.Expr]:
schema = ctx.env.schema
pols = get_access_policies(stype, ctx=ctx)
if not pols:
return None

ctx.anchors = ctx.anchors.copy()

Expand All @@ -133,7 +135,7 @@ def get_rewrite_filter(
if mode not in pol.get_access_kinds(schema):
continue

ir_set = compile_pol(stype, pol, ctx=ctx)
ir_set = compile_pol(pol, ctx=ctx)
expr = ctx.create_anchor(ir_set)

is_allow = pol.get_action(schema) == qltypes.AccessPolicyAction.Allow
Expand All @@ -142,9 +144,7 @@ def get_rewrite_filter(
else:
deny.append(expr)

if not pols:
filter_expr = None
elif allow:
if allow:
filter_expr = astutils.extend_binop(None, *allow, op='OR')
else:
filter_expr = qlast.BooleanConstant(value='false')
Expand All @@ -161,7 +161,7 @@ def get_rewrite_filter(
# from bogusly optimizing away the entire type CTE if it can prove
# it empty (which could then result in assert_exists on links to
# the type not always firing).
if filter_expr and mode == qltypes.AccessKind.Select:
if mode == qltypes.AccessKind.Select:
bogus_check = qlast.BinOp(
op='?=',
left=qlast.Path(partial=True, steps=[qlast.Ptr(
Expand Down Expand Up @@ -271,8 +271,7 @@ def try_type_rewrite(
subctx.partial_path_prefix = base_set
subctx.path_scope = subctx.env.path_scope.root.attach_fence()

clauses.compile_where_clause(
filtered_stmt,
filtered_stmt.where = clauses.compile_where_clause(
get_rewrite_filter(
stype, mode=qltypes.AccessKind.Select, ctx=subctx),
ctx=subctx)
Expand Down Expand Up @@ -314,36 +313,87 @@ def try_type_rewrite(
type_rewrites[rw_key] = filtered_set


def compile_dml_policy(
def compile_dml_write_policies(
stype: s_objtypes.ObjectType,
result: irast.Set,
mode: qltypes.AccessKind, *,
ctx: context.ContextLevel,
) -> Optional[irast.PolicyExpr]:
"""Compile a policy filter for a DML statement at a particular type"""
) -> Optional[irast.WritePolicies]:
"""Compile policy filters and wrap them into irast.WritePolicies"""
if not ctx.env.type_rewrites.get((stype, False)):
return None

pols = get_access_policies(stype, ctx=ctx)
if not pols:
with ctx.detached() as _, ctx.newscope(fenced=True) as subctx:
# TODO: can we make sure to always avoid generating needless
# select filters
_prepare_dml_policy_context(stype, result, ctx=subctx)

schema = subctx.env.schema
subctx.anchors = subctx.anchors.copy()

pols = get_access_policies(stype, ctx=ctx)
if not pols:
return None

policies = []
for pol in pols:
if mode not in pol.get_access_kinds(schema):
continue

ir_set = compile_pol(pol, ctx=subctx)

action = pol.get_action(schema)
name = str(pol.get_shortname(schema))

policies.append(
irast.WritePolicy(
expr=ir_set,
action=action,
name=name,
error_msg=pol.get_errmessage(schema),
)
)

return irast.WritePolicies(policies=policies)


def compile_dml_read_policies(
stype: s_objtypes.ObjectType,
result: irast.Set,
mode: qltypes.AccessKind,
*,
ctx: context.ContextLevel,
) -> Optional[irast.ReadPolicyExpr]:
"""Compile a policy filter for a DML statement at a particular type"""
if not ctx.env.type_rewrites.get((stype, False)):
return None

with ctx.detached() as _, ctx.newscope(fenced=True) as subctx:
# TODO: can we make sure to always avoid generating needless
# select filters
skip_subtypes = (stype, False) not in ctx.env.type_rewrites
result = setgen.class_set(
stype, path_id=result.path_id, skip_subtypes=skip_subtypes,
ctx=ctx)

subctx.anchors[qlast.Subject().name] = result
subctx.partial_path_prefix = result
_prepare_dml_policy_context(stype, result, ctx=subctx)

condition = get_rewrite_filter(stype, mode=mode, ctx=subctx)
assert condition
if not condition:
return None

return irast.PolicyExpr(
return irast.ReadPolicyExpr(
expr=setgen.scoped_set(
dispatch.compile(condition, ctx=subctx), ctx=subctx
),
)


def _prepare_dml_policy_context(
stype: s_objtypes.ObjectType,
result: irast.Set,
*,
ctx: context.ContextLevel,
) -> None:
skip_subtypes = (stype, False) not in ctx.env.type_rewrites
result = setgen.class_set(
stype, path_id=result.path_id, skip_subtypes=skip_subtypes, ctx=ctx
)

ctx.anchors[qlast.Subject().name] = result
ctx.partial_path_prefix = result
25 changes: 11 additions & 14 deletions edb/edgeql/compiler/stmt.py
Expand Up @@ -133,8 +133,7 @@ def compile_SelectQuery(
forward_rptr=forward_rptr,
ctx=sctx)

clauses.compile_where_clause(
stmt, expr.where, ctx=sctx)
stmt.where = clauses.compile_where_clause(expr.where, ctx=sctx)

stmt.orderby = clauses.compile_orderby_clause(
expr.orderby, ctx=sctx)
Expand Down Expand Up @@ -391,8 +390,7 @@ def compile_InternalGroupQuery(
result_alias=expr.result_alias,
ctx=bctx)

clauses.compile_where_clause(
stmt, expr.where, ctx=bctx)
stmt.where = clauses.compile_where_clause(expr.where, ctx=bctx)

stmt.orderby = clauses.compile_orderby_clause(
expr.orderby, ctx=bctx)
Expand Down Expand Up @@ -512,10 +510,10 @@ def compile_InsertQuery(
ctx=resultctx,
)

if pol_condition := policies.compile_dml_policy(
if pol_condition := policies.compile_dml_write_policies(
mat_stype, result, mode=qltypes.AccessKind.Insert, ctx=ctx
):
stmt.write_policy_exprs[mat_stype.id] = pol_condition
stmt.write_policies[mat_stype.id] = pol_condition

result = fini_stmt(stmt, expr, ctx=ictx, parent_ctx=ctx)

Expand Down Expand Up @@ -600,8 +598,7 @@ def compile_UpdateQuery(

ictx.partial_path_prefix = subject

clauses.compile_where_clause(
stmt, expr.where, ctx=ictx)
stmt.where = clauses.compile_where_clause(expr.where, ctx=ictx)

with ictx.new() as bodyctx:
bodyctx.class_view_overrides = ictx.class_view_overrides.copy()
Expand Down Expand Up @@ -634,14 +631,14 @@ def compile_UpdateQuery(
)

for dtype in schemactx.get_all_concrete(mat_stype, ctx=ctx):
if pol_cond := policies.compile_dml_policy(
if read_pol := policies.compile_dml_read_policies(
dtype, result, mode=qltypes.AccessKind.UpdateRead, ctx=ctx
):
stmt.read_policy_exprs[dtype.id] = pol_cond
if pol_cond := policies.compile_dml_policy(
stmt.read_policies[dtype.id] = read_pol
if write_pol := policies.compile_dml_write_policies(
dtype, result, mode=qltypes.AccessKind.UpdateWrite, ctx=ctx
):
stmt.write_policy_exprs[dtype.id] = pol_cond
stmt.write_policies[dtype.id] = write_pol

stmt.conflict_checks = conflicts.compile_inheritance_conflict_checks(
stmt, mat_stype, ctx=ictx)
Expand Down Expand Up @@ -754,10 +751,10 @@ def compile_DeleteQuery(
)

for dtype in schemactx.get_all_concrete(mat_stype, ctx=ctx):
if pol_cond := policies.compile_dml_policy(
if pol_cond := policies.compile_dml_read_policies(
dtype, result, mode=qltypes.AccessKind.Delete, ctx=ctx
):
stmt.read_policy_exprs[dtype.id] = pol_cond
stmt.read_policies[dtype.id] = pol_cond

result = fini_stmt(stmt, expr, ctx=ictx, parent_ctx=ctx)

Expand Down
3 changes: 3 additions & 0 deletions edb/edgeql/parser/grammar/ddl.py
Expand Up @@ -1781,6 +1781,7 @@ def reduce_DropLink(self, *kids):
commands_block(
'CreateAccessPolicy',
CreateAnnotationValueStmt,
SetFieldStmt,
)


Expand Down Expand Up @@ -1852,6 +1853,8 @@ def reduce_RESET_WHEN(self, *kids):
AccessPermStmt,
AccessUsingStmt,
AccessWhenStmt,
SetFieldStmt,
ResetFieldStmt,
opt=False
)

Expand Down
4 changes: 3 additions & 1 deletion edb/edgeql/parser/grammar/sdl.py
Expand Up @@ -939,7 +939,9 @@ def reduce_CreateQualifiedComputableLink(self, *kids):
#
sdl_commands_block(
'CreateAccessPolicy',
SetAnnotation)
SetField,
SetAnnotation
)


class AccessPolicyDeclarationBlock(Nonterm):
Expand Down
25 changes: 20 additions & 5 deletions edb/ir/ast.py
Expand Up @@ -1033,11 +1033,13 @@ class MutatingStmt(Stmt):
# for.
conflict_checks: typing.Optional[typing.List[OnConflictClause]] = None
# Access policy checks that we should raise errors on
write_policy_exprs: typing.Dict[
uuid.UUID, PolicyExpr] = ast.field(factory=dict)
write_policies: typing.Dict[uuid.UUID, WritePolicies] = ast.field(
factory=dict
)
# Access policy checks that we should filter on
read_policy_exprs: typing.Dict[
uuid.UUID, PolicyExpr] = ast.field(factory=dict)
read_policies: typing.Dict[uuid.UUID, ReadPolicyExpr] = ast.field(
factory=dict
)

@property
def material_type(self) -> TypeRef:
Expand All @@ -1048,8 +1050,21 @@ def material_type(self) -> TypeRef:
raise NotImplementedError


class PolicyExpr(Base):
class ReadPolicyExpr(Base):
expr: Set
cardinality: qltypes.Cardinality = qltypes.Cardinality.UNKNOWN


class WritePolicies(Base):
policies: typing.List[WritePolicy]


class WritePolicy(Base):
expr: Set
action: qltypes.AccessPolicyAction
name: str
error_msg: typing.Optional[str]

cardinality: qltypes.Cardinality = qltypes.Cardinality.UNKNOWN


Expand Down
1 change: 1 addition & 0 deletions edb/lib/schema.edgeql
Expand Up @@ -411,6 +411,7 @@ ALTER TYPE schema::AccessPolicy {
CREATE PROPERTY condition -> std::str;
CREATE REQUIRED PROPERTY action -> schema::AccessPolicyAction;
CREATE PROPERTY expr -> std::str;
CREATE PROPERTY errmessage -> std::str;
};


Expand Down

0 comments on commit 8aa7b9f

Please sign in to comment.