Skip to content

Commit

Permalink
Merge pull request #303 from dolthub/aaron/analyzer-join-search-goes-…
Browse files Browse the repository at this point in the history
…faster

sql/analyzer: Make join_search faster
  • Loading branch information
reltuk authored Feb 17, 2021
2 parents 5724607 + a53df31 commit bb0f2ed
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 300 deletions.
341 changes: 90 additions & 251 deletions sql/analyzer/join_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package analyzer

import (
"fmt"
"math"
"strings"

Expand Down Expand Up @@ -85,20 +84,17 @@ func buildJoinTree(
joinConds []*joinCond,
) *joinSearchNode {

rootNodes := searchJoins(nil, &joinSearchParams{
tables: tableOrder,
joinConds: joinConds,
})

for _, tree := range rootNodes {
// The search function here can return valid sub trees that don't have all the tables in the full join, so we need
// to check them for validity as an entire tree
if isValidJoinTree(tree) {
return tree
var found *joinSearchNode
visitJoinSearchNodes(tableOrder, func(n *joinSearchNode) bool {
assignConditions(n, joinConds)
if n.joinCond != nil {
found = n
return false
}
}
return true
})

return nil
return found
}

// Estimates the cost of the table ordering given. Lower numbers are better. Bails out and returns cost so far if cost
Expand Down Expand Up @@ -188,50 +184,97 @@ func permutations(a []int) (res [][]int) {
return res
}

// joinSearchParams is a simple struct to track available tables and join conditions during a join search
type joinSearchParams struct {
tables []string
usedTableIndexes []int
joinConds []*joinCond
usedJoinCondsIndexes []int
}

func (js *joinSearchParams) copy() *joinSearchParams {
usedTableIndexesCopy := make([]int, len(js.usedTableIndexes))
copy(usedTableIndexesCopy, js.usedTableIndexes)
usedJoinCondIndexesCopy := make([]int, len(js.usedJoinCondsIndexes))
copy(usedJoinCondIndexesCopy, js.usedJoinCondsIndexes)
return &joinSearchParams{
tables: js.tables,
usedTableIndexes: usedTableIndexesCopy,
joinConds: js.joinConds,
usedJoinCondsIndexes: usedJoinCondIndexesCopy,
// visitJoinSearchNodes visits every possible joinSearchNode where the
// in-order leaves are given by |tables|. If the callback returns
// |false|, visits stop.
func visitJoinSearchNodes(tables []string, cb func(n *joinSearchNode) bool) {
if len(tables) == 0 {
return
}
if len(tables) == 1 {
cb(&joinSearchNode{table: tables[0]})
return
}
var stop bool
for i := 1; i < len(tables) && !stop; i++ {
visitJoinSearchNodes(tables[:i], func(l *joinSearchNode) bool {
visitJoinSearchNodes(tables[i:len(tables)], func(r *joinSearchNode) bool {
if !cb(&joinSearchNode{left: l, right: r}) {
stop = true
return false
}
return true
})
return !stop
})
}
}

func (js *joinSearchParams) tableIndexUsed(i int) bool {
return indexOfInt(i, js.usedTableIndexes) >= 0
}

func (js *joinSearchParams) joinCondIndexUsed(i int) bool {
return indexOfInt(i, js.usedJoinCondsIndexes) >= 0
// assignConditions attempts to assign the conditions in |conditions|
// to the search tree in |root|, such that every condition is on an
// internal node, and all of the trees referenced in the condition
// appear in tables which are in the subtree of its internal node. If
// it finds an assignment, leaves it in the |joinSearchNode.joinCond|
// fields of the provided tree. Otherwise there is no valid assignment
// and leaves the provided tree unmodified.
func assignConditions(root *joinSearchNode, conditions []*joinCond) {
// A recursive helper which is going to assign conditions to
// subtrees, remove the assigned conditions from |conditions|
// and make a callback to |cb| for each such assignment that
// is found.
var helper func(n *joinSearchNode, cb func() bool) bool
helper = func(n *joinSearchNode, cb func() bool) bool {
if n.isLeaf() {
return cb()
}
// for each assignment of conditions to the left tree
return helper(n.left, func() bool {
// for each assignment of conditions to the right tree
return helper(n.right, func() bool {
tables := n.tableOrder()
// look at every remaining condition
for i := range conditions {
cond := conditions[i]
joinCondTables := findTables(cond.cond)
// if the condition only references tables in our subtree
if containsAll(joinCondTables, tables) {
n.joinCond = cond
conditions = append(conditions[:i], conditions[i+1:]...)
// continue the search with this assignment tried
if !cb() {
conditions = append(conditions, nil)
copy(conditions[i+1:], conditions[i:])
conditions[i] = n.joinCond
return false
}
conditions = append(conditions, nil)
copy(conditions[i+1:], conditions[i:])
conditions[i] = n.joinCond
n.joinCond = nil
}
}
return true
})
})
}
helper(root, func() bool {
if root.joinCond != nil && len(conditions) == 0 {
return false
}
return true
})
}

// A joinSearchNode is a simplified type representing a join tree node, which is either an internal node (a join) or a
// leaf node (a table). The top level node in a join tree is always an internal node. Every internal node has both a
// left and a right child.
type joinSearchNode struct {
table string // empty if this is an internal node
joinCond *joinCond // nil if this is a leaf node
parent *joinSearchNode // nil if this is the root node
left *joinSearchNode // nil if this is a leaf node
right *joinSearchNode // nil if this is a leaf node
params *joinSearchParams // search params that assembled this node
table string // empty if this is an internal node
joinCond *joinCond // nil if this is a leaf node
left *joinSearchNode // nil if this is a leaf node
right *joinSearchNode // nil if this is a leaf node
}

// used to mark the left or right branch of a node as being targeted for assignment
var childTargetNode = &joinSearchNode{}

// tableOrder returns the order of the tables in this part of the tree, using an in-order traversal
func (n *joinSearchNode) tableOrder() []string {
if n == nil {
Expand All @@ -253,225 +296,21 @@ func (n *joinSearchNode) isLeaf() bool {
return len(n.table) > 0
}

// joinConditionSatisfied returns whether all the tables mentioned in this join node are present in descendants.
func (n *joinSearchNode) joinConditionSatisfied() bool {
if n.isLeaf() {
return true
}

joinCondTables := findTables(n.joinCond.cond)
childTables := n.tableOrder()
// TODO: case sensitivity
if !containsAll(joinCondTables, childTables) {
return false
}

return n.left.joinConditionSatisfied() && n.right.joinConditionSatisfied()
}

// copy returns a copy of this node
func (n *joinSearchNode) copy() *joinSearchNode {
nn := *n
nn.params = nn.params.copy()
return &nn
}

// targetLeft returns a copy of this node with the left child marked for replacement by withChild
func (n *joinSearchNode) targetLeft() *joinSearchNode {
nn := n.copy()
nn.left = childTargetNode
return nn
}

// targetRight returns a copy of this node with the right child marked for replacement by withChild
func (n *joinSearchNode) targetRight() *joinSearchNode {
nn := n.copy()
nn.right = childTargetNode
return nn
}

// withChild returns a copy of this node with the previously marked child replaced by the node given.
// See targetLeft, targetRight
func (n *joinSearchNode) withChild(child *joinSearchNode) *joinSearchNode {
nn := n.copy()
if nn.left == childTargetNode {
nn.left = child
return nn
} else if nn.right == childTargetNode {
nn.right = child
return nn
} else {
panic("withChild couldn't find a child to assign")
}
}

// accumulateAllUsed rolls up joinSearchParams from this node and all descendants, combining their used tallies
func (n *joinSearchNode) accumulateAllUsed() *joinSearchParams {
if n == nil || n.params == nil {
return &joinSearchParams{}
}

if n.isLeaf() {
return n.params
}

leftParams := n.left.accumulateAllUsed()
rightParams := n.right.accumulateAllUsed()

result := n.params.copy()
// TODO: eliminate duplicates from these lists, or use sets
result.usedJoinCondsIndexes = append(result.usedJoinCondsIndexes, leftParams.usedJoinCondsIndexes...)
result.usedJoinCondsIndexes = append(result.usedJoinCondsIndexes, rightParams.usedJoinCondsIndexes...)
result.usedTableIndexes = append(result.usedTableIndexes, leftParams.usedTableIndexes...)
result.usedTableIndexes = append(result.usedTableIndexes, rightParams.usedTableIndexes...)

return result
}

func (n *joinSearchNode) String() string {
if n == nil {
return "nil"
}

if n == childTargetNode {
return "childTargetNode"
}

if n.isLeaf() {
return n.table
}

usedJoins := ""
if n.params != nil && len(n.params.usedJoinCondsIndexes) > 0 {
usedJoins = fmt.Sprintf("%v", n.params.usedJoinCondsIndexes)
}

usedTables := ""
if n.params != nil && len(n.params.usedTableIndexes) > 0 {
usedTables = fmt.Sprintf("%v", n.params.usedTableIndexes)
}

tp := sql.NewTreePrinter()
if len(usedTables)+len(usedJoins) > 0 {
_ = tp.WriteNode("%s (usedJoins = %v, usedTables = %v)", n.joinCond.cond, usedJoins, usedTables)
} else {
_ = tp.WriteNode("%s", n.joinCond.cond)
}

_ = tp.WriteNode("%s", n.joinCond.cond)
_ = tp.WriteChildren(n.left.String(), n.right.String())

return tp.String()
}

// searchJoins is the recursive helper function for buildJoinTree. It returns all possible join trees that satisfy the
// search parameters given. It calls itself recursively to generate subtrees as well. All nodes returned are valid
// subtrees (join conditions and table sub ordering satisfied), but may not be valid as an entire tree. Callers should
// verify this themselves using isValidJoinTree() on the result.
func searchJoins(parent *joinSearchNode, params *joinSearchParams) []*joinSearchNode {
// Our goal is to construct all possible child nodes for the parent given. Every permutation of a legal subtree should
// go into this list.
children := make([]*joinSearchNode, 0)

// If we have a parent to assign them to, consider returning tables as nodes. Otherwise, skip them.
if parent != nil {
// Find all tables mentioned in join nodes up to the root of the tree. We can't add any tables that aren't in this
// list
var validChildTables []string
n := parent
for n != nil {
validChildTables = append(validChildTables, findTables(n.joinCond.cond)...)
n = n.parent
}

// Tables are valid to return if they are mentioned in a join condition higher in the tree.
for i, table := range parent.params.tables {
if indexOf(table, validChildTables) < 0 || parent.params.tableIndexUsed(i) {
continue
}
paramsCopy := params.copy()
paramsCopy.usedTableIndexes = append(paramsCopy.usedTableIndexes, i)

childNode := &joinSearchNode{
table: table,
params: paramsCopy,
parent: parent.copy(),
}
if parent.withChild(childNode).tableOrderCorrect() {
children = append(children, childNode)
}
}
}

// now for each of the available join nodes
for i, cond := range params.joinConds {
if params.joinCondIndexUsed(i) {
continue
}

paramsCopy := params.copy()
paramsCopy.usedJoinCondsIndexes = append(paramsCopy.usedJoinCondsIndexes, i)

candidate := &joinSearchNode{
joinCond: cond,
parent: parent,
params: paramsCopy,
}

// For each of the left and right branch, find all possible children, add all valid subtrees to the list
candidate = candidate.targetLeft()
leftChildren := searchJoins(candidate, paramsCopy)

// pay attention to variable shadowing in this block
for _, left := range leftChildren {
if !isValidJoinSubTree(left) {
continue
}
candidate := candidate.withChild(left).targetRight()
candidate.params = candidate.accumulateAllUsed()
rightChildren := searchJoins(candidate, paramsCopy)
for _, right := range rightChildren {
if !isValidJoinSubTree(right) {
continue
}
candidate := candidate.withChild(right)
if isValidJoinSubTree(candidate) {
children = append(children, candidate)
}
}
}
}

return children
}

// tableOrderCorrect returns whether the tables in this subtree appear in a valid order.
func (n *joinSearchNode) tableOrderCorrect() bool {
tableOrder := n.tableOrder()
prevIdx := -1
for _, table := range tableOrder {
idx := indexOf(table, n.params.tables)
if idx <= prevIdx {
return false
}
prevIdx = idx
}
return true
}

// isValidJoinSubTree returns whether the node given satisfies all the constraints of a join subtree. Subtrees are not
// necessarily complete join plans, since they may not contain all tables. Use isValidJoinTree to verify that.
func isValidJoinSubTree(node *joinSearchNode) bool {
// Two constraints define a valid tree:
// 1) An in-order traversal has tables in the correct order
// 2) The conditions for all internal nodes can be satisfied by their child columns
return node.tableOrderCorrect() && node.joinConditionSatisfied()
}

// isValidJoinTree returns whether the join node given is a valid subtree and contains all the tables in the join.
func isValidJoinTree(node *joinSearchNode) bool {
return isValidJoinSubTree(node) && strArraysEqual(node.tableOrder(), node.params.tables)
}

func containsAll(needles []string, haystack []string) bool {
for _, needle := range needles {
if indexOf(needle, haystack) < 0 {
Expand Down
Loading

0 comments on commit bb0f2ed

Please sign in to comment.