Skip to content

Commit

Permalink
Adding Reconnect method (#8)
Browse files Browse the repository at this point in the history
* Adding Reconnect method 
* Improve mock client data and tests
  • Loading branch information
dlampsi committed Mar 7, 2021
1 parent de814e7 commit 039c3ac
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 87 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,20 @@ cl := New(cfg, adc.WithLogger(myCustomLogger))
```


### Reconnect

Client has reconnect method, that validates connection to server and reconnects to it with provided ticker interval and retries attempts count.

Exxample for recconect each 5 secconds with 24 retrie attempts:

```go
err := cl.Reconnect(nctx, time.NewTicker(5*time.Second), 24)
if err != nil {
// Handle error
}
```


## Contributing

1. Create new PR from `main` branch
Expand Down
47 changes: 47 additions & 0 deletions adc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
package adc

import (
"context"
"crypto/tls"
"errors"
"fmt"
"strings"
"time"

Expand Down Expand Up @@ -113,6 +115,51 @@ func (cl *Client) Disconnect() {
}
}

// Checks connections to AD and tries to reconnect if the connection is lost.
func (cl *Client) Reconnect(ctx context.Context, ticker *time.Ticker, maxAttempts int) error {
_, connErr := cl.searchEntry(&ldap.SearchRequest{
BaseDN: cl.cfg.SearchBase,
Scope: ldap.ScopeWholeSubtree,
DerefAliases: ldap.NeverDerefAliases,
TimeLimit: int(cl.cfg.Timeout.Seconds()),
Filter: fmt.Sprintf(cl.cfg.Users.FilterByDn, ldap.EscapeFilter(cl.cfg.Bind.DN)),
Attributes: []string{cl.cfg.Users.IdAttribute},
})
if connErr == nil {
return nil
}
if !ldap.IsErrorWithCode(connErr, 200) {
return connErr
}

if ticker == nil {
ticker = time.NewTicker(5 * time.Second)
}
defer ticker.Stop()

if maxAttempts == 0 {
maxAttempts = 2
}

attempt := 0
for {
select {
case <-ticker.C:
if attempt >= maxAttempts {
return fmt.Errorf("failed after '%d' attempts. error: %s", attempt, connErr)
}
attempt++
cl.logger.Debugf("Reconnecting to database. Attempt: %d", attempt)
if err := cl.Connect(); err == nil {
cl.logger.Debug("Successfully reconeted to server")
return nil
}
case <-ctx.Done():
return ctx.Err()
}
}
}

// SearchEntry Perfrom search for single ldap entry.
// Returns nil if no entries found.
// Returns 'ErrTooManyEntriesFound' error if entries more that one.
Expand Down
41 changes: 41 additions & 0 deletions adc_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package adc

import (
"context"
"testing"
"time"

Expand Down Expand Up @@ -68,3 +69,43 @@ func Test_Client_Connect(t *testing.T) {
err = cl.Connect()
require.NoError(t, err)
}

func Test_Client_Reconnect(t *testing.T) {
mock := &mockClient{}
cfg := &Config{
Bind: validMockBind,
}
cl := New(cfg, WithLdapClient(mock))
err := cl.Connect()
require.NoError(t, err)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

err = cl.Reconnect(ctx, time.NewTicker(2*time.Second), 2)
require.NoError(t, err)

cl.cfg.Bind = &BindAccount{DN: mockEntriesData["entryForErr"].DN}
err = cl.Reconnect(ctx, time.NewTicker(2*time.Second), 2)
require.Error(t, err)

cl.cfg.Bind = reconnectMockBind
err = cl.Reconnect(ctx, nil, 1)
require.Error(t, err)
err = cl.Reconnect(ctx, time.NewTicker(30*time.Millisecond), 0)
require.Error(t, err)
err = cl.Reconnect(ctx, time.NewTicker(1*time.Second), 1)
require.Error(t, err)

nctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()
err = cl.Reconnect(nctx, time.NewTicker(5*time.Second), 1)
require.Error(t, err)

cl.cfg.Bind = validMockBind
err = cl.Reconnect(ctx, time.NewTicker(30*time.Millisecond), 1)
require.NoError(t, err)
}
31 changes: 24 additions & 7 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func Test_Client_GetGroup(t *testing.T) {
_, badReqErr := cl.GetGroup(badReq)
require.Error(t, badReqErr)

req := &GetGroupequest{Id: "group2", SkipMembersSearch: true}
req := &GetGroupequest{Id: "entryForErr", SkipMembersSearch: true}
_, err = cl.GetGroup(req)
require.Error(t, err)

Expand All @@ -54,6 +54,23 @@ func Test_Client_GetGroup(t *testing.T) {
require.NoError(t, err)
require.Nil(t, group)

// Too many entries error
group, err = cl.GetGroup(&GetGroupequest{
Id: "notUniq",
SkipMembersSearch: true,
Attributes: []string{"sAMAccountName"},
})
require.Error(t, err)
require.Nil(t, group)

// Group with err members get
group, err = cl.GetGroup(&GetGroupequest{
Id: "groupWithErrMember",
SkipMembersSearch: false,
})
require.Error(t, err)
require.Nil(t, group)

dnReq := &GetGroupequest{Dn: "OU=group1,DC=company,DC=com", SkipMembersSearch: true}
groupByDn, err := cl.GetGroup(dnReq)
require.NoError(t, err)
Expand Down Expand Up @@ -98,7 +115,7 @@ func Test_popAddGroupMembers(t *testing.T) {
func Test_AddGroupMembers(t *testing.T) {
cl := New(&Config{}, WithLdapClient(&mockClient{}))

_, err := cl.AddGroupMembers("group2", "user1")
_, err := cl.AddGroupMembers("entryForErr", "user1")
require.Error(t, err)

_, err = cl.AddGroupMembers("groupFake", "user1")
Expand All @@ -110,7 +127,7 @@ func Test_AddGroupMembers(t *testing.T) {
require.Equal(t, 0, added)

// Error user
_, err = cl.AddGroupMembers("group1", "user2")
_, err = cl.AddGroupMembers("group1", "entryForErr")
require.Error(t, err)

// Already member user
Expand All @@ -119,7 +136,7 @@ func Test_AddGroupMembers(t *testing.T) {
require.Equal(t, 0, added)

// Ok user
added, err = cl.AddGroupMembers("group1", "user3")
added, err = cl.AddGroupMembers("group1", "userToAdd")
require.NoError(t, err)
require.Equal(t, 1, added)

Expand All @@ -144,7 +161,7 @@ func Test_popDelGroupMembers(t *testing.T) {
func Test_DeleteGroupMembers(t *testing.T) {
cl := New(&Config{}, WithLdapClient(&mockClient{}))

_, err := cl.DeleteGroupMembers("group2", "user1")
_, err := cl.DeleteGroupMembers("entryForErr", "user1")
require.Error(t, err)

_, err = cl.DeleteGroupMembers("groupFake", "user1")
Expand All @@ -156,11 +173,11 @@ func Test_DeleteGroupMembers(t *testing.T) {
require.Equal(t, 0, deleted)

// Error user
_, err = cl.DeleteGroupMembers("group1", "user2")
_, err = cl.DeleteGroupMembers("group1", "entryForErr")
require.Error(t, err)

// Already not member user
deleted, err = cl.DeleteGroupMembers("group1", "user3")
deleted, err = cl.DeleteGroupMembers("group1", "user2")
require.NoError(t, err)
require.Equal(t, 0, deleted)

Expand Down
Loading

0 comments on commit 039c3ac

Please sign in to comment.