From ea91f66ee9217dcdc924eec58fbfd3a8bb5ac788 Mon Sep 17 00:00:00 2001 From: Surya Asriadie Date: Sat, 19 Aug 2023 17:47:39 +0700 Subject: [PATCH] Refactor scanMulti to only call scan once (#339) --- .github/workflows/release.yml | 2 +- .github/workflows/test.yml | 4 +-- collection.go | 28 ++++++++------- cursor.go | 47 +++++++++++------------- cursor_test.go | 21 ++++++----- document.go | 33 ++++++++++++++--- document_test.go | 1 - go.mod | 2 +- repository_test.go | 68 +++++++++++++++++------------------ util.go | 6 ++++ util_test.go | 6 ++++ 11 files changed, 126 insertions(+), 92 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index fa75ee36..babe1155 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: '1.20' + go-version: 1.21 - name: Run GoReleaser uses: goreleaser/goreleaser-action@v2 with: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a076cd40..bf74b6f3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - go: [1.17, 1.18, 1.19, '1.20'] + go: [1.17, 1.18, 1.19, '1.20', 1.21] runs-on: ${{ matrix.os }} steps: - name: Set up Go 1.x @@ -33,7 +33,7 @@ jobs: - name: Set up Go 1.x uses: actions/setup-go@v3 with: - go-version: '1.20' + go-version: 1.21 - name: Check out code into the Go module directory uses: actions/checkout@v2 with: diff --git a/collection.go b/collection.go index f34ec28d..254f37d8 100644 --- a/collection.go +++ b/collection.go @@ -7,7 +7,8 @@ import ( type slice interface { table Reset() - Add() *Document + CreateDocument() *Document + Append(doc *Document) Get(index int) *Document Len() int Meta() DocumentMeta @@ -142,19 +143,22 @@ func (c Collection) Reset() { // Add new document into collection. func (c Collection) Add() *Document { - var ( - index = c.Len() - typ = c.rt.Elem() - drv = reflect.Zero(typ) - ) - - if typ.Kind() == reflect.Ptr && drv.IsNil() { - drv = reflect.New(drv.Type().Elem()) - } + c.Append(c.CreateDocument()) + return c.Get(c.Len() - 1) +} - c.rv.Set(reflect.Append(c.rv, drv)) +// CreateDocument returns new document with zero values. +func (c Collection) CreateDocument() *Document { + return newZeroDocument(c.rt.Elem()) +} - return NewDocument(c.rvIndex(index).Addr()) +// Append new document into collection. +func (c Collection) Append(doc *Document) { + if c.rt.Elem().Kind() == reflect.Ptr { + c.rv.Set(reflect.Append(c.rv, doc.rv.Addr())) + } else { + c.rv.Set(reflect.Append(c.rv, doc.rv)) + } } // Truncate collection. diff --git a/cursor.go b/cursor.go index 96892529..6e595367 100644 --- a/cursor.go +++ b/cursor.go @@ -1,7 +1,6 @@ package rel import ( - "database/sql" "reflect" ) @@ -63,48 +62,42 @@ func scanMulti(cur Cursor, keyField string, keyType reflect.Type, cols map[any][ return err } - var ( - found = false - keyValue = reflect.New(keyType) - keyScanners = make([]any, len(fields)) - ) - - for i, field := range fields { + keyFound := false + for _, field := range fields { if keyField == field { - found = true - keyScanners[i] = keyValue.Interface() - } else { - // need to create distinct copies - // otherwise next scan result will be corrupted - keyScanners[i] = &sql.RawBytes{} + keyFound = true } } - if !found && fields != nil { + if !keyFound && fields != nil { panic("rel: primary key row does not exists") } + var doc *Document + for k := range cols { + for _, col := range cols[k] { + doc = col.CreateDocument() + break + } + break + } + // scan the result for cur.Next() { // scan key - if err := cur.Scan(keyScanners...); err != nil { + if err := cur.Scan(doc.Scanners(fields)...); err != nil { return err } - var ( - key = reflect.Indirect(keyValue).Interface() - ) + key, found := doc.Value(keyField) + mustTrue(found, "rel: key field not found") for _, col := range cols[key] { - var ( - doc = col.Add() - scanners = doc.Scanners(fields) - ) - - if err := cur.Scan(scanners...); err != nil { - return err - } + col.Append(doc) } + + // create new doc for next scan + doc = doc.CreateDocument() } return nil diff --git a/cursor_test.go b/cursor_test.go index b25bf39a..4bd7b194 100644 --- a/cursor_test.go +++ b/cursor_test.go @@ -58,9 +58,13 @@ func (tc *testCursor) MockScan(ret ...any) *mock.Call { Return(func(scanners ...any) error { for i := 0; i < len(scanners); i++ { if v, ok := scanners[i].(sql.Scanner); ok { - v.Scan(ret[i]) + if err := v.Scan(ret[i]); err != nil { + return err + } } else { - convertAssign(scanners[i], ret[i]) + if err := convertAssign(scanners[i], ret[i]); err != nil { + return err + } } } @@ -192,8 +196,8 @@ func TestScanMulti(t *testing.T) { cur.On("Fields").Return([]string{"id", "name", "age", "created_at", "updated_at"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(10, "Del Piero", nil, now, nil).Times(3) - cur.MockScan(11, "Nedved", 46, now, now).Twice() + cur.MockScan(10, "Del Piero", nil, now, nil).Once() + cur.MockScan(11, "Nedved", 46, now, now).Once() cur.On("Next").Return(false).Once() assert.Nil(t, scanMulti(cur, keyField, keyType, cols)) @@ -237,7 +241,6 @@ func TestScanMulti_scanError(t *testing.T) { cur.On("Fields").Return([]string{"id", "name", "age", "created_at", "updated_at"}, nil).Once() cur.On("Next").Return(true).Once() - cur.MockScan(11, "Nedved", 46, Now, Now).Once() cur.On("Scan", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err).Once() assert.Equal(t, err, scanMulti(cur, keyField, keyType, cols)) @@ -324,8 +327,8 @@ func TestScanMulti_multipleTimes(t *testing.T) { cur.On("Fields").Return([]string{"id", "name", "age", "created_at", "updated_at"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(10, "Del Piero", nil, now, nil).Times(3) - cur.MockScan(11, "Nedved", 46, now, now).Twice() + cur.MockScan(10, "Del Piero", nil, now, nil).Once() + cur.MockScan(11, "Nedved", 46, now, now).Once() cur.On("Next").Return(false).Once() assert.Nil(t, scanMulti(cur, keyField, keyType, cols)) @@ -360,8 +363,8 @@ func TestScanMulti_multipleTimes(t *testing.T) { cur.On("Fields").Return([]string{"id", "name", "age", "created_at", "updated_at"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(12, "Linus Torvalds", 52, now, nil).Times(3) - cur.MockScan(13, "Tim Cook", 61, now, now).Twice() + cur.MockScan(12, "Linus Torvalds", 52, now, nil).Once() + cur.MockScan(13, "Tim Cook", 61, now, now).Once() cur.On("Next").Return(false).Once() assert.Nil(t, scanMulti(cur, keyField, keyType, cols)) diff --git a/document.go b/document.go index 86e54146..564c1808 100644 --- a/document.go +++ b/document.go @@ -261,15 +261,27 @@ func (d Document) association(name string) (Association, bool) { func (d Document) Reset() { } -// Add returns this document. -func (d *Document) Add() *Document { - // if d.rv is a null pointer, set it to a new struct. +// CreateDocument returns new document with zero values. +func (d Document) CreateDocument() *Document { + return newZeroDocument(d.rt) +} + +// Append is alias for Assign for compatibility with internal slice interface +func (d *Document) Append(o *Document) { + d.Assign(o) +} + +// Assign document value to this document. +func (d *Document) Assign(o *Document) { if d.rv.Kind() == reflect.Ptr && d.rv.IsNil() { - d.rv.Set(reflect.New(d.rv.Type().Elem())) + d.rv.Set(o.rv.Addr()) d.rv = d.rv.Elem() + } else { + d.rv.Set(o.rv) } - return d + d.meta = o.meta + d.v = o.v } // Get always returns this document, this is a noop for compatibility with collection. @@ -336,3 +348,14 @@ func newDocument(v any, rv reflect.Value, readonly bool) *Document { meta: getDocumentMeta(rt, false), } } + +func newZeroDocument(rt reflect.Type) *Document { + if rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + + rv := reflect.New(rt) + rv.Elem().Set(reflect.Zero(rt)) + + return NewDocument(rv) +} diff --git a/document_test.go b/document_test.go index d42dcaee..5f3d845f 100644 --- a/document_test.go +++ b/document_test.go @@ -571,7 +571,6 @@ func TestDocument_Slice(t *testing.T) { doc.Reset() assert.Equal(t, 1, doc.Len()) assert.Equal(t, doc, doc.Get(0)) - assert.Equal(t, doc, doc.Add()) }) } diff --git a/go.mod b/go.mod index fa72625c..52db283f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/go-rel/rel -go 1.20 +go 1.21 require ( github.com/jinzhu/inflection v1.0.0 diff --git a/repository_test.go b/repository_test.go index 71220cf5..7a83dbe0 100644 --- a/repository_test.go +++ b/repository_test.go @@ -3119,7 +3119,7 @@ func TestRepository_Preload_hasOne(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Once() - cur.MockScan(address.ID, *address.UserID).Times(2) + cur.MockScan(address.ID, *address.UserID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &user, "address")) @@ -3185,7 +3185,7 @@ func TestRepository_Preload_ptrHasOne(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Once() - cur.MockScan(address.ID, *address.UserID).Times(2) + cur.MockScan(address.ID, *address.UserID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &user, "work_address")) @@ -3235,8 +3235,8 @@ func TestRepository_Preload_sliceHasOne(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(addresses[0].ID, *addresses[0].UserID).Twice() - cur.MockScan(addresses[1].ID, *addresses[1].UserID).Twice() + cur.MockScan(addresses[0].ID, *addresses[0].UserID).Once() + cur.MockScan(addresses[1].ID, *addresses[1].UserID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &users, "address")) @@ -3266,8 +3266,8 @@ func TestRepository_Preload_ptrSliceHasOne(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(addresses[0].ID, *addresses[0].UserID).Twice() - cur.MockScan(addresses[1].ID, *addresses[1].UserID).Twice() + cur.MockScan(addresses[0].ID, *addresses[0].UserID).Once() + cur.MockScan(addresses[1].ID, *addresses[1].UserID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &users, "work_address")) @@ -3317,7 +3317,7 @@ func TestRepository_Preload_nestedHasOne(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Once() - cur.MockScan(address.ID, *address.UserID).Twice() + cur.MockScan(address.ID, *address.UserID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &transaction, "buyer.address")) @@ -3343,7 +3343,7 @@ func TestRepository_Preload_ptrNestedHasOne(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Once() - cur.MockScan(address.ID, *address.UserID).Twice() + cur.MockScan(address.ID, *address.UserID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &transaction, "buyer.work_address")) @@ -3398,8 +3398,8 @@ func TestRepository_Preload_sliceNestedHasOne(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(addresses[0].ID, *addresses[0].UserID).Twice() - cur.MockScan(addresses[1].ID, *addresses[1].UserID).Twice() + cur.MockScan(addresses[0].ID, *addresses[0].UserID).Once() + cur.MockScan(addresses[1].ID, *addresses[1].UserID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &transactions, "buyer.address")) @@ -3432,8 +3432,8 @@ func TestRepository_Preload_ptrSliceNestedHasOne(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(addresses[0].ID, *addresses[0].UserID).Twice() - cur.MockScan(addresses[1].ID, *addresses[1].UserID).Twice() + cur.MockScan(addresses[0].ID, *addresses[0].UserID).Once() + cur.MockScan(addresses[1].ID, *addresses[1].UserID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &transactions, "buyer.work_address")) @@ -3487,8 +3487,8 @@ func TestRepository_Preload_hasMany(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Twice() - cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Twice() + cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Once() + cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &user, "transactions")) @@ -3518,10 +3518,10 @@ func TestRepository_Preload_sliceHasMany(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Times(4) - cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Twice() - cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Twice() - cur.MockScan(transactions[2].ID, transactions[2].BuyerID).Twice() - cur.MockScan(transactions[3].ID, transactions[3].BuyerID).Twice() + cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Once() + cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Once() + cur.MockScan(transactions[2].ID, transactions[2].BuyerID).Once() + cur.MockScan(transactions[3].ID, transactions[3].BuyerID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &users, "transactions")) @@ -3550,8 +3550,8 @@ func TestRepository_Preload_nestedHasMany(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Twice() - cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Twice() + cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Once() + cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &address, "user.transactions")) @@ -3597,10 +3597,10 @@ func TestRepository_Preload_nestedSliceHasMany(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Times(4) - cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Twice() - cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Twice() - cur.MockScan(transactions[2].ID, transactions[2].BuyerID).Twice() - cur.MockScan(transactions[3].ID, transactions[3].BuyerID).Twice() + cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Once() + cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Once() + cur.MockScan(transactions[2].ID, transactions[2].BuyerID).Once() + cur.MockScan(transactions[3].ID, transactions[3].BuyerID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &addresses, "user.transactions")) @@ -3634,9 +3634,9 @@ func TestRepository_Preload_nestedNullSliceHasMany(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "user_id"}, nil).Once() cur.On("Next").Return(true).Times(3) - cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Twice() - cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Twice() - cur.MockScan(transactions[2].ID, transactions[2].BuyerID).Twice() + cur.MockScan(transactions[0].ID, transactions[0].BuyerID).Once() + cur.MockScan(transactions[1].ID, transactions[1].BuyerID).Once() + cur.MockScan(transactions[2].ID, transactions[2].BuyerID).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &addresses, "user.transactions")) @@ -3662,7 +3662,7 @@ func TestRepository_Preload_belongsTo(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "name"}, nil).Once() cur.On("Next").Return(true).Once() - cur.MockScan(user.ID, user.Name).Twice() + cur.MockScan(user.ID, user.Name).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &transaction, "buyer")) @@ -3686,7 +3686,7 @@ func TestRepository_Preload_ptrBelongsTo(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "name"}, nil).Once() cur.On("Next").Return(true).Once() - cur.MockScan(user.ID, user.Name).Twice() + cur.MockScan(user.ID, user.Name).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &address, "user")) @@ -3752,8 +3752,8 @@ func TestRepository_Preload_sliceBelongsTo(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "name"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(users[0].ID, users[0].Name).Twice() - cur.MockScan(users[1].ID, users[1].Name).Twice() + cur.MockScan(users[0].ID, users[0].Name).Once() + cur.MockScan(users[1].ID, users[1].Name).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &transactions, "buyer")) @@ -3785,8 +3785,8 @@ func TestRepository_Preload_ptrSliceBelongsTo(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "name"}, nil).Once() cur.On("Next").Return(true).Twice() - cur.MockScan(users[0].ID, users[0].Name).Twice() - cur.MockScan(users[1].ID, users[1].Name).Twice() + cur.MockScan(users[0].ID, users[0].Name).Once() + cur.MockScan(users[1].ID, users[1].Name).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &addresses, "user")) @@ -3847,7 +3847,7 @@ func TestRepository_Preload_sliceNestedBelongsTo(t *testing.T) { cur.On("Close").Return(nil).Once() cur.On("Fields").Return([]string{"id", "street"}, nil).Once() cur.On("Next").Return(true).Once() - cur.MockScan(address.ID, address.Street).Twice() + cur.MockScan(address.ID, address.Street).Once() cur.On("Next").Return(false).Once() assert.Nil(t, repo.Preload(context.TODO(), &users, "transactions.address")) diff --git a/util.go b/util.go index dbcd0bda..f9c75ee1 100644 --- a/util.go +++ b/util.go @@ -34,6 +34,12 @@ func must(err error) { } } +func mustTrue(flag bool, msg string) { + if !flag { + panic(msg) + } +} + type isZeroer interface { IsZero() bool } diff --git a/util_test.go b/util_test.go index 5fa1b6d7..4f2d0687 100644 --- a/util_test.go +++ b/util_test.go @@ -15,6 +15,12 @@ func TestMust(t *testing.T) { }) } +func TestMustTrue(t *testing.T) { + assert.Panics(t, func() { + mustTrue(false, "error") + }) +} + func TestIsZero(t *testing.T) { tests := []any{ nil,