diff --git a/group/memory_group.go b/group/memory_group.go index a6e809c..dd870c7 100644 --- a/group/memory_group.go +++ b/group/memory_group.go @@ -4,13 +4,17 @@ import ( "context" "sync" "time" + + "github.com/metagogs/gogs/utils/slicex" ) var _ Group = (*MemoryGroup)(nil) type MemoryGroup struct { + mutex sync.RWMutex name string - uids sync.Map + uids map[string]struct{} + uidsList []string groupID int64 lastRefresh int64 } @@ -19,13 +23,17 @@ func NewMemoryGroup(name string, groupID int64) *MemoryGroup { return &MemoryGroup{ name: name, groupID: groupID, + uids: make(map[string]struct{}), lastRefresh: time.Now().Unix(), } } func (group *MemoryGroup) AddUser(ctx context.Context, uid string) error { - if _, ok := group.uids.Load(uid); !ok { - group.uids.Store(uid, uid) + group.mutex.Lock() + defer group.mutex.Unlock() + if _, ok := group.uids[uid]; !ok { + group.uids[uid] = struct{}{} + group.uidsList = append(group.uidsList, uid) group.lastRefresh = time.Now().Unix() return nil } @@ -34,8 +42,11 @@ func (group *MemoryGroup) AddUser(ctx context.Context, uid string) error { } func (group *MemoryGroup) RemoveUser(ctx context.Context, uid string) error { - if _, ok := group.uids.Load(uid); ok { - group.uids.Delete(uid) + group.mutex.Lock() + defer group.mutex.Unlock() + if _, ok := group.uids[uid]; ok { + delete(group.uids, uid) + group.uidsList = slicex.RemoveSliceItem(group.uidsList, uid) group.lastRefresh = time.Now().Unix() return nil } @@ -51,31 +62,20 @@ func (group *MemoryGroup) RemoveUsers(ctx context.Context, uids []string) { } func (group *MemoryGroup) RemoveAllUsers(ctx context.Context) { - group.uids.Range(func(key, value interface{}) bool { - group.uids.Delete(key) - group.lastRefresh = time.Now().Unix() - return true - }) + group.mutex.Lock() + defer group.mutex.Unlock() + group.uids = make(map[string]struct{}) + group.uidsList = nil } func (group *MemoryGroup) GetUsers(ctx context.Context) []string { - uids := []string{} - group.uids.Range(func(key, value interface{}) bool { - uids = append(uids, key.(string)) - return true - }) - - return uids + group.mutex.RLock() + defer group.mutex.RUnlock() + return group.uidsList } func (group *MemoryGroup) GetUserCount(ctx context.Context) int { - count := 0 - group.uids.Range(func(key, value interface{}) bool { - count++ - return true - }) - - return count + return len(group.uidsList) } func (group *MemoryGroup) GetLastRefresh(ctx context.Context) int64 { @@ -83,7 +83,9 @@ func (group *MemoryGroup) GetLastRefresh(ctx context.Context) int64 { } func (group *MemoryGroup) ContainsUser(ctx context.Context, uid string) bool { - _, ok := group.uids.Load(uid) + group.mutex.RLock() + defer group.mutex.RUnlock() + _, ok := group.uids[uid] return ok } diff --git a/group/memory_group_test.go b/group/memory_group_test.go index e587b12..ce9fb26 100644 --- a/group/memory_group_test.go +++ b/group/memory_group_test.go @@ -1,8 +1,10 @@ package group import ( + "sync" "testing" + "github.com/metagogs/gogs/utils/randstr" "github.com/stretchr/testify/assert" ) @@ -64,4 +66,28 @@ func TestMemoryGroup_AddUser(t *testing.T) { exist = memGroup.ContainsUser(nil, "test2") assert.True(t, exist) + var wg sync.WaitGroup + for i := 0; i < 1000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + memGroup.AddUser(nil, randstr.RandStr(10)) + memGroup.AddUser(nil, "test") + memGroup.GetUsers(nil) + memGroup.GetUserCount(nil) + memGroup.GetGroupID(nil) + memGroup.GetGroupName(nil) + memGroup.ContainsUser(nil, "test") + memGroup.ContainsUser(nil, "test2") + memGroup.RemoveUser(nil, "test") + memGroup.RemoveUsers(nil, []string{"test2"}) + memGroup.RemoveAllUsers(nil) + memGroup.AddUser(nil, randstr.RandStr(10)) + memGroup.AddUser(nil, "test") + memGroup.GetUsers(nil) + memGroup.GetUserCount(nil) + memGroup.RemoveAllUsers(nil) + }() + } + wg.Wait() } diff --git a/session/pool.go b/session/pool.go index 8eaa102..64b56b8 100644 --- a/session/pool.go +++ b/session/pool.go @@ -61,20 +61,17 @@ func (s *sessionList) GetList(filter *SessionFilter) ([]int64, []int64) { } result := []int64{} for _, v := range s.data { - if filter != nil { - if len(filter.ConnType) > 0 && v.ConnType != filter.ConnType { - continue - } - - if len(filter.ConnName) > 0 && v.ConnName != filter.ConnName { - continue - } - - if len(filter.ConnGroup) > 0 && v.ConnGroup != filter.ConnGroup { - continue - } + if len(filter.ConnType) > 0 && v.ConnType != filter.ConnType { + continue } + if len(filter.ConnName) > 0 && v.ConnName != filter.ConnName { + continue + } + + if len(filter.ConnGroup) > 0 && v.ConnGroup != filter.ConnGroup { + continue + } result = append(result, v.SessionID) }