Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: check named value parameter types #35

Merged
merged 6 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 74 additions & 4 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"math/big"
"regexp"
"strconv"
"strings"
"time"

"cloud.google.com/go/civil"
"cloud.google.com/go/spanner"
adminapi "cloud.google.com/go/spanner/admin/database/apiv1"
"github.com/cloudspannerecosystem/go-sql-spanner/internal"
"google.golang.org/api/option"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

adminapi "cloud.google.com/go/spanner/admin/database/apiv1"
adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const userAgent = "go-sql-spanner/0.1"
Expand Down Expand Up @@ -248,6 +250,74 @@ func (c *conn) IsValid() bool {
return !c.closed
}

func (c *conn) CheckNamedValue(value *driver.NamedValue) error {
if value == nil {
return nil
}
switch t := value.Value.(type) {
default:
// Default is to fail, unless it is one of the following supported types.
return spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "unsupported value type: %v", t))
case nil:
case sql.NullInt64:
case sql.NullTime:
case sql.NullString:
case sql.NullFloat64:
case sql.NullBool:
case sql.NullInt32:
case string:
case spanner.NullString:
case []string:
case []spanner.NullString:
case *string:
case []*string:
case []byte:
case [][]byte:
case int:
case []int:
case int64:
case []int64:
case spanner.NullInt64:
case []spanner.NullInt64:
case *int64:
case []*int64:
case bool:
case []bool:
case spanner.NullBool:
case []spanner.NullBool:
case *bool:
case []*bool:
case float64:
case []float64:
case spanner.NullFloat64:
case []spanner.NullFloat64:
case *float64:
case []*float64:
case big.Rat:
case []big.Rat:
case spanner.NullNumeric:
case []spanner.NullNumeric:
case *big.Rat:
case []*big.Rat:
case time.Time:
case []time.Time:
case spanner.NullTime:
case []spanner.NullTime:
case *time.Time:
case []*time.Time:
case civil.Date:
case []civil.Date:
case spanner.NullDate:
case []spanner.NullDate:
case *civil.Date:
case []*civil.Date:
case spanner.NullJSON:
case []spanner.NullJSON:
case spanner.GenericColumnValue:
}
return nil
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
Expand Down
260 changes: 145 additions & 115 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,129 +707,159 @@ func TestQueryWithNullParameters(t *testing.T) {
t.Fatal(err)
}
defer stmt.Close()
rows, err := stmt.QueryContext(
context.Background(),
nil, // bool
nil, // string
nil, // bytes
nil, // int64
nil, // float64
nil, // numeric
nil, // date
nil, // timestamp
nil, // json
nil, // bool array
nil, // string array
nil, // bytes array
nil, // int64 array
nil, // float64 array
nil, // numeric array
nil, // date array
nil, // timestamp array
nil, // json array
)
if err != nil {
t.Fatal(err)
}
defer rows.Close()

