From b8057f4c3c196e5cd05402c9b8a21ed09d79d16e Mon Sep 17 00:00:00 2001 From: Tommy Reilly Date: Fri, 4 Aug 2023 16:13:24 +0000 Subject: [PATCH] sql: fix overload type checking of nested case expressions If we have a nested case expression where the inner case expression is ambiguous the AnyCollatedString type would be selected and it would leak to the execution engine causing the 'failed to parse locale ""' internal error. Instead have the overload type checker remember if it saw a AnyCollatedString type and go back and repair types if a concrete type is found in any of the other exprs. Release note (bug fix): Fix a bug with collated string type checking with nested case expressions where inner case had no explicit collated type. Epic: none Fixes: #101418 Release justification: Low risk fix for edge case SQL construct. --- pkg/sql/sem/tree/overload.go | 32 +++++++++++++++++++++++++++++ pkg/sql/sem/tree/type_check_test.go | 27 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/pkg/sql/sem/tree/overload.go b/pkg/sql/sem/tree/overload.go index 7a0be0d6ead7..d9bdd263b335 100644 --- a/pkg/sql/sem/tree/overload.go +++ b/pkg/sql/sem/tree/overload.go @@ -807,6 +807,7 @@ func (s *overloadTypeChecker) typeCheckOverloadedExprs( // Filter out overloads on resolved types. This includes resolved placeholders // and any other resolvable exprs. + ambiguousCollatedTypes := false var typeableIdxs intsets.Fast for i, ok := s.resolvableIdxs.Next(0); ok; i, ok = s.resolvableIdxs.Next(i + 1) { typeableIdxs.Add(i) @@ -833,6 +834,8 @@ func (s *overloadTypeChecker) typeCheckOverloadedExprs( break } } + // Don't allow ambiguous types to be desired, this prevents for instance + // AnyCollatedString from trumping the concrete collated type. if sameType != nil { paramDesired = sameType } @@ -841,6 +844,9 @@ func (s *overloadTypeChecker) typeCheckOverloadedExprs( return err } s.typedExprs[i] = typ + if typ.ResolvedType() == types.AnyCollatedString { + ambiguousCollatedTypes = true + } rt := typ.ResolvedType() s.overloadIdxs = filterParams(s.overloadIdxs, s.params, func( params TypeList, @@ -849,6 +855,32 @@ func (s *overloadTypeChecker) typeCheckOverloadedExprs( }) } + // If we typed any exprs as AnyCollatedString but have a concrete collated + // string type, redo the types using the contrete type. Note we're probably + // still lacking full compliance with PG on collation handling: + // https://www.postgresql.org/docs/current/collation.html#id-1.6.11.4.4 + if ambiguousCollatedTypes { + var concreteType *types.T + for i, ok := typeableIdxs.Next(0); ok; i, ok = typeableIdxs.Next(i + 1) { + typ := s.typedExprs[i].ResolvedType() + if typ != types.AnyCollatedString { + concreteType = typ + break + } + } + if concreteType != nil { + for i, ok := typeableIdxs.Next(0); ok; i, ok = typeableIdxs.Next(i + 1) { + if s.typedExprs[i].ResolvedType() == types.AnyCollatedString { + typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, concreteType) + if err != nil { + return err + } + s.typedExprs[i] = typ + } + } + } + } + // At this point, all remaining overload candidates accept the argument list, // so we begin checking for a single remainig candidate implementation to choose. // In case there is more than one candidate remaining, the following code uses diff --git a/pkg/sql/sem/tree/type_check_test.go b/pkg/sql/sem/tree/type_check_test.go index 58a71e41ffdb..80586be8f429 100644 --- a/pkg/sql/sem/tree/type_check_test.go +++ b/pkg/sql/sem/tree/type_check_test.go @@ -473,6 +473,33 @@ func TestTypeCheckCollatedString(t *testing.T) { require.Equal(t, rightTyp.Locale(), "en-US-u-ks-level2") } +func TestTypeCheckCollatedStringNestedCaseComparison(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + semaCtx := tree.MakeSemaContext() + + // The collated string constant must be on the LHS for this test, so that + // the type-checker chooses the collated string overload first. + for _, exprStr := range []string{ + `CASE WHEN false THEN CASE WHEN (NOT (false)) THEN NULL END ELSE ('' COLLATE "es_ES") END >= ('' COLLATE "es_ES")`, + `CASE WHEN false THEN NULL ELSE ('' COLLATE "es_ES") END >= ('' COLLATE "es_ES")`, + `CASE WHEN false THEN ('' COLLATE "es_ES") ELSE NULL END >= ('' COLLATE "es_ES")`, + `('' COLLATE "es_ES") >= CASE WHEN false THEN CASE WHEN (NOT (false)) THEN NULL END ELSE ('' COLLATE "es_ES") END`} { + expr, err := parser.ParseExpr(exprStr) + require.NoError(t, err) + typed, err := tree.TypeCheck(ctx, expr, &semaCtx, types.Any) + require.NoError(t, err) + + for _, ex := range []tree.Expr{typed.(*tree.ComparisonExpr).Left, typed.(*tree.ComparisonExpr).Right} { + typ := ex.(tree.TypedExpr).ResolvedType() + require.Equal(t, types.CollatedStringFamily, typ.Family()) + require.Equal(t, "es_ES", typ.Locale()) + } + } +} + func TestTypeCheckCaseExprWithPlaceholders(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t)