From 9b85953d4206e5f0a62ad621ffd50c51fe3ff2ca Mon Sep 17 00:00:00 2001 From: dlampsi <32041193+dlampsi@users.noreply.github.com> Date: Sun, 23 May 2021 12:14:00 +0300 Subject: [PATCH] Improvements and optimisations (#12) * Remove unused elements from mock * Use ticker duration in Reconnect func * Renaming structs and stoping pointer usage --- README.md | 8 ++++---- adc.go | 7 ++++--- adc_test.go | 14 +++++++------- group.go | 33 +++++++++++++++------------------ group_test.go | 46 +++++++++++++++++++++++----------------------- mock.go | 12 ------------ user.go | 25 +++++++++++-------------- user_test.go | 42 +++++++++++++++++++++--------------------- 8 files changed, 85 insertions(+), 102 deletions(-) diff --git a/README.md b/README.md index f5daba7..1d1ebb8 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ if err := cl.Connect(); err != nil { } // Search for a user -user, err := cl.GetUser(&adc.GetUserRequest{Id:"userId"}) +user, err := cl.GetUser(adc.GetUserArgs{Id:"userId"}) if err != nil { // Handle error } @@ -45,7 +45,7 @@ if user == nil { fmt.Println(user) // Search for a group -group, err := cl.GetGroup(&adc.GetGroupequest{Id:"groupId"}) +group, err := cl.GetGroup(adc.GetGroupArgs{Id:"groupId"}) if err != nil { // Handle error } @@ -115,7 +115,7 @@ You can parse custom attributes to client config to fetch those attributes durin cl.Config().AppendUsesAttributes("manager") // Search for a user -user, err := cl.GetUser(&adc.GetUserRequest{Id:"userId"}) +user, err := cl.GetUser(adc.GetUserArgs{Id:"userId"}) if err != nil { // Handle error } @@ -130,7 +130,7 @@ fmt.Println(userManager) Also you can parse custom attributes during each get requests: ```go -user, err := cl.GetUser(&adc.GetUserRequest{Id: "userId", Attributes: []string{"manager"}}) +user, err := cl.GetUser(adc.GetUserArgs{Id: "userId", Attributes: []string{"manager"}}) if err != nil { // Handle error } diff --git a/adc.go b/adc.go index 8ccff27..0c3ef9f 100644 --- a/adc.go +++ b/adc.go @@ -112,7 +112,7 @@ func (cl *Client) Disconnect() { } // Checks connections to AD and tries to reconnect if the connection is lost. -func (cl *Client) Reconnect(ctx context.Context, ticker *time.Ticker, maxAttempts int) error { +func (cl *Client) Reconnect(ctx context.Context, tickerDuration time.Duration, maxAttempts int) error { _, connErr := cl.searchEntry(&ldap.SearchRequest{ BaseDN: cl.cfg.SearchBase, Scope: ldap.ScopeWholeSubtree, @@ -125,9 +125,10 @@ func (cl *Client) Reconnect(ctx context.Context, ticker *time.Ticker, maxAttempt return nil } - if ticker == nil { - ticker = time.NewTicker(5 * time.Second) + if tickerDuration == 0 { + tickerDuration = 5 * time.Second } + ticker := time.NewTicker(tickerDuration) defer ticker.Stop() if maxAttempts == 0 { diff --git a/adc_test.go b/adc_test.go index c7834af..7eed882 100644 --- a/adc_test.go +++ b/adc_test.go @@ -80,19 +80,19 @@ func Test_Client_Reconnect(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - err = cl.Reconnect(ctx, time.NewTicker(2*time.Second), 2) + err = cl.Reconnect(ctx, 2*time.Second, 2) require.NoError(t, err) cl.cfg.Bind = &BindAccount{DN: mockEntriesData["entryForErr"].DN} - err = cl.Reconnect(ctx, time.NewTicker(2*time.Second), 2) + err = cl.Reconnect(ctx, 2*time.Second, 2) require.Error(t, err) cl.cfg.Bind = reconnectMockBind - err = cl.Reconnect(ctx, nil, 1) + err = cl.Reconnect(ctx, 0, 1) require.Error(t, err) - err = cl.Reconnect(ctx, time.NewTicker(30*time.Millisecond), 0) + err = cl.Reconnect(ctx, 30*time.Millisecond, 0) require.Error(t, err) - err = cl.Reconnect(ctx, time.NewTicker(1*time.Second), 1) + err = cl.Reconnect(ctx, 1*time.Second, 1) require.Error(t, err) nctx, cancel := context.WithCancel(context.Background()) @@ -100,10 +100,10 @@ func Test_Client_Reconnect(t *testing.T) { time.Sleep(10 * time.Millisecond) cancel() }() - err = cl.Reconnect(nctx, time.NewTicker(5*time.Second), 1) + err = cl.Reconnect(nctx, 5*time.Second, 1) require.Error(t, err) cl.cfg.Bind = validMockBind - err = cl.Reconnect(ctx, time.NewTicker(30*time.Millisecond), 1) + err = cl.Reconnect(ctx, 30*time.Millisecond, 1) require.NoError(t, err) } diff --git a/group.go b/group.go index c00a0ef..9b183fd 100644 --- a/group.go +++ b/group.go @@ -36,7 +36,7 @@ func (g *Group) GetStringAttribute(name string) string { return "" } -type GetGroupequest struct { +type GetGroupArgs struct { // Group ID to search. Id string `json:"id"` // Optional group DN. Overwrites ID if provided in request. @@ -47,24 +47,21 @@ type GetGroupequest struct { SkipMembersSearch bool `json:"skip_members_search"` } -func (req *GetGroupequest) Validate() error { - if req == nil { - return errors.New("nil request") - } - if req.Id == "" && req.Dn == "" { +func (args GetGroupArgs) Validate() error { + if args.Id == "" && args.Dn == "" { return errors.New("neither of ID of DN provided") } return nil } -func (cl *Client) GetGroup(r *GetGroupequest) (*Group, error) { - if err := r.Validate(); err != nil { +func (cl *Client) GetGroup(args GetGroupArgs) (*Group, error) { + if err := args.Validate(); err != nil { return nil, err } - filter := fmt.Sprintf(cl.cfg.Groups.FilterById, r.Id) - if r.Dn != "" { - filter = fmt.Sprintf(cl.cfg.Groups.FilterByDn, ldap.EscapeFilter(r.Dn)) + filter := fmt.Sprintf(cl.cfg.Groups.FilterById, args.Id) + if args.Dn != "" { + filter = fmt.Sprintf(cl.cfg.Groups.FilterByDn, ldap.EscapeFilter(args.Dn)) } req := &ldap.SearchRequest{ @@ -75,8 +72,8 @@ func (cl *Client) GetGroup(r *GetGroupequest) (*Group, error) { Filter: filter, Attributes: cl.cfg.Groups.Attributes, } - if r.Attributes != nil { - req.Attributes = r.Attributes + if args.Attributes != nil { + req.Attributes = args.Attributes } entry, err := cl.searchEntry(req) @@ -96,7 +93,7 @@ func (cl *Client) GetGroup(r *GetGroupequest) (*Group, error) { result.Attributes[a.Name] = entry.GetAttributeValue(a.Name) } - if !r.SkipMembersSearch { + if !args.SkipMembersSearch { members, err := cl.getGroupMembers(entry.DN) if err != nil { return nil, fmt.Errorf("can't get group members: %s", err.Error()) @@ -150,7 +147,7 @@ func (g *Group) MembersId() []string { // Adds provided accounts IDs to provided group members. Returns number of addedd accounts. func (cl *Client) AddGroupMembers(groupId string, membersIds ...string) (int, error) { - group, err := cl.GetGroup(&GetGroupequest{Id: groupId}) + group, err := cl.GetGroup(GetGroupArgs{Id: groupId}) if err != nil { return 0, fmt.Errorf("can't get group: %s", err.Error()) } @@ -166,7 +163,7 @@ func (cl *Client) AddGroupMembers(groupId string, membersIds ...string) (int, er wg.Add(1) go func(userId string, ch chan<- string, errCh chan<- error, wg *sync.WaitGroup) { defer wg.Done() - user, err := cl.GetUser(&GetUserRequest{Id: userId}) + user, err := cl.GetUser(GetUserArgs{Id: userId}) if err != nil { errCh <- fmt.Errorf("can't get account '%s': %s", userId, err.Error()) return @@ -226,7 +223,7 @@ func popAddGroupMembers(g *Group, toAdd []string) []string { // Deletes provided accounts IDs from provided group members. Returns number of deleted from group members. func (cl *Client) DeleteGroupMembers(groupId string, membersIds ...string) (int, error) { - group, err := cl.GetGroup(&GetGroupequest{Id: groupId}) + group, err := cl.GetGroup(GetGroupArgs{Id: groupId}) if err != nil { return 0, fmt.Errorf("can't get group: %s", err.Error()) } @@ -242,7 +239,7 @@ func (cl *Client) DeleteGroupMembers(groupId string, membersIds ...string) (int, wg.Add(1) go func(userId string, ch chan<- string, errCh chan<- error, wg *sync.WaitGroup) { defer wg.Done() - user, err := cl.GetUser(&GetUserRequest{Id: userId}) + user, err := cl.GetUser(GetUserArgs{Id: userId}) if err != nil { errCh <- fmt.Errorf("can't get account '%s': %s", userId, err.Error()) return diff --git a/group_test.go b/group_test.go index 360c4df..d28e57b 100644 --- a/group_test.go +++ b/group_test.go @@ -22,15 +22,15 @@ func Test_Group_GetStringAttribute(t *testing.T) { } func Test_GetGroupRequest_Validate(t *testing.T) { - var req *GetGroupequest + var req GetGroupArgs err := req.Validate() require.Error(t, err) - req = &GetGroupequest{} + req = GetGroupArgs{} err1 := req.Validate() require.Error(t, err1) - req = &GetGroupequest{Id: "fake"} + req = GetGroupArgs{Id: "fake"} errOk := req.Validate() require.NoError(t, errOk) } @@ -40,21 +40,21 @@ func Test_Client_GetGroup(t *testing.T) { err := cl.Connect() require.NoError(t, err) - var badReq *GetGroupequest - _, badReqErr := cl.GetGroup(badReq) + var badArgs GetGroupArgs + _, badReqErr := cl.GetGroup(badArgs) require.Error(t, badReqErr) - req := &GetGroupequest{Id: "entryForErr", SkipMembersSearch: true} - _, err = cl.GetGroup(req) + args := GetGroupArgs{Id: "entryForErr", SkipMembersSearch: true} + _, err = cl.GetGroup(args) require.Error(t, err) - req = &GetGroupequest{Id: "groupFake", SkipMembersSearch: true} - group, err := cl.GetGroup(req) + args = GetGroupArgs{Id: "groupFake", SkipMembersSearch: true} + group, err := cl.GetGroup(args) require.NoError(t, err) require.Nil(t, group) // Too many entries error - group, err = cl.GetGroup(&GetGroupequest{ + group, err = cl.GetGroup(GetGroupArgs{ Id: "notUniq", SkipMembersSearch: true, Attributes: []string{"sAMAccountName"}, @@ -63,36 +63,36 @@ func Test_Client_GetGroup(t *testing.T) { require.Nil(t, group) // Group with err members get - group, err = cl.GetGroup(&GetGroupequest{ + group, err = cl.GetGroup(GetGroupArgs{ Id: "groupWithErrMember", SkipMembersSearch: false, }) require.Error(t, err) require.Nil(t, group) - dnReq := &GetGroupequest{Dn: "OU=group1,DC=company,DC=com", SkipMembersSearch: true} - groupByDn, err := cl.GetGroup(dnReq) + dnArgs := GetGroupArgs{Dn: "OU=group1,DC=company,DC=com", SkipMembersSearch: true} + groupByDn, err := cl.GetGroup(dnArgs) require.NoError(t, err) require.NotNil(t, groupByDn) - require.Equal(t, dnReq.Dn, groupByDn.DN) + require.Equal(t, dnArgs.Dn, groupByDn.DN) - req = &GetGroupequest{Id: "group1", SkipMembersSearch: true} - group, err = cl.GetGroup(req) + args = GetGroupArgs{Id: "group1", SkipMembersSearch: true} + group, err = cl.GetGroup(args) require.NoError(t, err) require.NotNil(t, group) - require.Equal(t, req.Id, group.Id) + require.Equal(t, args.Id, group.Id) - req.Attributes = []string{"something"} - group, err = cl.GetGroup(req) + args.Attributes = []string{"something"} + group, err = cl.GetGroup(args) require.NoError(t, err) require.NotNil(t, group) - require.Equal(t, req.Id, group.Id) + require.Equal(t, args.Id, group.Id) - req.SkipMembersSearch = false - group, err = cl.GetGroup(req) + args.SkipMembersSearch = false + group, err = cl.GetGroup(args) require.NoError(t, err) require.NotNil(t, group) - require.Equal(t, req.Id, group.Id) + require.Equal(t, args.Id, group.Id) require.NotNil(t, group.Members) require.Len(t, group.Members, 1) } diff --git a/mock.go b/mock.go index ddf5b88..f68e6bb 100644 --- a/mock.go +++ b/mock.go @@ -131,13 +131,6 @@ var mockEntriesData = mockEntries{ type mockEntries map[string]*ldap.Entry -func (me mockEntries) getEntryById(id string) *ldap.Entry { - if v, ok := me[id]; ok { - return v - } - return nil -} - func (me mockEntries) getEntryByDn(dn string) *ldap.Entry { for _, entry := range me { if entry.DN == dn { @@ -247,11 +240,6 @@ func (cl *mockClient) PasswordModify(*ldap.PasswordModifyRequest) (*ldap.Passwor return nil, nil } -type mockDataEntry struct { - entry *ldap.Entry - filters []string -} - func (cl *mockClient) Search(req *ldap.SearchRequest) (*ldap.SearchResult, error) { entries, err := mockEntriesData.getEntriesByFilter(req.Filter) if err != nil { diff --git a/user.go b/user.go index d75c418..eb16df2 100644 --- a/user.go +++ b/user.go @@ -34,7 +34,7 @@ func (u *User) GetStringAttribute(name string) string { return "" } -type GetUserRequest struct { +type GetUserArgs struct { // User ID to search. Id string `json:"id"` // Optional User DN. Overwrites ID if provided in request. @@ -45,24 +45,21 @@ type GetUserRequest struct { SkipGroupsSearch bool `json:"skip_groups_search"` } -func (req *GetUserRequest) Validate() error { - if req == nil { - return errors.New("nil request") - } - if req.Id == "" && req.Dn == "" { +func (args GetUserArgs) Validate() error { + if args.Id == "" && args.Dn == "" { return errors.New("neither of ID of DN provided") } return nil } -func (cl *Client) GetUser(r *GetUserRequest) (*User, error) { - if err := r.Validate(); err != nil { +func (cl *Client) GetUser(args GetUserArgs) (*User, error) { + if err := args.Validate(); err != nil { return nil, err } - filter := fmt.Sprintf(cl.cfg.Users.FilterById, r.Id) - if r.Dn != "" { - filter = fmt.Sprintf(cl.cfg.Users.FilterByDn, ldap.EscapeFilter(r.Dn)) + filter := fmt.Sprintf(cl.cfg.Users.FilterById, args.Id) + if args.Dn != "" { + filter = fmt.Sprintf(cl.cfg.Users.FilterByDn, ldap.EscapeFilter(args.Dn)) } req := &ldap.SearchRequest{ @@ -73,8 +70,8 @@ func (cl *Client) GetUser(r *GetUserRequest) (*User, error) { Filter: filter, Attributes: cl.cfg.Users.Attributes, } - if r.Attributes != nil { - req.Attributes = r.Attributes + if args.Attributes != nil { + req.Attributes = args.Attributes } entry, err := cl.searchEntry(req) @@ -94,7 +91,7 @@ func (cl *Client) GetUser(r *GetUserRequest) (*User, error) { result.Attributes[a.Name] = entry.GetAttributeValue(a.Name) } - if !r.SkipGroupsSearch { + if !args.SkipGroupsSearch { groups, err := cl.getUserGroups(entry.DN) if err != nil { return nil, fmt.Errorf("can't get user groups: %s", err.Error()) diff --git a/user_test.go b/user_test.go index 3b68e88..5a167f4 100644 --- a/user_test.go +++ b/user_test.go @@ -21,16 +21,16 @@ func Test_User_GetStringAttribute(t *testing.T) { require.Empty(t, u.GetStringAttribute("nonexists")) } -func Test_GetUserRequest_Validate(t *testing.T) { - var req *GetUserRequest +func Test_GetUserArgs_Validate(t *testing.T) { + var req GetUserArgs err := req.Validate() require.Error(t, err) - req = &GetUserRequest{} + req = GetUserArgs{} err1 := req.Validate() require.Error(t, err1) - req = &GetUserRequest{Id: "fake"} + req = GetUserArgs{Id: "fake"} errOk := req.Validate() require.NoError(t, errOk) } @@ -40,21 +40,21 @@ func Test_Client_GetUser(t *testing.T) { err := cl.Connect() require.NoError(t, err) - var badReq *GetUserRequest - _, badReqErr := cl.GetUser(badReq) + var badArgs GetUserArgs + _, badReqErr := cl.GetUser(badArgs) require.Error(t, badReqErr) - req := &GetUserRequest{Id: "entryForErr", SkipGroupsSearch: true} - _, err = cl.GetUser(req) + args := GetUserArgs{Id: "entryForErr", SkipGroupsSearch: true} + _, err = cl.GetUser(args) require.Error(t, err) - req = &GetUserRequest{Id: "userFake", SkipGroupsSearch: true} - user, err := cl.GetUser(req) + args = GetUserArgs{Id: "userFake", SkipGroupsSearch: true} + user, err := cl.GetUser(args) require.NoError(t, err) require.Nil(t, user) // Too many entries error - user, err = cl.GetUser(&GetUserRequest{ + user, err = cl.GetUser(GetUserArgs{ Id: "notUniq", SkipGroupsSearch: true, Attributes: []string{"sAMAccountName"}, @@ -62,31 +62,31 @@ func Test_Client_GetUser(t *testing.T) { require.Error(t, err) require.Nil(t, user) - dnReq := &GetUserRequest{Dn: "OU=user1,DC=company,DC=com", SkipGroupsSearch: true} + dnReq := GetUserArgs{Dn: "OU=user1,DC=company,DC=com", SkipGroupsSearch: true} groupByDn, err := cl.GetUser(dnReq) require.NoError(t, err) require.NotNil(t, groupByDn) require.Equal(t, dnReq.Dn, groupByDn.DN) - req = &GetUserRequest{Id: "user1", SkipGroupsSearch: true} - user, err = cl.GetUser(req) + args = GetUserArgs{Id: "user1", SkipGroupsSearch: true} + user, err = cl.GetUser(args) require.NoError(t, err) require.NotNil(t, user) - require.Equal(t, req.Id, user.Id) + require.Equal(t, args.Id, user.Id) require.Nil(t, user.Groups) - req.Attributes = []string{"something"} - user, err = cl.GetUser(req) + args.Attributes = []string{"something"} + user, err = cl.GetUser(args) require.NoError(t, err) require.NotNil(t, user) - require.Equal(t, req.Id, user.Id) + require.Equal(t, args.Id, user.Id) require.Nil(t, user.Groups) - req.SkipGroupsSearch = false - user, err = cl.GetUser(req) + args.SkipGroupsSearch = false + user, err = cl.GetUser(args) require.NoError(t, err) require.NotNil(t, user) - require.Equal(t, req.Id, user.Id) + require.Equal(t, args.Id, user.Id) require.NotNil(t, user.Groups) require.Len(t, user.Groups, 1) }