Skip to content

Commit

Permalink
opt/memo: extend OutputCols with VirtualCols in statistics builder
Browse files Browse the repository at this point in the history
Throughout the statistics builder we use OutputCols to determine which
columns come from the input to an expression. We then typically call
colStatXXX with those columns as part of statistics calculation.

In order to use statistics on virtual computed columns, we need to call
colStatXXX on any virtual columns that could come from our input, even
if they are not passed upward through OutputCols. To do this we extend
OutputCols with the VirtualCols set we built in a previous commit. This
commit replaces almost all usages of OutputCols in statistics builder
with a call to helper function colStatCols, which returns a union of
OutputCols and VirtualCols.

This is enough to get the optimizer to use statistics on virtual
computed columns in some simple plans. More complex plans will require
matching the virtual column scalar expressions, which will come in the
next PR. I've left some TODOs marking the spots that the next PR will
touch.

Informs: #68254

Epic: CRDB-8949

Release note: None
  • Loading branch information
michae2 committed Mar 19, 2024
1 parent 62de0bf commit 05dabed
Show file tree
Hide file tree
Showing 2 changed files with 449 additions and 29 deletions.
74 changes: 45 additions & 29 deletions pkg/sql/opt/memo/statistics_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ func (sb *statisticsBuilder) clear() {
sb.md = nil
}

// colStatCols reports which columns of the given relational properties are
// eligible for lookup in props.Statistics().ColStats.
//
// When the OptimizerUseVirtualComputedColumnStats session setting is off, the
// result is simply props.OutputCols. When it is on, the result is the union of
// props.OutputCols and props.Statistics().VirtualCols, so that in-scope
// virtual computed columns we have statistics on are included as well.
func (sb *statisticsBuilder) colStatCols(props *props.Relational) opt.ColSet {
	if !sb.evalCtx.SessionData().OptimizerUseVirtualComputedColumnStats {
		// Virtual computed column stats are disabled: only output columns count.
		return props.OutputCols
	}
	virtualCols := props.Statistics().VirtualCols
	return props.OutputCols.Union(virtualCols)
}

