Skip to content

Commit

Permalink
sql: fix nested expression type-checking
Browse files Browse the repository at this point in the history
PR #108387 introduced new logic to type-checking that allows nested
expressions of a single expression to have different types in some cases
when the types can be implicitly casted to a common type. For example,
the expression `ARRAY[1::INT, 1::FLOAT8]` would have failed
type-checking but is not successfully typed as `[]FLOAT8`.

However, #108387 does not add an implicit cast to ensure that each
nested expression actually has that type during execution, which causes
internal errors in some cases. This commit add the necessary implicit
casts.

Fixes #109629

There is no release note because the bug is not present in any releases.

Release note: None
  • Loading branch information
mgartner committed Aug 30, 2023
1 parent 8b7fb0c commit 8c72602
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 16 deletions.
14 changes: 14 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/aggregate
Original file line number Diff line number Diff line change
Expand Up @@ -3968,3 +3968,17 @@ FROM t;

statement ok
RESET null_ordered_last

# Regression test for #109629. Implicit casts should be added during
# type-checking when necessary.
query T
SELECT array_cat_agg(ARRAY[(1::INT,), (1::FLOAT8,)]);
----
{(1),(1)}

query T
SELECT array_cat_agg(
ARRAY[(416644234484367676:::INT8,),(NULL,),((-0.12116245180368423):::FLOAT8,)]
)
----
{(4.166442344843677e+17),(),(-0.12116245180368423)}
72 changes: 56 additions & 16 deletions pkg/sql/sem/tree/type_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ func (expr *CaseExpr) TypeCheck(
tmpExprs = tmpExprs[:0]
// As described in the Postgres docs, CASE treats its ELSE clause (if any) as
// the "first" input.
// See https://www.postgresql.org/docs/current/typeconv-union-case.html#ftn.id-1.5.9.10.9.6.1.1.
// See https://www.postgresql.org/docs/15/typeconv-union-case.html#ftn.id-1.5.9.10.9.6.1.1.
if expr.Else != nil {
tmpExprs = append(tmpExprs, expr.Else)
}
Expand Down Expand Up @@ -2619,21 +2619,21 @@ func typeCheckSameTypedExprs(
return typeCheckConstsAndPlaceholdersWithDesired(s, desired)
default:
firstValidIdx := -1
firstValidType := types.Unknown
candidateType := types.Unknown
for i, ok := s.resolvableIdxs.Next(0); ok; i, ok = s.resolvableIdxs.Next(i + 1) {
typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, desired)
if err != nil {
return nil, nil, err
}
typedExprs[i] = typedExpr
if returnType := typedExpr.ResolvedType(); returnType.Family() != types.UnknownFamily {
firstValidType = returnType
candidateType = returnType
firstValidIdx = i
break
}
}

if firstValidType.Family() == types.UnknownFamily {
if candidateType.Family() == types.UnknownFamily {
// We got to the end without finding a non-null expression.
switch {
case !constIdxs.Empty():
Expand All @@ -2653,35 +2653,75 @@ func typeCheckSameTypedExprs(
}

for i, ok := s.resolvableIdxs.Next(firstValidIdx + 1); ok; i, ok = s.resolvableIdxs.Next(i + 1) {
typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, firstValidType)
typedExpr, err := exprs[i].TypeCheck(ctx, semaCtx, candidateType)
if err != nil {
return nil, nil, err
}
// From the Postgres docs
// https://www.postgresql.org/docs/current/typeconv-union-case.html:
// If the candidate type can be implicitly converted to the other type,
// but not vice-versa, select the other type as the new candidate type.
if typ := typedExpr.ResolvedType(); cast.ValidCast(firstValidType, typ, cast.ContextImplicit) {
if !cast.ValidCast(typ, firstValidType, cast.ContextImplicit) {
firstValidType = typ
// https://www.postgresql.org/docs/15/typeconv-union-case.html:
// If the candidate type can be implicitly converted to the other
// type, but not vice-versa, select the other type as the new
// candidate type.
if typ := typedExpr.ResolvedType(); cast.ValidCast(candidateType, typ, cast.ContextImplicit) {
if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) {
candidateType = typ
}
}
if typ := typedExpr.ResolvedType(); !(typ.Equivalent(firstValidType) || typ.Family() == types.UnknownFamily) {
return nil, nil, unexpectedTypeError(exprs[i], firstValidType, typ)
// TODO(mgartner): Remove this check now that we check the types
// below.
if typ := typedExpr.ResolvedType(); !(typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily) {
return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ)
}
typedExprs[i] = typedExpr
}
if !constIdxs.Empty() {
if _, err := typeCheckSameTypedConsts(s, firstValidType, true); err != nil {
if _, err := typeCheckSameTypedConsts(s, candidateType, true); err != nil {
return nil, nil, err
}
}
if !placeholderIdxs.Empty() {
if _, err := typeCheckSameTypedPlaceholders(s, firstValidType); err != nil {
if _, err := typeCheckSameTypedPlaceholders(s, candidateType); err != nil {
return nil, nil, err
}
}
return typedExprs, firstValidType, nil
// Now we check that each expression can be implicit cast to the
// candidate type, and add the cast if necessary. If any expressions
// cannot be cast, type-checking fails. This is described in Step 6 of
// Postgres's "UNION, CASE, and Related Constructs" type conversion
// documentation:
// https://www.postgresql.org/docs/15/typeconv-union-case.html
//
// NOTE: We fail if any part of the candidate type is a tuple type,
// because we don't have a way to serialize this cast over DistSQL.
var hasNestedTupleType func(t *types.T) bool
hasNestedTupleType = func(t *types.T) bool {
switch t.Family() {
case types.TupleFamily:
return true
case types.ArrayFamily:
if hasNestedTupleType(t.ArrayContents()) {
return true
}
}
return false
}
nestedTupleType := hasNestedTupleType(candidateType)
for i, e := range typedExprs {
typ := e.ResolvedType()
// TODO(mgartner): There should probably be a cast if the types are
// not identical, not just if the types are not equivalent.
if typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily {
continue
}
if nestedTupleType {
return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ)
}
if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) {
return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ)
}
typedExprs[i] = NewTypedCastExpr(e, candidateType)
}
return typedExprs, candidateType, nil
}
}

Expand Down
2 changes: 2 additions & 0 deletions pkg/sql/sem/tree/type_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ func TestTypeCheck(t *testing.T) {
{`ARRAY[NULL, NULL]:::int[]`, `ARRAY[NULL, NULL]:::INT8[]`},
{`ARRAY[]::INT8[]`, `ARRAY[]:::INT8[]`},
{`ARRAY[]:::INT8[]`, `ARRAY[]:::INT8[]`},
{`ARRAY[1::INT, 1::FLOAT8]`, `ARRAY[1:::INT8::FLOAT8, 1.0:::FLOAT8]:::FLOAT8[]`},
{`ARRAY[(1::INT,), (1::FLOAT8,)]`, `ARRAY[(1:::INT8::FLOAT8,), (1.0:::FLOAT8,)]:::RECORD[]`},
{`1 = ANY ARRAY[1.5, 2.5, 3.5]`, `1:::DECIMAL = ANY ARRAY[1.5:::DECIMAL, 2.5:::DECIMAL, 3.5:::DECIMAL]:::DECIMAL[]`},
{`true = SOME (ARRAY[true, false])`, `true = SOME ARRAY[true, false]:::BOOL[]`},
{`1.3 = ALL ARRAY[1, 2, 3]`, `1.3:::DECIMAL = ALL ARRAY[1:::DECIMAL, 2:::DECIMAL, 3:::DECIMAL]:::DECIMAL[]`},
Expand Down

0 comments on commit 8c72602

Please sign in to comment.