Skip to content

Commit

Permalink
parsing error messages, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Oct 20, 2022
1 parent 1a119a1 commit e374fa5
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 60 deletions.
3 changes: 3 additions & 0 deletions edb/edgeql/parser/grammar/ddl.py
Expand Up @@ -1775,6 +1775,7 @@ def reduce_DropLink(self, *kids):
commands_block(
'CreateAccessPolicy',
CreateAnnotationValueStmt,
SetFieldStmt,
)


Expand Down Expand Up @@ -1846,6 +1847,8 @@ def reduce_RESET_WHEN(self, *kids):
AccessPermStmt,
AccessUsingStmt,
AccessWhenStmt,
SetFieldStmt,
ResetFieldStmt,
opt=False
)

Expand Down
4 changes: 1 addition & 3 deletions edb/edgeql/parser/grammar/sdl.py
Expand Up @@ -937,9 +937,7 @@ def reduce_CreateQualifiedComputableLink(self, *kids):
#
# Access Policies
#
sdl_commands_block(
'CreateAccessPolicy',
SetAnnotation)
sdl_commands_block('CreateAccessPolicy', SetField, SetAnnotation)


class AccessPolicyDeclarationBlock(Nonterm):
Expand Down
40 changes: 0 additions & 40 deletions edb/pgsql/compiler/astutils.py
Expand Up @@ -233,46 +233,6 @@ def extend_select_op(
return result


def conditional_string_agg(
pairs: Sequence[Tuple[pgast.BaseExpr, pgast.BaseExpr]],
) -> Optional[pgast.BaseExpr]:

selects = [
pgast.SelectStmt(
target_list=[pgast.ResTarget(val=value)],
where_clause=cond,
)
for value, cond in pairs
]
union = extend_select_op(None, *selects)

if not union:
return None

return pgast.SubLink(
type=pgast.SubLinkType.EXPR,
expr=pgast.SelectStmt(
target_list=[
pgast.ResTarget(
val=pgast.FuncCall(
name=('string_agg',),
args=[
pgast.ColumnRef(name=('error_msg',)),
pgast.StringConstant(val=', '),
],
)
)
],
from_clause=[
pgast.RangeSubselect(
subquery=union,
alias=pgast.Alias(aliasname='t', colnames=['error_msg']),
)
],
),
)


def new_unop(op: str, expr: pgast.BaseExpr) -> pgast.Expr:
return pgast.Expr(
kind=pgast.ExprKind.OP,
Expand Down
58 changes: 48 additions & 10 deletions edb/pgsql/compiler/dml.py
Expand Up @@ -891,7 +891,7 @@ def raise_if(
hint += ', '.join(allow_msgs)
allow_msg = f'{msg} ({hint})'
else:
allow_msg = f'{msg} (no allow policies)'
allow_msg = msg
allow_msg_expr = pgast.StringConstant(val=allow_msg)

ictx.rel.target_list.append(
Expand All @@ -907,16 +907,9 @@ def raise_if(
deny_expr = astutils.extend_binop(None, *deny_conds, op='OR')

deny_messages = [
(
pgast.StringConstant(
val=f'denied by policy {pol.error_msg}'
),
cond,
)
for pol, cond in deny
if pol.error_msg
(pol.error_msg, cond) for pol, cond in deny if pol.error_msg
]
deny_message = astutils.conditional_string_agg(deny_messages)
deny_message = _conditional_string_agg(deny_messages)
if deny_message:
deny_message = astutils.extend_concat(
msg + ' (', deny_message, ')'
Expand All @@ -940,6 +933,51 @@ def raise_if(
return policy_cte


def _conditional_string_agg(
pairs: Sequence[Tuple[str, pgast.BaseExpr]],
) -> Optional[pgast.BaseExpr]:

selects = [
pgast.SelectStmt(
target_list=[pgast.ResTarget(val=pgast.StringConstant(val=str))],
where_clause=cond,
)
for str, cond in pairs
]
union = astutils.extend_select_op(None, *selects)

if not union:
return None

return (
pgast.SelectStmt(
target_list=[
pgast.ResTarget(
val=pgast.FuncCall(
name=('coalesce',),
args=[
pgast.FuncCall(
name=('string_agg',),
args=[
pgast.ColumnRef(name=('error_msg',)),
pgast.StringConstant(val=', '),
],
),
pgast.StringConstant(val=''),
],
)
)
],
from_clause=[
pgast.RangeSubselect(
subquery=union,
alias=pgast.Alias(aliasname='t', colnames=['error_msg']),
)
],
),
)


def force_policy_checks(
policy_cte: pgast.CommonTableExpr,
queries: Sequence[pgast.Query],
Expand Down
45 changes: 38 additions & 7 deletions tests/test_edgeql_policies.py
Expand Up @@ -876,22 +876,53 @@ async def test_edgeql_policies_messages(self):
create required property val -> str;
create access policy allow_insert_of_a
allow insert using (.val = 'a')
{ errmessage := 'you can insert a' };
{ set errmessage := 'you can insert a' };
create access policy allow_insert_of_b
allow insert using (.val = 'b')
{ errmessage := 'you can insert b' };
};
{ set errmessage := 'you can insert b' };
};
create type ThreeDenies {
create required property val -> str;
create access policy allow_insert
allow insert;
create access policy deny_starting_with_f
deny insert using (.val[0] = 'f')
{ set errmessage := 'val cannot start with f' };
create access policy deny_foo
deny insert using (.val = 'foo')
{ set errmessage := 'val cannot be foo' };
create access policy deny_bar
deny insert using (.val = 'bar');
};
'''
)

await self.con.execute("insert TwoAllows { val := 'a' };")

async with self.assertRaisesRegexTx(
edgedb.InvalidValueError,
r"no allow policies"):
edgedb.InvalidValueError, r"access policy violation"
):
await self.con.query('insert NoAllows')

async with self.assertRaisesRegexTx(
edgedb.InvalidValueError,
r"none of these allow policies match: "):
edgedb.InvalidValueError,
"none of these allow policies match: "
"you can insert a, "
"you can insert b",
):
await self.con.query("insert TwoAllows { val := 'c' }")

async with self.assertRaisesRegexTx(
edgedb.InvalidValueError,
"access policy violation.*\(" ".*val cannot be foo.*\)",
):
await self.con.query("insert ThreeDenies { val := 'foo' }")

async with self.assertRaisesRegexTx(
edgedb.InvalidValueError, "access policy violation.*\(\)"
):
await self.con.query("insert ThreeDenies { val := 'bar' }")

0 comments on commit e374fa5

Please sign in to comment.