Skip to content

Commit

Permalink
fix: check named value parameter types (#35)
Browse files Browse the repository at this point in the history
* fix: check named value parameter types

Move the implementation of NamedValueChecker from Stmt to Conn so it
is always applied. Also, the check will always be the same, and will
not be different per statement. This will allow queries that use Spanner
specific types, such as spanner.NullDate to be executed.

* fix: include sql.Null* types
  • Loading branch information
olavloite committed Sep 6, 2021
1 parent 80b5d3f commit f260dd2
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 123 deletions.
78 changes: 74 additions & 4 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"math/big"
"regexp"
"strconv"
"strings"
"time"

"cloud.google.com/go/civil"
"cloud.google.com/go/spanner"
adminapi "cloud.google.com/go/spanner/admin/database/apiv1"
"github.com/cloudspannerecosystem/go-sql-spanner/internal"
"google.golang.org/api/option"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

adminapi "cloud.google.com/go/spanner/admin/database/apiv1"
adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const userAgent = "go-sql-spanner/0.1"
Expand Down Expand Up @@ -248,6 +250,74 @@ func (c *conn) IsValid() bool {
return !c.closed
}

func (c *conn) CheckNamedValue(value *driver.NamedValue) error {
if value == nil {
return nil
}
switch t := value.Value.(type) {
default:
// Default is to fail, unless it is one of the following supported types.
return spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "unsupported value type: %v", t))
case nil:
case sql.NullInt64:
case sql.NullTime:
case sql.NullString:
case sql.NullFloat64:
case sql.NullBool:
case sql.NullInt32:
case string:
case spanner.NullString:
case []string:
case []spanner.NullString:
case *string:
case []*string:
case []byte:
case [][]byte:
case int:
case []int:
case int64:
case []int64:
case spanner.NullInt64:
case []spanner.NullInt64:
case *int64:
case []*int64:
case bool:
case []bool:
case spanner.NullBool:
case []spanner.NullBool:
case *bool:
case []*bool:
case float64:
case []float64:
case spanner.NullFloat64:
case []spanner.NullFloat64:
case *float64:
case []*float64:
case big.Rat:
case []big.Rat:
case spanner.NullNumeric:
case []spanner.NullNumeric:
case *big.Rat:
case []*big.Rat:
case time.Time:
case []time.Time:
case spanner.NullTime:
case []spanner.NullTime:
case *time.Time:
case []*time.Time:
case civil.Date:
case []civil.Date:
case spanner.NullDate:
case []spanner.NullDate:
case *civil.Date:
case []*civil.Date:
case spanner.NullJSON:
case []spanner.NullJSON:
case spanner.GenericColumnValue:
}
return nil
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
Expand Down
260 changes: 145 additions & 115 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,129 +707,159 @@ func TestQueryWithNullParameters(t *testing.T) {
t.Fatal(err)
}
defer stmt.Close()
rows, err := stmt.QueryContext(
context.Background(),
nil, // bool
nil, // string
nil, // bytes
nil, // int64
nil, // float64
nil, // numeric
nil, // date
nil, // timestamp
nil, // json
nil, // bool array
nil, // string array
nil, // bytes array
nil, // int64 array
nil, // float64 array
nil, // numeric array
nil, // date array
nil, // timestamp array
nil, // json array
)
if err != nil {
t.Fatal(err)
}
defer rows.Close()

