diff --git a/pkg/sql/logictest/testdata/logic_test/aggregate b/pkg/sql/logictest/testdata/logic_test/aggregate index 8d34e68cf9b7..71f7b055138b 100644 --- a/pkg/sql/logictest/testdata/logic_test/aggregate +++ b/pkg/sql/logictest/testdata/logic_test/aggregate @@ -3968,3 +3968,17 @@ FROM t; statement ok RESET null_ordered_last + +# Regression test for #109629. Implicit casts should be added during +# type-checking when necessary. +query T +SELECT array_cat_agg(ARRAY[(1::INT,), (1::FLOAT8,)]); +---- +{(1),(1)} + +query T +SELECT array_cat_agg( + ARRAY[(416644234484367676:::INT8,),(NULL,),((-0.12116245180368423):::FLOAT8,)] +) +---- +{(4.166442344843677e+17),(),(-0.12116245180368423)} diff --git a/pkg/sql/sem/tree/type_check.go b/pkg/sql/sem/tree/type_check.go index d05e44edb9b6..7c1498382762 100644 --- a/pkg/sql/sem/tree/type_check.go +++ b/pkg/sql/sem/tree/type_check.go @@ -463,7 +463,7 @@ func (expr *CaseExpr) TypeCheck( tmpExprs = tmpExprs[:0] // As described in the Postgres docs, CASE treats its ELSE clause (if any) as // the "first" input. - // See https://www.postgresql.org/docs/current/typeconv-union-case.html#ftn.id-1.5.9.10.9.6.1.1. + // See https://www.postgresql.org/docs/15/typeconv-union-case.html#ftn.id-1.5.9.10.9.6.1.1. if expr.Else != nil { tmpExprs = append(tmpExprs, expr.Else) } @@ -2619,7 +2619,7 @@ func typeCheckSameTypedExprs( return typeCheckConstsAndPlaceholdersWithDesired(s, desired) default: firstValidIdx := -1 - firstValidType := types.Unknown + candidateType := types.Unknown for i, ok := s.resolvableIdxs.Next(0); ok; i, ok = s.resolvableIdxs.Next(i + 1) { typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, desired) if err != nil { @@ -2627,13 +2627,13 @@ func typeCheckSameTypedExprs( } typedExprs[i] = typedExpr if returnType := typedExpr.ResolvedType(); returnType.Family() != types.UnknownFamily { - firstValidType = returnType + candidateType = returnType firstValidIdx = i break } } - if firstValidType.Family() == types.UnknownFamily { + if candidateType.Family() == types.UnknownFamily { // We got to the end without finding a non-null expression. switch { case !constIdxs.Empty(): @@ -2653,35 +2653,56 @@ func typeCheckSameTypedExprs( } for i, ok := s.resolvableIdxs.Next(firstValidIdx + 1); ok; i, ok = s.resolvableIdxs.Next(i + 1) { - typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, firstValidType) + typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, candidateType) if err != nil { return nil, nil, err } // From the Postgres docs - // https://www.postgresql.org/docs/current/typeconv-union-case.html: - // If the candidate type can be implicitly converted to the other type, - // but not vice-versa, select the other type as the new candidate type. - if typ := typedExpr.ResolvedType(); cast.ValidCast(firstValidType, typ, cast.ContextImplicit) { - if !cast.ValidCast(typ, firstValidType, cast.ContextImplicit) { - firstValidType = typ + // https://www.postgresql.org/docs/15/typeconv-union-case.html: + // If the candidate type can be implicitly converted to the other + // type, but not vice-versa, select the other type as the new + // candidate type. + if typ := typedExpr.ResolvedType(); cast.ValidCast(candidateType, typ, cast.ContextImplicit) { + if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) { + candidateType = typ } } - if typ := typedExpr.ResolvedType(); !(typ.Equivalent(firstValidType) || typ.Family() == types.UnknownFamily) { - return nil, nil, unexpectedTypeError(exprs[i], firstValidType, typ) + // TODO(mgartner): Remove this check now that we check the types + // below. + if typ := typedExpr.ResolvedType(); !(typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily) { + return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ) } typedExprs[i] = typedExpr } if !constIdxs.Empty() { - if _, err := typeCheckSameTypedConsts(s, firstValidType, true); err != nil { + if _, err := typeCheckSameTypedConsts(s, candidateType, true); err != nil { return nil, nil, err } } if !placeholderIdxs.Empty() { - if _, err := typeCheckSameTypedPlaceholders(s, firstValidType); err != nil { + if _, err := typeCheckSameTypedPlaceholders(s, candidateType); err != nil { return nil, nil, err } } - return typedExprs, firstValidType, nil + // Now we check that each expression can be implicit cast to the + // candidate type, and add the cast if necessary. If any expressions + // cannot be cast, type-checking fails. This is described in Step 6 of + // Postgres's "UNION, CASE, and Related Constructs" type conversion + // documentation: + // https://www.postgresql.org/docs/15/typeconv-union-case.html + for i, e := range typedExprs { + typ := e.ResolvedType() + // TODO(mgartner): There should probably be a cast if the types are + // not identical, not just if the types are not equivalent. + if typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily { + continue + } + if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) { + return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ) + } + typedExprs[i] = NewTypedCastExpr(e, candidateType) + } + return typedExprs, candidateType, nil } } diff --git a/pkg/sql/sem/tree/type_check_test.go b/pkg/sql/sem/tree/type_check_test.go index 80586be8f429..cff12080de07 100644 --- a/pkg/sql/sem/tree/type_check_test.go +++ b/pkg/sql/sem/tree/type_check_test.go @@ -103,6 +103,8 @@ func TestTypeCheck(t *testing.T) { {`ARRAY[NULL, NULL]:::int[]`, `ARRAY[NULL, NULL]:::INT8[]`}, {`ARRAY[]::INT8[]`, `ARRAY[]:::INT8[]`}, {`ARRAY[]:::INT8[]`, `ARRAY[]:::INT8[]`}, + {`ARRAY[1::INT, 1::FLOAT8]`, `ARRAY[1:::INT8::FLOAT8, 1.0:::FLOAT8]:::FLOAT8[]`}, + {`ARRAY[(1::INT,), (1::FLOAT8,)]`, `ARRAY[(1:::INT8::FLOAT8,), (1.0:::FLOAT8,)]:::RECORD[]`}, {`1 = ANY ARRAY[1.5, 2.5, 3.5]`, `1:::DECIMAL = ANY ARRAY[1.5:::DECIMAL, 2.5:::DECIMAL, 3.5:::DECIMAL]:::DECIMAL[]`}, {`true = SOME (ARRAY[true, false])`, `true = SOME ARRAY[true, false]:::BOOL[]`}, {`1.3 = ALL ARRAY[1, 2, 3]`, `1.3:::DECIMAL = ALL ARRAY[1:::DECIMAL, 2:::DECIMAL, 3:::DECIMAL]:::DECIMAL[]`},