Skip to content

Commit

Permalink
sql: avoid direct pointer comparisons for *types.T
Browse files Browse the repository at this point in the history
This commit replaces various `*types.T` pointer comparisons with type
family comparisons or calls to the `type.Identical()` method. This is
necessary because deserialization from disk may result in identical
types other than the global singletons (e.g. `types.Any`).

There is no release note, since these bugs have not resulted in any
known issues.

Informs #114846

Release note: None
  • Loading branch information
DrewKimball committed Feb 23, 2024
1 parent 6d65201 commit 3b455a7
Show file tree
Hide file tree
Showing 38 changed files with 75 additions and 67 deletions.
6 changes: 3 additions & 3 deletions pkg/ccl/changefeedccl/cdceval/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func cdcTimestampBuiltin(
},
Info: doc + " as HLC timestamp",
Volatility: v,
PreferredOverload: preferredOverloadReturnType == types.Decimal,
PreferredOverload: preferredOverloadReturnType.Identical(types.Decimal),
},
{
Types: tree.ParamTypes{},
Expand All @@ -179,7 +179,7 @@ func cdcTimestampBuiltin(
},
Info: doc + " as TIMESTAMPTZ",
Volatility: v,
PreferredOverload: preferredOverloadReturnType == types.TimestampTZ,
PreferredOverload: preferredOverloadReturnType.Identical(types.TimestampTZ),
},
{
Types: tree.ParamTypes{},
Expand All @@ -190,7 +190,7 @@ func cdcTimestampBuiltin(
},
Info: doc + " as TIMESTAMP",
Volatility: v,
PreferredOverload: preferredOverloadReturnType == types.Timestamp,
PreferredOverload: preferredOverloadReturnType.Identical(types.Timestamp),
},
},
)
Expand Down
2 changes: 1 addition & 1 deletion pkg/ccl/changefeedccl/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ func TestJsonRountrip(t *testing.T) {
rng, _ := randutil.NewTestRand()

isFloatOrDecimal := func(typ *types.T) bool {
return typ == types.Float4 || typ == types.Float || typ == types.Decimal
return typ.Identical(types.Float4) || typ.Identical(types.Float) || typ.Identical(types.Decimal)
}

type test struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/col/coldata/vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ func BenchmarkAppend(b *testing.B) {
for _, nullProbability := range []float64{0, 0.2} {
for _, bc := range benchCases {
// Only test the AppendNoInline case for bytes.
if typ != types.Bytes && bc.bytesLen == longBytesLen {
if !typ.Identical(types.Bytes) && bc.bytesLen == longBytesLen {
continue
}
src := coldata.NewMemColumn(typ, coldata.BatchSize(), coldata.StandardColumnFactory)
Expand Down
4 changes: 2 additions & 2 deletions pkg/col/colserde/arrowbatchconverter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ func runConversionBenchmarks(
b.Fatalf("unexpected batch width: %d", batch.Width())
}
var typNameSuffix string
if typ == types.Bytes {
if typ.Identical(types.Bytes) {
tc.numBytes = int64(tc.bytesFixedLength * coldata.BatchSize())
if tc.bytesFixedLength == bytesInlinedLen {
typNameSuffix = "_inlined"
}
} else if typ == types.Decimal {
} else if typ.Identical(types.Decimal) {
// Decimal is variable length type, so we want to calculate precisely the
// total size of all decimals in the vector.
decimals := batch.ColVec(0).Decimal()
Expand Down
2 changes: 1 addition & 1 deletion pkg/internal/sqlsmith/relational.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ func (s *Smither) makeCreateFunc() (cf *tree.CreateRoutine, ok bool) {
stmt = expr.(*tree.StatementSource).Statement
// If the rtype isn't a RECORD, change it to rrefs or RECORD depending
// how many columns there are to avoid return type mismatch errors.
if rtyp != types.AnyTuple {
if !rtyp.Identical(types.AnyTuple) {
if len(rrefs) == 1 && s.coin() && rrefs[0].typ.Family() != types.CollatedStringFamily {
rtyp = rrefs[0].typ
} else {
Expand Down
2 changes: 1 addition & 1 deletion pkg/internal/sqlsmith/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ func (v *replaceDatumPlaceholderVisitor) VisitPre(
) (recurse bool, newExpr tree.Expr) {
switch t := expr.(type) {
case tree.Datum:
if t.ResolvedType().IsNumeric() || t.ResolvedType() == types.Bool {
if t.ResolvedType().IsNumeric() || t.ResolvedType().Identical(types.Bool) {
v.Args = append(v.Args, expr)
placeholder, _ := tree.NewPlaceholder(strconv.Itoa(len(v.Args)))
return false, placeholder
Expand Down
2 changes: 1 addition & 1 deletion pkg/internal/sqlsmith/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (s *Smither) canRecurseScalar(isPredicate bool, typ *types.T) bool {
// Smither option is `true`, the desired expression type is boolean, and the
// expression is being generated for use in a query predicate.
func (s *Smither) avoidConstantBooleanExpressions(isPredicate bool, typ *types.T) bool {
if isPredicate && s.unlikelyConstantPredicate && typ == types.Bool {
if isPredicate && s.unlikelyConstantPredicate && typ.Identical(types.Bool) {
return true
}
return false
Expand Down
2 changes: 1 addition & 1 deletion pkg/internal/sqlsmith/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (s *Smither) randType() *types.T {
(s.disableOIDs && typ.Family() == types.OidFamily) {
continue
}
if s.postgres && typ == types.Name {
if s.postgres && typ.Identical(types.Name) {
// Name type in CRDB doesn't match Postgres behavior. Exclude for tests
// which compare CRDB behavior to Postgres.
continue
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/colexec/colbuilder/execplan.go
Original file line number Diff line number Diff line change
Expand Up @@ -2293,7 +2293,7 @@ func planProjectionOperators(
// anything extra for them, but we do need to handle the string
// case.
leftType, rightType := leftExpr.ResolvedType(), rightExpr.ResolvedType()
if t.Op.ReturnType == types.String && leftType.Family() != rightType.Family() {
if t.Op.ReturnType.Identical(types.String) && leftType.Family() != rightType.Family() {
// This is a special case of the STRING concatenation - we have
// to plan a cast of the non-string type to a STRING.
if leftType.Family() == types.StringFamily {
Expand Down
12 changes: 6 additions & 6 deletions pkg/sql/colexec/distinct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,9 @@ func runDistinctBenchmarks(
}
bytesValueScratch := make([]byte, bytesValueLength)
setFirstValue := func(vec coldata.Vec) {
if typ := vec.Type(); typ == types.Int {
if typ := vec.Type(); typ.Identical(types.Int) {
vec.Int64()[0] = 0
} else if typ == types.Bytes {
} else if typ.Identical(types.Bytes) {
vec.Bytes().Set(0, bytesValueScratch)
} else {
colexecerror.InternalError(errors.AssertionFailedf("unsupported type %s", typ))
Expand All @@ -503,13 +503,13 @@ func runDistinctBenchmarks(
if i == 0 {
colexecerror.InternalError(errors.New("setIthValue called with i == 0"))
}
if typ := vec.Type(); typ == types.Int {
if typ := vec.Type(); typ.Identical(types.Int) {
col := vec.Int64()
col[i] = col[i-1]
if rng.Float64() < newValueProbability {
col[i]++
}
} else if typ == types.Bytes {
} else if typ.Identical(types.Bytes) {
if rng.Float64() < newValueProbability {
copy(bytesValueScratch, vec.Bytes().Get(i-1))
for pos := 0; pos < bytesValueLength; pos++ {
Expand Down Expand Up @@ -569,13 +569,13 @@ func runDistinctBenchmarks(
})
for colIdx, oldCol := range cols {
cols[colIdx] = testAllocator.NewMemColumn(typs[colIdx], nRows)
if typs[colIdx] == types.Int {
if typs[colIdx].Identical(types.Int) {
oldInt64s := oldCol.Int64()
newInt64s := cols[colIdx].Int64()
for i := 0; i < nRows; i++ {
newInt64s[i] = oldInt64s[order[i]]
}
} else if typs[colIdx] == types.Bytes {
} else if typs[colIdx].Identical(types.Bytes) {
oldBytes := oldCol.Bytes()
newBytes := cols[colIdx].Bytes()
for i := 0; i < nRows; i++ {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/colexec/external_hash_joiner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func BenchmarkExternalHashJoiner(b *testing.B) {
continue
}
var cols []coldata.Vec
if typ == types.Int {
if typ.Identical(types.Int) {
cols = newIntColumns(nCols, nRows, 1 /* dupCount */)
} else {
cols = newBytesColumns(nCols, nRows, 1 /* dupCount */)
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/colexec/hashjoiner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ func BenchmarkHashJoiner(b *testing.B) {
// will have 15 duplicates.
dupCount = 16
}
if typ == types.Int {
if typ.Identical(types.Int) {
cols = newIntColumns(nCols, length, dupCount)
} else {
cols = newBytesColumns(nCols, length, dupCount)
Expand Down
3 changes: 2 additions & 1 deletion pkg/sql/create_as_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ func TestCreateAsVTable(t *testing.T) {
}
// Filter out vector columns to prevent error in CTAS:
// "VECTOR column types are unsupported".
if colDef.Type == types.Int2Vector || colDef.Type == types.OidVector {
if colDef.Type.(*types.T).Identical(types.Int2Vector) ||
colDef.Type.(*types.T).Identical(types.OidVector) {
continue
}
ctasColumns = append(ctasColumns, colDef.Name.String())
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/distsql_physical_planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,12 @@ func (v *distSQLExprCheckVisitor) VisitPre(expr tree.Expr) (recurse bool, newExp
// We need to check for arrays of untyped tuples here since constant-folding
// on builtin functions sometimes produces this. DecodeUntaggedDatum
// requires that all the types of the tuple contents are known.
if t.ResolvedType().ArrayContents() == types.AnyTuple {
if t.ResolvedType().ArrayContents().Identical(types.AnyTuple) {
v.err = newQueryNotSupportedErrorf("array %s cannot be executed with distsql", t)
return false, expr
}
case *tree.DTuple:
if t.ResolvedType() == types.AnyTuple {
if t.ResolvedType().Identical(types.AnyTuple) {
v.err = newQueryNotSupportedErrorf("tuple %s cannot be executed with distsql", t)
return false, expr
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/evalcatalog/encode_table_index_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (ec *Builtins) EncodeTableIndexKey(
if err != nil {
return nil, err
}
if d.ResolvedType() == types.Unknown {
if d.ResolvedType().Family() == types.UnknownFamily {
if !col.IsNullable() {
return nil, pgerror.Newf(pgcode.NotNullViolation, "NULL provided as a value for a nonnullable column")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/importer/read_import_avro_logical_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ type avroLogicalInfo struct {
}

func logicalEncoder(datum tree.Datum, avroType string) (ans interface{}, err error) {
if datum.ResolvedType() == types.Unknown {
if datum.ResolvedType().Family() == types.UnknownFamily {
return nil, nil
}
switch datum.ResolvedType().Family() {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/invertedidx/tsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (t *tsqueryFilterPlanner) extractInvertedFilterConditionFromLeaf(
return inverted.NonInvertedColExpression{}, expr, nil
}
d := memo.ExtractConstDatum(constantVal)
if d.ResolvedType() != types.TSQuery {
if !d.ResolvedType().Identical(types.TSQuery) {
panic(errors.AssertionFailedf(
"trying to apply tsvector inverted index to unsupported type %s", d.ResolvedType().SQLStringForError(),
))
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ func (b *Builder) trackReferencedColumnForViews(col *scopeColumn) {

func (b *Builder) maybeTrackRegclassDependenciesForViews(texpr tree.TypedExpr) {
if b.trackSchemaDeps {
if texpr.ResolvedType() == types.RegClass {
if texpr.ResolvedType().Identical(types.RegClass) {
// We do not add a dependency if the RegClass Expr contains variables,
// we cannot resolve the variables in this context. This matches Postgres
// behavior.
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/opt/optbuilder/plpgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) (newStmt ast.Statement, re
}
case *ast.Return:
desired := types.Any
if r.typ != types.Unknown {
if r.typ.Family() != types.UnknownFamily {
desired = r.typ
}
expr, _ := tree.WalkExpr(r.s, t.Expr)
Expand All @@ -1658,13 +1658,13 @@ func (r *recordTypeVisitor) Visit(stmt ast.Statement) (newStmt ast.Statement, re
panic(err)
}
typ := typedExpr.ResolvedType()
if typ == types.Unknown {
if typ.Family() == types.UnknownFamily {
return stmt, false
}
if typ.Family() != types.TupleFamily {
panic(nonCompositeErr)
}
if r.typ == types.Unknown {
if r.typ.Family() == types.UnknownFamily {
r.typ = typ
return stmt, false
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/optbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,11 @@ func (b *Builder) buildScalar(
// arguments with a CastExpr that preserves the static type.

left := t.TypedLeft()
if left.ResolvedType() == types.Unknown {
if left.ResolvedType().Family() == types.UnknownFamily {
left = reType(left, t.ResolvedBinOp().LeftType)
}
right := t.TypedRight()
if right.ResolvedType() == types.Unknown {
if right.ResolvedType().Family() == types.UnknownFamily {
right = reType(right, t.ResolvedBinOp().RightType)
}
out = b.constructBinary(
Expand Down
11 changes: 5 additions & 6 deletions pkg/sql/opt/optbuilder/union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,14 @@ func TestUnionType(t *testing.T) {
expected: types.Decimal,
},
{
// Error.
left: types.Float,
right: types.String,
expected: nil,
left: types.MakeArray(types.MakeTuple([]*types.T{types.Any})),
right: types.MakeArray(types.MakeTuple([]*types.T{types.Bool})),
expected: types.MakeArray(types.MakeTuple([]*types.T{types.Bool})),
},
{
// Error.
left: types.MakeArray(types.MakeTuple([]*types.T{types.Any})),
right: types.MakeArray(types.MakeTuple([]*types.T{types.Bool})),
left: types.Float,
right: types.String,
expected: nil,
},
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func (b *Builder) projectColumn(dst *scopeColumn, src *scopeColumn) {
// default label for a function expression. Returns true if the function's
// return type is not an empty tuple and doesn't declare any tuple labels.
func (b *Builder) shouldCreateDefaultColumn(texpr tree.TypedExpr) bool {
if texpr.ResolvedType() == types.EmptyTuple {
if texpr.ResolvedType().Identical(types.EmptyTuple) {
// This is only to support crdb_internal.unary_table().
return false
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/optbuilder/values.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ func rightHasMoreSpecificTuple(left, right *types.T) (isMoreSpecific bool, isEqu
return rightHasMoreSpecificTuple(left.ArrayContents(), right.ArrayContents())
}
if left.Family() == types.TupleFamily && right.Family() == types.TupleFamily {
if right == types.AnyTuple {
if right.Identical(types.AnyTuple) {
return false, true
} else if left == types.AnyTuple {
} else if left.Identical(types.AnyTuple) {
return true, true
} else if len(left.TupleContents()) != len(right.TupleContents()) {
return false, false
Expand Down
5 changes: 3 additions & 2 deletions pkg/sql/randgen/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,9 @@ func typeToStringCastHasIncorrectVolatility(t *types.T) bool {
types.IntervalFamily, types.TupleFamily:
return true
case types.OidFamily:
return t == types.RegClass || t == types.RegNamespace || t == types.RegProc ||
t == types.RegProcedure || t == types.RegRole || t == types.RegType
return t.Identical(types.RegClass) || t.Identical(types.RegNamespace) ||
t.Identical(types.RegProc) || t.Identical(types.RegProcedure) ||
t.Identical(types.RegRole) || t.Identical(types.RegType)
default:
return false
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/randgen/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func RandColumnTypes(rng *rand.Rand, numCols int) []*types.T {
// RandSortingType returns a column type which can be key-encoded.
func RandSortingType(rng *rand.Rand) *types.T {
typ := RandType(rng)
for colinfo.MustBeValueEncoded(typ) || typ == types.Void {
for colinfo.MustBeValueEncoded(typ) || typ.Family() == types.VoidFamily {
typ = RandType(rng)
}
return typ
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/rowenc/roundtrip_format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestRandParseDatumStringAs(t *testing.T) {
},
types.Scalar...)
for _, ty := range types.Scalar {
if ty != types.Jsonb {
if !ty.Identical(types.Jsonb) {
tests = append(tests, types.MakeArray(ty))
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/sem/builtins/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -10786,7 +10786,7 @@ func makeEnumTypeFunc(impl func(t *types.T) (tree.Datum, error)) tree.FnWithExpr
ctx context.Context, evalCtx *eval.Context, args tree.Exprs,
) (tree.Datum, error) {
enumType := args[0].(tree.TypedExpr).ResolvedType()
if enumType == types.Unknown || enumType == types.AnyEnum {
if enumType.Family() == types.UnknownFamily || enumType.Identical(types.AnyEnum) {
return nil, errors.WithHint(pgerror.New(pgcode.InvalidParameterValue, "input expression must always resolve to the same enum type"),
"Try NULL::yourenumtype")
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/sql/sem/cast/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ func ValidCast(src, tgt *types.T, ctx Context) bool {
// Casts from a tuple type to AnyTuple are a no-op so they are always valid.
// If tgt is AnyTuple, we continue to LookupCast below which contains a
// special case for these casts.
if srcFamily == types.TupleFamily && tgtFamily == types.TupleFamily && tgt != types.AnyTuple {
if srcFamily == types.TupleFamily && tgtFamily == types.TupleFamily &&
!tgt.Identical(types.AnyTuple) {
srcTypes := src.TupleContents()
tgtTypes := tgt.TupleContents()
// The tuple types must have the same number of elements.
Expand Down Expand Up @@ -265,7 +266,7 @@ func LookupCast(src, tgt *types.T) (Cast, bool) {

// Casts from any tuple type to AnyTuple are no-ops, so they are implicit
// and immutable.
if srcFamily == types.TupleFamily && tgt == types.AnyTuple {
if srcFamily == types.TupleFamily && tgt.Identical(types.AnyTuple) {
return Cast{
MaxContext: ContextImplicit,
Volatility: volatility.Immutable,
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/sem/eval/binary_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (e *evaluator) EvalConcatArraysOp(
func (e *evaluator) EvalConcatOp(
ctx context.Context, op *tree.ConcatOp, left, right tree.Datum,
) (tree.Datum, error) {
if op.Left == types.String {
if op.Left.Identical(types.String) {
casted, err := PerformCast(ctx, e.ctx(), right, types.String)
if err != nil {
return nil, err
Expand All @@ -238,7 +238,7 @@ func (e *evaluator) EvalConcatOp(
string(tree.MustBeDString(left)) + string(tree.MustBeDString(casted)),
), nil
}
if op.Right == types.String {
if op.Right.Identical(types.String) {
casted, err := PerformCast(ctx, e.ctx(), left, types.String)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/sem/eval/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ func performCastWithoutPrecisionTruncation(
case types.TupleFamily:
switch v := d.(type) {
case *tree.DTuple:
if t == types.AnyTuple {
if t.Identical(types.AnyTuple) {
// If AnyTuple is the target type, we can just use the input tuple.
return v, nil
}
Expand Down

0 comments on commit 3b455a7

Please sign in to comment.