From 62cd6ec987c760a76c4b86f2c5183c1e27d08d94 Mon Sep 17 00:00:00 2001 From: Egon Elbre Date: Wed, 7 Aug 2024 12:10:24 +0300 Subject: [PATCH] feat: support structs in queries cloud.google.com/go/spanner supports passing structs as arguments to queries. Fixes #281 --- driver.go | 36 ++++++++++++++++++ examples/struct-types/main.go | 71 +++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 examples/struct-types/main.go diff --git a/driver.go b/driver.go index 36efcc74..9f63b424 100644 --- a/driver.go +++ b/driver.go @@ -20,6 +20,7 @@ import ( "database/sql/driver" "fmt" "math/big" + "reflect" "regexp" "strconv" "strings" @@ -36,6 +37,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" ) const userAgent = "go-sql-spanner/1.0.2" @@ -748,6 +751,10 @@ func (c *conn) IsValid() bool { func checkIsValidType(v driver.Value) bool { switch v.(type) { default: + // google-cloud-go/spanner knows how to deal with these + if isStructOrArrayOfStructValue(v) || isAnArrayOfProtoColumn(v) { + return true + } return false case nil: case sql.NullInt64: @@ -1052,3 +1059,32 @@ func (c *conn) createPartitionedDmlQueryOptions() spanner.QueryOptions { defer func() { c.excludeTxnFromChangeStreams = false }() return spanner.QueryOptions{ExcludeTxnFromChangeStreams: c.excludeTxnFromChangeStreams} } + +/* The following is the same implementation as in google-cloud-go/spanner */ + +func isStructOrArrayOfStructValue(v interface{}) bool { + typ := reflect.TypeOf(v) + if typ.Kind() == reflect.Slice { + typ = typ.Elem() + } + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + return typ.Kind() == reflect.Struct +} + +func isAnArrayOfProtoColumn(v interface{}) bool { + typ := reflect.TypeOf(v) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + if typ.Kind() == reflect.Slice { + typ = typ.Elem() + } + return typ.Implements(protoMsgReflectType) || typ.Implements(protoEnumReflectType) +} + +var ( + protoMsgReflectType = reflect.TypeOf((*proto.Message)(nil)).Elem() + protoEnumReflectType = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem() +) diff --git a/examples/struct-types/main.go b/examples/struct-types/main.go new file mode 100644 index 00000000..a5240fd6 --- /dev/null +++ b/examples/struct-types/main.go @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "database/sql" + "fmt" + + _ "github.com/googleapis/go-sql-spanner" + "github.com/googleapis/go-sql-spanner/examples" +) + +// Example for executing a query with struct arguments described in: +// +// * https://cloud.google.com/spanner/docs/structs +// * https://pkg.go.dev/cloud.google.com/go/spanner#hdr-Structs +func structTypes(projectId, instanceId, databaseId string) error { + ctx := context.Background() + db, err := sql.Open("spanner", fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, databaseId)) + if err != nil { + return fmt.Errorf("failed to open database connection: %v\n", err) + } + defer db.Close() + + type Entry struct { + ID int64 + Name string + } + + entries := []Entry{ + {ID: 0, Name: "Hello"}, + {ID: 1, Name: "World"}, + } + + rows, err := db.QueryContext(ctx, "SELECT id, name FROM UNNEST(@entries)", entries) + if err != nil { + return fmt.Errorf("failed to execute query: %v", err) + } + defer rows.Close() + + for rows.Next() { + var id int64 + var name string + + if err := rows.Scan(&id, &name); err != nil { + return fmt.Errorf("failed to scan row values: %v", err) + } + fmt.Printf("%v %v\n", id, name) + } + if err := rows.Err(); err != nil { + return fmt.Errorf("failed to execute query: %v", err) + } + return nil +} + +func main() { + examples.RunSampleOnEmulator(structTypes) +}