for rows.Next() {
var b sql.NullBool
var s sql.NullString
var bt []byte
var i sql.NullInt64
var f sql.NullFloat64
var r spanner.NullNumeric // There's no equivalent sql type.
var d spanner.NullDate // There's no equivalent sql type.
var ts sql.NullTime
var j spanner.NullJSON // There's no equivalent sql type.
var bArray []spanner.NullBool
var sArray []spanner.NullString
var btArray [][]byte
var iArray []spanner.NullInt64
var fArray []spanner.NullFloat64
var rArray []spanner.NullNumeric
var dArray []spanner.NullDate
var tsArray []spanner.NullTime
var jArray []spanner.NullJSON
err = rows.Scan(&b, &s, &bt, &i, &f, &r, &d, &ts, &j, &bArray, &sArray, &btArray, &iArray, &fArray, &rArray, &dArray, &tsArray, &jArray)
for _, p := range []struct {
typed int
values []interface{}
}{
{
typed: 0,
values: []interface{}{
nil, // bool
nil, // string
nil, // bytes
nil, // int64
nil, // float64
nil, // numeric
nil, // date
nil, // timestamp
nil, // json
nil, // bool array
nil, // string array
nil, // bytes array
nil, // int64 array
nil, // float64 array
nil, // numeric array
nil, // date array
nil, // timestamp array
nil, // json array
}},
{
typed: 8,
values: []interface{}{
spanner.NullBool{},
spanner.NullString{},
nil, // bytes
spanner.NullInt64{},
spanner.NullFloat64{},
spanner.NullNumeric{},
spanner.NullDate{},
spanner.NullTime{},
spanner.NullJSON{},
nil, // bool array
nil, // string array
nil, // bytes array
nil, // int64 array
nil, // float64 array
nil, // numeric array
nil, // date array
nil, // timestamp array
nil, // json array
}},
} {
rows, err := stmt.QueryContext(context.Background(), p.values...)
if err != nil {
t.Fatal(err)
}
if b.Valid {
t.Errorf("row value mismatch for bool\nGot: %v\nWant: %v", b, spanner.NullBool{})
}
if s.Valid {
t.Errorf("row value mismatch for string\nGot: %v\nWant: %v", s, spanner.NullString{})
}
if bt != nil {
t.Errorf("row value mismatch for bytes\nGot: %v\nWant: %v", bt, nil)
}
if i.Valid {
t.Errorf("row value mismatch for int64\nGot: %v\nWant: %v", i, spanner.NullInt64{})
}
if f.Valid {
t.Errorf("row value mismatch for float64\nGot: %v\nWant: %v", f, spanner.NullFloat64{})
}
if r.Valid {
t.Errorf("row value mismatch for numeric\nGot: %v\nWant: %v", r, spanner.NullNumeric{})
}
if d.Valid {
t.Errorf("row value mismatch for date\nGot: %v\nWant: %v", d, spanner.NullDate{})
}
if ts.Valid {
t.Errorf("row value mismatch for timestamp\nGot: %v\nWant: %v", ts, spanner.NullTime{})
}
if j.Valid {
t.Errorf("row value mismatch for json\nGot: %v\nWant: %v", j, spanner.NullJSON{})
}
if bArray != nil {
t.Errorf("row value mismatch for bool array\nGot: %v\nWant: %v", bArray, nil)
}
if sArray != nil {
t.Errorf("row value mismatch for string array\nGot: %v\nWant: %v", sArray, nil)
}
if btArray != nil {
t.Errorf("row value mismatch for bytes array array\nGot: %v\nWant: %v", btArray, nil)
}
if iArray != nil {
t.Errorf("row value mismatch for int64 array\nGot: %v\nWant: %v", iArray, nil)
}
if fArray != nil {
t.Errorf("row value mismatch for float64 array\nGot: %v\nWant: %v", fArray, nil)
defer rows.Close()

for rows.Next() {
var b sql.NullBool
var s sql.NullString
var bt []byte
var i sql.NullInt64
var f sql.NullFloat64
var r spanner.NullNumeric // There's no equivalent sql type.
var d spanner.NullDate // There's no equivalent sql type.
var ts sql.NullTime
var j spanner.NullJSON // There's no equivalent sql type.
var bArray []spanner.NullBool
var sArray []spanner.NullString
var btArray [][]byte
var iArray []spanner.NullInt64
var fArray []spanner.NullFloat64
var rArray []spanner.NullNumeric
var dArray []spanner.NullDate
var tsArray []spanner.NullTime
var jArray []spanner.NullJSON
err = rows.Scan(&b, &s, &bt, &i, &f, &r, &d, &ts, &j, &bArray, &sArray, &btArray, &iArray, &fArray, &rArray, &dArray, &tsArray, &jArray)
if err != nil {
t.Fatal(err)
}
if b.Valid {
t.Errorf("row value mismatch for bool\nGot: %v\nWant: %v", b, spanner.NullBool{})
}
if s.Valid {
t.Errorf("row value mismatch for string\nGot: %v\nWant: %v", s, spanner.NullString{})
}
if bt != nil {
t.Errorf("row value mismatch for bytes\nGot: %v\nWant: %v", bt, nil)
}
if i.Valid {
t.Errorf("row value mismatch for int64\nGot: %v\nWant: %v", i, spanner.NullInt64{})
}
if f.Valid {
t.Errorf("row value mismatch for float64\nGot: %v\nWant: %v", f, spanner.NullFloat64{})
}
if r.Valid {
t.Errorf("row value mismatch for numeric\nGot: %v\nWant: %v", r, spanner.NullNumeric{})
}
if d.Valid {
t.Errorf("row value mismatch for date\nGot: %v\nWant: %v", d, spanner.NullDate{})
}
if ts.Valid {
t.Errorf("row value mismatch for timestamp\nGot: %v\nWant: %v", ts, spanner.NullTime{})
}
if j.Valid {
t.Errorf("row value mismatch for json\nGot: %v\nWant: %v", j, spanner.NullJSON{})
}
if bArray != nil {
t.Errorf("row value mismatch for bool array\nGot: %v\nWant: %v", bArray, nil)
}
if sArray != nil {
t.Errorf("row value mismatch for string array\nGot: %v\nWant: %v", sArray, nil)
}
if btArray != nil {
t.Errorf("row value mismatch for bytes array array\nGot: %v\nWant: %v", btArray, nil)
}
if iArray != nil {
t.Errorf("row value mismatch for int64 array\nGot: %v\nWant: %v", iArray, nil)
}
if fArray != nil {
t.Errorf("row value mismatch for float64 array\nGot: %v\nWant: %v", fArray, nil)
}
if rArray != nil {
t.Errorf("row value mismatch for numeric array\nGot: %v\nWant: %v", rArray, nil)
}
if dArray != nil {
t.Errorf("row value mismatch for date array\nGot: %v\nWant: %v", dArray, nil)
}
if tsArray != nil {
t.Errorf("row value mismatch for timestamp array\nGot: %v\nWant: %v", tsArray, nil)
}
if jArray != nil {
t.Errorf("row value mismatch for json array\nGot: %v\nWant: %v", jArray, nil)
}
}
if rArray != nil {
t.Errorf("row value mismatch for numeric array\nGot: %v\nWant: %v", rArray, nil)
if rows.Err() != nil {
t.Fatal(rows.Err())
}
if dArray != nil {
t.Errorf("row value mismatch for date array\nGot: %v\nWant: %v", dArray, nil)
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if g, w := len(sqlRequests), 1; g != w {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w)
}
if tsArray != nil {
t.Errorf("row value mismatch for timestamp array\nGot: %v\nWant: %v", tsArray, nil)
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
// The param types map should be empty when we are only sending untyped nil params.
if g, w := len(req.ParamTypes), p.typed; g != w {
t.Fatalf("param types length mismatch\nGot: %v\nWant: %v", g, w)
}
if jArray != nil {
t.Errorf("row value mismatch for json array\nGot: %v\nWant: %v", jArray, nil)
if g, w := len(req.Params.Fields), 18; g != w {
t.Fatalf("params length mismatch\nGot: %v\nWant: %v", g, w)
}
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}
requests := drainRequestsFromServer(server.TestSpanner)
sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{}))
if g, w := len(sqlRequests), 1; g != w {
t.Fatalf("sql requests count mismatch\nGot: %v\nWant: %v", g, w)
}
req := sqlRequests[0].(*sppb.ExecuteSqlRequest)
// The param types map should be empty, as we are only sending nil params.
if g, w := len(req.ParamTypes), 0; g != w {
t.Fatalf("param types length mismatch\nGot: %v\nWant: %v", g, w)
}
if g, w := len(req.Params.Fields), 18; g != w {
t.Fatalf("params length mismatch\nGot: %v\nWant: %v", g, w)
}
for _, param := range req.Params.Fields {
if _, ok := param.GetKind().(*proto3.Value_NullValue); !ok {
t.Errorf("param value mismatch\nGot: %v\nWant: %v", param.GetKind(), proto3.Value_NullValue{})
for _, param := range req.Params.Fields {
if _, ok := param.GetKind().(*proto3.Value_NullValue); !ok {
t.Errorf("param value mismatch\nGot: %v\nWant: %v", param.GetKind(), proto3.Value_NullValue{})
}
}
}
}
Expand Down
4 changes: 0 additions & 4 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
return &rows{it: it}, nil
}

func (s *stmt) CheckNamedValue(value *driver.NamedValue) error {
return nil
}

func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement, error) {
names, err := internal.ParseNamedParameters(q)
if err != nil {
Expand Down

0 comments on commit f260dd2

Please sign in to comment.