From 8c726023bb2b222ee7ab6dd7451aa4f0615874bb Mon Sep 17 00:00:00 2001 From: Marcus Gartner Date: Mon, 28 Aug 2023 17:41:34 -0400 Subject: [PATCH] sql: fix nested expression type-checking PR #108387 introduced new logic to type-checking that allows nested expressions of a single expression to have different types in some cases when the types can be implicitly casted to a common type. For example, the expression `ARRAY[1::INT, 1::FLOAT8]` would have failed type-checking but is not successfully typed as `[]FLOAT8`. However, #108387 does not add an implicit cast to ensure that each nested expression actually has that type during execution, which causes internal errors in some cases. This commit add the necessary implicit casts. Fixes #109629 There is no release note because the bug is not present in any releases. Release note: None --- .../logictest/testdata/logic_test/aggregate | 14 ++++ pkg/sql/sem/tree/type_check.go | 72 ++++++++++++++----- pkg/sql/sem/tree/type_check_test.go | 2 + 3 files changed, 72 insertions(+), 16 deletions(-) diff --git a/pkg/sql/logictest/testdata/logic_test/aggregate b/pkg/sql/logictest/testdata/logic_test/aggregate index 8d34e68cf9b7..71f7b055138b 100644 --- a/pkg/sql/logictest/testdata/logic_test/aggregate +++ b/pkg/sql/logictest/testdata/logic_test/aggregate @@ -3968,3 +3968,17 @@ FROM t; statement ok RESET null_ordered_last + +# Regression test for #109629. Implicit casts should be added during +# type-checking when necessary. +query T +SELECT array_cat_agg(ARRAY[(1::INT,), (1::FLOAT8,)]); +---- +{(1),(1)} + +query T +SELECT array_cat_agg( + ARRAY[(416644234484367676:::INT8,),(NULL,),((-0.12116245180368423):::FLOAT8,)] +) +---- +{(4.166442344843677e+17),(),(-0.12116245180368423)} diff --git a/pkg/sql/sem/tree/type_check.go b/pkg/sql/sem/tree/type_check.go index d05e44edb9b6..21192c706928 100644 --- a/pkg/sql/sem/tree/type_check.go +++ b/pkg/sql/sem/tree/type_check.go @@ -463,7 +463,7 @@ func (expr *CaseExpr) TypeCheck( tmpExprs = tmpExprs[:0] // As described in the Postgres docs, CASE treats its ELSE clause (if any) as // the "first" input. - // See https://www.postgresql.org/docs/current/typeconv-union-case.html#ftn.id-1.5.9.10.9.6.1.1. + // See https://www.postgresql.org/docs/15/typeconv-union-case.html#ftn.id-1.5.9.10.9.6.1.1. if expr.Else != nil { tmpExprs = append(tmpExprs, expr.Else) } @@ -2619,7 +2619,7 @@ func typeCheckSameTypedExprs( return typeCheckConstsAndPlaceholdersWithDesired(s, desired) default: firstValidIdx := -1 - firstValidType := types.Unknown + candidateType := types.Unknown for i, ok := s.resolvableIdxs.Next(0); ok; i, ok = s.resolvableIdxs.Next(i + 1) { typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, desired) if err != nil { @@ -2627,13 +2627,13 @@ func typeCheckSameTypedExprs( } typedExprs[i] = typedExpr if returnType := typedExpr.ResolvedType(); returnType.Family() != types.UnknownFamily { - firstValidType = returnType + candidateType = returnType firstValidIdx = i break } } - if firstValidType.Family() == types.UnknownFamily { + if candidateType.Family() == types.UnknownFamily { // We got to the end without finding a non-null expression. switch { case !constIdxs.Empty(): @@ -2653,35 +2653,75 @@ func typeCheckSameTypedExprs( } for i, ok := s.resolvableIdxs.Next(firstValidIdx + 1); ok; i, ok = s.resolvableIdxs.Next(i + 1) { - typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, firstValidType) + typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, candidateType) if err != nil { return nil, nil, err } // From the Postgres docs - // https://www.postgresql.org/docs/current/typeconv-union-case.html: - // If the candidate type can be implicitly converted to the other type, - // but not vice-versa, select the other type as the new candidate type. - if typ := typedExpr.ResolvedType(); cast.ValidCast(firstValidType, typ, cast.ContextImplicit) { - if !cast.ValidCast(typ, firstValidType, cast.ContextImplicit) { - firstValidType = typ + // https://www.postgresql.org/docs/15/typeconv-union-case.html: + // If the candidate type can be implicitly converted to the other + // type, but not vice-versa, select the other type as the new + // candidate type. + if typ := typedExpr.ResolvedType(); cast.ValidCast(candidateType, typ, cast.ContextImplicit) { + if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) { + candidateType = typ } } - if typ := typedExpr.ResolvedType(); !(typ.Equivalent(firstValidType) || typ.Family() == types.UnknownFamily) { - return nil, nil, unexpectedTypeError(exprs[i], firstValidType, typ) + // TODO(mgartner): Remove this check now that we check the types + // below. + if typ := typedExpr.ResolvedType(); !(typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily) { + return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ) } typedExprs[i] = typedExpr } if !constIdxs.Empty() { - if _, err := typeCheckSameTypedConsts(s, firstValidType, true); err != nil { + if _, err := typeCheckSameTypedConsts(s, candidateType, true); err != nil { return nil, nil, err } } if !placeholderIdxs.Empty() { - if _, err := typeCheckSameTypedPlaceholders(s, firstValidType); err != nil { + if _, err := typeCheckSameTypedPlaceholders(s, candidateType); err != nil { return nil, nil, err } } - return typedExprs, firstValidType, nil + // Now we check that each expression can be implicit cast to the + // candidate type, and add the cast if necessary. If any expressions + // cannot be cast, type-checking fails. This is described in Step 6 of + // Postgres's "UNION, CASE, and Related Constructs" type conversion + // documentation: + // https://www.postgresql.org/docs/15/typeconv-union-case.html + // + // NOTE: We fail if any part of the candidate type is a tuple type, + // because we don't have a way to serialize this cast over DistSQL. + var hasNestedTupleType func(t *types.T) bool + hasNestedTupleType = func(t *types.T) bool { + switch t.Family() { + case types.TupleFamily: + return true + case types.ArrayFamily: + if hasNestedTupleType(t.ArrayContents()) { + return true + } + } + return false + } + nestedTupleType := hasNestedTupleType(candidateType) + for i, e := range typedExprs { + typ := e.ResolvedType() + // TODO(mgartner): There should probably be a cast if the types are + // not identical, not just if the types are not equivalent. + if typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily { + continue + } + if nestedTupleType { + return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ) + } + if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) { + return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ) + } + typedExprs[i] = NewTypedCastExpr(e, candidateType) + } + return typedExprs, candidateType, nil } } diff --git a/pkg/sql/sem/tree/type_check_test.go b/pkg/sql/sem/tree/type_check_test.go index 80586be8f429..cff12080de07 100644 --- a/pkg/sql/sem/tree/type_check_test.go +++ b/pkg/sql/sem/tree/type_check_test.go @@ -103,6 +103,8 @@ func TestTypeCheck(t *testing.T) { {`ARRAY[NULL, NULL]:::int[]`, `ARRAY[NULL, NULL]:::INT8[]`}, {`ARRAY[]::INT8[]`, `ARRAY[]:::INT8[]`}, {`ARRAY[]:::INT8[]`, `ARRAY[]:::INT8[]`}, + {`ARRAY[1::INT, 1::FLOAT8]`, `ARRAY[1:::INT8::FLOAT8, 1.0:::FLOAT8]:::FLOAT8[]`}, + {`ARRAY[(1::INT,), (1::FLOAT8,)]`, `ARRAY[(1:::INT8::FLOAT8,), (1.0:::FLOAT8,)]:::RECORD[]`}, {`1 = ANY ARRAY[1.5, 2.5, 3.5]`, `1:::DECIMAL = ANY ARRAY[1.5:::DECIMAL, 2.5:::DECIMAL, 3.5:::DECIMAL]:::DECIMAL[]`}, {`true = SOME (ARRAY[true, false])`, `true = SOME ARRAY[true, false]:::BOOL[]`}, {`1.3 = ALL ARRAY[1, 2, 3]`, `1.3:::DECIMAL = ALL ARRAY[1:::DECIMAL, 2:::DECIMAL, 3:::DECIMAL]:::DECIMAL[]`},