From 7972b70ac55f8c85ea775aa08a6ac9d291f9eb78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Fri, 9 Jun 2023 05:51:07 +0200 Subject: [PATCH] Fix casts from json to custom scalars (#5624) We've had a codepath to fix this bug for enum already, but now I extended it to anything that: * is a scalar, * it not the top-most concrete base of itself, Closes #5616 --- edb/edgeql/compiler/casts.py | 58 ++++++++++++++++++++---------------- tests/test_edgeql_casts.py | 19 ++++++++++++ 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/edb/edgeql/compiler/casts.py b/edb/edgeql/compiler/casts.py index 4c16b362308..ba01478535a 100644 --- a/edb/edgeql/compiler/casts.py +++ b/edb/edgeql/compiler/casts.py @@ -104,8 +104,6 @@ def compile_cast( ): return _find_object_by_id(ir_expr, new_stype, ctx=ctx) - json_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'json')) - if isinstance(ir_set.expr, irast.Array): return _cast_array_literal( ir_set, orig_stype, new_stype, srcctx=srcctx, ctx=ctx) @@ -145,6 +143,7 @@ def compile_cast( ir_set, orig_stype, new_stype, cardinality_mod=cardinality_mod, ctx=ctx) + json_t = ctx.env.get_schema_type_and_track(sn.QualName('std', 'json')) if ( new_stype.issubclass(ctx.env.schema, json_t) and ir_set.path_id.is_objtype_path() @@ -152,31 +151,25 @@ def compile_cast( # JSON casts of objects are special: we want the full shape # and not just an identity. viewgen.late_compile_view_shapes(ir_set, ctx=ctx) - else: - if orig_stype.issubclass(ctx.env.schema, json_t) and new_stype.is_enum( - ctx.env.schema - ): - # Casts from json to enums need some special handling - # here, where we have access to the enum type. Just turn - # it into json->str and str->enum. - str_typ = ctx.env.get_schema_type_and_track( - sn.QualName('std', 'str') - ) - str_ir = compile_cast(ir_expr, str_typ, srcctx=srcctx, ctx=ctx) + elif orig_stype.issubclass(ctx.env.schema, json_t): + + if base_stype := _get_concrete_scalar_base(new_stype, ctx): + # Casts from json to custom scalars may have special handling. + # So we turn the type cast json->x into json->base and base->x. + base_ir = compile_cast(ir_expr, base_stype, srcctx=srcctx, ctx=ctx) + return compile_cast( - str_ir, + base_ir, new_stype, cardinality_mod=cardinality_mod, srcctx=srcctx, ctx=ctx, ) - if ( - orig_stype.issubclass(ctx.env.schema, json_t) - and isinstance(new_stype, s_types.Array) - and not new_stype.get_subtypes(ctx.env.schema)[0].issubclass( - ctx.env.schema, json_t - ) + elif isinstance( + new_stype, s_types.Array + ) and not new_stype.get_subtypes(ctx.env.schema)[0].issubclass( + ctx.env.schema, json_t ): # Turn casts from json->array into json->array # and array->array. @@ -194,9 +187,7 @@ def compile_cast( json_array_ir, new_stype, srcctx=srcctx, ctx=ctx ) - if orig_stype.issubclass(ctx.env.schema, json_t) and isinstance( - new_stype, s_types.Tuple - ): + elif isinstance(new_stype, s_types.Tuple): return _cast_json_to_tuple( ir_set, orig_stype, @@ -206,9 +197,7 @@ def compile_cast( ctx=ctx, ) - if orig_stype.issubclass(ctx.env.schema, json_t) and isinstance( - new_stype, s_types.Range - ): + elif isinstance(new_stype, s_types.Range): return _cast_json_to_range( ir_set, orig_stype, @@ -260,6 +249,23 @@ def _has_common_concrete_scalar( ) +def _get_concrete_scalar_base( + stype: s_types.Type, + ctx: context.ContextLevel +) -> Optional[s_types.Type]: + """Returns None if stype is not scalar or if it is already topmost""" + + if stype.is_enum(ctx.env.schema): + return ctx.env.get_schema_type_and_track(sn.QualName('std', 'str')) + + if not isinstance(stype, s_scalars.ScalarType): + return None + if topmost := stype.maybe_get_topmost_concrete_base(ctx.env.schema): + if topmost != stype: + return topmost + return None + + def _compile_cast( ir_expr: Union[irast.Set, irast.Expr], orig_stype: s_types.Type, diff --git a/tests/test_edgeql_casts.py b/tests/test_edgeql_casts.py index c1b2e3b416c..889de7485f2 100644 --- a/tests/test_edgeql_casts.py +++ b/tests/test_edgeql_casts.py @@ -2624,6 +2624,25 @@ async def test_edgeql_casts_custom_scalar_05(self): [[42]], ) + async def test_edgeql_casts_custom_scalar_06(self): + await self.con.execute( + ''' + create scalar type x extending str { + create constraint expression on (false) + }; + ''' + ) + + async with self.assertRaisesRegexTx( + edgedb.ConstraintViolationError, 'invalid x' + ): + await self.con.query("""SELECT 42""") + + async with self.assertRaisesRegexTx( + edgedb.ConstraintViolationError, 'invalid x' + ): + await self.con.query("""SELECT to_json('"a"')""") + async def test_edgeql_casts_tuple_params_01(self): # insert tuples into a nested array def nest(data):