Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch aggregate calls in constraints and indexes in ql compiler #7343

Merged
merged 5 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions edb/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from edb.common import ordered

from edb.edgeql import qltypes as ft
from edb.schema import name as sn

from . import ast as irast
from . import typeutils
Expand Down Expand Up @@ -489,13 +490,47 @@ def find_potentially_visible(
return visible_paths


def contains_set_of_op(ir: irast.Base) -> bool:
def is_singleton_set_of_call(
    call: irast.Call
) -> bool:
    """Tell whether *call* is one of the allow-listed std set calls.

    Some set functions and operators are allowed in singleton mode
    as long as their inputs are singletons.  The check is purely by
    name: ``call.func_shortname`` is compared against a fixed list of
    ``std`` qualified names.
    """
    allowed_names = {
        sn.QualName('std', short)
        for short in ('IN', 'NOT IN', 'EXISTS', '??', 'IF')
    }
    return call.func_shortname in allowed_names
Comment on lines +493 to +505
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could make this a new flag on operators that we configure in the standard library creation code. Not sure if it is worth it, though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially was going to do that, but it seemed like a lot of machinery. I think it's probably the "correct" thing to do in the long run.



def has_set_of_param(
    call: irast.Call,
) -> bool:
    """Return True if any argument of *call* has a SET OF parameter.

    Every value in ``call.args`` is inspected; the first one whose
    ``param_typemod`` equals ``TypeModifier.SetOfType`` settles the
    answer.  A call with no arguments yields False.
    """
    set_of = ft.TypeModifier.SetOfType
    for bound_arg in call.args.values():
        if bound_arg.param_typemod == set_of:
            return True
    return False


def returns_set_of(
    call: irast.Call,
) -> bool:
    """Return True if the type modifier of *call* itself is SET OF."""
    result_typemod = call.typemod
    return result_typemod == ft.TypeModifier.SetOfType


def find_set_of_op(
ir: irast.Base,
has_multi_param: bool,
) -> Optional[irast.Call]:
def flt(n: irast.Call) -> bool:
return any(
arg.param_typemod == ft.TypeModifier.SetOfType
for arg in n.args.values()
return (
(has_multi_param or not is_singleton_set_of_call(n))
and (has_set_of_param(n) or returns_set_of(n))
)
return bool(ast.find_children(ir, irast.Call, flt, terminate_early=True))
calls = ast.find_children(ir, irast.Call, flt, terminate_early=True)
return next(iter(calls or []), None)
Comment on lines +532 to +533
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the bool thing not actually work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you're trying to get the actual value, OK



T = TypeVar('T')
Expand Down
20 changes: 16 additions & 4 deletions edb/pgsql/compiler/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,10 +433,16 @@ def compile_OperatorCall(
dispatch.compile(r_expr, ctx=ctx),
],
)
elif expr.typemod is ql_ft.TypeModifier.SetOfType:
elif irutils.is_singleton_set_of_call(expr):
pass
elif irutils.returns_set_of(expr):
raise errors.UnsupportedFeatureError(
f'set returning operator {expr.func_shortname} is not supported '
f'in singleton expressions')
f"set returning operator '{expr.func_shortname}' is not supported "
f"in singleton expressions")
elif irutils.has_set_of_param(expr):
raise errors.UnsupportedFeatureError(
f"aggregate operator '{expr.func_shortname}' is not supported "
f"in singleton expressions")

