Skip to content

Commit

Permalink
Fix casts from json to custom scalars (#5624)
Browse files Browse the repository at this point in the history
We've had a codepath to fix this bug for enum already, but now I extended it to anything that:

 * is a scalar,
 * it not the top-most concrete base of itself,

Closes #5616
  • Loading branch information
aljazerzen committed Jun 9, 2023
1 parent 02d2119 commit 7972b70
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 26 deletions.
58 changes: 32 additions & 26 deletions edb/edgeql/compiler/casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ def compile_cast(
):
return _find_object_by_id(ir_expr, new_stype, ctx=ctx)

json_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'json'))

if isinstance(ir_set.expr, irast.Array):
return _cast_array_literal(
ir_set, orig_stype, new_stype, srcctx=srcctx, ctx=ctx)
Expand Down Expand Up @@ -145,38 +143,33 @@ def compile_cast(
ir_set, orig_stype, new_stype,
cardinality_mod=cardinality_mod, ctx=ctx)

json_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'json'))
if (
new_stype.issubclass(ctx.env.schema, json_t)
and ir_set.path_id.is_objtype_path()
):
# JSON casts of objects are special: we want the full shape
# and not just an identity.
viewgen.late_compile_view_shapes(ir_set, ctx=ctx)
else:
if orig_stype.issubclass(ctx.env.schema, json_t) and new_stype.is_enum(
ctx.env.schema
):
# Casts from json to enums need some special handling
# here, where we have access to the enum type. Just turn
# it into json->str and str->enum.
str_typ = ctx.env.get_schema_type_and_track(
sn.QualName('std', 'str')
)
str_ir = compile_cast(ir_expr, str_typ, srcctx=srcctx, ctx=ctx)
elif orig_stype.issubclass(ctx.env.schema, json_t):

if base_stype := _get_concrete_scalar_base(new_stype, ctx):
# Casts from json to custom scalars may have special handling.
# So we turn the type cast json->x into json->base and base->x.
base_ir = compile_cast(ir_expr, base_stype, srcctx=srcctx, ctx=ctx)

return compile_cast(
str_ir,
base_ir,
new_stype,
cardinality_mod=cardinality_mod,
srcctx=srcctx,
ctx=ctx,
)

if (
orig_stype.issubclass(ctx.env.schema, json_t)
and isinstance(new_stype, s_types.Array)
and not new_stype.get_subtypes(ctx.env.schema)[0].issubclass(
ctx.env.schema, json_t
)
elif isinstance(
new_stype, s_types.Array
) and not new_stype.get_subtypes(ctx.env.schema)[0].issubclass(
ctx.env.schema, json_t
):
# Turn casts from json->array<T> into json->array<json>
# and array<json>->array<T>.
Expand All @@ -194,9 +187,7 @@ def compile_cast(
json_array_ir, new_stype, srcctx=srcctx, ctx=ctx
)

if orig_stype.issubclass(ctx.env.schema, json_t) and isinstance(
new_stype, s_types.Tuple
):
elif isinstance(new_stype, s_types.Tuple):
return _cast_json_to_tuple(
ir_set,
orig_stype,
Expand All @@ -206,9 +197,7 @@ def compile_cast(
ctx=ctx,
)

if orig_stype.issubclass(ctx.env.schema, json_t) and isinstance(
new_stype, s_types.Range
):
elif isinstance(new_stype, s_types.Range):
return _cast_json_to_range(
ir_set,
orig_stype,
Expand Down Expand Up @@ -260,6 +249,23 @@ def _has_common_concrete_scalar(
)


def _get_concrete_scalar_base(
stype: s_types.Type,
ctx: context.ContextLevel
) -> Optional[s_types.Type]:
"""Returns None if stype is not scalar or if it is already topmost"""

if stype.is_enum(ctx.env.schema):
return ctx.env.get_schema_type_and_track(sn.QualName('std', 'str'))

if not isinstance(stype, s_scalars.ScalarType):
return None
if topmost := stype.maybe_get_topmost_concrete_base(ctx.env.schema):
if topmost != stype:
return topmost
return None


def _compile_cast(
ir_expr: Union[irast.Set, irast.Expr],
orig_stype: s_types.Type,
Expand Down
19 changes: 19 additions & 0 deletions tests/test_edgeql_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2624,6 +2624,25 @@ async def test_edgeql_casts_custom_scalar_05(self):
[[42]],
)

async def test_edgeql_casts_custom_scalar_06(self):
await self.con.execute(
'''
create scalar type x extending str {
create constraint expression on (false)
};
'''
)

async with self.assertRaisesRegexTx(
edgedb.ConstraintViolationError, 'invalid x'
):
await self.con.query("""SELECT <x>42""")

async with self.assertRaisesRegexTx(
edgedb.ConstraintViolationError, 'invalid x'
):
await self.con.query("""SELECT <x>to_json('"a"')""")

async def test_edgeql_casts_tuple_params_01(self):
# insert tuples into a nested array
def nest(data):
Expand Down

0 comments on commit 7972b70

Please sign in to comment.