Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Yaiba committed Mar 25, 2024
1 parent c79b07c commit 75e47aa
Show file tree
Hide file tree
Showing 15 changed files with 283 additions and 133 deletions.
18 changes: 17 additions & 1 deletion internal/engine/cost/datasource/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func dsScan(dsSchema *datatypes.Schema, dsRecords []Row, projection []string) *R
//}

fieldIndex := dsSchema.MapProjection(projection)
newSchema := dsSchema.Select(projection...)
newSchema := dsSchema.Project(projection...)
//newFields := make([]datatypes.Field, len(projection))
//for i, idx := range fieldIndex {
// newFields[i] = dsSchema.Fields[idx]
Expand Down Expand Up @@ -290,3 +290,19 @@ func (ds *csvDataSource) Scan(projection ...string) *Result {
func (ds *csvDataSource) SourceType() SourceType {
return "csv"
}

type DefaultSchemaSource struct {
datasource DataSource
}

func (s *DefaultSchemaSource) Schema() *datatypes.Schema {
return s.datasource.Schema()
}

func (s *DefaultSchemaSource) Scan(projection ...string) *Result {
return s.datasource.Scan(projection...)
}

func DataAsSchemaSource(ds DataSource) SchemaSource {
return &DefaultSchemaSource{datasource: ds}
}
15 changes: 12 additions & 3 deletions internal/engine/cost/datasource/schema.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package datasource

import "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes"
import (
"github.com/kwilteam/kwil-db/internal/engine/cost/datatypes"
)

// SchemaSource is an interface that provides the access to schema.
// It's used to get the schema of a table. It doesn't have the ability to
Expand All @@ -11,6 +13,13 @@ type SchemaSource interface {
}

// SchemaSourceToDataSource converts a SchemaSource to a DataSource.
func SchemaSourceToDataSource(s SchemaSource) DataSource {
return nil
func SchemaSourceToDataSource(ss SchemaSource) DataSource {
switch t := ss.(type) {
case *DefaultSchemaSource:
return t.datasource
case *csvDataSource:
return t
default:
panic("SchemaSourceToDataSource: SchemaSource cannot be converted to DataSource")
}
}
2 changes: 1 addition & 1 deletion internal/engine/cost/datatypes/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (s *Schema) String() string {
return fmt.Sprintf("[%s]", strings.Join(fields, ", "))
}

func (s *Schema) Select(projection ...string) *Schema {
func (s *Schema) Project(projection ...string) *Schema {
if len(projection) == 0 {
return NewSchema(s.Fields...)
}
Expand Down
19 changes: 4 additions & 15 deletions internal/engine/cost/logical_plan/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package logical_plan

import (
"fmt"
"github.com/kwilteam/kwil-db/parse/sql/tree"
"strings"

dt "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes"
pt "github.com/kwilteam/kwil-db/internal/engine/cost/plantree"
"github.com/kwilteam/kwil-db/parse/sql/tree"
)

//
Expand All @@ -29,16 +27,6 @@ type LogicalExpr interface {
Resolve(*dt.Schema) dt.Field
}

type LogicalExprList []LogicalExpr

func (e LogicalExprList) String() string {
fields := make([]string, len(e))
for i, expr := range e {
fields[i] = expr.String()
}
return strings.Join(fields, ", ")
}

// ColumnExpr represents a column in a schema.
// NOTE: it will be transformed to columnIdxExpr in the logical plan.????
type ColumnExpr struct {
Expand Down Expand Up @@ -512,6 +500,7 @@ func (a *aggregateIntExpr) E() LogicalExpr {
func (a *aggregateIntExpr) aggregate() {}

func Count(expr LogicalExpr) *aggregateIntExpr {
// TODO: Count should be an sql function expression
return &aggregateIntExpr{BaseTreeNode: pt.NewBaseTreeNode(), name: "COUNT", expr: expr}
}

Expand Down Expand Up @@ -642,7 +631,7 @@ type scalarFuncExpr struct {
}

func (e *scalarFuncExpr) String() string {
return fmt.Sprintf("%s(%s)", e.fn.Name(), LogicalExprList(e.args))
return fmt.Sprintf("%s(%s)", e.fn.Name(), ppList(e.args))
}

func (e *scalarFuncExpr) Resolve(schema *dt.Schema) dt.Field {
Expand Down Expand Up @@ -672,7 +661,7 @@ type aggregateFuncExpr struct {
}

func (e *aggregateFuncExpr) String() string {
return fmt.Sprintf("%s(%s)", e.fn.Name(), LogicalExprList(e.args))
return fmt.Sprintf("%s(%s)", e.fn.Name(), ppList(e.args))
}

func (e *aggregateFuncExpr) Resolve(schema *dt.Schema) dt.Field {
Expand Down
44 changes: 25 additions & 19 deletions internal/engine/cost/logical_plan/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type ScanOp struct {
// schema after projection(i.e. only keep the projected columns in the schema)
projectedSchema *dt.Schema
// used for selection push down optimization
filter LogicalExprList
filter []LogicalExpr
}

func (s *ScanOp) Table() *dt.TableRef {
Expand All @@ -57,22 +57,26 @@ func (s *ScanOp) Projection() []string {
return s.projection
}

func (s *ScanOp) Selection() []LogicalExpr {
func (s *ScanOp) Filter() []LogicalExpr {
if len(s.filter) == 0 {
return []LogicalExpr{}
}
return s.filter
}

func (s *ScanOp) String() string {
output := fmt.Sprintf("Scan: %s", s.table)
if len(s.filter) > 0 {
return fmt.Sprintf("Scan: %s; filter=%s; projection=%s", s.table, s.filter, s.projection)
output += fmt.Sprintf("; filter=[%s]", ppList(s.filter))
}
return fmt.Sprintf("Scan: %s; projection=%s", s.table, s.projection)
if len(s.projection) > 0 {
output += fmt.Sprintf("; projection=[%s]", ppList(s.projection))
}
return output
}

func (s *ScanOp) Schema() *dt.Schema {
//return s.dataSource.Schema().Select(s.projection...)
//return s.dataSource.Schema().Project(s.projection...)
return s.projectedSchema
}

Expand All @@ -86,23 +90,23 @@ func (s *ScanOp) Exprs() []LogicalExpr {

// Scan creates a table scan logical plan.
func Scan(table *dt.TableRef, ds ds.SchemaSource,
selection []LogicalExpr, projection ...string) LogicalPlan {
projectedSchema := ds.Schema().Select(projection...)
filter []LogicalExpr, projection ...string) LogicalPlan {
projectedSchema := ds.Schema().Project(projection...)
qualifiedSchema := dt.NewSchemaQualified(table, projectedSchema.Fields...)
return &ScanOp{table: table, dataSource: ds, projection: projection,
filter: selection, projectedSchema: qualifiedSchema}
filter: filter, projectedSchema: qualifiedSchema}
}

// ProjectionOp represents a projection operator, which produces new columns
// from the input by evaluating given expressions.
// It corresponds to `SELECT (expr...)` clause in SQL.
type ProjectionOp struct {
input LogicalPlan
exprs LogicalExprList
exprs []LogicalExpr
}

func (p *ProjectionOp) String() string {
return fmt.Sprintf("Projection: %s", p.exprs)
return fmt.Sprintf("Projection: %s", ppList(p.exprs))
}

func (p *ProjectionOp) Schema() *dt.Schema {
Expand Down Expand Up @@ -182,7 +186,16 @@ func (a *AggregateOp) Aggregate() []LogicalExpr {
}

func (a *AggregateOp) String() string {
return fmt.Sprintf("Aggregate: %s, %s", a.groupBy, a.aggregate)
output := "Aggregate: "
if len(a.groupBy) > 0 {
output += fmt.Sprintf("groupBy=[%s]", ppList(a.groupBy))
}
if len(a.aggregate) > 0 {
output += fmt.Sprintf("; aggr=[%s]", ppList(a.aggregate))

}

return output
}

// Schema returns groupBy fields and aggregate fields
Expand Down Expand Up @@ -285,15 +298,10 @@ func Limit(plan LogicalPlan, skip int, fetch int) LogicalPlan {
type SortOp struct {
input LogicalPlan
by []LogicalExpr
//asc bool
}

//func (s *SortOp) IsAsc() bool {
// return s.asc
//}

func (s *SortOp) String() string {
return fmt.Sprintf("Sort: %s", s.by)
return fmt.Sprintf("Sort: %s", ppList(s.by))
}

func (s *SortOp) Schema() *dt.Schema {
Expand All @@ -309,12 +317,10 @@ func (s *SortOp) Exprs() []LogicalExpr {
}

// Sort creates a sort logical plan.
// func Sort(plan LogicalPlan, by []LogicalExpr, asc bool) LogicalPlan {
func Sort(plan LogicalPlan, by []LogicalExpr) LogicalPlan {
return &SortOp{
input: plan,
by: by,
//asc: asc,
}
}

Expand Down
125 changes: 107 additions & 18 deletions internal/engine/cost/logical_plan/operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,124 @@ package logical_plan_test
import (
"fmt"

"github.com/kwilteam/kwil-db/internal/engine/cost/datasource"
"github.com/kwilteam/kwil-db/internal/engine/cost/logical_plan"
ds "github.com/kwilteam/kwil-db/internal/engine/cost/datasource"
dt "github.com/kwilteam/kwil-db/internal/engine/cost/datatypes"
lp "github.com/kwilteam/kwil-db/internal/engine/cost/logical_plan"
)

func ExampleLogicalPlan_String_selection() {
ds := datasource.NewMemDataSource(nil, nil)
plan := logical_plan.Scan("users", ds, nil)
//plan = logical_plan.Projection(plan, logical_plan.Column("", "username"), logical_plan.Column("", "age"))
plan = logical_plan.Projection(plan, logical_plan.ColumnUnqualified("username"), logical_plan.ColumnUnqualified("age"))
fmt.Println(logical_plan.Format(plan, 0))
var stubDS, _ = ds.NewCSVDataSource("../testdata/users.csv")
var stubTable = &dt.TableRef{Table: "users"}

func ExampleScanOp_String_no_filter() {
op := lp.Scan(stubTable, stubDS, nil, "username", "age")
fmt.Println(op.String())
// Output:
// Scan: users; projection=[username, age]
}

func ExampleScanOp_String_with_filter() {
op := lp.Scan(stubTable, stubDS,
[]lp.LogicalExpr{
lp.Gt(lp.ColumnUnqualified("age"),
lp.LiteralInt(20)),
lp.Lt(lp.ColumnUnqualified("age"),
lp.LiteralInt(30)),
}, "username", "age")
fmt.Println(op.String())
// Output:
// Scan: users; filter=[age > 20, age < 30]; projection=[username, age]
}

func ExampleProjectionOp_String() {
op := lp.Projection(
nil,
lp.ColumnUnqualified("username"),
lp.ColumnUnqualified("age"))
fmt.Println(op.String())
// Output:
// Projection: username, age
}

func ExampleFilterOp_String() {
op := lp.Filter(nil,
lp.Eq(
lp.ColumnUnqualified("age"),
lp.LiteralInt(20)))
fmt.Println(op.String())
// Output:
// Filter: age = 20
}

func ExampleAggregateOp_String() {
op := lp.Aggregate(
lp.Scan(stubTable, stubDS, nil),
[]lp.LogicalExpr{lp.ColumnUnqualified("state")},
[]lp.LogicalExpr{lp.Count(lp.ColumnUnqualified("username"))})
fmt.Println(op.String())
// Output:
// Aggregate: groupBy=[state]; aggr=[COUNT(username)]
}

func ExampleAggregateOp_String_without_groupby() {
op := lp.Aggregate(
lp.Scan(stubTable, stubDS, nil),
nil,
[]lp.LogicalExpr{lp.Count(lp.ColumnUnqualified("username"))})
fmt.Println(op.String())
// Output:
// Aggregate: ; aggr=[COUNT(username)]
}

func ExampleLimitOp_String_without_skip() {
op := lp.Limit(nil, 0, 10)
fmt.Println(op.String())
// Output:
// Limit: skip=0, fetch=10
}

func ExampleLimitOp_String_with_skip() {
op := lp.Limit(nil, 5, 10)
fmt.Println(op.String())
// Output:
// Limit: skip=5, fetch=10
}

func ExampleSortOp_String() {
op := lp.Sort(nil,
[]lp.LogicalExpr{
lp.SortExpr(lp.ColumnUnqualified("state"), false, true),
lp.SortExpr(lp.ColumnUnqualified("age"), true, false),
},
)
fmt.Println(op.String())
// Output:
// Sort: state DESC NULLS FIRST, age ASC NULLS LAST
}

func ExampleLogicalPlan_Projection() {
plan := lp.Scan(stubTable, stubDS, nil)
plan = lp.Projection(plan,
lp.ColumnUnqualified("username"),
lp.ColumnUnqualified("age"))
fmt.Println(lp.Format(plan, 0))
// Output:
// Projection: username, age
// Scan: users; projection=[]
// Scan: users
}

func ExampleLogicalPlan_DataFrame() {
ds := datasource.NewMemDataSource(nil, nil)
aop := logical_plan.NewDataFrame(logical_plan.Scan("users", ds, nil))
plan := aop.Filter(logical_plan.Eq(logical_plan.ColumnUnqualified("age"), logical_plan.LiteralInt(20))).
Aggregate([]logical_plan.LogicalExpr{logical_plan.ColumnUnqualified("state")},
[]logical_plan.LogicalExpr{logical_plan.Count(logical_plan.ColumnUnqualified("username"))}).
aop := lp.NewDataFrame(lp.Scan(stubTable, stubDS, nil))
plan := aop.Filter(lp.Eq(lp.ColumnUnqualified("age"), lp.LiteralInt(20))).
Aggregate([]lp.LogicalExpr{lp.ColumnUnqualified("state")},
[]lp.LogicalExpr{lp.Count(lp.ColumnUnqualified("username"))}).
// the alias for aggregate result is bit weird
Project(logical_plan.ColumnUnqualified("state"), logical_plan.Alias(logical_plan.Count(logical_plan.ColumnUnqualified("username")), "num")).
Project(lp.ColumnUnqualified("state"), lp.Alias(lp.Count(lp.ColumnUnqualified("username")), "num")).
LogicalPlan()

fmt.Println(logical_plan.Format(plan, 0))
fmt.Println(lp.Format(plan, 0))
// Output:
// Projection: state, COUNT(username) AS num
// Aggregate: [state], [COUNT(username)]
// Aggregate: groupBy=[state]; aggr=[COUNT(username)]
// Filter: age = 20
// Scan: users; projection=[]
// Scan: users
}
16 changes: 16 additions & 0 deletions internal/engine/cost/logical_plan/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,19 @@ func inferNullable(expr LogicalExpr, schema *ds.Schema) bool {
panic(fmt.Sprintf("unknown expression type %T", e))
}
}

// ppList returns a string representation of the given list.
func ppList[T any](l []T) string {
if len(l) == 0 {
return ""
}

str := ""
for i, e := range l {
str += fmt.Sprintf("%v", e)
if i < len(l)-1 {
str += ", "
}
}
return str
}
Loading

0 comments on commit 75e47aa

Please sign in to comment.