Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements and optimisations #12

Merged
merged 3 commits into from
May 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ if err := cl.Connect(); err != nil {
}

// Search for a user
user, err := cl.GetUser(&adc.GetUserRequest{Id:"userId"})
user, err := cl.GetUser(adc.GetUserArgs{Id:"userId"})
if err != nil {
// Handle error
}
Expand All @@ -45,7 +45,7 @@ if user == nil {
fmt.Println(user)

// Search for a group
group, err := cl.GetGroup(&adc.GetGroupequest{Id:"groupId"})
group, err := cl.GetGroup(adc.GetGroupArgs{Id:"groupId"})
if err != nil {
// Handle error
}
Expand Down Expand Up @@ -115,7 +115,7 @@ You can parse custom attributes to client config to fetch those attributes durin
cl.Config().AppendUsesAttributes("manager")

// Search for a user
user, err := cl.GetUser(&adc.GetUserRequest{Id:"userId"})
user, err := cl.GetUser(adc.GetUserArgs{Id:"userId"})
if err != nil {
// Handle error
}
Expand All @@ -130,7 +130,7 @@ fmt.Println(userManager)

Also you can parse custom attributes during each get requests:
```go
user, err := cl.GetUser(&adc.GetUserRequest{Id: "userId", Attributes: []string{"manager"}})
user, err := cl.GetUser(adc.GetUserArgs{Id: "userId", Attributes: []string{"manager"}})
if err != nil {
// Handle error
}
Expand Down
7 changes: 4 additions & 3 deletions adc.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (cl *Client) Disconnect() {
}

// Checks connections to AD and tries to reconnect if the connection is lost.
func (cl *Client) Reconnect(ctx context.Context, ticker *time.Ticker, maxAttempts int) error {
func (cl *Client) Reconnect(ctx context.Context, tickerDuration time.Duration, maxAttempts int) error {
_, connErr := cl.searchEntry(&ldap.SearchRequest{
BaseDN: cl.cfg.SearchBase,
Scope: ldap.ScopeWholeSubtree,
Expand All @@ -125,9 +125,10 @@ func (cl *Client) Reconnect(ctx context.Context, ticker *time.Ticker, maxAttempt
return nil
}

if ticker == nil {
ticker = time.NewTicker(5 * time.Second)
if tickerDuration == 0 {
tickerDuration = 5 * time.Second
}
ticker := time.NewTicker(tickerDuration)
defer ticker.Stop()

if maxAttempts == 0 {
Expand Down
14 changes: 7 additions & 7 deletions adc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,30 +80,30 @@ func Test_Client_Reconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

err = cl.Reconnect(ctx, time.NewTicker(2*time.Second), 2)
err = cl.Reconnect(ctx, 2*time.Second, 2)
require.NoError(t, err)

cl.cfg.Bind = &BindAccount{DN: mockEntriesData["entryForErr"].DN}
err = cl.Reconnect(ctx, time.NewTicker(2*time.Second), 2)
err = cl.Reconnect(ctx, 2*time.Second, 2)
require.Error(t, err)

cl.cfg.Bind = reconnectMockBind
err = cl.Reconnect(ctx, nil, 1)
err = cl.Reconnect(ctx, 0, 1)
require.Error(t, err)
err = cl.Reconnect(ctx, time.NewTicker(30*time.Millisecond), 0)
err = cl.Reconnect(ctx, 30*time.Millisecond, 0)
require.Error(t, err)
err = cl.Reconnect(ctx, time.NewTicker(1*time.Second), 1)
err = cl.Reconnect(ctx, 1*time.Second, 1)
require.Error(t, err)

nctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()
err = cl.Reconnect(nctx, time.NewTicker(5*time.Second), 1)
err = cl.Reconnect(nctx, 5*time.Second, 1)
require.Error(t, err)

cl.cfg.Bind = validMockBind
err = cl.Reconnect(ctx, time.NewTicker(30*time.Millisecond), 1)
err = cl.Reconnect(ctx, 30*time.Millisecond, 1)
require.NoError(t, err)
}
33 changes: 15 additions & 18 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (g *Group) GetStringAttribute(name string) string {
return ""
}

type GetGroupequest struct {
type GetGroupArgs struct {
// Group ID to search.
Id string `json:"id"`
// Optional group DN. Overwrites ID if provided in request.
Expand All @@ -47,24 +47,21 @@ type GetGroupequest struct {
SkipMembersSearch bool `json:"skip_members_search"`
}

func (req *GetGroupequest) Validate() error {
if req == nil {
return errors.New("nil request")
}
if req.Id == "" && req.Dn == "" {
func (args GetGroupArgs) Validate() error {
if args.Id == "" && args.Dn == "" {
return errors.New("neither of ID of DN provided")
}
return nil
}

func (cl *Client) GetGroup(r *GetGroupequest) (*Group, error) {
if err := r.Validate(); err != nil {
func (cl *Client) GetGroup(args GetGroupArgs) (*Group, error) {
if err := args.Validate(); err != nil {
return nil, err
}

filter := fmt.Sprintf(cl.cfg.Groups.FilterById, r.Id)
if r.Dn != "" {
filter = fmt.Sprintf(cl.cfg.Groups.FilterByDn, ldap.EscapeFilter(r.Dn))
filter := fmt.Sprintf(cl.cfg.Groups.FilterById, args.Id)
if args.Dn != "" {
filter = fmt.Sprintf(cl.cfg.Groups.FilterByDn, ldap.EscapeFilter(args.Dn))
}

req := &ldap.SearchRequest{
Expand All @@ -75,8 +72,8 @@ func (cl *Client) GetGroup(r *GetGroupequest) (*Group, error) {
Filter: filter,
Attributes: cl.cfg.Groups.Attributes,
}
if r.Attributes != nil {
req.Attributes = r.Attributes
if args.Attributes != nil {
req.Attributes = args.Attributes
}

entry, err := cl.searchEntry(req)
Expand All @@ -96,7 +93,7 @@ func (cl *Client) GetGroup(r *GetGroupequest) (*Group, error) {
result.Attributes[a.Name] = entry.GetAttributeValue(a.Name)
}

if !r.SkipMembersSearch {
if !args.SkipMembersSearch {
members, err := cl.getGroupMembers(entry.DN)
if err != nil {
return nil, fmt.Errorf("can't get group members: %s", err.Error())
Expand Down Expand Up @@ -150,7 +147,7 @@ func (g *Group) MembersId() []string {

// Adds provided accounts IDs to provided group members. Returns number of addedd accounts.
func (cl *Client) AddGroupMembers(groupId string, membersIds ...string) (int, error) {
group, err := cl.GetGroup(&GetGroupequest{Id: groupId})
group, err := cl.GetGroup(GetGroupArgs{Id: groupId})
if err != nil {
return 0, fmt.Errorf("can't get group: %s", err.Error())
}
Expand All @@ -166,7 +163,7 @@ func (cl *Client) AddGroupMembers(groupId string, membersIds ...string) (int, er
wg.Add(1)
go func(userId string, ch chan<- string, errCh chan<- error, wg *sync.WaitGroup) {
defer wg.Done()
user, err := cl.GetUser(&GetUserRequest{Id: userId})
user, err := cl.GetUser(GetUserArgs{Id: userId})
if err != nil {
errCh <- fmt.Errorf("can't get account '%s': %s", userId, err.Error())
return
Expand Down Expand Up @@ -226,7 +223,7 @@ func popAddGroupMembers(g *Group, toAdd []string) []string {

// Deletes provided accounts IDs from provided group members. Returns number of deleted from group members.
func (cl *Client) DeleteGroupMembers(groupId string, membersIds ...string) (int, error) {
group, err := cl.GetGroup(&GetGroupequest{Id: groupId})
group, err := cl.GetGroup(GetGroupArgs{Id: groupId})
if err != nil {
return 0, fmt.Errorf("can't get group: %s", err.Error())
}
Expand All @@ -242,7 +239,7 @@ func (cl *Client) DeleteGroupMembers(groupId string, membersIds ...string) (int,
wg.Add(1)
go func(userId string, ch chan<- string, errCh chan<- error, wg *sync.WaitGroup) {
defer wg.Done()
user, err := cl.GetUser(&GetUserRequest{Id: userId})
user, err := cl.GetUser(GetUserArgs{Id: userId})
if err != nil {
errCh <- fmt.Errorf("can't get account '%s': %s", userId, err.Error())
return
Expand Down
46 changes: 23 additions & 23 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ func Test_Group_GetStringAttribute(t *testing.T) {
}

func Test_GetGroupRequest_Validate(t *testing.T) {
var req *GetGroupequest
var req GetGroupArgs
err := req.Validate()
require.Error(t, err)

req = &GetGroupequest{}
req = GetGroupArgs{}
err1 := req.Validate()
require.Error(t, err1)

req = &GetGroupequest{Id: "fake"}
req = GetGroupArgs{Id: "fake"}
errOk := req.Validate()
require.NoError(t, errOk)
}
Expand All @@ -40,21 +40,21 @@ func Test_Client_GetGroup(t *testing.T) {
err := cl.Connect()
require.NoError(t, err)

var badReq *GetGroupequest
_, badReqErr := cl.GetGroup(badReq)
var badArgs GetGroupArgs
_, badReqErr := cl.GetGroup(badArgs)
require.Error(t, badReqErr)

req := &GetGroupequest{Id: "entryForErr", SkipMembersSearch: true}
_, err = cl.GetGroup(req)
args := GetGroupArgs{Id: "entryForErr", SkipMembersSearch: true}
_, err = cl.GetGroup(args)
require.Error(t, err)

req = &GetGroupequest{Id: "groupFake", SkipMembersSearch: true}
group, err := cl.GetGroup(req)
args = GetGroupArgs{Id: "groupFake", SkipMembersSearch: true}
group, err := cl.GetGroup(args)
require.NoError(t, err)
require.Nil(t, group)

// Too many entries error
group, err = cl.GetGroup(&GetGroupequest{
group, err = cl.GetGroup(GetGroupArgs{
Id: "notUniq",
SkipMembersSearch: true,
Attributes: []string{"sAMAccountName"},
Expand All @@ -63,36 +63,36 @@ func Test_Client_GetGroup(t *testing.T) {
require.Nil(t, group)

// Group with err members get
group, err = cl.GetGroup(&GetGroupequest{
group, err = cl.GetGroup(GetGroupArgs{
Id: "groupWithErrMember",
SkipMembersSearch: false,
})
require.Error(t, err)
require.Nil(t, group)

dnReq := &GetGroupequest{Dn: "OU=group1,DC=company,DC=com", SkipMembersSearch: true}
groupByDn, err := cl.GetGroup(dnReq)
dnArgs := GetGroupArgs{Dn: "OU=group1,DC=company,DC=com", SkipMembersSearch: true}
groupByDn, err := cl.GetGroup(dnArgs)
require.NoError(t, err)
require.NotNil(t, groupByDn)
require.Equal(t, dnReq.Dn, groupByDn.DN)
require.Equal(t, dnArgs.Dn, groupByDn.DN)

req = &GetGroupequest{Id: "group1", SkipMembersSearch: true}
group, err = cl.GetGroup(req)
args = GetGroupArgs{Id: "group1", SkipMembersSearch: true}
group, err = cl.GetGroup(args)
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, req.Id, group.Id)
require.Equal(t, args.Id, group.Id)

req.Attributes = []string{"something"}
group, err = cl.GetGroup(req)
args.Attributes = []string{"something"}
group, err = cl.GetGroup(args)
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, req.Id, group.Id)
require.Equal(t, args.Id, group.Id)

req.SkipMembersSearch = false
group, err = cl.GetGroup(req)
args.SkipMembersSearch = false
group, err = cl.GetGroup(args)
require.NoError(t, err)
require.NotNil(t, group)
require.Equal(t, req.Id, group.Id)
require.Equal(t, args.Id, group.Id)
require.NotNil(t, group.Members)
require.Len(t, group.Members, 1)
}
Expand Down
12 changes: 0 additions & 12 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,6 @@ var mockEntriesData = mockEntries{

type mockEntries map[string]*ldap.Entry

func (me mockEntries) getEntryById(id string) *ldap.Entry {
if v, ok := me[id]; ok {
return v
}
return nil
}

func (me mockEntries) getEntryByDn(dn string) *ldap.Entry {
for _, entry := range me {
if entry.DN == dn {
Expand Down Expand Up @@ -247,11 +240,6 @@ func (cl *mockClient) PasswordModify(*ldap.PasswordModifyRequest) (*ldap.Passwor
return nil, nil
}

type mockDataEntry struct {
entry *ldap.Entry
filters []string
}

func (cl *mockClient) Search(req *ldap.SearchRequest) (*ldap.SearchResult, error) {
entries, err := mockEntriesData.getEntriesByFilter(req.Filter)
if err != nil {
Expand Down
25 changes: 11 additions & 14 deletions user.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (u *User) GetStringAttribute(name string) string {
return ""
}

type GetUserRequest struct {
type GetUserArgs struct {
// User ID to search.
Id string `json:"id"`
// Optional User DN. Overwrites ID if provided in request.
Expand All @@ -45,24 +45,21 @@ type GetUserRequest struct {
SkipGroupsSearch bool `json:"skip_groups_search"`
}

func (req *GetUserRequest) Validate() error {
if req == nil {
return errors.New("nil request")
}
if req.Id == "" && req.Dn == "" {
func (args GetUserArgs) Validate() error {
if args.Id == "" && args.Dn == "" {
return errors.New("neither of ID of DN provided")
}
return nil
}

func (cl *Client) GetUser(r *GetUserRequest) (*User, error) {
if err := r.Validate(); err != nil {
func (cl *Client) GetUser(args GetUserArgs) (*User, error) {
if err := args.Validate(); err != nil {
return nil, err
}

filter := fmt.Sprintf(cl.cfg.Users.FilterById, r.Id)
if r.Dn != "" {
filter = fmt.Sprintf(cl.cfg.Users.FilterByDn, ldap.EscapeFilter(r.Dn))
filter := fmt.Sprintf(cl.cfg.Users.FilterById, args.Id)
if args.Dn != "" {
filter = fmt.Sprintf(cl.cfg.Users.FilterByDn, ldap.EscapeFilter(args.Dn))
}

req := &ldap.SearchRequest{
Expand All @@ -73,8 +70,8 @@ func (cl *Client) GetUser(r *GetUserRequest) (*User, error) {
Filter: filter,
Attributes: cl.cfg.Users.Attributes,
}
if r.Attributes != nil {
req.Attributes = r.Attributes
if args.Attributes != nil {
req.Attributes = args.Attributes
}

entry, err := cl.searchEntry(req)
Expand All @@ -94,7 +91,7 @@ func (cl *Client) GetUser(r *GetUserRequest) (*User, error) {
result.Attributes[a.Name] = entry.GetAttributeValue(a.Name)
}

if !r.SkipGroupsSearch {
if !args.SkipGroupsSearch {
groups, err := cl.getUserGroups(entry.DN)
if err != nil {
return nil, fmt.Errorf("can't get user groups: %s", err.Error())
Expand Down
Loading