Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batch save/update for groups and users #2245

Merged
merged 9 commits into from
Jul 15, 2024
2 changes: 2 additions & 0 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type AccountManager interface {
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
Expand All @@ -95,6 +96,7 @@ type AccountManager interface {
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error)
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
Expand Down
8 changes: 8 additions & 0 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,3 +746,11 @@ func (s *FileStore) Close(ctx context.Context) error {
func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}

func (s *FileStore) SaveUsers(_ *Account) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}

func (s *FileStore) SaveGroups(_ *Account) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}
106 changes: 68 additions & 38 deletions management/server/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,61 +112,85 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
}

// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
pappz marked this conversation as resolved.
Show resolved Hide resolved
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}

if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
var eventsToStore []func()

if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}

existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
}

// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}

newGroup.ID = xid.New().String()
}

// avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are
// coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}

newGroup.ID = xid.New().String()
}
oldGroup := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup

for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...)
}

oldGroup, exists := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup

account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = am.Store.SaveGroups(account); err != nil {
return err
}

am.updateAccountPeers(ctx, account)

// the following snippet tracks the activity and stores the group events in the event store.
// It has to happen after all the operations have been successfully performed.
for _, storeEvent := range eventsToStore {
storeEvent()
}

return nil
}

// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
var eventsToStore []func()

addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
if exists {

if oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
})
}

for _, p := range addedPeers {
Expand All @@ -175,11 +199,14 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
}

for _, p := range removedPeers {
Expand All @@ -188,14 +215,17 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
}

return nil
return eventsToStore
}

// difference returns the elements in `a` that aren't in `b`.
Expand Down
18 changes: 18 additions & 0 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type MockAccountManager struct {
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
Expand All @@ -64,6 +65,7 @@ type MockAccountManager struct {
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
Expand Down Expand Up @@ -308,6 +310,14 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
}

// SaveGroups mock implementation of SaveGroups from server.AccountManager interface
func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error {
if am.SaveGroupsFunc != nil {
return am.SaveGroupsFunc(ctx, accountID, userID, groups)
}
return status.Errorf(codes.Unimplemented, "method SaveGroups is not implemented")
}

// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
if am.DeleteGroupFunc != nil {
Expand Down Expand Up @@ -502,6 +512,14 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
}

// SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) {
if am.SaveOrAddUsersFunc != nil {
return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUsers is not implemented")
}

// DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil {
Expand Down
28 changes: 28 additions & 0 deletions management/server/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,32 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
return nil
}

// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(account *Account) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same like groups

for id, user := range account.Users {
user.Id = id
user.AccountID = account.Id
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&account.UsersG).Error
}

// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
func (s *SqlStore) SaveGroups(account *Account) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is not memory efficient. If I understand well the reason of the "G" postfix it is just used by the Gorm. So after that, we always need to clean up because the garbage collector will not do it. Nevertheless, we should not mix the layers. The account structure holds the account info and the SQL-related things. We should create separate struct definition just for the SQL. It is more code but at the end, it will be safer.

Now I recommend to:

  • change the parameters of the function *SaveGroups(accountID string, groups map[string]nbgroup.Group)
  • create a new slice []nbgroup.Group
  • copy the relevant data to there
  • change the db.clauses to use the new slice

So after it the GroupsG will be empty again

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the "&gorm.Session{FullSaveAssociations: true}" keyword help to you

for id, group := range account.Groups {
group.ID = id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to do this operation? In this line I think we fill well the variable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily needed, but it was to make sure the group ID is the same as the index

group.AccountID = account.Id
account.GroupsG = append(account.GroupsG, *group)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&account.GroupsG).Error
}

// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
return nil
Expand Down Expand Up @@ -653,6 +679,8 @@ func (s *SqlStore) GetStoreEngine() StoreEngine {
return s.storeEngine
}

//func (s *SqlStore) SaveGroups()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary comment


// NewSqliteStore creates a new SQLite store.
func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
Expand Down
2 changes: 2 additions & 0 deletions management/server/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
SaveUsers(account *Account) error
SaveGroups(account *Account) error
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string
Expand Down
Loading
Loading