// colStatFromChild retrieves a column statistic from a specific child of the
// given expression.
func (sb *statisticsBuilder) colStatFromChild(
Expand All @@ -259,7 +269,7 @@ func (sb *statisticsBuilder) colStatFromChild(
child := e.Child(childIdx).(RelExpr)
childProps := child.Relational()
if !colSet.SubsetOf(childProps.OutputCols) {
colSet = colSet.Intersection(childProps.OutputCols)
colSet = colSet.Intersection(sb.colStatCols(childProps))
if colSet.Empty() {
// All the columns in colSet are outer columns; therefore, we can treat
// them as a constant.
Expand Down Expand Up @@ -338,24 +348,26 @@ func (sb *statisticsBuilder) colStatFromInput(

if lookupJoin != nil || invertedJoin != nil || zigzagJoin != nil ||
opt.IsJoinOp(e) || e.Op() == opt.MergeJoinOp {

var leftProps *props.Relational
if zigzagJoin != nil {
leftProps = &zigzagJoin.leftProps
} else {
leftProps = e.Child(0).(RelExpr).Relational()
}
intersectsLeft := sb.colStatCols(leftProps).Intersects(colSet)

intersectsLeft := leftProps.OutputCols.Intersects(colSet)
var intersectsRight bool
var rightProps *props.Relational
if lookupJoin != nil {
intersectsRight = lookupJoin.lookupProps.OutputCols.Intersects(colSet)
rightProps = &lookupJoin.lookupProps
} else if invertedJoin != nil {
intersectsRight = invertedJoin.lookupProps.OutputCols.Intersects(colSet)
rightProps = &invertedJoin.lookupProps
} else if zigzagJoin != nil {
intersectsRight = zigzagJoin.rightProps.OutputCols.Intersects(colSet)
rightProps = &zigzagJoin.rightProps
} else {
intersectsRight = e.Child(1).(RelExpr).Relational().OutputCols.Intersects(colSet)
rightProps = e.Child(1).(RelExpr).Relational()
}
intersectsRight := sb.colStatCols(rightProps).Intersects(colSet)

// It's possible that colSet intersects both left and right if we have a
// lookup join that was converted from an index join, so check the left
Expand Down Expand Up @@ -1095,7 +1107,7 @@ func (sb *statisticsBuilder) colStatProject(
// Columns may be passed through from the input, or they may reference a
// higher scope (in the case of a correlated subquery), or they
// may be synthesized by the projection operation.
inputCols := prj.Input.Relational().OutputCols
inputCols := sb.colStatCols(prj.Input.Relational())
reqInputCols := colSet.Intersection(inputCols)
nonNullFound := false
reqSynthCols := colSet.Difference(inputCols)
Expand All @@ -1110,6 +1122,7 @@ func (sb *statisticsBuilder) colStatProject(
for i := range prj.Projections {
item := &prj.Projections[i]
if reqSynthCols.Contains(item.Col) {
// TODO(michae2): Check for matching virtual column expressions here.
reqInputCols.UnionWith(item.scalar.OuterCols)

// If the element is not a null constant, account for that
Expand Down Expand Up @@ -1231,8 +1244,8 @@ func (sb *statisticsBuilder) buildJoin(

leftStats := h.leftProps.Statistics()
rightStats := h.rightProps.Statistics()
leftCols := h.leftProps.OutputCols.Copy()
rightCols := h.rightProps.OutputCols.Copy()
leftCols := sb.colStatCols(h.leftProps)
rightCols := sb.colStatCols(h.rightProps)
equivReps := h.filtersFD.EquivReps()

switch h.joinType {
Expand Down Expand Up @@ -1316,8 +1329,8 @@ func (sb *statisticsBuilder) buildJoin(
constrainedCols = sb.tryReduceJoinCols(
constrainedCols,
s,
h.leftProps.OutputCols,
h.rightProps.OutputCols,
leftCols,
rightCols,
&h.leftProps.FuncDeps,
&h.rightProps.FuncDeps,
)
Expand All @@ -1334,7 +1347,7 @@ func (sb *statisticsBuilder) buildJoin(
// calculations. It will be fixed below.
s.RowCount = leftStats.RowCount
s.ApplySelectivity(sb.selectivityFromEquivalenciesSemiJoin(
equivReps, h.leftProps.OutputCols, h.rightProps.OutputCols, &h.filtersFD, join, s,
equivReps, leftCols, rightCols, &h.filtersFD, join, s,
))
var oredTermSelectivity props.Selectivity
oredTermSelectivity, numUnappliedConjuncts =
Expand Down Expand Up @@ -1394,7 +1407,7 @@ func (sb *statisticsBuilder) buildJoin(
switch h.joinType {
case opt.SemiJoinOp, opt.SemiJoinApplyOp, opt.AntiJoinOp, opt.AntiJoinApplyOp:
// Keep only column stats from the left side.
s.ColStats.RemoveIntersecting(h.rightProps.OutputCols)
s.ColStats.RemoveIntersecting(rightCols)
}

// The above calculation is for inner joins. Other joins need to remove stats
Expand All @@ -1403,12 +1416,12 @@ func (sb *statisticsBuilder) buildJoin(
case opt.LeftJoinOp, opt.LeftJoinApplyOp:
// Keep only column stats from the right side. The stats from the left side
// are not valid.
s.ColStats.RemoveIntersecting(h.leftProps.OutputCols)
s.ColStats.RemoveIntersecting(leftCols)

case opt.RightJoinOp:
// Keep only column stats from the left side. The stats from the right side
// are not valid.
s.ColStats.RemoveIntersecting(h.rightProps.OutputCols)
s.ColStats.RemoveIntersecting(rightCols)

case opt.FullJoinOp:
// Do not keep any column stats.
Expand Down Expand Up @@ -1534,6 +1547,9 @@ func (sb *statisticsBuilder) colStatJoin(colSet opt.ColSet, join RelExpr) *props
rightProps = join.Child(1).(RelExpr).Relational()
}

leftCols := sb.colStatCols(leftProps)
rightCols := sb.colStatCols(rightProps)

switch joinType {
case opt.SemiJoinOp, opt.SemiJoinApplyOp, opt.AntiJoinOp, opt.AntiJoinApplyOp:
// Column stats come from left side of join.
Expand All @@ -1560,8 +1576,8 @@ func (sb *statisticsBuilder) colStatJoin(colSet opt.ColSet, join RelExpr) *props
inputRowCount := leftProps.Statistics().RowCount * rightProps.Statistics().RowCount
leftNullCount := leftProps.Statistics().RowCount
rightNullCount := rightProps.Statistics().RowCount
leftColsAreEmpty := !leftProps.OutputCols.Intersects(colSet)
rightColsAreEmpty := !rightProps.OutputCols.Intersects(colSet)
leftColsAreEmpty := !leftCols.Intersects(colSet)
rightColsAreEmpty := !rightCols.Intersects(colSet)
if rightColsAreEmpty {
colStat = sb.copyColStat(colSet, s, sb.colStatFromJoinLeft(colSet, join))
leftNullCount = colStat.NullCount
Expand All @@ -1577,12 +1593,9 @@ func (sb *statisticsBuilder) colStatJoin(colSet opt.ColSet, join RelExpr) *props
colStat.ApplySelectivity(s.Selectivity, inputRowCount)
}
} else {
// Column stats come from both sides of join.
leftCols := leftProps.OutputCols.Intersection(colSet)
rightCols := rightProps.OutputCols.Intersection(colSet)
// Make a copy of the input column stats so we don't modify the originals.
leftColStat := *sb.colStatFromJoinLeft(leftCols, join)
rightColStat := *sb.colStatFromJoinRight(rightCols, join)
leftColStat := *sb.colStatFromJoinLeft(leftCols.Intersection(colSet), join)
rightColStat := *sb.colStatFromJoinRight(rightCols.Intersection(colSet), join)

leftNullCount = leftColStat.NullCount
rightNullCount = rightColStat.NullCount
Expand Down Expand Up @@ -1775,7 +1788,7 @@ func (sb *statisticsBuilder) colStatIndexJoin(
s := relProps.Statistics()

inputProps := join.Input.Relational()
inputCols := inputProps.OutputCols
inputCols := sb.colStatCols(inputProps)

colStat, _ := s.ColStats.Add(colSet)
colStat.DistinctCount = 1
Expand Down Expand Up @@ -2581,7 +2594,7 @@ func (sb *statisticsBuilder) colStatProjectSet(

inputProps := projectSet.Input.Relational()
inputStats := inputProps.Statistics()
inputCols := inputProps.OutputCols
inputCols := sb.colStatCols(inputProps)

colStat, _ := s.ColStats.Add(colSet)
colStat.DistinctCount = 1
Expand Down Expand Up @@ -2636,7 +2649,7 @@ func (sb *statisticsBuilder) colStatProjectSet(
zipColsNullCount *= (s.RowCount - inputStats.RowCount) / s.RowCount
}

if item.ScalarProps().OuterCols.Intersects(inputProps.OutputCols) {
if item.ScalarProps().OuterCols.Intersects(inputCols) {
// The column(s) are correlated with the input, so they may have a
// distinct value for each distinct row of the input.
zipColsDistinctCount *= inputStats.RowCount * UnknownDistinctCountRatio
Expand Down Expand Up @@ -3077,8 +3090,8 @@ func (sb *statisticsBuilder) rowsProcessed(e RelExpr) float64 {
panic(errors.AssertionFailedf("rowsProcessed not supported for operator type %v", redact.Safe(e.Op())))
}

leftCols := e.Child(0).(RelExpr).Relational().OutputCols
rightCols := e.Child(1).(RelExpr).Relational().OutputCols
leftCols := sb.colStatCols(e.Child(0).(RelExpr).Relational())
rightCols := sb.colStatCols(e.Child(1).(RelExpr).Relational())
filters := e.Child(2).(*FiltersExpr)

// Remove ON conditions that are not equality conditions,
Expand Down Expand Up @@ -3277,6 +3290,7 @@ func (sb *statisticsBuilder) applyFilters(
func (sb *statisticsBuilder) applyFiltersItem(
filter *FiltersItem, e RelExpr, relProps *props.Relational,
) (numUnappliedConjuncts float64, constrainedCols, histCols opt.ColSet) {
// TODO(michae2): Check for matching virtual column expressions here.
if isEqualityWithTwoVars(filter.Condition) {
// Equalities are handled by applyEquivalencies.
return 0, opt.ColSet{}, opt.ColSet{}
Expand Down Expand Up @@ -4393,6 +4407,7 @@ func (sb *statisticsBuilder) selectivityFromOredEquivalencies(
// is able to build column equivalencies.
switch disjuncts[i].(type) {
case *EqExpr, *AndExpr:
// TODO(michae2): Check for matching virtual column expressions here.
if andFilters, ok = addEqExprConjuncts(disjuncts[i], andFilters, e.Memo()); !ok {
numUnappliedConjuncts++
continue
Expand Down Expand Up @@ -4423,7 +4438,8 @@ func (sb *statisticsBuilder) selectivityFromOredEquivalencies(
equivReps := FD.EquivReps()
if semiJoin {
singleSelectivity = sb.selectivityFromEquivalenciesSemiJoin(
equivReps, h.leftProps.OutputCols, h.rightProps.OutputCols, FD, e, s,
equivReps, sb.colStatCols(h.leftProps), sb.colStatCols(h.rightProps),
FD, e, s,
)
} else {
singleSelectivity = sb.selectivityFromEquivalencies(equivReps, FD, e, s)
Expand Down

0 comments on commit 05dabed

Please sign in to comment.