args, maybe_null = _compile_call_args(expr, ctx=ctx)
return _wrap_call(
Expand Down Expand Up @@ -683,9 +689,15 @@ def compile_FunctionCall(
f'unimplemented function for singleton mode: {fname}'
)

if expr.typemod is ql_ft.TypeModifier.SetOfType:
if irutils.is_singleton_set_of_call(expr):
pass
elif irutils.returns_set_of(expr):
raise errors.UnsupportedFeatureError(
'set returning functions are not supported in simple expressions')
elif irutils.has_set_of_param(expr):
raise errors.UnsupportedFeatureError(
f"aggregate function '{expr.func_shortname}' is not supported "
f"in singleton expressions")

args, maybe_null = _compile_call_args(expr, ctx=ctx)

Expand Down
93 changes: 45 additions & 48 deletions edb/schema/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ def _populate_concrete_constraint_attrs(
self,
schema: s_schema.Schema,
context: sd.CommandContext,
subject_obj: Optional[so.Object],
subject_obj: so.Object,
*,
name: sn.QualName,
subjectexpr: Optional[s_expr.Expression] = None,
Expand Down Expand Up @@ -761,17 +761,27 @@ def _populate_concrete_constraint_attrs(
if not constr_base.is_non_concrete(schema):
return

orig_subjectexpr = subjectexpr
orig_subject = subject_obj
attrs = dict(kwargs)
inherited = dict()

base_subjectexpr = constr_base.get_field_value(schema, 'subjectexpr')
if subjectexpr is None:
attrs['subjectexpr'] = subjectexpr
inherited['subjectexpr'] = subjectexpr_inherited

subjectexpr = base_subjectexpr
elif (base_subjectexpr is not None
else:
if (base_subjectexpr is not None
and subjectexpr.text != base_subjectexpr.text):
raise errors.InvalidConstraintDefinitionError(
f'subjectexpr is already defined for {name}',
span=sourcectx,
)
raise errors.InvalidConstraintDefinitionError(
f'subjectexpr is already defined for {name}',
span=sourcectx,
)

base_subjectexpr = constr_base.get_subjectexpr(schema)
if base_subjectexpr is not None:
attrs['subjectexpr'] = base_subjectexpr
inherited['subjectexpr'] = True

if (isinstance(subject_obj, s_scalars.ScalarType)
and constr_base.get_is_aggregate(schema)):
Expand All @@ -790,12 +800,6 @@ def _populate_concrete_constraint_attrs(
span=sourcectx,
)

if subjectexpr is not None:
subject_ql = subjectexpr.parse()
subject = subject_ql
else:
subject = subject_obj

expr: s_expr.Expression = constr_base.get_field_value(schema, 'expr')
if not expr:
raise errors.InvalidConstraintDefinitionError(
Expand All @@ -805,26 +809,16 @@ def _populate_concrete_constraint_attrs(
# the AST below.
expr_ql = qlparser.parse_query(expr.text)

if not args:
args = constr_base.get_field_value(schema, 'args')

attrs = dict(kwargs)
inherited = dict()
if orig_subjectexpr is not None:
attrs['subjectexpr'] = orig_subjectexpr
inherited['subjectexpr'] = subjectexpr_inherited
else:
base_subjectexpr = constr_base.get_subjectexpr(schema)
if base_subjectexpr is not None:
attrs['subjectexpr'] = base_subjectexpr
inherited['subjectexpr'] = True

if subject is not orig_subject:
if subjectexpr is not None:
# subject has been redefined
assert isinstance(subject, qlast.Base)
subject_ql = subjectexpr.parse()

assert isinstance(subject_ql, qlast.Base)
qlutils.inline_anchors(
expr_ql, anchors={'__subject__': subject})
subject = orig_subject
expr_ql, anchors={'__subject__': subject_ql})

if not args:
args = constr_base.get_field_value(schema, 'args')

if args:
args_ql: List[qlast.Base] = [
Expand All @@ -840,17 +834,13 @@ def _populate_concrete_constraint_attrs(

attrs['args'] = args

if subject_obj:
assert isinstance(subject_obj, (s_types.Type, s_pointers.Pointer))
singletons = frozenset({subject_obj})
else:
singletons = frozenset()
assert isinstance(subject_obj, (s_types.Type, s_pointers.Pointer))
singletons = frozenset({subject_obj})

assert subject is not None
final_expr = s_expr.Expression.from_ast(expr_ql, schema, {}).compiled(
schema=schema,
options=qlcompiler.CompilerOptions(
anchors={'__subject__': subject},
anchors={'__subject__': subject_obj},
path_prefix_anchor='__subject__',
singletons=singletons,
apply_query_rewrites=False,
Expand All @@ -874,7 +864,7 @@ def _populate_concrete_constraint_attrs(

if subjectexpr is not None:
options = qlcompiler.CompilerOptions(
anchors={'__subject__': subject},
anchors={'__subject__': subject_obj},
path_prefix_anchor='__subject__',
singletons=singletons,
apply_query_rewrites=False,
Expand All @@ -896,7 +886,6 @@ def _populate_concrete_constraint_attrs(

has_any_multi = has_non_subject_multi = False
for ref in refs:
assert subject_obj
while isinstance(ref.expr, irast.Pointer):
rptr = ref.expr

Expand All @@ -920,7 +909,7 @@ def _populate_concrete_constraint_attrs(
irast.TupleIndirectionPointerRef)
and rptr.ptrref.source_ptr is None
and isinstance(rptr.source.expr, irast.Pointer)):
if isinstance(subject, s_links.Link):
if isinstance(subject_obj, s_links.Link):
raise errors.InvalidConstraintDefinitionError(
"link constraints may not access "
"the link target",
Expand All @@ -942,12 +931,20 @@ def _populate_concrete_constraint_attrs(
span=sourcectx
)

if has_any_multi and ir_utils.contains_set_of_op(
final_subjectexpr.irast):
raise errors.InvalidConstraintDefinitionError(
"cannot use aggregate functions or operators "
"in a non-aggregating constraint",
span=sourcectx
if set_of_op := ir_utils.find_set_of_op(
final_subjectexpr.irast,
has_any_multi,
):
label = (
'function'
if isinstance(set_of_op, irast.FunctionCall) else
'operator'
)
op_name = str(set_of_op.func_shortname)
raise errors.UnsupportedFeatureError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should say SET OF instead of aggregate.

Also, I know I wrote the original error message, but I have no idea what "non-aggregating constraint" means, so we should stop saying it. :P

f"cannot use SET OF {label} '{op_name}' "
f"in a constraint",
span=set_of_op.span
)

if (
Expand Down
18 changes: 14 additions & 4 deletions edb/schema/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ def compile_expr_field(
value: s_expr.Expression,
track_schema_ref_exprs: bool=False,
) -> s_expr.CompiledExpression:
from edb.ir import ast as irast
from edb.ir import utils as irutils

if field.name in {'expr', 'except_expr'}:
Expand Down Expand Up @@ -850,11 +851,20 @@ def compile_expr_field(
has_multi = True
break

if has_multi and irutils.contains_set_of_op(expr.irast):
if set_of_op := irutils.find_set_of_op(
expr.irast,
has_multi,
):
label = (
'function'
if isinstance(set_of_op, irast.FunctionCall) else
'operator'
)
op_name = str(set_of_op.func_shortname)
raise errors.SchemaDefinitionError(
"cannot use aggregate functions or operators "
"in an index expression",
span=self.span,
f"cannot use SET OF {label} '{op_name}' "
f"in an index expression",
span=set_of_op.span
)

# compile the expression to sql to preempt errors downstream
Expand Down
Loading