Skip to content

Commit

Permalink
fix: circular reference save, close #5140
Browse files Browse the repository at this point in the history
commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144
Author: Jinzhu <wosmvp@gmail.com>
Date:   Thu Mar 17 23:49:21 2022 +0800

    Refactor #5140

commit 6e3ca2d
Author: a631807682 <631807682@qq.com>
Date:   Sun Mar 13 12:52:08 2022 +0800

    test: add test for LoadOrStoreVisitMap

commit 9d5c68e
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:33:47 2022 +0800

    chore: add more comment

commit bfffefb
Author: chenrui <chenrui@jingdaka.com>
Date:   Thu Mar 10 20:28:48 2022 +0800

    fix: should check values has been saved instead of rel.Name

commit e55cdfa
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:48:01 2022 +0800

    chore: go lint

commit fe4715c
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:27:24 2022 +0800

    chore: add test comment

commit 326862f
Author: chenrui <chenrui@jingdaka.com>
Date:   Tue Mar 8 17:22:33 2022 +0800

    fix: circular reference save
  • Loading branch information
a631807682 authored and jinzhu committed Mar 17, 2022
1 parent 2990790 commit 9b9ae32
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 8 deletions.
41 changes: 34 additions & 7 deletions callbacks/associations.go
Expand Up @@ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
}

if elems.Len() > 0 {
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
Expand All @@ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
rv = rv.Addr()
}

if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
Expand Down Expand Up @@ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}

saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
Expand All @@ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}

saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
}
}
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}

saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
}

Expand Down Expand Up @@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
// optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
}

for i := 0; i < elemLen; i++ {
Expand Down Expand Up @@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[
return
}

func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
// stop save association loop
if checkAssociationsSaved(db, rValues) {
return nil
}

var (
selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
refName = rel.Name + "."
values = rValues.Interface()
)

for name, ok := range selectColumns {
Expand Down Expand Up @@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},

return db.AddError(tx.Create(values).Error)
}

// check association values has been saved
// if values kind is Struct, check it has been saved
// if values kind is Slice/Array, check all items have been saved
var visitMapStoreKey = "gorm:saved_association_map"

func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(*visitMap); ok {
if loadOrStoreVisitMap(v, values) {
return true
}
}
} else {
vistMap := make(visitMap)
loadOrStoreVisitMap(&vistMap, values)
db.Set(visitMapStoreKey, &vistMap)
}

return false
}
30 changes: 30 additions & 0 deletions callbacks/helper.go
@@ -1,6 +1,7 @@
package callbacks

import (
"reflect"
"sort"

"gorm.io/gorm"
Expand Down Expand Up @@ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) {
return
}
}

type visitMap = map[reflect.Value]bool

// Check if circular values, return true if loaded
func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

switch v.Kind() {
case reflect.Slice, reflect.Array:
loaded = true
for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(vistMap, v.Index(i)) {
loaded = false
}
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*vistMap)[p]; ok {
return true
}
(*vistMap)[p] = true
}
}

return
}
36 changes: 36 additions & 0 deletions callbacks/visit_map_test.go
@@ -0,0 +1,36 @@
package callbacks

import (
"reflect"
"testing"
)

func TestLoadOrStoreVisitMap(t *testing.T) {
var vm visitMap
var loaded bool
type testM struct {
Name string
}

t1 := testM{Name: "t1"}
t2 := testM{Name: "t2"}
t3 := testM{Name: "t3"}

vm = make(visitMap)
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
t.Fatalf("loaded should be false")
}

if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
t.Fatalf("loaded should be true")
}

// t1 already exist but t2 not
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
t.Fatalf("loaded should be false")
}

if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
t.Fatalf("loaded should be true")
}
}
41 changes: 41 additions & 0 deletions tests/associations_test.go
Expand Up @@ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) {
t.Errorf("Failed to preload AppliesToProduct")
}
}

func TestSaveBelongsCircularReference(t *testing.T) {
parent := Parent{}
DB.Create(&parent)

child := Child{ParentID: &parent.ID, Parent: &parent}
DB.Create(&child)

parent.FavChildID = child.ID
parent.FavChild = &child
DB.Save(&parent)

var parent1 Parent
DB.First(&parent1, parent.ID)
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")

// Save and Updates is the same
DB.Updates(&parent)
DB.First(&parent1, parent.ID)
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
}

func TestSaveHasManyCircularReference(t *testing.T) {
parent := Parent{}
DB.Create(&parent)

child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"}
child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"}

parent.Children = []*Child{&child, &child1}
DB.Save(&parent)

var children []*Child
DB.Where("parent_id = ?", parent.ID).Find(&children)
if len(children) != len(parent.Children) ||
children[0].ID != parent.Children[0].ID ||
children[1].ID != parent.Children[1].ID {
t.Errorf("circular reference children save not equal children:%v parent.Children:%v",
children, parent.Children)
}
}
2 changes: 1 addition & 1 deletion tests/tests_test.go
Expand Up @@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {

func RunMigrations() {
var err error
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}}
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })

Expand Down
14 changes: 14 additions & 0 deletions utils/tests/models.go
Expand Up @@ -80,3 +80,17 @@ type Order struct {
Coupon *Coupon
CouponID string
}

type Parent struct {
gorm.Model
FavChildID uint
FavChild *Child
Children []*Child
}

type Child struct {
gorm.Model
Name string
ParentID *uint
Parent *Parent
}

0 comments on commit 9b9ae32

Please sign in to comment.