diff --git a/callbacks/associations.go b/callbacks/associations.go index 9d7c1412e..f3cd464ae 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -51,25 +51,40 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} for i := 0; i < rValLen; i++ { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() != reflect.Struct { break } - if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value + if !isPtr { + rv = rv.Addr() + } objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + elems = reflect.Append(elems, rv) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + cacheKey := utils.ToStringKey(relPrimaryValues...) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + distinctElems = reflect.Append(distinctElems, rv) } } } if elems.Len() > 0 { - if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 845c16af5..b69d668aa 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -393,3 +393,33 @@ func TestConcurrentMany2ManyAssociation(t *testing.T) { AssertEqual(t, err, nil) AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append") } + +func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) { + user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{ + {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{ + ID: 1, + Name: "Test-company-1", + }}, + }} + + user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{ + {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{ + ID: 1, + Name: "Test-company-1", + }}, + }} + users := []*User{&user1, &user2} + var err error + err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error + AssertEqual(t, nil, err) + + var findUser1 User + err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error + AssertEqual(t, nil, err) + AssertEqual(t, user1, findUser1) + + var findUser2 User + err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error + AssertEqual(t, nil, err) + AssertEqual(t, user2, findUser2) +}