Skip to content

Commit

Permalink
sql: fix nested expression type-checking
Browse files Browse the repository at this point in the history
PR cockroachdb#108387 introduced new logic to type-checking that allows nested
expressions of a single expression to have different types in some cases
when the types can be implicitly casted to a common type. For example,
the expression `ARRAY[1::INT, 1::FLOAT8]` would have failed
type-checking but is not successfully typed as `[]FLOAT8`.

However, cockroachdb#108387 does not add an implicit cast to ensure that each
nested expression actually has that type during execution, which causes
internal errors in some cases. This commit add the necessary implicit
casts.

Fixes cockroachdb#109629

There is no release note because the bug is not present in any releases.

Release note: None
  • Loading branch information
mgartner committed Aug 31, 2023
1 parent 8b7fb0c commit 1758e91
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 16 deletions.
14 changes: 14 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/aggregate
Expand Up @@ -3968,3 +3968,17 @@ FROM t;

statement ok
RESET null_ordered_last

# Regression test for #109629. Implicit casts should be added during
# type-checking when necessary.
query T
SELECT array_cat_agg(ARRAY[(1::INT,), (1::FLOAT8,)]);
----
{(1),(1)}

query T
SELECT array_cat_agg(
ARRAY[(416644234484367676:::INT8,),(NULL,),((-0.12116245180368423):::FLOAT8,)]
)
----
{(4.166442344843677e+17),(),(-0.12116245180368423)}
53 changes: 37 additions & 16 deletions pkg/sql/sem/tree/type_check.go
Expand Up @@ -463,7 +463,7 @@ func (expr *CaseExpr) TypeCheck(
tmpExprs = tmpExprs[:0]
// As described in the Postgres docs, CASE treats its ELSE clause (if any) as
// the "first" input.
// See https://www.postgresql.org/docs/current/typeconv-union-case.html#ftn.id-1.5.9.10.9.6.1.1.
// See https://www.postgresql.org/docs/15/typeconv-union-case.html#ftn.id-1.5.9.10.9.6.1.1.
if expr.Else != nil {
tmpExprs = append(tmpExprs, expr.Else)
}
Expand Down Expand Up @@ -2619,21 +2619,21 @@ func typeCheckSameTypedExprs(
return typeCheckConstsAndPlaceholdersWithDesired(s, desired)
default:
firstValidIdx := -1
firstValidType := types.Unknown
candidateType := types.Unknown
for i, ok := s.resolvableIdxs.Next(0); ok; i, ok = s.resolvableIdxs.Next(i + 1) {
typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, desired)
if err != nil {
return nil, nil, err
}
typedExprs[i] = typedExpr
if returnType := typedExpr.ResolvedType(); returnType.Family() != types.UnknownFamily {
firstValidType = returnType
candidateType = returnType
firstValidIdx = i
break
}
}

if firstValidType.Family() == types.UnknownFamily {
if candidateType.Family() == types.UnknownFamily {
// We got to the end without finding a non-null expression.
switch {
case !constIdxs.Empty():
Expand All @@ -2653,35 +2653,56 @@ func typeCheckSameTypedExprs(
}

for i, ok := s.resolvableIdxs.Next(firstValidIdx + 1); ok; i, ok = s.resolvableIdxs.Next(i + 1) {
typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, firstValidType)
typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, candidateType)
if err != nil {
return nil, nil, err
}
// From the Postgres docs
// https://www.postgresql.org/docs/current/typeconv-union-case.html:
// If the candidate type can be implicitly converted to the other type,
// but not vice-versa, select the other type as the new candidate type.
if typ := typedExpr.ResolvedType(); cast.ValidCast(firstValidType, typ, cast.ContextImplicit) {
if !cast.ValidCast(typ, firstValidType, cast.ContextImplicit) {
firstValidType = typ
// https://www.postgresql.org/docs/15/typeconv-union-case.html:
// If the candidate type can be implicitly converted to the other
// type, but not vice-versa, select the other type as the new
// candidate type.
if typ := typedExpr.ResolvedType(); cast.ValidCast(candidateType, typ, cast.ContextImplicit) {
if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) {
candidateType = typ
}
}
if typ := typedExpr.ResolvedType(); !(typ.Equivalent(firstValidType) || typ.Family() == types.UnknownFamily) {
return nil, nil, unexpectedTypeError(exprs[i], firstValidType, typ)
// TODO(mgartner): Remove this check now that we check the types
// below.
if typ := typedExpr.ResolvedType(); !(typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily) {
return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ)
}
typedExprs[i] = typedExpr
}
if !constIdxs.Empty() {
if _, err := typeCheckSameTypedConsts(s, firstValidType, true); err != nil {
if _, err := typeCheckSameTypedConsts(s, candidateType, true); err != nil {
return nil, nil, err
}
}
if !placeholderIdxs.Empty() {
if _, err := typeCheckSameTypedPlaceholders(s, firstValidType); err != nil {
if _, err := typeCheckSameTypedPlaceholders(s, candidateType); err != nil {
return nil, nil, err
}
}
return typedExprs, firstValidType, nil
// Now we check that each expression can be implicit cast to the
// candidate type, and add the cast if necessary. If any expressions
// cannot be cast, type-checking fails. This is described in Step 6 of
// Postgres's "UNION, CASE, and Related Constructs" type conversion
// documentation:
// https://www.postgresql.org/docs/15/typeconv-union-case.html
for i, e := range typedExprs {
typ := e.ResolvedType()
// TODO(mgartner): There should probably be a cast if the types are
// not identical, not just if the types are not equivalent.
if typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily {
continue
}
if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) {
return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ)
}
typedExprs[i] = NewTypedCastExpr(e, candidateType)
}
return typedExprs, candidateType, nil
}
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/sql/sem/tree/type_check_test.go
Expand Up @@ -103,6 +103,8 @@ func TestTypeCheck(t *testing.T) {
{`ARRAY[NULL, NULL]:::int[]`, `ARRAY[NULL, NULL]:::INT8[]`},
{`ARRAY[]::INT8[]`, `ARRAY[]:::INT8[]`},
{`ARRAY[]:::INT8[]`, `ARRAY[]:::INT8[]`},
{`ARRAY[1::INT, 1::FLOAT8]`, `ARRAY[1:::INT8::FLOAT8, 1.0:::FLOAT8]:::FLOAT8[]`},
{`ARRAY[(1::INT,), (1::FLOAT8,)]`, `ARRAY[(1:::INT8::FLOAT8,), (1.0:::FLOAT8,)]:::RECORD[]`},
{`1 = ANY ARRAY[1.5, 2.5, 3.5]`, `1:::DECIMAL = ANY ARRAY[1.5:::DECIMAL, 2.5:::DECIMAL, 3.5:::DECIMAL]:::DECIMAL[]`},
{`true = SOME (ARRAY[true, false])`, `true = SOME ARRAY[true, false]:::BOOL[]`},
{`1.3 = ALL ARRAY[1, 2, 3]`, `1.3:::DECIMAL = ALL ARRAY[1:::DECIMAL, 2:::DECIMAL, 3:::DECIMAL]:::DECIMAL[]`},
Expand Down

0 comments on commit 1758e91

Please sign in to comment.