Skip to content

Commit

Permalink
Autoload association (#140)
Browse files Browse the repository at this point in the history
* add new autoload tag

* add auto tag to enable both

* auto preload
  • Loading branch information
Fs02 committed Nov 5, 2020
1 parent 110059a commit 53b3ab6
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 38 deletions.
9 changes: 8 additions & 1 deletion association.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type associationData struct {
foreignField string
foreignIndex int
through string
autoload bool
autosave bool
}

Expand Down Expand Up @@ -147,6 +148,11 @@ func (a Association) Through() string {
return a.data.through
}

// Autoload assoc setting when parent is loaded.
func (a Association) Autoload() bool {
return a.data.autoload
}

// Autosave setting when parent is created/updated/deleted.
func (a Association) Autosave() bool {
return a.data.autosave
Expand Down Expand Up @@ -184,7 +190,8 @@ func extractAssociationData(rt reflect.Type, index int) associationData {
assocData = associationData{
targetIndex: sf.Index,
through: sf.Tag.Get("through"),
autosave: sf.Tag.Get("autosave") == "true",
autoload: sf.Tag.Get("auto") == "true" || sf.Tag.Get("autoload") == "true",
autosave: sf.Tag.Get("auto") == "true" || sf.Tag.Get("autosave") == "true",
}
)

Expand Down
12 changes: 12 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func TestAssociation_Document(t *testing.T) {
referenceValue interface{}
foreignField string
foreignValue interface{}
autosave bool
autoload bool
}{
{
record: "Transaction",
Expand All @@ -42,6 +44,7 @@ func TestAssociation_Document(t *testing.T) {
referenceValue: transaction.BuyerID,
foreignField: "id",
foreignValue: transaction.Buyer.ID,
autoload: true,
},
{
record: "Transaction",
Expand All @@ -55,6 +58,7 @@ func TestAssociation_Document(t *testing.T) {
referenceValue: transactionLoaded.BuyerID,
foreignField: "id",
foreignValue: transactionLoaded.Buyer.ID,
autoload: true,
},
{
record: "User",
Expand All @@ -68,6 +72,7 @@ func TestAssociation_Document(t *testing.T) {
referenceValue: user.ID,
foreignField: "user_id",
foreignValue: nil,
autosave: true,
},
{
record: "User",
Expand All @@ -81,6 +86,7 @@ func TestAssociation_Document(t *testing.T) {
referenceValue: userLoaded.ID,
foreignField: "user_id",
foreignValue: nil,
autosave: true,
},
{
record: "Address",
Expand Down Expand Up @@ -128,6 +134,8 @@ func TestAssociation_Document(t *testing.T) {
assert.Equal(t, test.referenceField, assoc.ReferenceField())
assert.Equal(t, test.referenceValue, assoc.ReferenceValue())
assert.Equal(t, test.foreignField, assoc.ForeignField())
assert.Equal(t, test.autoload, assoc.Autoload())
assert.Equal(t, test.autosave, assoc.Autosave())

if test.typ == HasMany {
assert.Panics(t, func() {
Expand Down Expand Up @@ -169,6 +177,8 @@ func TestAssociation_Collection(t *testing.T) {
foreignValue interface{}
foreignThrough string
through string
autoload bool
autosave bool
}{
{
record: "User",
Expand Down Expand Up @@ -271,6 +281,8 @@ func TestAssociation_Collection(t *testing.T) {
assert.Equal(t, test.referenceValue, assoc.ReferenceValue())
assert.Equal(t, test.foreignField, assoc.ForeignField())
assert.Equal(t, test.through, assoc.Through())
assert.Equal(t, test.autoload, assoc.Autoload())
assert.Equal(t, test.autosave, assoc.Autosave())

if test.typ == HasMany {
assert.Panics(t, func() {
Expand Down
16 changes: 15 additions & 1 deletion document.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type documentData struct {
hasMany []string
primaryField []string
primaryIndex []int
preload []string
flag DocumentFlag
}

Expand Down Expand Up @@ -315,6 +316,11 @@ func (d Document) HasMany() []string {
return d.data.hasMany
}

// Preload fields of this document.
func (d Document) Preload() []string {
return d.data.preload
}

// Association of this document with given name.
func (d Document) Association(name string) Association {
index, ok := d.data.index[name]
Expand Down Expand Up @@ -440,14 +446,22 @@ func extractDocumentData(rt reflect.Type, skipAssoc bool) documentData {
}

if !skipAssoc {
switch extractAssociationData(rt, i).typ {
var (
assocData = extractAssociationData(rt, i)
)

switch assocData.typ {
case BelongsTo:
data.belongsTo = append(data.belongsTo, name)
case HasOne:
data.hasOne = append(data.hasOne, name)
case HasMany:
data.hasMany = append(data.hasMany, name)
}

if assocData.autoload {
data.preload = append(data.preload, name)
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ func TestDocument_Association(t *testing.T) {
belongsTo []string
hasOne []string
hasMany []string
preload []string
}{
{
name: "User",
Expand All @@ -408,6 +409,7 @@ func TestDocument_Association(t *testing.T) {
record: &Transaction{},
belongsTo: []string{"buyer", "address"},
hasMany: []string{"histories"},
preload: []string{"buyer"},
},
{
name: "Address",
Expand All @@ -429,6 +431,7 @@ func TestDocument_Association(t *testing.T) {
assert.Equal(t, test.belongsTo, doc.BelongsTo())
assert.Equal(t, test.hasOne, doc.HasOne())
assert.Equal(t, test.hasMany, doc.HasMany())
assert.Equal(t, test.preload, doc.Preload())
})
}
}
Expand Down
2 changes: 1 addition & 1 deletion rel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type Transaction struct {
Item string
Status Status
BuyerID int `db:"user_id"`
Buyer User `ref:"user_id" fk:"id"`
Buyer User `ref:"user_id" fk:"id" autoload:"true"`
AddressID int
Address Address
Histories *[]History
Expand Down
30 changes: 15 additions & 15 deletions repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (r repository) MustFind(ctx context.Context, record interface{}, queriers .
}

func (r repository) find(cw contextWrapper, doc *Document, query Query) error {
query = r.withDefaultScope(doc.data, query)
query = r.withDefaultScope(doc.data, query, true)
cur, err := cw.adapter.Query(cw.ctx, query.Limit(1))
if err != nil {
return err
Expand All @@ -227,11 +227,9 @@ func (r repository) find(cw contextWrapper, doc *Document, query Query) error {
}
finish(nil)

if query.CascadeQuery {
for i := range query.PreloadQuery {
if err := r.preload(cw, doc, query.PreloadQuery[i], nil); err != nil {
return err
}
for i := range query.PreloadQuery {
if err := r.preload(cw, doc, query.PreloadQuery[i], nil); err != nil {
return err
}
}

Expand All @@ -258,7 +256,7 @@ func (r repository) MustFindAll(ctx context.Context, records interface{}, querie
}

func (r repository) findAll(cw contextWrapper, col *Collection, query Query) error {
query = r.withDefaultScope(col.data, query)
query = r.withDefaultScope(col.data, query, true)
cur, err := cw.adapter.Query(cw.ctx, query)
if err != nil {
return err
Expand All @@ -271,11 +269,9 @@ func (r repository) findAll(cw contextWrapper, col *Collection, query Query) err
}
finish(nil)

if query.CascadeQuery {
for i := range query.PreloadQuery {
if err := r.preload(cw, col, query.PreloadQuery[i], nil); err != nil {
return err
}
for i := range query.PreloadQuery {
if err := r.preload(cw, col, query.PreloadQuery[i], nil); err != nil {
return err
}
}

Expand Down Expand Up @@ -479,7 +475,7 @@ func (r repository) update(cw contextWrapper, doc *Document, mutation Mutation,

if !mutation.IsMutatesEmpty() {
var (
query = r.withDefaultScope(doc.data, Build(doc.Table(), filter, mutation.Unscoped, mutation.Cascade))
query = r.withDefaultScope(doc.data, Build(doc.Table(), filter, mutation.Unscoped, mutation.Cascade), false)
)

if updatedCount, err := cw.adapter.Update(cw.ctx, query, mutation.Mutates); err != nil {
Expand Down Expand Up @@ -925,7 +921,7 @@ func (r repository) preload(cw contextWrapper, records slice, field string, quer
}

var (
cur, err = cw.adapter.Query(cw.ctx, r.withDefaultScope(ddata, query))
cur, err = cw.adapter.Query(cw.ctx, r.withDefaultScope(ddata, query, false))
)

if err != nil {
Expand Down Expand Up @@ -1052,7 +1048,7 @@ func (r repository) targetIDs(targets map[interface{}][]slice) []interface{} {
return ids
}

func (r repository) withDefaultScope(ddata documentData, query Query) Query {
func (r repository) withDefaultScope(ddata documentData, query Query, preload bool) Query {
if query.UnscopedQuery {
return query
}
Expand All @@ -1061,6 +1057,10 @@ func (r repository) withDefaultScope(ddata documentData, query Query) Query {
query = query.Where(Nil("deleted_at"))
}

if preload && bool(query.CascadeQuery) {
query.PreloadQuery = append(ddata.preload, query.PreloadQuery...)
}

return query
}

Expand Down
44 changes: 24 additions & 20 deletions repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,19 +227,21 @@ func TestRepository_Find_softDeleteUnscoped(t *testing.T) {
cur.AssertExpectations(t)
}

func TestRepository_Find_withPreloadAndDisabledCascade(t *testing.T) {
func TestRepository_Find_withCascade(t *testing.T) {
var (
user User
adapter = &testAdapter{}
repo = New(adapter)
query = From("users").Limit(1).Preload("address").Cascade(false)
cur = createCursor(1)
trx Transaction
adapter = &testAdapter{}
repo = New(adapter)
query = From("transactions").Limit(1).Cascade(true)
cur = createCursor(1)
curPreload = createCursor(0)
)

adapter.On("Query", query).Return(cur, nil).Once()
adapter.On("Query", query.Preload("buyer")).Return(cur, nil).Once()
adapter.On("Query", From("users").Where(In("id", 0))).Return(curPreload, nil).Once()

assert.Nil(t, repo.Find(context.TODO(), &user, query))
assert.Equal(t, 10, user.ID)
assert.Nil(t, repo.Find(context.TODO(), &trx, query))
assert.Equal(t, 10, trx.ID)
assert.False(t, cur.Next())

adapter.AssertExpectations(t)
Expand Down Expand Up @@ -434,21 +436,23 @@ func TestRepository_FindAll_softDeleteUnscoped(t *testing.T) {
cur.AssertExpectations(t)
}

func TestRepository_FindAll_withPreloadAndDisabledCascade(t *testing.T) {
func TestRepository_FindAll_withCascade(t *testing.T) {
var (
addresses []Address
adapter = &testAdapter{}
repo = New(adapter)
query = From("addresses").Preload("user").Cascade(false)
cur = createCursor(2)
trxs []Transaction
adapter = &testAdapter{}
repo = New(adapter)
query = From("transactions")
cur = createCursor(2)
curPreload = createCursor(0)
)

adapter.On("Query", query.Where(Nil("deleted_at"))).Return(cur, nil).Once()
adapter.On("Query", query.Preload("buyer")).Return(cur, nil).Once()
adapter.On("Query", From("users").Where(In("id", 0))).Return(curPreload, nil)

assert.Nil(t, repo.FindAll(context.TODO(), &addresses, query))
assert.Len(t, addresses, 2)
assert.Equal(t, 10, addresses[0].ID)
assert.Equal(t, 10, addresses[1].ID)
assert.Nil(t, repo.FindAll(context.TODO(), &trxs, query))
assert.Len(t, trxs, 2)
assert.Equal(t, 10, trxs[0].ID)
assert.Equal(t, 10, trxs[1].ID)

adapter.AssertExpectations(t)
cur.AssertExpectations(t)
Expand Down

0 comments on commit 53b3ab6

Please sign in to comment.