Skip to content

Commit

Permalink
Improvements and optimisations (#12)
Browse files Browse the repository at this point in the history
* Remove unused elements from mock
* Use ticker duration in Reconnect func
* Renaming structs and stoping pointer usage
  • Loading branch information
dlampsi committed May 23, 2021
1 parent 8d9d16a commit 9b85953
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 102 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ if err := cl.Connect(); err != nil {
}

// Search for a user
user, err := cl.GetUser(&adc.GetUserRequest{Id:"userId"})
user, err := cl.GetUser(adc.GetUserArgs{Id:"userId"})
if err != nil {
// Handle error
}
Expand All @@ -45,7 +45,7 @@ if user == nil {
fmt.Println(user)

// Search for a group
group, err := cl.GetGroup(&adc.GetGroupequest{Id:"groupId"})
group, err := cl.GetGroup(adc.GetGroupArgs{Id:"groupId"})
if err != nil {
// Handle error
}
Expand Down Expand Up @@ -115,7 +115,7 @@ You can parse custom attributes to client config to fetch those attributes durin
cl.Config().AppendUsesAttributes("manager")

// Search for a user
user, err := cl.GetUser(&adc.GetUserRequest{Id:"userId"})
user, err := cl.GetUser(adc.GetUserArgs{Id:"userId"})
if err != nil {
// Handle error
}
Expand All @@ -130,7 +130,7 @@ fmt.Println(userManager)

Also you can parse custom attributes during each get requests:
```go
user, err := cl.GetUser(&adc.GetUserRequest{Id: "userId", Attributes: []string{"manager"}})
user, err := cl.GetUser(adc.GetUserArgs{Id: "userId", Attributes: []string{"manager"}})
if err != nil {
// Handle error
}
Expand Down
7 changes: 4 additions & 3 deletions adc.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (cl *Client) Disconnect() {
}

// Checks connections to AD and tries to reconnect if the connection is lost.
func (cl *Client) Reconnect(ctx context.Context, ticker *time.Ticker, maxAttempts int) error {
func (cl *Client) Reconnect(ctx context.Context, tickerDuration time.Duration, maxAttempts int) error {
_, connErr := cl.searchEntry(&ldap.SearchRequest{
BaseDN: cl.cfg.SearchBase,
Scope: ldap.ScopeWholeSubtree,
Expand All @@ -125,9 +125,10 @@ func (cl *Client) Reconnect(ctx context.Context, ticker *time.Ticker, maxAttempt
return nil
}

if ticker == nil {
ticker = time.NewTicker(5 * time.Second)
if tickerDuration == 0 {
tickerDuration = 5 * time.Second
}
ticker := time.NewTicker(tickerDuration)
defer ticker.Stop()

if maxAttempts == 0 {
Expand Down
14 changes: 7 additions & 7 deletions adc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,30 +80,30 @@ func Test_Client_Reconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

err = cl.Reconnect(ctx, time.NewTicker(2*time.Second), 2)
err = cl.Reconnect(ctx, 2*time.Second, 2)
require.NoError(t, err)

cl.cfg.Bind = &BindAccount{DN: mockEntriesData["entryForErr"].DN}
err = cl.Reconnect(ctx, time.NewTicker(2*time.Second), 2)
err = cl.Reconnect(ctx, 2*time.Second, 2)
require.Error(t, err)

cl.cfg.Bind = reconnectMockBind
err = cl.Reconnect(ctx, nil, 1)
err = cl.Reconnect(ctx, 0, 1)
require.Error(t, err)
err = cl.Reconnect(ctx, time.NewTicker(30*time.Millisecond), 0)
err = cl.Reconnect(ctx, 30*time.Millisecond, 0)
require.Error(t, err)
err = cl.Reconnect(ctx, time.NewTicker(1*time.Second), 1)
err = cl.Reconnect(ctx, 1*time.Second, 1)
require.Error(t, err)

nctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()
err = cl.Reconnect(nctx, time.NewTicker(5*time.Second), 1)
err = cl.Reconnect(nctx, 5*time.Second, 1)
require.Error(t, err)

cl.cfg.Bind = validMockBind
err = cl.Reconnect(ctx, time.NewTicker(30*time.Millisecond), 1)
err = cl.Reconnect(ctx, 30*time.Millisecond, 1)
require.NoError(t, err)
}
33 changes: 15 additions & 18 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (g *Group) GetStringAttribute(name string) string {
return ""
}

type GetGroupequest struct {
type GetGroupArgs struct {
// Group ID to search.
Id string `json:"id"`
// Optional group DN. Overwrites ID if provided in request.
Expand All @@ -47,24 +47,21 @@ type GetGroupequest struct {
SkipMembersSearch bool `json:"skip_members_search"`
}

func (req *GetGroupequest) Validate() error {
if req == nil {
return errors.New("nil request")
}
if req.Id == "" && req.Dn == "" {
func (args GetGroupArgs) Validate() error {
if args.Id == "" && args.Dn == "" {
return errors.New("neither of ID of DN provided")
}
return nil
}

func (cl *Client) GetGroup(r *GetGroupequest) (*Group, error) {
if err := r.Validate(); err != nil {
func (cl *Client) GetGroup(args GetGroupArgs) (*Group, error) {
if err := args.Validate(); err != nil {
return nil, err
}

filter := fmt.Sprintf(cl.cfg.Groups.FilterById, r.Id)
if r.Dn != "" {
filter = fmt.Sprintf(cl.cfg.Groups.FilterByDn, ldap.EscapeFilter(r.Dn))
filter := fmt.Sprintf(cl.cfg.Groups.FilterById, args.Id)
if args.Dn != "" {
filter = fmt.Sprintf(cl.cfg.Groups.FilterByDn, ldap.EscapeFilter(args.Dn))
}

req := &ldap.SearchRequest{
Expand All @@ -75,8 +72,8 @@ func (cl *Client) GetGroup(r *GetGroupequest) (*Group, error) {
Filter: filter,
Attributes: cl.cfg.Groups.Attributes,
}
if r.Attributes != nil {
req.Attributes = r.Attributes
if args.Attributes != nil {
req.Attributes = args.Attributes
}

entry, err := cl.searchEntry(req)
Expand All @@ -96,7 +93,7 @@ func (cl *Client) GetGroup(r *GetGroupequest) (*Group, error) {
result.Attributes[a.Name] = entry.GetAttributeValue(a.Name)
}

if !r.SkipMembersSearch {
if !args.SkipMembersSearch {
members, err := cl.getGroupMembers(entry.DN)
if err != nil {
return nil, fmt.Errorf("can't get group members: %s", err.Error())
Expand Down Expand Up @@ -150,7 +147,7 @@ func (g *Group) MembersId() []string {

// Adds provided accounts IDs to provided group members. Returns number of addedd accounts.
func (cl *Client) AddGroupMembers(groupId string, membersIds ...string) (int, error) {
group, err := cl.GetGroup(&GetGroupequest{Id: groupId})
group, err := cl.GetGroup(GetGroupArgs{Id: groupId})
if err != nil {
return 0, fmt.Errorf("can't get group: %s", err.Error())
}
Expand All @@ -166,7 +163,7 @@ func (cl *Client) AddGroupMembers(groupId string, membersIds ...string) (int, er
wg.Add(1)
go func(userId string, ch chan<- string, errCh chan<- error, wg *sync.WaitGroup) {
defer wg.Done()
user, err := cl.GetUser(&GetUserRequest{Id: userId})
user, err := cl.GetUser(GetUserArgs{Id: userId})
if err != nil {
errCh <- fmt.Errorf("can't get account '%s': %s", userId, err.Error())
return
Expand Down Expand Up @@ -226,7 +223,7 @@ func popAddGroupMembers(g *Group, toAdd []string) []string {

// Deletes provided accounts IDs from provided group members. Returns number of deleted from group members.
func (cl *Client) DeleteGroupMembers(groupId string, membersIds ...string) (int, error) {
group, err := cl.GetGroup(&GetGroupequest{Id: groupId})
group, err := cl.GetGroup(GetGroupArgs{Id: groupId})
if err != nil {
return 0, fmt.Errorf("can't get group: %s", err.Error())
}
Expand All @@ -242,7 +239,7 @@ func (cl *Client) DeleteGroupMembers(groupId string, membersIds ...string) (int,
wg.Add(1)
go func(userId string, ch chan<- string, errCh chan<- error, wg *sync.WaitGroup) {
defer wg.Done()
user, err := cl.GetUser(&GetUserRequest{Id: userId})
user, err := cl.GetUser(GetUserArgs{Id: userId})
if err != nil {
errCh <- fmt.Errorf("can't get account '%s': %s", userId, err.Error())
return
Expand Down
46 changes: 23 additions & 23 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ func Test_Group_GetStringAttribute(t *testing.T) {
}

func Test_GetGroupRequest_Validate(t *testing.T) {
var req *GetGroupequest
var req GetGroupArgs
err := req.Validate()
require.Error(t, err)

req = &GetGroupequest{}
req = GetGroupArgs{}
err1 := req.Validate()
require.Error(t, err1)

req = &GetGroupequest{Id: "fake"}
req = GetGroupArgs{Id: "fake"}
errOk := req.Validate()
require.NoError(t, errOk)
}
Expand All @@ -40,21 +40,21 @@ func Test_Client_GetGroup(t *testing.T) {
err := cl.Connect()
require.NoError(t, err)

var badReq *GetGroupequest
_, badReqErr := cl.GetGroup(badReq)
var badArgs GetGroupArgs
_, badReqErr := cl.GetGroup(badArgs)
require.Error(t, badReqErr)

req := &GetGroupequest{Id: "entryForErr", SkipMembersSearch: true}
_, err = cl.GetGroup(req)
args := GetGroupArgs{Id: "entryForErr", SkipMembersSearch: true}
_, err = cl.GetGroup(args)
require.Error(t, err)

req = &GetGroupequest{Id: "groupFake", SkipMembersSearch: true}
group, err := cl.GetGroup(req)
args = GetGroupArgs{Id: "groupFake", SkipMembersSearch: true}
group, err := cl.GetGroup(args)
require.NoError(t, err)
require.Nil(t, group)

// Too many entries error
group, err = cl.GetGroup(&GetGroupequest{
group, err = cl.GetGroup(GetGroupArgs{
Id: "notUniq",
SkipMembersSearch: true,
Attributes: []string{"sAMAccountName"},
Expand All @@ -63,36 +63,36 @@ func Test_Client_GetGroup(t *testing.T) {
require.Nil(t, group)

// Group with err members get
group, err = cl.GetGroup(&GetGroupequest{
group, err = cl.GetGroup(GetGroupArgs{
Id: "groupWithErrMember",
SkipMembersSearch: false,
})
require.Error(t, err)
require.Nil(t, group)

dnReq := &GetGroupequest{Dn: "OU=group1,DC=company,DC=com", SkipMembersSearch: true}
groupByDn, err := cl.GetGroup(dnReq)
dnArgs := GetGroupArgs{Dn: "OU=group1,DC=company,DC=com", SkipMembersSearch: true}
groupByDn, err := cl.GetGroup(dnArgs)
require.NoError(t, err)
require.NotNil(t, groupByDn)
require.Equal(t, dnReq.Dn, groupByDn.DN)
require.Equal(t, dnArgs.Dn, groupByDn.DN)

req = &GetGroupequest{Id: "group1", SkipMembersSearch: true}
group, err = cl.GetGroup(req)
args = GetGroupArgs{Id: "group1", SkipMembersSearch: true}
group, err = cl.GetGroup(args)
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, req.Id, group.Id)
require.Equal(t, args.Id, group.Id)

req.Attributes = []string{"something"}
group, err = cl.GetGroup(req)
args.Attributes = []string{"something"}
group, err = cl.GetGroup(args)
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, req.Id, group.Id)
require.Equal(t, args.Id, group.Id)

req.SkipMembersSearch = false
group, err = cl.GetGroup(req)
args.SkipMembersSearch = false
group, err = cl.GetGroup(args)
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, req.Id, group.Id)
require.Equal(t, args.Id, group.Id)
require.NotNil(t, group.Members)
require.Len(t, group.Members, 1)
}
Expand Down
12 changes: 0 additions & 12 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,6 @@ var mockEntriesData = mockEntries{

type mockEntries map[string]*ldap.Entry

func (me mockEntries) getEntryById(id string) *ldap.Entry {
if v, ok := me[id]; ok {
return v
}
return nil
}

func (me mockEntries) getEntryByDn(dn string) *ldap.Entry {
for _, entry := range me {
if entry.DN == dn {
Expand Down Expand Up @@ -247,11 +240,6 @@ func (cl *mockClient) PasswordModify(*ldap.PasswordModifyRequest) (*ldap.Passwor
return nil, nil
}

type mockDataEntry struct {
entry *ldap.Entry
filters []string
}

func (cl *mockClient) Search(req *ldap.SearchRequest) (*ldap.SearchResult, error) {
entries, err := mockEntriesData.getEntriesByFilter(req.Filter)
if err != nil {
Expand Down
25 changes: 11 additions & 14 deletions user.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (u *User) GetStringAttribute(name string) string {
return ""
}

type GetUserRequest struct {
type GetUserArgs struct {
// User ID to search.
Id string `json:"id"`
// Optional User DN. Overwrites ID if provided in request.
Expand All @@ -45,24 +45,21 @@ type GetUserRequest struct {
SkipGroupsSearch bool `json:"skip_groups_search"`
}

func (req *GetUserRequest) Validate() error {
if req == nil {
return errors.New("nil request")
}
if req.Id == "" && req.Dn == "" {
func (args GetUserArgs) Validate() error {
if args.Id == "" && args.Dn == "" {
return errors.New("neither of ID of DN provided")
}
return nil
}

func (cl *Client) GetUser(r *GetUserRequest) (*User, error) {
if err := r.Validate(); err != nil {
func (cl *Client) GetUser(args GetUserArgs) (*User, error) {
if err := args.Validate(); err != nil {
return nil, err
}

filter := fmt.Sprintf(cl.cfg.Users.FilterById, r.Id)
if r.Dn != "" {
filter = fmt.Sprintf(cl.cfg.Users.FilterByDn, ldap.EscapeFilter(r.Dn))
filter := fmt.Sprintf(cl.cfg.Users.FilterById, args.Id)
if args.Dn != "" {
filter = fmt.Sprintf(cl.cfg.Users.FilterByDn, ldap.EscapeFilter(args.Dn))
}

req := &ldap.SearchRequest{
Expand All @@ -73,8 +70,8 @@ func (cl *Client) GetUser(r *GetUserRequest) (*User, error) {
Filter: filter,
Attributes: cl.cfg.Users.Attributes,
}
if r.Attributes != nil {
req.Attributes = r.Attributes
if args.Attributes != nil {
req.Attributes = args.Attributes
}

entry, err := cl.searchEntry(req)
Expand All @@ -94,7 +91,7 @@ func (cl *Client) GetUser(r *GetUserRequest) (*User, error) {
result.Attributes[a.Name] = entry.GetAttributeValue(a.Name)
}

if !r.SkipGroupsSearch {
if !args.SkipGroupsSearch {
groups, err := cl.getUserGroups(entry.DN)
if err != nil {
return nil, fmt.Errorf("can't get user groups: %s", err.Error())
Expand Down
Loading

0 comments on commit 9b85953

Please sign in to comment.