Skip to content

Commit

Permalink
Always correctly handle variadic arguments when producing AST from mi…
Browse files Browse the repository at this point in the history
…grations (#3855)

In our schema, we store variadic types with an array argument, and we need
be consistent about stripping it away when producing ASTs.

Fixes #3807.
  • Loading branch information
msullivan committed May 14, 2022
1 parent 5248eb9 commit 3fb2efa
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
46 changes: 32 additions & 14 deletions edb/schema/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from edb.edgeql import compiler as qlcompiler
from edb.edgeql import qltypes as ft
from edb.edgeql import parser as qlparser
from edb.edgeql import qltypes
from edb.common import uuidgen

from . import abc as s_abc
Expand Down Expand Up @@ -323,6 +324,33 @@ def _params_are_all_required_singletons(
)


def make_func_param(
*,
name: str,
type: qlast.TypeExpr,
typemod: qltypes.TypeModifier = qltypes.TypeModifier.SingletonType,
kind: qltypes.ParameterKind,
default: Optional[qlast.Expr] = None,
) -> qlast.FuncParam:
# If the param is variadic, strip the array from the type in the schema
if kind is ft.ParameterKind.VariadicParam:
assert (
isinstance(type, qlast.TypeName)
and isinstance(type.maintype, qlast.ObjectRef)
and type.maintype.name == 'array'
and type.subtypes
)
type = type.subtypes[0]

return qlast.FuncParam(
name=name,
type=type,
typemod=typemod,
kind=kind,
default=default,
)


class Parameter(
so.ObjectFragment,
ParameterLike,
Expand Down Expand Up @@ -439,21 +467,11 @@ def compare_field_value(

def get_ast(self, schema: s_schema.Schema) -> qlast.FuncParam:
default = self.get_default(schema)
type = utils.typeref_to_ast(schema, self.get_type(schema))
kind = self.get_kind(schema)
# If the param is variadic, strip the array from the type in the schema
if kind is ft.ParameterKind.VariadicParam:
assert (
isinstance(type, qlast.TypeName)
and isinstance(type.maintype, qlast.ObjectRef)
and type.maintype.name == 'array'
and type.subtypes
)
type = type.subtypes[0]

return qlast.FuncParam(
return make_func_param(
name=self.get_parameter_name(schema),
type=type,
type=utils.typeref_to_ast(schema, self.get_type(schema)),
typemod=self.get_typemod(schema),
kind=kind,
default=default.qlast if default else None,
Expand Down Expand Up @@ -1107,7 +1125,7 @@ def _get_params_ast(

num: int = props['num']
default = props.get('default')
param = qlast.FuncParam(
param = make_func_param(
name=Parameter.paramname_from_fullname(props['name']),
type=utils.typeref_to_ast(schema, props['type']),
typemod=props['typemod'],
Expand Down Expand Up @@ -2153,7 +2171,7 @@ def _apply_fields_ast(
for op in self.get_subcommands(type=ParameterCommand):
props = op.get_orig_attributes(schema, context)
num: int = props['num']
param = qlast.FuncParam(
param = make_func_param(
name=Parameter.paramname_from_fullname(props['name']),
type=utils.typeref_to_ast(schema, props['type']),
typemod=props['typemod'],
Expand Down
7 changes: 7 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4766,6 +4766,13 @@ def test_schema_migrations_equivalence_function_20(self):
}
"""])

def test_schema_migrations_equivalence_function_21(self):
self._assert_migration_equivalence([r"""
function foo(variadic s: str) -> str using ("!");
""", r"""
function foo(variadic s: str) -> str using ("?");
"""])

def test_schema_migrations_equivalence_recursive_01(self):
with self.assertRaisesRegex(
errors.InvalidDefinitionError,
Expand Down

0 comments on commit 3fb2efa

Please sign in to comment.