Skip to content

Commit

Permalink
[sqle] Allow bindvars in most function constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
max-hoffman committed Mar 13, 2024
1 parent fc9dd6b commit 2b2abcb
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 47 deletions.
9 changes: 5 additions & 4 deletions go/libraries/doltcore/sqle/dolt_diff_stat_table_function.go
Expand Up @@ -17,6 +17,7 @@ package sqle
import (
"errors"
"fmt"
expression2 "github.com/dolthub/go-mysql-server/sql/expression"
"io"
"math"
"strings"
Expand Down Expand Up @@ -246,20 +247,20 @@ func (ds *DiffStatTableFunction) WithExpressions(expression ...sql.Expression) (

// validate the expressions
if newDstf.dotCommitExpr != nil {
if !types.IsText(newDstf.dotCommitExpr.Type()) {
if !types.IsText(newDstf.dotCommitExpr.Type()) && !expression2.IsBindVar(newDstf.dotCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.dotCommitExpr.String())
}
} else {
if !types.IsText(newDstf.fromCommitExpr.Type()) {
if !types.IsText(newDstf.fromCommitExpr.Type()) && !expression2.IsBindVar(newDstf.fromCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.fromCommitExpr.String())
}
if !types.IsText(newDstf.toCommitExpr.Type()) {
if !types.IsText(newDstf.toCommitExpr.Type()) && !expression2.IsBindVar(newDstf.toCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.toCommitExpr.String())
}
}

if newDstf.tableNameExpr != nil {
if !types.IsText(newDstf.tableNameExpr.Type()) {
if !types.IsText(newDstf.tableNameExpr.Type()) && !expression2.IsBindVar(newDstf.tableNameExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.tableNameExpr.String())
}
}
Expand Down
Expand Up @@ -16,6 +16,7 @@ package sqle

import (
"fmt"
expression2 "github.com/dolthub/go-mysql-server/sql/expression"
"io"
"sort"
"strings"
Expand Down Expand Up @@ -235,20 +236,20 @@ func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression

// validate the expressions
if newDstf.dotCommitExpr != nil {
if !types.IsText(newDstf.dotCommitExpr.Type()) {
if !types.IsText(newDstf.dotCommitExpr.Type()) && !expression2.IsBindVar(newDstf.dotCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.dotCommitExpr.String())
}
} else {
if !types.IsText(newDstf.fromCommitExpr.Type()) {
if !types.IsText(newDstf.fromCommitExpr.Type()) && !expression2.IsBindVar(newDstf.fromCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.fromCommitExpr.String())
}
if !types.IsText(newDstf.toCommitExpr.Type()) {
if !types.IsText(newDstf.toCommitExpr.Type()) && !expression2.IsBindVar(newDstf.toCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.toCommitExpr.String())
}
}

if newDstf.tableNameExpr != nil {
if !types.IsText(newDstf.tableNameExpr.Type()) {
if !types.IsText(newDstf.tableNameExpr.Type()) && !expression2.IsBindVar(newDstf.tableNameExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.tableNameExpr.String())
}
}
Expand Down
7 changes: 5 additions & 2 deletions go/libraries/doltcore/sqle/dolt_diff_table_function.go
Expand Up @@ -364,8 +364,11 @@ func (dtf *DiffTableFunction) CheckPrivileges(ctx *sql.Context, opChecker sql.Pr
sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select))
}

// evaluateArguments evaluates the argument expressions to turn them into values this DiffTableFunction
// can use. Note that this method only evals the expressions, and doesn't validate the values.
// evaluateArguments evaluates the argument expressions to turn them into
// values this DiffTableFunction can use. Note that this method only evals
// the expressions, and doesn't validate the values.
// TODO: evaluating expression arguments during binding is incompatible
// with prepared statement support.
func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, interface{}, string, error) {
if !dtf.Resolved() {
return nil, nil, nil, "", nil
Expand Down
40 changes: 20 additions & 20 deletions go/libraries/doltcore/sqle/dolt_patch_table_function.go
Expand Up @@ -350,12 +350,12 @@ func (p *PatchTableFunction) Expressions() []sql.Expression {
}

// WithExpressions implements the sql.Expressioner interface.
func (p *PatchTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
if len(expression) < 1 {
return nil, sql.ErrInvalidArgumentNumber.New(p.Name(), "1 to 3", len(expression))
func (p *PatchTableFunction) WithExpressions(expr ...sql.Expression) (sql.Node, error) {
if len(expr) < 1 {
return nil, sql.ErrInvalidArgumentNumber.New(p.Name(), "1 to 3", len(expr))
}

for _, expr := range expression {
for _, expr := range expr {
if !expr.Resolved() {
return nil, ErrInvalidNonLiteralArgument.New(p.Name(), expr.String())
}
Expand All @@ -366,41 +366,41 @@ func (p *PatchTableFunction) WithExpressions(expression ...sql.Expression) (sql.
}

newPtf := *p
if strings.Contains(expression[0].String(), "..") {
if len(expression) < 1 || len(expression) > 2 {
return nil, sql.ErrInvalidArgumentNumber.New(newPtf.Name(), "1 or 2", len(expression))
if strings.Contains(expr[0].String(), "..") {
if len(expr) < 1 || len(expr) > 2 {
return nil, sql.ErrInvalidArgumentNumber.New(newPtf.Name(), "1 or 2", len(expr))
}
newPtf.dotCommitExpr = expression[0]
if len(expression) == 2 {
newPtf.tableNameExpr = expression[1]
newPtf.dotCommitExpr = expr[0]
if len(expr) == 2 {
newPtf.tableNameExpr = expr[1]
}
} else {
if len(expression) < 2 || len(expression) > 3 {
return nil, sql.ErrInvalidArgumentNumber.New(newPtf.Name(), "2 or 3", len(expression))
if len(expr) < 2 || len(expr) > 3 {
return nil, sql.ErrInvalidArgumentNumber.New(newPtf.Name(), "2 or 3", len(expr))
}
newPtf.fromCommitExpr = expression[0]
newPtf.toCommitExpr = expression[1]
if len(expression) == 3 {
newPtf.tableNameExpr = expression[2]
newPtf.fromCommitExpr = expr[0]
newPtf.toCommitExpr = expr[1]
if len(expr) == 3 {
newPtf.tableNameExpr = expr[2]
}
}

// validate the expressions
if newPtf.dotCommitExpr != nil {
if !sqltypes.IsText(newPtf.dotCommitExpr.Type()) {
if !sqltypes.IsText(newPtf.dotCommitExpr.Type()) && !expression.IsBindVar(newPtf.dotCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.dotCommitExpr.String())
}
} else {
if !sqltypes.IsText(newPtf.fromCommitExpr.Type()) {
if !sqltypes.IsText(newPtf.fromCommitExpr.Type()) && !expression.IsBindVar(newPtf.fromCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.fromCommitExpr.String())
}
if !sqltypes.IsText(newPtf.toCommitExpr.Type()) {
if !sqltypes.IsText(newPtf.toCommitExpr.Type()) && !expression.IsBindVar(newPtf.toCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.toCommitExpr.String())
}
}

if newPtf.tableNameExpr != nil {
if !sqltypes.IsText(newPtf.tableNameExpr.Type()) {
if !sqltypes.IsText(newPtf.tableNameExpr.Type()) && !expression.IsBindVar(newPtf.tableNameExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newPtf.Name(), newPtf.tableNameExpr.String())
}
}
Expand Down
9 changes: 5 additions & 4 deletions go/libraries/doltcore/sqle/dolt_schema_diff_table_function.go
Expand Up @@ -16,6 +16,7 @@ package sqle

import (
"fmt"
expression2 "github.com/dolthub/go-mysql-server/sql/expression"
"io"
"sort"
"strings"
Expand Down Expand Up @@ -241,19 +242,19 @@ func (ds *SchemaDiffTableFunction) WithExpressions(expression ...sql.Expression)

// validate the expressions
if newDstf.dotCommitExpr != nil {
if !types.IsText(newDstf.dotCommitExpr.Type()) {
if !types.IsText(newDstf.dotCommitExpr.Type()) && !expression2.IsBindVar(newDstf.dotCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.dotCommitExpr.String())
}
} else {
if !types.IsText(newDstf.fromCommitExpr.Type()) {
if !types.IsText(newDstf.fromCommitExpr.Type()) && !expression2.IsBindVar(newDstf.fromCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.fromCommitExpr.String())
}
if !types.IsText(newDstf.toCommitExpr.Type()) {
if !types.IsText(newDstf.toCommitExpr.Type()) && !expression2.IsBindVar(newDstf.toCommitExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.toCommitExpr.String())
}
}

if newDstf.tableNameExpr != nil && !types.IsText(newDstf.tableNameExpr.Type()) {
if newDstf.tableNameExpr != nil && !types.IsText(newDstf.tableNameExpr.Type()) && !expression2.IsBindVar(newDstf.tableNameExpr) {
return nil, sql.ErrInvalidArgumentDetails.New(newDstf.Name(), newDstf.tableNameExpr.String())
}

Expand Down
59 changes: 46 additions & 13 deletions go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go
Expand Up @@ -141,26 +141,59 @@ func TestSchemaOverrides(t *testing.T) {

// Convenience test for debugging a single query. Unskip and set to the desired query.
func TestSingleScript(t *testing.T) {
t.Skip()
//t.Skip()

var scripts = []queries.ScriptTest{
{
Name: "physical columns added after virtual one",
Name: "prepared table functions",
SetUpScript: []string{
"create table t (pk int primary key, col1 int as (pk + 1));",
"insert into t (pk) values (1), (3)",
"alter table t add index idx1 (col1, pk);",
"alter table t add index idx2 (col1);",
"alter table t add column col2 int;",
"alter table t add column col3 int;",
"insert into t (pk, col2, col3) values (2, 4, 5);",
"create table t1 (a int primary key)",
"insert into t1 values (0), (1)",
"call dolt_add('.');",
"set @Commit0 = '';",
"call dolt_commit_hash_out(@Commit0, '-am', 'commit 0');",
//
"alter table t1 add column b int default 1",
"call dolt_add('.');",
"set @Commit1 = '';",
"call dolt_commit_hash_out(@Commit1, '-am', 'commit 1');",
//
"create table t2 (a int primary key)",
"insert into t2 values (0), (1)",
"insert into t1 values (2,2), (3,2)",
"call dolt_add('.');",
"set @Commit2 = '';",
"call dolt_commit_hash_out(@Commit2, '-am', 'commit 2');",
//
"prepare sch_diff from 'select count(*) from dolt_schema_diff(?,?,?)'",
"prepare diff_stat from 'select count(*) from dolt_diff_stat(?,?,?)'",
"prepare diff_sum from 'select count(*) from dolt_diff_summary(?,?,?)'",
//"prepare table_diff from 'select * from dolt_diff(?,?,?)'",
"prepare patch from 'select count(*) from dolt_schema_diff(?,?,?)'",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "select * from t where col1 = 2",
Expected: []sql.Row{
{1, 2, nil, nil},
},
Query: "set @t1_name = 't1';",
},
{
Query: "execute sch_diff using @Commit0, @Commit1, @t1_name",
Expected: []sql.Row{{1}},
},
{
Query: "execute diff_stat using @Commit1, @Commit2, @t1_name",
Expected: []sql.Row{{1}},
},
{
Query: "execute diff_sum using @Commit1, @Commit2, @t1_name",
Expected: []sql.Row{{1}},
},
//{
// Query: "execute table_diff using @Commit2, @Commit2, @t1_name",
// Expected: []sql.Row{},
//},
{
Query: "execute patch using @Commit0, @Commit1, @t1_name",
Expected: []sql.Row{{1}},
},
},
},
Expand Down
53 changes: 53 additions & 0 deletions go/libraries/doltcore/sqle/enginetest/dolt_queries.go
Expand Up @@ -721,6 +721,59 @@ var DoltScripts = []queries.ScriptTest{
},
},
},
{
Name: "prepared table functions",
SetUpScript: []string{
"create table t1 (a int primary key)",
"insert into t1 values (0), (1)",
"call dolt_add('.');",
"set @Commit0 = '';",
"call dolt_commit_hash_out(@Commit0, '-am', 'commit 0');",
//
"alter table t1 add column b int default 1",
"call dolt_add('.');",
"set @Commit1 = '';",
"call dolt_commit_hash_out(@Commit1, '-am', 'commit 1');",
//
"create table t2 (a int primary key)",
"insert into t2 values (0), (1)",
"insert into t1 values (2,2), (3,2)",
"call dolt_add('.');",
"set @Commit2 = '';",
"call dolt_commit_hash_out(@Commit2, '-am', 'commit 2');",
//
"prepare sch_diff from 'select count(*) from dolt_schema_diff(?,?,?)'",
"prepare diff_stat from 'select count(*) from dolt_diff_stat(?,?,?)'",
"prepare diff_sum from 'select count(*) from dolt_diff_summary(?,?,?)'",
//"prepare table_diff from 'select * from dolt_diff(?,?,?)'",
"prepare patch from 'select count(*) from dolt_schema_diff(?,?,?)'",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "set @t1_name = 't1';",
},
{
Query: "execute sch_diff using @Commit0, @Commit1, @t1_name",
Expected: []sql.Row{{1}},
},
{
Query: "execute diff_stat using @Commit1, @Commit2, @t1_name",
Expected: []sql.Row{{1}},
},
{
Query: "execute diff_sum using @Commit1, @Commit2, @t1_name",
Expected: []sql.Row{{1}},
},
//{
// Query: "execute table_diff using @Commit2, @Commit2, @t1_name",
// Expected: []sql.Row{},
//},
{
Query: "execute patch using @Commit0, @Commit1, @t1_name",
Expected: []sql.Row{{1}},
},
},
},
{
Name: "test has_ancestor",
SetUpScript: []string{
Expand Down

0 comments on commit 2b2abcb

Please sign in to comment.