Skip to content

Commit

Permalink
Support casting between scalars with a common concrete base (#5108)
Browse files Browse the repository at this point in the history
Seems like this was not supported because of an oversight, and then I
unthinkingly added a test that asserted the behavior after I permuted
some related code in #3662.
  • Loading branch information
msullivan committed Mar 1, 2023
1 parent 9180505 commit 833bc4a
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 18 deletions.
28 changes: 24 additions & 4 deletions edb/edgeql/compiler/casts.py
Expand Up @@ -36,6 +36,7 @@
from edb.schema import functions as s_func
from edb.schema import indexes as s_indexes
from edb.schema import name as sn
from edb.schema import scalars as s_scalars
from edb.schema import types as s_types
from edb.schema import utils as s_utils
from edb.schema import name as s_name
Expand Down Expand Up @@ -134,9 +135,13 @@ def compile_cast(
ir_set, orig_stype, new_stype,
cardinality_mod=cardinality_mod, ctx=ctx)

if new_stype.issubclass(ctx.env.schema, orig_stype):
# The new type is a subtype, so may potentially have
# a more restrictive domain, generate a cast call.
if (
new_stype.issubclass(ctx.env.schema, orig_stype)
or _has_common_concrete_scalar(orig_stype, new_stype, ctx=ctx)
):
# The new type is a subtype or a sibling type of a shared
# ancestor, so may potentially have a more restrictive domain,
# generate a cast call.
return _inheritance_cast_to_ir(
ir_set, orig_stype, new_stype,
cardinality_mod=cardinality_mod, ctx=ctx)
Expand Down Expand Up @@ -236,6 +241,20 @@ def compile_cast(
)


def _has_common_concrete_scalar(
orig_stype: s_types.Type,
new_stype: s_types.Type, *,
ctx: context.ContextLevel) -> bool:
schema = ctx.env.schema
return bool(
isinstance(orig_stype, s_scalars.ScalarType)
and isinstance(new_stype, s_scalars.ScalarType)
and (orig_base := orig_stype.maybe_get_topmost_concrete_base(schema))
and (new_base := new_stype.maybe_get_topmost_concrete_base(schema))
and orig_base == new_base
)


def _compile_cast(
ir_expr: Union[irast.Set, irast.Expr],
orig_stype: s_types.Type,
Expand Down Expand Up @@ -424,7 +443,8 @@ def _find_cast(
# Don't try to pick up casts when there is a direct subtyping
# relationship.
if (orig_stype.issubclass(ctx.env.schema, new_stype)
or new_stype.issubclass(ctx.env.schema, orig_stype)):
or new_stype.issubclass(ctx.env.schema, orig_stype)
or _has_common_concrete_scalar(orig_stype, new_stype, ctx=ctx)):
return None

casts = ctx.env.schema.get_casts_to_type(new_stype)
Expand Down
94 changes: 80 additions & 14 deletions tests/test_edgeql_casts.py
Expand Up @@ -2544,6 +2544,86 @@ async def test_edgeql_casts_custom_scalar_01(self):
await self.con.query(
"SELECT <custom_str_t>'123'")

async def test_edgeql_casts_custom_scalar_02(self):
await self.assert_query_result(
"""
SELECT <foo><bar>'test'
""",
['test'],
)

await self.assert_query_result(
"""
SELECT <array<foo>><array<bar>>['test']
""",
[['test']],
)

async def test_edgeql_casts_custom_scalar_03(self):
await self.assert_query_result(
"""
SELECT <array<custom_str_t>><array<bar>>['TEST']
""",
[['TEST']],
)

async with self.assertRaisesRegexTx(
edgedb.ConstraintViolationError, r'invalid'
):
await self.con.query("""
SELECT <custom_str_t><bar>'test'
""")

async with self.assertRaisesRegexTx(
edgedb.ConstraintViolationError, r'invalid'
):
await self.con.query("""
SELECT <array<custom_str_t>><array<bar>>['test']
""")

async def test_edgeql_casts_custom_scalar_04(self):
await self.con.execute('''
create abstract scalar type abs extending int64;
create scalar type foo2 extending abs;
create scalar type bar2 extending abs;
''')

await self.assert_query_result(
"""
SELECT <foo2><bar2>42
""",
[42],
)

await self.assert_query_result(
"""
SELECT <array<foo2>><array<bar2>>[42]
""",
[[42]],
)

async def test_edgeql_casts_custom_scalar_05(self):
await self.con.execute('''
create abstract scalar type xfoo extending int64;
create abstract scalar type xbar extending int64;
create scalar type bar1 extending xfoo, xbar;
create scalar type bar2 extending xfoo, xbar;
''')

await self.assert_query_result(
"""
SELECT <bar1><bar2>42
""",
[42],
)

await self.assert_query_result(
"""
SELECT <array<bar1>><array<bar2>>[42]
""",
[[42]],
)

async def test_edgeql_casts_tuple_params_01(self):
# insert tuples into a nested array
def nest(data):
Expand Down Expand Up @@ -2792,20 +2872,6 @@ async def test_edgeql_cast_empty_set_to_array_01(self):
[],
)

async def test_edgeql_casts_custom_scalars_01(self):
async with self.assertRaisesRegexTx(
edgedb.QueryError, r'cannot cast'):
await self.con.execute("""
SELECT <foo><bar>'test'
""")

async def test_edgeql_casts_custom_scalars_02(self):
async with self.assertRaisesRegexTx(
edgedb.QueryError, r'cannot cast'):
await self.con.execute("""
SELECT <array<foo>><array<bar>>['test']
""")

async def test_edgeql_casts_std_enum_01(self):
await self.assert_query_result(
'''
Expand Down

0 comments on commit 833bc4a

Please sign in to comment.