Skip to content

Commit

Permalink
fix: Add ReplaceManyToMany function which replaces associations
Browse files Browse the repository at this point in the history
  • Loading branch information
codelite7 committed Jan 31, 2024
1 parent eadce0a commit c9daf2c
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 0 deletions.
22 changes: 22 additions & 0 deletions example/cockroachdb/example.pb.gorm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions example/postgres/example.pb.gorm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions plugin/generics_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,28 @@ func (m *ManyToManyAssociations) AddAssociation(modelId, associatedId string) {
m.data.Store(modelId, associations)
}
func ReplaceManyToMany[L Models, R Models](ctx context.Context, db *gorm.DB, associations *ManyToManyAssociations, associationName string) error {
session := db.Session(&gorm.Session{})
session = session.Clauses(clause.OnConflict{DoNothing: true})
for id, associatedIds := range associations.Associations() {
var associations []R
var temp L
model := temp.New().(L)
model.SetModelId(id)
for _, id := range associatedIds {
var associatedTemp R
associatedModel := associatedTemp.New().(R)
associatedModel.SetModelId(id)
associations = append(associations, associatedModel)
}
err := session.Model(&model).Association(associationName).Replace(&associations)
if err != nil {
return err
}
}
return nil
}
func AssociateManyToMany[L Models, R Models](ctx context.Context, db *gorm.DB, associations *ManyToManyAssociations, associationName string) error {
session := db.Session(&gorm.Session{})
session = session.Clauses(clause.OnConflict{DoNothing: true})
Expand Down
50 changes: 50 additions & 0 deletions test/cockroachdb_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,56 @@ func (s *CockroachdbPluginSuite) TestDissociateManyToMany() {
}
}

// TestDissociateManyToMany tests that dissociateManyToMany works as expected
func (s *CockroachdbPluginSuite) TestReplaceManyToMany() {
// create a user and profiles
user := getCockroachdbUser(s.T())
userModels, err := Upsert[*User, *UserGormModel](context.Background(), cockroachdbDb, []*User{user})
require.NoError(s.T(), err)
numProfiles := gofakeit.Number(5, 10)
profiles := getCockroachdbProfiles(s.T(), numProfiles)
profileModels, err := Upsert[*Profile, *ProfileGormModel](context.Background(), cockroachdbDb, profiles)
require.NoError(s.T(), err)
// associate the users and profiles
associations := &ManyToManyAssociations{}
for _, profile := range profileModels {
associations.AddAssociation(*userModels[0].Id, *profile.Id)
}
err = AssociateManyToMany[*UserGormModel, *ProfileGormModel](context.Background(), cockroachdbDb, associations, "Profiles")
require.NoError(s.T(), err)
// get with preload
fetchedUsers, err := GetByIds[*UserGormModel](context.Background(), cockroachdbDb, []string{*user.Id}, []string{"Profiles"})
require.NoError(s.T(), err)
// assert
expectedUserModel := fetchedUsers[0]
expectedUser, err := expectedUserModel.ToProto()
require.NoError(s.T(), err)
assertCockroachdbProtosEquality(s.T(), profiles, expectedUser.Profiles,
protocmp.IgnoreFields(&Profile{}, "created_at", "updated_at"),
)
// replace
replacementProfiles := getCockroachdbProfiles(s.T(), numProfiles)
replacementProfileModels, err := Upsert[*Profile, *ProfileGormModel](context.Background(), cockroachdbDb, replacementProfiles)
require.NoError(s.T(), err)
// associate the users and profiles
replacementAssociations := &ManyToManyAssociations{}
for _, profile := range replacementProfileModels {
replacementAssociations.AddAssociation(*userModels[0].Id, *profile.Id)
}
err = ReplaceManyToMany[*UserGormModel, *ProfileGormModel](context.Background(), cockroachdbDb, replacementAssociations, "Profiles")
require.NoError(s.T(), err)
// get with preload
fetchedUsers, err = GetByIds[*UserGormModel](context.Background(), cockroachdbDb, []string{*user.Id}, []string{"Profiles"})
require.NoError(s.T(), err)
// assert
expectedUserModel = fetchedUsers[0]
expectedUser, err = expectedUserModel.ToProto()
require.NoError(s.T(), err)
assertCockroachdbProtosEquality(s.T(), replacementProfiles, expectedUser.Profiles,
protocmp.IgnoreFields(&Profile{}, "created_at", "updated_at"),
)
}

// TestListWithWhere tests that the list function works with a where clause set on the tx
func (s *CockroachdbPluginSuite) TestListWithWhere() {
// create profiles
Expand Down
50 changes: 50 additions & 0 deletions test/postgres_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,56 @@ func (s *PostgresPluginSuite) TestDissociateManyToMany() {
}
}

// TestReplaceManyToMany tests that replaceManyToMany works as expected
func (s *PostgresPluginSuite) TestReplaceManyToMany() {
// create a user and profiles
user := getPostgresUser(s.T())
userModels, err := Upsert[*User, *UserGormModel](context.Background(), postgresDb, []*User{user})
require.NoError(s.T(), err)
numProfiles := gofakeit.Number(5, 10)
profiles := getPostgresProfiles(s.T(), numProfiles)
profileModels, err := Upsert[*Profile, *ProfileGormModel](context.Background(), postgresDb, profiles)
require.NoError(s.T(), err)
// associate the users and profiles
associations := &ManyToManyAssociations{}
for _, profile := range profileModels {
associations.AddAssociation(*userModels[0].Id, *profile.Id)
}
err = AssociateManyToMany[*UserGormModel, *ProfileGormModel](context.Background(), postgresDb, associations, "Profiles")
require.NoError(s.T(), err)
// get with preload
fetchedUsers, err := GetByIds[*UserGormModel](context.Background(), postgresDb, []string{*user.Id}, []string{"Profiles"})
require.NoError(s.T(), err)
// assert
expectedUserModel := fetchedUsers[0]
expectedUser, err := expectedUserModel.ToProto()
require.NoError(s.T(), err)
assertPostgresProtosEquality(s.T(), profiles, expectedUser.Profiles,
protocmp.IgnoreFields(&Profile{}, "created_at", "updated_at"),
)
// replace
replacementProfiles := getPostgresProfiles(s.T(), numProfiles)
replacementProfileModels, err := Upsert[*Profile, *ProfileGormModel](context.Background(), postgresDb, replacementProfiles)
require.NoError(s.T(), err)
// associate the users and profiles
replacementAssociations := &ManyToManyAssociations{}
for _, profile := range replacementProfileModels {
replacementAssociations.AddAssociation(*userModels[0].Id, *profile.Id)
}
err = ReplaceManyToMany[*UserGormModel, *ProfileGormModel](context.Background(), postgresDb, replacementAssociations, "Profiles")
require.NoError(s.T(), err)
// get with preload
fetchedUsers, err = GetByIds[*UserGormModel](context.Background(), postgresDb, []string{*user.Id}, []string{"Profiles"})
require.NoError(s.T(), err)
// assert
expectedUserModel = fetchedUsers[0]
expectedUser, err = expectedUserModel.ToProto()
require.NoError(s.T(), err)
assertPostgresProtosEquality(s.T(), replacementProfiles, expectedUser.Profiles,
protocmp.IgnoreFields(&Profile{}, "created_at", "updated_at"),
)
}

// TestListWithWhere tests that the list function works with a where clause set on the tx
func (s *PostgresPluginSuite) TestListWithWhere() {
// create profiles
Expand Down

0 comments on commit c9daf2c

Please sign in to comment.