diff --git a/enginetest/queries.go b/enginetest/queries.go index 0c89d493a7..58c21f495e 100755 --- a/enginetest/queries.go +++ b/enginetest/queries.go @@ -204,6 +204,14 @@ var QueryTests = []QueryTest{ {"second row"}, }, }, + { + Query: "SELECT mytable.i, selfjoined.s FROM mytable LEFT JOIN (SELECT * FROM mytable) selfjoined ON mytable.i = selfjoined.i", + Expected: []sql.Row{ + {1, "first row"}, + {2, "second row"}, + {3, "third row"}, + }, + }, { Query: "SELECT s,i FROM MyTable ORDER BY 2", Expected: []sql.Row{ diff --git a/enginetest/query_plans.go b/enginetest/query_plans.go index 1918678c30..06a8a5b32b 100755 --- a/enginetest/query_plans.go +++ b/enginetest/query_plans.go @@ -161,12 +161,13 @@ var PlanTests = []QueryPlanTest{ " ├─ Filter(ot.i2 > 0)\n" + " │ └─ TableAlias(ot)\n" + " │ └─ IndexedTableAccess(othertable on [othertable.i2])\n" + - " └─ SubqueryAlias(sub)\n" + - " └─ Project(mytable.i, othertable.i2, othertable.s2)\n" + - " └─ IndexedJoin(mytable.i = othertable.i2)\n" + - " ├─ Table(mytable)\n" + - " └─ Filter(NOT(convert(othertable.s2, signed) = 0))\n" + - " └─ IndexedTableAccess(othertable on [othertable.i2])\n" + + " └─ CachedResults\n" + + " └─ SubqueryAlias(sub)\n" + + " └─ Project(mytable.i, othertable.i2, othertable.s2)\n" + + " └─ IndexedJoin(mytable.i = othertable.i2)\n" + + " ├─ Table(mytable)\n" + + " └─ Filter(NOT(convert(othertable.s2, signed) = 0))\n" + + " └─ IndexedTableAccess(othertable on [othertable.i2])\n" + "", }, { @@ -253,6 +254,44 @@ var PlanTests = []QueryPlanTest{ " └─ IndexedTableAccess(othertable on [othertable.i2])\n" + "", }, + { + Query: "SELECT /*+ JOIN_ORDER(mytable, othertable) */ s2, i2, i FROM mytable INNER JOIN (SELECT * FROM othertable) othertable ON i2 = i", + ExpectedPlan: "Project(othertable.s2, othertable.i2, mytable.i)\n" + + " └─ InnerJoin(othertable.i2 = mytable.i)\n" + + " ├─ Table(mytable)\n" + + " └─ CachedResults\n" + + " └─ SubqueryAlias(othertable)\n" + + " └─ Table(othertable)\n" + + "", + }, + { + Query: "SELECT s2, i2, i FROM mytable LEFT JOIN (SELECT * FROM othertable) othertable ON i2 = i", + ExpectedPlan: "Project(othertable.s2, othertable.i2, mytable.i)\n" + + " └─ LeftJoin(othertable.i2 = mytable.i)\n" + + " ├─ Table(mytable)\n" + + " └─ CachedResults\n" + + " └─ SubqueryAlias(othertable)\n" + + " └─ Table(othertable)\n" + + "", + }, + { + Query: "SELECT s2, i2, i FROM mytable RIGHT JOIN (SELECT * FROM othertable) othertable ON i2 = i", + ExpectedPlan: "Project(othertable.s2, othertable.i2, mytable.i)\n" + + " └─ RightIndexedJoin(othertable.i2 = mytable.i)\n" + + " ├─ SubqueryAlias(othertable)\n" + + " │ └─ Table(othertable)\n" + + " └─ IndexedTableAccess(mytable on [mytable.i])\n" + + "", + }, + { + Query: "SELECT s2, i2, i FROM mytable INNER JOIN (SELECT * FROM othertable) othertable ON i2 = i", + ExpectedPlan: "Project(othertable.s2, othertable.i2, mytable.i)\n" + + " └─ IndexedJoin(othertable.i2 = mytable.i)\n" + + " ├─ SubqueryAlias(othertable)\n" + + " │ └─ Table(othertable)\n" + + " └─ IndexedTableAccess(mytable on [mytable.i])\n" + + "", + }, { Query: "SELECT othertable.s2, othertable.i2, mytable.i FROM mytable INNER JOIN (SELECT * FROM othertable) othertable ON othertable.i2 = mytable.i WHERE othertable.s2 > 'a'", ExpectedPlan: "Project(othertable.s2, othertable.i2, mytable.i)\n" + diff --git a/sql/analyzer/resolve_subqueries.go b/sql/analyzer/resolve_subqueries.go index 7cc5873f3e..df6682ecef 100644 --- a/sql/analyzer/resolve_subqueries.go +++ b/sql/analyzer/resolve_subqueries.go @@ -122,6 +122,23 @@ func nodeIsCacheable(n sql.Node, lowestAllowedIdx int) bool { return cacheable } +func isDeterminstic(n sql.Node) bool { + res := true + plan.InspectExpressions(n, func(e sql.Expression) bool { + if s, ok := e.(*plan.Subquery); ok { + if !isDeterminstic(s.Query) { + res = false + } + return false + } else if nd, ok := e.(sql.NonDeterministicExpression); ok && nd.IsNonDeterministic() { + res = false + return false + } + return true + }) + return res +} + // cacheSubqueryResults determines whether it's safe to cache the results for any subquery expressions, and marks the // subquery as cacheable if so. Caching subquery results is safe in the case that no outer scope columns are referenced, // and if all expressions in the subquery are deterministic. @@ -142,3 +159,49 @@ func cacheSubqueryResults(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scop return s, nil }) } + +// cacheSubqueryAlisesInJoins will look for joins against subquery aliases that +// will repeatedly execute the subquery, and will insert a *plan.CachedResults +// node on top of those nodes when it is safe to do so. +func cacheSubqueryAlisesInJoins(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) { + n, err := plan.TransformUpWithParent(n, func(child, parent sql.Node, childNum int) (sql.Node, error) { + _, isJoin := parent.(plan.JoinNode) + _, isIndexedJoin := parent.(*plan.IndexedJoin) + if isJoin || isIndexedJoin { + sa, isSubqueryAlias := child.(*plan.SubqueryAlias) + if isSubqueryAlias && isDeterminstic(sa.Child) { + return plan.NewCachedResults(child), nil + } + } + return child, nil + }) + if err != nil { + return n, err + } + + // If the most primary table in the top level join is a CachedResults, remove it. + // We only want to do this if we're at the top of the tree. + // TODO: Not a perfect indicator of whether we're at the top of the tree... + if scope == nil { + selector := func(parent sql.Node, child sql.Node, childNum int) bool { + if _, isIndexedJoin := parent.(*plan.IndexedJoin); isIndexedJoin { + return childNum == 0 + } else if j, isJoin := parent.(plan.JoinNode); isJoin { + if j.JoinType() == plan.JoinTypeRight { + return childNum == 1 + } else { + return childNum == 0 + } + } + return true + } + n, err = plan.TransformUpWithSelector(n, selector, func(n sql.Node) (sql.Node, error) { + cr, isCR := n.(*plan.CachedResults) + if isCR { + return cr.UnaryNode.Child, nil + } + return n, nil + }) + } + return n, err +} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index bd88687812..cbbb8326d1 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -77,6 +77,7 @@ var OnceAfterDefault = []Rule{ // previous rules. {"resolve_subquery_exprs", resolveSubqueryExpressions}, {"cache_subquery_results", cacheSubqueryResults}, + {"cache_subquery_aliases_in_joins", cacheSubqueryAlisesInJoins}, {"resolve_insert_rows", resolveInsertRows}, {"apply_triggers", applyTriggers}, {"apply_procedures", applyProcedures}, diff --git a/sql/plan/cached_results.go b/sql/plan/cached_results.go new file mode 100644 index 0000000000..e0ee5cbfa2 --- /dev/null +++ b/sql/plan/cached_results.go @@ -0,0 +1,141 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "io" + "sync" + + "github.com/dolthub/go-mysql-server/sql" +) + +// NewCachedResults returns a cached results plan Node, which will use a +// RowCache to cache results generated by Child.RowIter() and return those +// results for future calls to RowIter. This node is only safe to use if the +// Child is determinstic and is not dependent on the |row| parameter in the +// call to RowIter. +func NewCachedResults(n sql.Node) *CachedResults { + return &CachedResults{UnaryNode: UnaryNode{n}} +} + +type CachedResults struct { + UnaryNode + cache sql.RowsCache + dispose sql.DisposeFunc + mutex sync.Mutex + noCache bool +} + +func (n *CachedResults) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { + n.mutex.Lock() + defer n.mutex.Unlock() + if n.cache != nil { + return sql.RowsToRowIter(n.cache.Get()...), nil + } else if n.noCache { + return n.UnaryNode.Child.RowIter(ctx, r) + } + ci, err := n.UnaryNode.Child.RowIter(ctx, r) + if err != nil { + return nil, err + } + cache, dispose := ctx.Memory.NewRowsCache() + return &cachedResultsIter{n, ci, cache, dispose}, nil +} + +func (n *CachedResults) Dispose() { + if n.dispose != nil { + n.dispose() + } +} + +func (n *CachedResults) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("CachedResults") + _ = pr.WriteChildren(n.UnaryNode.Child.String()) + return pr.String() +} + +func (n *CachedResults) DebugString() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("CachedResults") + _ = pr.WriteChildren(sql.DebugString(n.UnaryNode.Child)) + return pr.String() +} + +func (n *CachedResults) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 1) + } + nn := *n + nn.UnaryNode.Child = children[0] + return &nn, nil +} + + + +type cachedResultsIter struct { + parent *CachedResults + iter sql.RowIter + cache sql.RowsCache + dispose sql.DisposeFunc +} + +func (i *cachedResultsIter) Next() (sql.Row, error) { + r, err := i.iter.Next() + if i.cache != nil { + if err != nil { + if err == io.EOF { + i.parent.mutex.Lock() + defer i.parent.mutex.Unlock() + i.setCacheInParent() + } else { + i.cleanUp() + } + } else { + aerr := i.cache.Add(r) + if aerr != nil { + i.cleanUp() + i.parent.mutex.Lock() + defer i.parent.mutex.Unlock() + i.parent.noCache = true + } + } + } + return r, err +} + +func (i *cachedResultsIter) setCacheInParent() { + if i.parent.cache == nil { + i.parent.cache = i.cache + i.parent.dispose = i.dispose + i.cache = nil + i.dispose = nil + } else { + i.cleanUp() + } +} + +func (i *cachedResultsIter) cleanUp() { + if i.dispose != nil { + i.dispose() + i.cache = nil + i.dispose = nil + } +} + +func (i *cachedResultsIter) Close(ctx *sql.Context) error { + i.cleanUp() + return i.iter.Close(ctx) +}