for rows.Next() {
var b sql.NullBool
var s sql.NullString
var bt []byte
var i sql.NullInt64
var f sql.NullFloat64
var r spanner.NullNumeric // There's no equivalent sql type.
var d spanner.NullDate // There's no equivalent sql type.
var ts sql.NullTime
var j spanner.NullJSON // There's no equivalent sql type.
var bArray []spanner.NullBool
var sArray []spanner.NullString
var btArray [][]byte
var iArray []spanner.NullInt64
var fArray []spanner.NullFloat64
var rArray []spanner.NullNumeric
var dArray []spanner.NullDate
var tsArray []spanner.NullTime
var jArray []spanner.NullJSON
err = rows.Scan(&b, &s, &bt, &i, &f, &r, &d, &ts, &j, &bArray, &sArray, &btArray, &iArray, &fArray, &rArray, &dArray, &tsArray, &jArray)
for _, p := range []struct {
typed int
values []interface{}
}{
{
typed: 0,
values: []interface{}{
nil, // bool
nil, // string
nil, // bytes
nil, // int64
nil, // float64
nil, // numeric
nil, // date
nil, // timestamp
nil, // json
nil, // bool array
nil, // string array
nil, // bytes array
nil, // int64 array
nil, // float64 array
nil, // numeric array
nil, // date array
nil, // timestamp array
nil, // json array
}},
{
typed: 8,
values: []interface{}{
spanner.NullBool{},
spanner.NullString{},
nil, // bytes
spanner.NullInt64{},
spanner.NullFloat64{},
spanner.NullNumeric{},
spanner.NullDate{},
spanner.NullTime{},
spanner.NullJSON{},
nil, // bool array
nil, // string array
nil, // bytes array
nil, // int64 array
nil, // float64 array
nil, // numeric array
nil, // date array
nil, // timestamp array
nil, // json array
}},
} {
rows, err := stmt.QueryContext(context.Background(), p.values...)
if err != nil {
t.Fatal(err)
}
if b.Valid {
t.Errorf("row value mismatch for bool\nGot: %v\nWant: %v", b, spanner.NullBool{})
}
if s.Valid {
t.Errorf("row value mismatch for string\nGot: %v\nWant: %v", s, spanner.NullString{})
}
if bt != nil {
t.Errorf("row value mismatch for bytes\nGot: %v\nWant: %v", bt, nil)
}
if i.Valid {
t.Errorf("row value mismatch for int64\nGot: %v\nWant: %v", i, spanner.NullInt64{})
}
if f.Valid {
t.Errorf("row value mismatch for float64\nGot: %v\nWant: %v", f, spanner.NullFloat64{})
}
if r.Valid {
t.Errorf("row value mismatch for numeric\nGot: %v\nWant: %v", r, spanner.NullNumeric{})
}
if d.Valid {
t.Errorf("row value mismatch for date\nGot: %v\nWant: %v", d, spanner.NullDate{})
}
if ts.Valid {
t.Errorf("row value mismatch for timestamp\nGot: %v\nWant: %v", ts, spanner.NullTime{})
}
if j.Valid {
t.Errorf("row value mismatch for json\nGot: %v\nWant: %v", j, spanner.NullJSON{})
}
if bArray != nil {
t.Errorf("row value mismatch for bool array\nGot: %v\nWant: %v", bArray, nil)
}
if sArray != nil {
t.Errorf("row value mismatch for string array\nGot: %v\nWant: %v", sArray, nil)
}
if btArray != nil {
t.Errorf("row value mismatch for bytes array array\nGot: %v\nWant: %v", btArray, nil)
}
if iArray != nil {
t.Errorf("row value mismatch for int64 array\nGot: %v\nWant: %v", iArray, nil)
}
if fArray != nil {
t.Errorf("row value mismatch for float64 array\nGot: %v\nWant: %v", fArray, nil)
defer rows.Close()

for rows.Next() {
var b sql.NullBool
var s sql.NullString
var bt []byte
var i sql.NullInt64
var f sql.NullFloat64
var r spanner.NullNumeric // There's no equivalent sql type.
var d spanner.NullDate // There's no equivalent sql type.
var ts sql.NullTime
var j spanner.NullJSON // There's no equivalent sql type.
var bArray []spanner.NullBool
var sArray []spanner.NullString
var btArray [][]byte
var iArray []spanner.NullInt64
var fArray []spanner.NullFloat64
var rArray []spanner.NullNumeric
var dArray []spanner.NullDate
var tsArray []spanner.NullTime
var jArray []spanner.NullJSON
err = rows.Scan(&b, &s, &bt, &i, &f, &r, &d, &ts, &j, &bArray, &sArray, &btArray, &iArray, &fArray, &rArray, &dArray, &tsArray, &jArray)
if err != nil {
t.Fatal(err)
}
if b.Valid {
t.Errorf("row value mismatch for bool\nGot: %v\nWant: %v", b, spanner.NullBool{})
}
if s.Valid {
t.Errorf("row value mismatch for string\nGot: %v\nWant: %v", s, spanner.NullString{})
}
if bt != nil {
t.Errorf("row value mismatch for bytes\nGot: %v\nWant: %v", bt, nil)
}
if i.Valid {
t.Errorf("row value mismatch for int64\nGot: %v\nWant: %v", i, spanner.NullInt64{})
}
if f.Valid {
t.Errorf("row value mismatch for float64\nGot: %v\nWant: %v", f, spanner.NullFloat64{})
}
if r.Valid {
t.Errorf("row value mismatch for numeric\nGot: %v\nWant: %v", r, spanner.NullNumeric{})
}
if d.Valid {
t.Errorf("row value mismatch for date\nGot: %v\nWant: %v", d, spanner.NullDate{})
}
if ts.Valid {
t.Errorf("row value mismatch for timestamp\nGot: %v\nWant: %v", ts, spanner.NullTime{})
}
if j.Valid {
t.Errorf("row value mismatch for json\nGot: %v\nWant: %v", j, spanner.NullJSON{})
}
if bArray != nil {
t.Errorf("row value mismatch for bool array\nGot: %v\nWant: %v", bArray, nil)
}
if sArray != nil {
t.Errorf("row value mismatch for string array\nGot: %v\nWant: %v", sArray, nil)
}
if btArray != nil {
t.Errorf("row value mismatch for bytes array array\nGot: %v\nWant: %v", btArray, nil)
}
if iArray != nil {
t.Errorf("row value mismatch for int64 array\nGot: %v\nWant: %v", iArray, nil)
}
if fArray != nil {
t.Errorf("row value mismatch for float64 array\nGot: %v\nWant: %v", fArray, nil)
}
if rArray != nil {
t.Errorf("row value mismatch for numeric array\nGot: %v\nWant: %v", rArray, nil)
}
if dArray != nil {
t.Errorf("row value mismatch for date array\nGot: %v\nWant: %v", dArray, nil)
}
if tsArray != nil {
t.Errorf("row value mismatch for timestamp array\nGot: %v\nWant: %v", tsArray, nil)
}
if jArray != nil {
t.Errorf("row value mismatch for json array\nGot: %v\nWant: %v", jArray, nil)
}
}
if rArray != nil {
t.Errorf("row value mismatch for numeric array\nGot: %v\nWant: %v", rArray, nil)
if rows.Err() != nil {
t.Fatal(rows.Err())
}
if dArray != nil {
t.Errorf("row value mismatch for date array\nGot: %v\nWant: %v", dArray, nil)
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if g, w := len(sqlRequests), 1; g != w {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w)
}
if tsArray != nil {
t.Errorf("row value mismatch for timestamp array\nGot: %v\nWant: %v", tsArray, nil)
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
// The param types map should be empty when we are only sending untyped nil params.
if g, w := len(req.ParamTypes), p.typed; g != w {
t.Fatalf("param types length mismatch\nGot: %v\nWant: %v", g, w)
}
if jArray != nil {
t.Errorf("row value mismatch for json array\nGot: %v\nWant: %v", jArray, nil)
if g, w := len(req.Params.Fields), 18; g != w {
t.Fatalf("params length mismatch\nGot: %v\nWant: %v", g, w)
}
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if g, w := len(sqlRequests), 1; g != w {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
// The param types map should be empty, as we are only sending nil params.
if g, w := len(req.ParamTypes), 0; g != w {
t.Fatalf("param types length mismatch\nGot: %v\nWant: %v", g, w)
}
if g, w := len(req.Params.Fields), 18; g != w {
t.Fatalf("params length mismatch\nGot: %v\nWant: %v", g, w)
}
for _, param := range req.Params.Fields {
if _, ok := param.GetKind().(*proto3.Value_NullValue); !ok {
t.Errorf("param value mismatch\nGot: %v\nWant: %v", param.GetKind(), proto3.Value_NullValue{})
for _, param := range req.Params.Fields {
if _, ok := param.GetKind().(*proto3.Value_NullValue); !ok {
t.Errorf("param value mismatch\nGot: %v\nWant: %v", param.GetKind(), proto3.Value_NullValue{})
}
}
}
}
Expand Down
4 changes: 0 additions & 4 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
return &rows{it: it}, nil
}

func (s *stmt) CheckNamedValue(value *driver.NamedValue) error {
return nil
}

func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement, error) {
names, err := internal.ParseNamedParameters(q)
if err != nil {
Expand Down