Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support casting between scalars with a common concrete base #5108

Merged
merged 3 commits into from Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 24 additions & 4 deletions edb/edgeql/compiler/casts.py
Expand Up @@ -36,6 +36,7 @@
from edb.schema import functions as s_func
from edb.schema import indexes as s_indexes
from edb.schema import name as sn
from edb.schema import scalars as s_scalars
from edb.schema import types as s_types
from edb.schema import utils as s_utils
from edb.schema import name as s_name
Expand Down Expand Up @@ -134,9 +135,13 @@ def compile_cast(
ir_set, orig_stype, new_stype,
cardinality_mod=cardinality_mod, ctx=ctx)

if new_stype.issubclass(ctx.env.schema, orig_stype):
# The new type is a subtype, so may potentially have
# a more restrictive domain, generate a cast call.
if (
new_stype.issubclass(ctx.env.schema, orig_stype)
or _has_common_concrete_scalar(orig_stype, new_stype, ctx=ctx)
):
# The new type is a subtype or a sibling type of a shared
# ancestor, so may potentially have a more restrictive domain,
# generate a cast call.
return _inheritance_cast_to_ir(
ir_set, orig_stype, new_stype,
cardinality_mod=cardinality_mod, ctx=ctx)
Expand Down Expand Up @@ -236,6 +241,20 @@ def compile_cast(
)


def _has_common_concrete_scalar(
orig_stype: s_types.Type,
new_stype: s_types.Type, *,
ctx: context.ContextLevel) -> bool:
schema = ctx.env.schema
return bool(
isinstance(orig_stype, s_scalars.ScalarType)
and isinstance(new_stype, s_scalars.ScalarType)
and (orig_base := orig_stype.maybe_get_topmost_concrete_base(schema))
and (new_base := new_stype.maybe_get_topmost_concrete_base(schema))
and orig_base == new_base
)


def _compile_cast(
ir_expr: Union[irast.Set, irast.Expr],
orig_stype: s_types.Type,
Expand Down Expand Up @@ -424,7 +443,8 @@ def _find_cast(
# Don't try to pick up casts when there is a direct subtyping
# relationship.
if (orig_stype.issubclass(ctx.env.schema, new_stype)
or new_stype.issubclass(ctx.env.schema, orig_stype)):
or new_stype.issubclass(ctx.env.schema, orig_stype)
or _has_common_concrete_scalar(orig_stype, new_stype, ctx=ctx)):
return None

casts = ctx.env.schema.get_casts_to_type(new_stype)
Expand Down
94 changes: 80 additions & 14 deletions tests/test_edgeql_casts.py
Expand Up @@ -2544,6 +2544,86 @@ async def test_edgeql_casts_custom_scalar_01(self):
await self.con.query(
"SELECT <custom_str_t>'123'")

async def test_edgeql_casts_custom_scalar_02(self):
await self.assert_query_result(
"""
SELECT <foo><bar>'test'
""",
['test'],
)

await self.assert_query_result(
"""
SELECT <array<foo>><array<bar>>['test']
""",
[['test']],
)

async def test_edgeql_casts_custom_scalar_03(self):
await self.assert_query_result(
"""
SELECT <array<custom_str_t>><array<bar>>['TEST']
""",
[['TEST']],
)

async with self.assertRaisesRegexTx(
edgedb.ConstraintViolationError, r'invalid'
):
await self.con.query("""
SELECT <custom_str_t><bar>'test'
""")

async with self.assertRaisesRegexTx(
edgedb.ConstraintViolationError, r'invalid'
):
await self.con.query("""
SELECT <array<custom_str_t>><array<bar>>['test']
""")

async def test_edgeql_casts_custom_scalar_04(self):
await self.con.execute('''
create abstract scalar type abs extending int64;
create scalar type foo2 extending abs;
create scalar type bar2 extending abs;
''')

await self.assert_query_result(
"""
SELECT <foo2><bar2>42
""",
[42],
)

await self.assert_query_result(
"""
SELECT <array<foo2>><array<bar2>>[42]
""",
[[42]],
)

async def test_edgeql_casts_custom_scalar_05(self):
await self.con.execute('''
create abstract scalar type xfoo extending int64;
create abstract scalar type xbar extending int64;
create scalar type bar1 extending xfoo, xbar;
create scalar type bar2 extending xfoo, xbar;
''')

await self.assert_query_result(
"""
SELECT <bar1><bar2>42
""",
[42],
)

await self.assert_query_result(
"""
SELECT <array<bar1>><array<bar2>>[42]
""",
[[42]],
)

async def test_edgeql_casts_tuple_params_01(self):
# insert tuples into a nested array
def nest(data):
Expand Down Expand Up @@ -2792,20 +2872,6 @@ async def test_edgeql_cast_empty_set_to_array_01(self):
[],
)

async def test_edgeql_casts_custom_scalars_01(self):
async with self.assertRaisesRegexTx(
edgedb.QueryError, r'cannot cast'):
await self.con.execute("""
SELECT <foo><bar>'test'
""")

async def test_edgeql_casts_custom_scalars_02(self):
async with self.assertRaisesRegexTx(
edgedb.QueryError, r'cannot cast'):
await self.con.execute("""
SELECT <array<foo>><array<bar>>['test']
""")

async def test_edgeql_casts_std_enum_01(self):
await self.assert_query_result(
'''
Expand Down