Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Feature/sql exec cost estimate #603

Draft
wants to merge 27 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions internal/engine/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
cost_1st
cost_2nd
15 changes: 15 additions & 0 deletions internal/engine/cost/catalog/catalog.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package catalog

import (
ds "github.com/kwilteam/kwil-db/internal/engine/cost/datasource"
dt "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes"
)

type Catalog interface {
GetDataSource(tableRef *dt.TableRef) (ds.DataSource, error)
}

type defaultCatalogProvider struct {
dbidAliases map[string]string // alias -> dbid
srcs map[string]ds.DataSource
}
39 changes: 39 additions & 0 deletions internal/engine/cost/catalog/relation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package catalog

import (
"fmt"
"github.com/kwilteam/kwil-db/internal/engine/cost/datatypes"
"strings"
)

// RelationName is the name of a relation in the catalog.
// It can be an unqualified name(table) or a fully qualified name (database.table).
type RelationName string

func (t RelationName) String() string {
return string(t)
}

func (t RelationName) Segments() []string {
return strings.Split(string(t), ".")
}

func (t RelationName) IsQualified() bool {
return len(t.Segments()) > 1
}

func (t RelationName) Parse() (*datatypes.TableRef, error) {
segments := t.Segments()
switch len(segments) {
case 1:
return &datatypes.TableRef{Table: segments[0]}, nil
case 2:
return &datatypes.TableRef{DB: segments[0], Table: segments[1]}, nil
default:
return nil, fmt.Errorf("invalid relation name: %s", t)
}
}

func RelationNameFromString(s string) RelationName {
return RelationName(s)
}
136 changes: 136 additions & 0 deletions internal/engine/cost/costmodel/rel_expr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package costmodel

import (
"bytes"
"fmt"
"github.com/kwilteam/kwil-db/internal/engine/cost/datatypes"
"github.com/kwilteam/kwil-db/internal/engine/cost/logical_plan"
)

const (
SeqAccessCostPerRow = 1 // sequential access disk cost
RandAccessCost = 3 // random access disk cost, i.e., index scan
)

// RelExpr is a wrapper of a logical plan,it's used for cost estimation.
// It tracks the statistics and cost from bottom to top.
// NOTE: this is simplified version of LogicalRel in memo package.
type RelExpr struct {
logical_plan.LogicalPlan

stat *datatypes.Statistics // current node's statistics
cost int64
inputs []*RelExpr
}

func (r *RelExpr) Inputs() []*RelExpr {
return r.inputs
}

func (r *RelExpr) String() string {
return fmt.Sprintf("%s, Stat: (%s), Cost: %d",
logical_plan.PlanString(r.LogicalPlan), r.stat, r.cost)
}

//// reorderColStat reorders the columns in the statistics according to the schema.
//// Schema can be changed by the projection/join, so we need to reorder the columns in
//// the statistics.
//func reorderColStat(oldStat *datatypes.Statistics, schema *datatypes.Schema) *datatypes.Statistics {
//
//}

// BuildRelExpr builds a RelExpr from a logical plan, also build the statistics.
// TODO: using iterator to traverse the plan tree.
func BuildRelExpr(plan logical_plan.LogicalPlan) *RelExpr {
inputs := make([]*RelExpr, len(plan.Inputs()))
for i, input := range plan.Inputs() {
inputs[i] = BuildRelExpr(input)
}

var stat *datatypes.Statistics

switch p := plan.(type) {
case *logical_plan.ScanOp:
stat = p.DataSource().Statistics()

case *logical_plan.ProjectionOp:
stat = inputs[0].stat

case *logical_plan.FilterOp:
stat = inputs[0].stat
// with filter, we can make uniformity assumption to simplify the cost model
exprs := p.Exprs()
fields := make([]datatypes.Field, len(exprs))
for i, expr := range exprs {
fields[i] = expr.Resolve(plan.Schema())
}

default:
stat = datatypes.NewEmptyStatistics()
}

return &RelExpr{
LogicalPlan: plan,
cost: 0,
inputs: inputs,
stat: stat,
}
}

func Format(plan *RelExpr, indent int) string {
var msg bytes.Buffer
for i := 0; i < indent; i++ {
msg.WriteString(" ")
}
msg.WriteString(plan.String())
msg.WriteString("\n")
for _, child := range plan.Inputs() {
msg.WriteString(Format(child, indent+2))
}
return msg.String()
}

//func EstimateCost(plan *RelExpr) int64 {
// cost := int64(0)
// // bottom-up
// for _, child := range plan.Inputs() {
// cost += EstimateCost(child)
// }
//
// // estimate current node's cost
// switch plan.LogicalPlan.(type) {
// case *logical_plan.ScanOp:
// // TODO: index scan
// cost += SeqAccessCost
// }
// return cost
//}
//

