Skip to content

Commit

Permalink
feat(spanner): add SelectAll method to decode from Spanner iterator.R…
Browse files Browse the repository at this point in the history
…ows to golang struct (#9206)

* feat(spanner): add SelectAll method to decode from Spanner iterator.Rows to golang struct

* fix go vet

* incorporate suggestions

* preallocate if returned rows count is known

* fix go vet

* incorporate suggestions

* allocate when  rowsReturned is lowerbound

* incorporate changes and add benchmark to compare test runs for 5 fields struct

* incorporate suggestions
  • Loading branch information
rahul2393 committed Jan 18, 2024
1 parent 00b9900 commit 802088f
Show file tree
Hide file tree
Showing 8 changed files with 835 additions and 4 deletions.
1 change: 1 addition & 0 deletions spanner/go.mod
Expand Up @@ -41,6 +41,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 // indirect
go.opentelemetry.io/otel v1.21.0 // indirect
Expand Down
1 change: 1 addition & 0 deletions spanner/go.sum
Expand Up @@ -87,6 +87,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand Down
95 changes: 95 additions & 0 deletions spanner/mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions spanner/read.go
Expand Up @@ -90,6 +90,13 @@ func streamWithReplaceSessionFunc(
}
}

// rowIterator is an interface for iterating over Rows.
type rowIterator interface {
Next() (*Row, error)
Do(f func(r *Row) error) error
Stop()
}