//// EstimateCost estimates the cost of a logical plan.
//// It uses iterator to traverse the plan tree.
//func EstimateCost(plan *RelExpr) int64 {
// stack := []*RelExpr{plan}
// cost := int64(0)
//
// for len(stack) > 0 {
// // Pop a node from the stack
// n := len(stack) - 1
// node := stack[n]
// stack = stack[:n]
//
// // Estimate current node's cost
// switch p := node.LogicalPlan.(type) {
// case *logical_plan.ScanOp:
// // TODO: index scan
// cost += p.
// }
//
// // Push all children onto the stack
// for _, child := range node.Inputs() {
// stack = append(stack, child)
// }
// }
//
// return cost
//}
183 changes: 183 additions & 0 deletions internal/engine/cost/costmodel/rel_expr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package costmodel

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/kwilteam/kwil-db/core/types"
"github.com/kwilteam/kwil-db/internal/engine/cost/internal/testkit"
"github.com/kwilteam/kwil-db/internal/engine/cost/query_planner"
"github.com/kwilteam/kwil-db/parse"
)

func Test_RelExpr_String(t *testing.T) {
tests := []struct {
name string
r *RelExpr
want string
}{
{
name: "test",
r: &RelExpr{},
want: "Unknown LogicalPlan type <nil>, Stat: (<nil>), Cost: 0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.r.String(); got != tt.want {
t.Errorf("RelExpr.String() = %v, want %v", got, tt.want)
}
})
}
}

func Test_NewRelExpr(t *testing.T) {
cat := testkit.InitMockCatalog()

tests := []struct {
name string
sql string
wt string // want
}{
/////////////////////// no relation
{
name: "select int",
sql: "SELECT 1",
wt: "Projection: 1, Stat: (RowCount: 0), Cost: 0\n" +
" NoRelationOp, Stat: (RowCount: 0), Cost: 0\n",
},
{
name: "select string",
sql: "SELECT 'hello'",
wt: "Projection: 'hello', Stat: (RowCount: 0), Cost: 0\n" +
" NoRelationOp, Stat: (RowCount: 0), Cost: 0\n",
},
{
name: "select value expression",
sql: "SELECT 1+2",
wt: "Projection: 1 + 2, Stat: (RowCount: 0), Cost: 0\n" +
" NoRelationOp, Stat: (RowCount: 0), Cost: 0\n",
},
// TODO: add function metadata to catalog
// TODO: add support for functions in logical expr
//{
// name: "select function abs",
// sql: "SELECT ABS(-1)",
// wt: "",
//},
/////////////////////// one relation
{
name: "select wildcard",
sql: "SELECT * FROM users",
wt: "Sort: id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.id, users.username, users.age, users.state, users.wallet, Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
//{ // TODO?
// name: "select wildcard, deduplication",
// sql: "SELECT *, age FROM users",
// wt: "Projection: users.id, users.username, users.age, users.state, users.wallet\n" +
// " Scan: users; projection=[]\n",
//},
{
name: "select columns",
sql: "select username, age from users",
wt: "Sort: id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username, users.age, Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
{
name: "select column with alias",
sql: "select username as name from users",
wt: "Sort: id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username AS name, Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
{
name: "select column expression",
sql: "select username, age+10 from users",
wt: "Sort: id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username, users.age + 10, Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
{
name: "select with where",
sql: "select username, age from users where age > 20",
wt: "Sort: id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username, users.age, Stat: (RowCount: 5), Cost: 0\n" +
" Filter: users.age > 20, Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
{
name: "select with multiple where",
sql: "select username, age from users where age > 20 and state = 'CA'",
wt: "Sort: id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username, users.age, Stat: (RowCount: 5), Cost: 0\n" +
" Filter: users.age > 20 AND users.state = 'CA', Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
//{
// name: "select with group by",
// sql: "select username, count(*) from users group by username",
// wt: "GroupBy: users.username\n",
//},
{
name: "select with limit, without offset",
sql: "select username, age from users limit 10",
wt: "Limit: skip=0, fetch=10, Stat: (RowCount: 0), Cost: 0\n" +
" Sort: id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username, users.age, Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
{
name: "select with limit and offset",
sql: "select username, age from users limit 10 offset 5",
wt: "Limit: skip=5, fetch=10, Stat: (RowCount: 0), Cost: 0\n" +
" Sort: id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username, users.age, Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
{
name: "select with order by default",
sql: "select username, age from users order by age",
wt: "Sort: age ASC NULLS LAST, id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username, users.age, Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
{
name: "select with order by desc",
sql: "select username, age from users order by age desc",
wt: "Sort: age DESC NULLS LAST, id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username, users.age, Stat: (RowCount: 5), Cost: 0\n" +
" Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
/////////////////////// subquery
{
name: "select with subquery",
sql: "select username, age from (select * from users) as u",
wt: "Sort: id ASC NULLS LAST, username ASC NULLS LAST, age ASC NULLS LAST, state ASC NULLS LAST, wallet ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.username, users.age, Stat: (RowCount: 0), Cost: 0\n" +
" Sort: id ASC NULLS LAST, Stat: (RowCount: 0), Cost: 0\n" +
" Projection: users.id, users.username, users.age, users.state, users.wallet, Stat: (RowCount: 5), Cost: 0\n Scan: users, Stat: (RowCount: 5), Cost: 0\n",
},
/////////////////////// two relations

}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pr, err := parse.ParseSQL(tt.sql, &types.Schema{
Name: "",
Tables: []*types.Table{testkit.MockUsersSchemaTable},
})

assert.NoError(t, err)
assert.NoError(t, pr.ParseErrs.Err())

q := query_planner.NewPlanner(cat)
plan := q.ToPlan(pr.AST)
rel := BuildRelExpr(plan)
assert.Equal(t, tt.wt, Format(rel, 0))
})
}
}
Loading