// RowIterator is an iterator over Rows.
type RowIterator struct {
// The plan for the query. Available after RowIterator.Next returns
Expand Down Expand Up @@ -121,6 +128,9 @@ type RowIterator struct {
sawStats bool
}

// this is for safety from future changes to RowIterator making sure that it implements rowIterator interface.
var _ rowIterator = (*RowIterator)(nil)

// Next returns the next result. Its second return value is iterator.Done if
// there are no more results. Once Next returns Done, all subsequent calls
// will return Done.
Expand Down
184 changes: 184 additions & 0 deletions spanner/row.go
Expand Up @@ -249,6 +249,18 @@ func errColNotFound(n string) error {
return spannerErrorf(codes.NotFound, "column %q not found", n)
}

func errNotASlicePointer() error {
return spannerErrorf(codes.InvalidArgument, "destination must be a pointer to a slice")
}

func errNilSlicePointer() error {
return spannerErrorf(codes.InvalidArgument, "destination must be a non nil pointer")
}

func errTooManyColumns() error {
return spannerErrorf(codes.InvalidArgument, "too many columns returned for primitive slice")
}

// ColumnByName fetches the value from the named column, decoding it into ptr.
// See the Row documentation for the list of acceptable argument types.
func (r *Row) ColumnByName(name string, ptr interface{}) error {
Expand Down Expand Up @@ -378,3 +390,175 @@ func (r *Row) ToStructLenient(p interface{}) error {
true,
)
}

// SelectAll iterates all rows to the end. After iterating it closes the rows
// and propagates any errors that could pop up with destination slice partially filled.
// It expects that destination should be a slice. For each row, it scans data and appends it to the destination slice.
// SelectAll supports both types of slices: slice of pointers and slice of structs or primitives by value,
// for example:
//
// type Singer struct {
// ID string
// Name string
// }
//
// var singersByPtr []*Singer
// var singersByValue []Singer
//
// Both singersByPtr and singersByValue are valid destinations for SelectAll function.
//
// Add the option `spanner.WithLenient()` to instruct SelectAll to ignore additional columns in the rows that are not present in the destination struct.
// example:
//
// var singersByPtr []*Singer
// err := spanner.SelectAll(row, &singersByPtr, spanner.WithLenient())
func SelectAll(rows rowIterator, destination interface{}, options ...DecodeOptions) error {
if rows == nil {
return fmt.Errorf("rows is nil")
}
if destination == nil {
return fmt.Errorf("destination is nil")
}
dstVal := reflect.ValueOf(destination)
if !dstVal.IsValid() || (dstVal.Kind() == reflect.Ptr && dstVal.IsNil()) {
return errNilSlicePointer()
}
if dstVal.Kind() != reflect.Ptr {
return errNotASlicePointer()
}
dstVal = dstVal.Elem()
dstType := dstVal.Type()
if k := dstType.Kind(); k != reflect.Slice {
return errNotASlicePointer()
}

itemType := dstType.Elem()
var itemByPtr bool
// If it's a slice of pointers to structs,
// we handle it the same way as it would be slice of struct by value
// and dereference pointers to values,
// because eventually we work with fields.
// But if it's a slice of primitive type e.g. or []string or []*string,
// we must leave and pass elements as is.
if itemType.Kind() == reflect.Ptr {
elementBaseTypeElem := itemType.Elem()
if elementBaseTypeElem.Kind() == reflect.Struct {
itemType = elementBaseTypeElem
itemByPtr = true
}
}
s := &decodeSetting{}
for _, opt := range options {
opt.Apply(s)
}

isPrimitive := itemType.Kind() != reflect.Struct
var pointers []interface{}
isFirstRow := true
var err error
return rows.Do(func(row *Row) error {
sliceItem := reflect.New(itemType)
if isFirstRow && !isPrimitive {
defer func() {
isFirstRow = false
}()
if pointers, err = structPointers(sliceItem.Elem(), row.fields, s.Lenient); err != nil {
return err
}
} else if isPrimitive {
if len(row.fields) > 1 && !s.Lenient {
return errTooManyColumns()
}
pointers = []interface{}{sliceItem.Interface()}
}
if len(pointers) == 0 {
return nil
}
err = row.Columns(pointers...)
if err != nil {
return err
}
if !isPrimitive {
e := sliceItem.Elem()
for i, p := range pointers {
if p == nil {
continue
}
e.Field(i).Set(reflect.ValueOf(p).Elem())
}
}
var elemVal reflect.Value
if itemByPtr {
if isFirstRow {
// create a new pointer to the struct with all the values copied from sliceItem
// because same underlying pointers array will be used for next rows
elemVal = reflect.New(itemType)
elemVal.Elem().Set(sliceItem.Elem())
} else {
elemVal = sliceItem
}
} else {
elemVal = sliceItem.Elem()
}
dstVal.Set(reflect.Append(dstVal, elemVal))
return nil
})
}

func structPointers(sliceItem reflect.Value, cols []*sppb.StructType_Field, lenient bool) ([]interface{}, error) {
pointers := make([]interface{}, 0, len(cols))
fieldTag := make(map[string]reflect.Value, len(cols))
initFieldTag(sliceItem, &fieldTag)

for _, colName := range cols {
var fieldVal reflect.Value
if v, ok := fieldTag[colName.GetName()]; ok {
fieldVal = v
} else {
if !lenient {
return nil, errNoOrDupGoField(sliceItem, colName.GetName())
}
fieldVal = sliceItem.FieldByName(colName.GetName())
}
if !fieldVal.IsValid() || !fieldVal.CanSet() {
// have to add if we found a column because Columns() requires
// len(cols) arguments or it will error. This way we can scan to
// a useless pointer
pointers = append(pointers, nil)
continue
}

pointers = append(pointers, fieldVal.Addr().Interface())
}
return pointers, nil
}

// Initialization the tags from struct.
func initFieldTag(sliceItem reflect.Value, fieldTagMap *map[string]reflect.Value) {
typ := sliceItem.Type()

for i := 0; i < sliceItem.NumField(); i++ {
fieldType := typ.Field(i)
exported := (fieldType.PkgPath == "")
// If a named field is unexported, ignore it. An anonymous
// unexported field is processed, because it may contain
// exported fields, which are visible.
if !exported && !fieldType.Anonymous {
continue
}
if fieldType.Type.Kind() == reflect.Struct {
// found an embedded struct
sliceItemOfAnonymous := sliceItem.Field(i)
initFieldTag(sliceItemOfAnonymous, fieldTagMap)
continue
}
name, keep, _, _ := spannerTagParser(fieldType.Tag)
if !keep {
continue
}
if name == "" {
name = fieldType.Name
}
(*fieldTagMap)[name] = sliceItem.Field(i)
}
}

0 comments on commit 802088f

Please sign in to comment.