From d41eb0bdf9199ddf2f6e27c0b419bce65e09e0c8 Mon Sep 17 00:00:00 2001 From: Emil Tullstedt Date: Tue, 14 Sep 2021 10:49:37 +0200 Subject: [PATCH] LDAP: Search all DNs for users (#38891) (cherry picked from commit ad971cc9beb78251ea7edda17c24b1de49d6099a) --- pkg/services/ldap/ldap.go | 47 +-- pkg/services/ldap/ldap_helpers_test.go | 242 ++++++------- pkg/services/ldap/ldap_login_test.go | 438 ++++++++++++------------ pkg/services/ldap/ldap_private_test.go | 453 ++++++++++++------------ pkg/services/ldap/ldap_test.go | 455 +++++++++++++++---------- pkg/services/ldap/testing.go | 23 +- 6 files changed, 845 insertions(+), 813 deletions(-) diff --git a/pkg/services/ldap/ldap.go b/pkg/services/ldap/ldap.go index 6caf5494205c..1ad3fa159055 100644 --- a/pkg/services/ldap/ldap.go +++ b/pkg/services/ldap/ldap.go @@ -12,9 +12,10 @@ import ( "strings" "github.com/davecgh/go-spew/spew" + "gopkg.in/ldap.v3" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" - "gopkg.in/ldap.v3" ) // IConnection is interface for LDAP connection manipulation @@ -252,16 +253,11 @@ func (server *Server) Users(logins []string) ( []*models.ExternalUserInfo, error, ) { - var users []*ldap.Entry + var users [][]*ldap.Entry err := getUsersIteration(logins, func(previous, current int) error { - entries, err := server.users(logins[previous:current]) - if err != nil { - return err - } - - users = append(users, entries...) - - return nil + var err error + users, err = server.users(logins[previous:current]) + return err }) if err != nil { return nil, err @@ -308,13 +304,15 @@ func getUsersIteration(logins []string, fn func(int, int) error) error { // users is helper method for the Users() func (server *Server) users(logins []string) ( - []*ldap.Entry, + [][]*ldap.Entry, error, ) { var result *ldap.SearchResult var Config = server.Config var err error + var entries = make([][]*ldap.Entry, 0, len(Config.SearchBaseDNs)) + for _, base := range Config.SearchBaseDNs { result, err = server.Connection.Search( server.getSearchRequest(base, logins), @@ -324,11 +322,11 @@ func (server *Server) users(logins []string) ( } if len(result.Entries) > 0 { - break + entries = append(entries, result.Entries) } } - return result.Entries, nil + return entries, nil } // validateGrafanaUser validates user access. @@ -557,17 +555,26 @@ func (server *Server) requestMemberOf(entry *ldap.Entry) ([]string, error) { // serializeUsers serializes the users // from LDAP result to ExternalInfo struct func (server *Server) serializeUsers( - entries []*ldap.Entry, + entries [][]*ldap.Entry, ) ([]*models.ExternalUserInfo, error) { var serialized []*models.ExternalUserInfo + var users = map[string]struct{}{} - for _, user := range entries { - extUser, err := server.buildGrafanaUser(user) - if err != nil { - return nil, err - } + for _, dn := range entries { + for _, user := range dn { + extUser, err := server.buildGrafanaUser(user) + if err != nil { + return nil, err + } + + if _, exists := users[extUser.Login]; exists { + // ignore duplicates + continue + } + users[extUser.Login] = struct{}{} - serialized = append(serialized, extUser) + serialized = append(serialized, extUser) + } } return serialized, nil diff --git a/pkg/services/ldap/ldap_helpers_test.go b/pkg/services/ldap/ldap_helpers_test.go index 5062623d546f..e276917c973f 100644 --- a/pkg/services/ldap/ldap_helpers_test.go +++ b/pkg/services/ldap/ldap_helpers_test.go @@ -1,191 +1,141 @@ package ldap import ( + "fmt" "testing" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" "gopkg.in/ldap.v3" ) -func TestLDAPHelpers(t *testing.T) { - Convey("isMemberOf()", t, func() { - Convey("Wildcard", func() { - result := isMemberOf([]string{}, "*") - So(result, ShouldBeTrue) +func TestIsMemberOf(t *testing.T) { + tests := []struct { + memberOf []string + group string + expected bool + }{ + {memberOf: []string{}, group: "*", expected: true}, + {memberOf: []string{"one", "Two", "three"}, group: "two", expected: true}, + {memberOf: []string{"one", "Two", "three"}, group: "twos", expected: false}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("isMemberOf(%v, \"%s\") = %v", tc.memberOf, tc.group, tc.expected), func(t *testing.T) { + assert.Equal(t, tc.expected, isMemberOf(tc.memberOf, tc.group)) }) + } +} - Convey("Should find one", func() { - result := isMemberOf([]string{"one", "Two", "three"}, "two") - So(result, ShouldBeTrue) - }) +func TestGetUsersIteration(t *testing.T) { + const pageSize = UsersMaxRequest + iterations := map[int]int{ + 0: 0, + 400: 1, + 600: 2, + 1500: 3, + } - Convey("Should not find one", func() { - result := isMemberOf([]string{"one", "Two", "three"}, "twos") - So(result, ShouldBeFalse) - }) - }) + for userCount, expectedIterations := range iterations { + t.Run(fmt.Sprintf("getUserIteration iterates %d times for %d users", expectedIterations, userCount), func(t *testing.T) { + logins := make([]string, userCount) - Convey("getUsersIteration()", t, func() { - Convey("it should execute twice for 600 users", func() { - logins := make([]string, 600) i := 0 + _ = getUsersIteration(logins, func(first int, last int) error { + assert.Equal(t, pageSize*i, first) - result := getUsersIteration(logins, func(previous, current int) error { - i++ - - if i == 1 { - So(previous, ShouldEqual, 0) - So(current, ShouldEqual, 500) - } else { - So(previous, ShouldEqual, 500) - So(current, ShouldEqual, 600) + expectedLast := pageSize*i + pageSize + if expectedLast > userCount { + expectedLast = userCount } - return nil - }) - - So(i, ShouldEqual, 2) - So(result, ShouldBeNil) - }) - - Convey("it should execute three times for 1500 users", func() { - logins := make([]string, 1500) - i := 0 + assert.Equal(t, expectedLast, last) - result := getUsersIteration(logins, func(previous, current int) error { i++ - switch i { - case 1: - So(previous, ShouldEqual, 0) - So(current, ShouldEqual, 500) - case 2: - So(previous, ShouldEqual, 500) - So(current, ShouldEqual, 1000) - default: - So(previous, ShouldEqual, 1000) - So(current, ShouldEqual, 1500) - } - - return nil - }) - - So(i, ShouldEqual, 3) - So(result, ShouldBeNil) - }) - - Convey("it should execute once for 400 users", func() { - logins := make([]string, 400) - i := 0 - - result := getUsersIteration(logins, func(previous, current int) error { - i++ - if i == 1 { - So(previous, ShouldEqual, 0) - So(current, ShouldEqual, 400) - } - return nil }) - So(i, ShouldEqual, 1) - So(result, ShouldBeNil) + assert.Equal(t, expectedIterations, i) }) + } +} - Convey("it should not execute for 0 users", func() { - logins := make([]string, 0) - i := 0 - - result := getUsersIteration(logins, func(previous, current int) error { - i++ - return nil - }) +func TestGetAttribute(t *testing.T) { + t.Run("DN", func(t *testing.T) { + entry := &ldap.Entry{ + DN: "test", + } - So(i, ShouldEqual, 0) - So(result, ShouldBeNil) - }) + result := getAttribute("dn", entry) + assert.Equal(t, "test", result) }) - Convey("getAttribute()", t, func() { - Convey("Should get DN", func() { - entry := &ldap.Entry{ - DN: "test", - } - - result := getAttribute("dn", entry) - - So(result, ShouldEqual, "test") - }) - - Convey("Should get username", func() { - value := []string{"roelgerrits"} - entry := &ldap.Entry{ - Attributes: []*ldap.EntryAttribute{ - { - Name: "username", Values: value, - }, + t.Run("username", func(t *testing.T) { + value := "roelgerrits" + entry := &ldap.Entry{ + Attributes: []*ldap.EntryAttribute{ + { + Name: "username", Values: []string{value}, }, - } + }, + } - result := getAttribute("username", entry) - - So(result, ShouldEqual, value[0]) - }) + result := getAttribute("username", entry) + assert.Equal(t, value, result) + }) - Convey("Should not get anything", func() { - value := []string{"roelgerrits"} - entry := &ldap.Entry{ - Attributes: []*ldap.EntryAttribute{ - { - Name: "killa", Values: value, - }, + t.Run("no result", func(t *testing.T) { + value := []string{"roelgerrits"} + entry := &ldap.Entry{ + Attributes: []*ldap.EntryAttribute{ + { + Name: "killa", Values: value, }, - } + }, + } - result := getAttribute("username", entry) - - So(result, ShouldEqual, "") - }) + result := getAttribute("username", entry) + assert.Empty(t, result) }) +} - Convey("getArrayAttribute()", t, func() { - Convey("Should get DN", func() { - entry := &ldap.Entry{ - DN: "test", - } +func TestGetArrayAttribute(t *testing.T) { + t.Run("DN", func(t *testing.T) { + entry := &ldap.Entry{ + DN: "test", + } - result := getArrayAttribute("dn", entry) + result := getArrayAttribute("dn", entry) - So(result, ShouldResemble, []string{"test"}) - }) + assert.EqualValues(t, []string{"test"}, result) + }) - Convey("Should get username", func() { - value := []string{"roelgerrits"} - entry := &ldap.Entry{ - Attributes: []*ldap.EntryAttribute{ - { - Name: "username", Values: value, - }, + t.Run("username", func(t *testing.T) { + value := []string{"roelgerrits"} + entry := &ldap.Entry{ + Attributes: []*ldap.EntryAttribute{ + { + Name: "username", Values: value, }, - } + }, + } - result := getArrayAttribute("username", entry) + result := getArrayAttribute("username", entry) - So(result, ShouldResemble, value) - }) + assert.EqualValues(t, value, result) + }) - Convey("Should not get anything", func() { - value := []string{"roelgerrits"} - entry := &ldap.Entry{ - Attributes: []*ldap.EntryAttribute{ - { - Name: "username", Values: value, - }, + t.Run("no result", func(t *testing.T) { + value := []string{"roelgerrits"} + entry := &ldap.Entry{ + Attributes: []*ldap.EntryAttribute{ + { + Name: "username", Values: value, }, - } + }, + } - result := getArrayAttribute("something", entry) + result := getArrayAttribute("something", entry) - So(result, ShouldResemble, []string{}) - }) + assert.Empty(t, result) }) } diff --git a/pkg/services/ldap/ldap_login_test.go b/pkg/services/ldap/ldap_login_test.go index dea64fab48c6..7b552a8edfaa 100644 --- a/pkg/services/ldap/ldap_login_test.go +++ b/pkg/services/ldap/ldap_login_test.go @@ -4,231 +4,227 @@ import ( "errors" "testing" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/ldap.v3" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" ) -func TestLDAPLogin(t *testing.T) { - defaultLogin := &models.LoginUserQuery{ - Username: "user", - Password: "pwd", - IpAddress: "192.168.1.1:56433", - } - - Convey("Login()", t, func() { - Convey("Should get invalid credentials when userBind fails", func() { - connection := &MockConnection{} - entry := ldap.Entry{} - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - connection.setSearchResult(&result) - - connection.BindProvider = func(username, password string) error { - return &ldap.Error{ - ResultCode: 49, - } - } - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - _, err := server.Login(defaultLogin) - - So(err, ShouldEqual, ErrInvalidCredentials) - }) - - Convey("Returns an error when search didn't find anything", func() { - connection := &MockConnection{} - result := ldap.SearchResult{Entries: []*ldap.Entry{}} - connection.setSearchResult(&result) - - connection.BindProvider = func(username, password string) error { - return nil - } - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - _, err := server.Login(defaultLogin) - - So(err, ShouldEqual, ErrCouldNotFindUser) - }) - - Convey("When search returns an error", func() { - connection := &MockConnection{} - expected := errors.New("Killa-gorilla") - connection.setSearchError(expected) - - connection.BindProvider = func(username, password string) error { - return nil - } - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - _, err := server.Login(defaultLogin) - - So(err, ShouldEqual, expected) - }) - - Convey("When login with valid credentials", func() { - connection := &MockConnection{} - entry := ldap.Entry{ - DN: "dn", Attributes: []*ldap.EntryAttribute{ - {Name: "username", Values: []string{"markelog"}}, - {Name: "surname", Values: []string{"Gaidarenko"}}, - {Name: "email", Values: []string{"markelog@gmail.com"}}, - {Name: "name", Values: []string{"Oleg"}}, - {Name: "memberof", Values: []string{"admins"}}, - }, - } - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - connection.setSearchResult(&result) - - connection.BindProvider = func(username, password string) error { - return nil - } - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - }, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - resp, err := server.Login(defaultLogin) - - So(err, ShouldBeNil) - So(resp.Login, ShouldEqual, "markelog") - }) - - Convey("Should perform unauthenticated bind without admin", func() { - connection := &MockConnection{} - entry := ldap.Entry{ - DN: "test", - } - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - connection.setSearchResult(&result) - - connection.UnauthenticatedBindProvider = func() error { - return nil - } - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - user, err := server.Login(defaultLogin) - - So(err, ShouldBeNil) - So(user.AuthId, ShouldEqual, "test") - So(connection.UnauthenticatedBindCalled, ShouldBeTrue) - }) - - Convey("Should perform authenticated binds", func() { - connection := &MockConnection{} - entry := ldap.Entry{ - DN: "test", - } - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - connection.setSearchResult(&result) - - adminUsername := "" - adminPassword := "" - username := "" - password := "" - - i := 0 - connection.BindProvider = func(name, pass string) error { - i++ - if i == 1 { - adminUsername = name - adminPassword = pass - } - - if i == 2 { - username = name - password = pass - } - - return nil - } - server := &Server{ - Config: &ServerConfig{ - BindDN: "killa", - BindPassword: "gorilla", - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - user, err := server.Login(defaultLogin) - - So(err, ShouldBeNil) - - So(user.AuthId, ShouldEqual, "test") - So(connection.BindCalled, ShouldBeTrue) - - So(adminUsername, ShouldEqual, "killa") - So(adminPassword, ShouldEqual, "gorilla") - - So(username, ShouldEqual, "test") - So(password, ShouldEqual, "pwd") - }) - Convey("Should bind with user if %s exists in the bind_dn", func() { - connection := &MockConnection{} - entry := ldap.Entry{ - DN: "test", - } - connection.setSearchResult(&ldap.SearchResult{Entries: []*ldap.Entry{&entry}}) - - authBindUser := "" - authBindPassword := "" - - connection.BindProvider = func(name, pass string) error { - authBindUser = name - authBindPassword = pass - return nil - } - server := &Server{ - Config: &ServerConfig{ - BindDN: "cn=%s,ou=users,dc=grafana,dc=org", - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - _, err := server.Login(defaultLogin) - - So(err, ShouldBeNil) - - So(authBindUser, ShouldEqual, "cn=user,ou=users,dc=grafana,dc=org") - So(authBindPassword, ShouldEqual, "pwd") - So(connection.BindCalled, ShouldBeTrue) - }) - }) +var defaultLogin = &models.LoginUserQuery{ + Username: "user", + Password: "pwd", + IpAddress: "192.168.1.1:56433", +} + +func TestServer_Login_UserBind_Fail(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{} + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + connection.setSearchResult(&result) + + connection.BindProvider = func(username, password string) error { + return &ldap.Error{ + ResultCode: 49, + } + } + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + _, err := server.Login(defaultLogin) + + assert.ErrorIs(t, err, ErrInvalidCredentials) +} + +func TestServer_Login_Search_NoResult(t *testing.T) { + connection := &MockConnection{} + result := ldap.SearchResult{Entries: []*ldap.Entry{}} + connection.setSearchResult(&result) + + connection.BindProvider = func(username, password string) error { + return nil + } + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + _, err := server.Login(defaultLogin) + assert.ErrorIs(t, err, ErrCouldNotFindUser) +} + +func TestServer_Login_Search_Error(t *testing.T) { + connection := &MockConnection{} + expected := errors.New("Killa-gorilla") + connection.setSearchError(expected) + + connection.BindProvider = func(username, password string) error { + return nil + } + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + _, err := server.Login(defaultLogin) + assert.ErrorIs(t, err, expected) +} + +func TestServer_Login_ValidCredentials(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"markelog"}}, + {Name: "surname", Values: []string{"Gaidarenko"}}, + {Name: "email", Values: []string{"markelog@gmail.com"}}, + {Name: "name", Values: []string{"Oleg"}}, + {Name: "memberof", Values: []string{"admins"}}, + }, + } + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + connection.setSearchResult(&result) + + connection.BindProvider = func(username, password string) error { + return nil + } + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", + }, + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + resp, err := server.Login(defaultLogin) + require.NoError(t, err) + assert.Equal(t, "markelog", resp.Login) +} + +// TestServer_Login_UnauthenticatedBind tests that unauthenticated bind +// is called when there is no admin password or user wildcard in the +// bind_dn. +func TestServer_Login_UnauthenticatedBind(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{ + DN: "test", + } + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + connection.setSearchResult(&result) + + connection.UnauthenticatedBindProvider = func() error { + return nil + } + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + user, err := server.Login(defaultLogin) + require.NoError(t, err) + assert.Equal(t, "test", user.AuthId) + assert.True(t, connection.UnauthenticatedBindCalled) +} + +func TestServer_Login_AuthenticatedBind(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{ + DN: "test", + } + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + connection.setSearchResult(&result) + + adminUsername := "" + adminPassword := "" + username := "" + password := "" + + i := 0 + connection.BindProvider = func(name, pass string) error { + i++ + if i == 1 { + adminUsername = name + adminPassword = pass + } + + if i == 2 { + username = name + password = pass + } + + return nil + } + server := &Server{ + Config: &ServerConfig{ + BindDN: "killa", + BindPassword: "gorilla", + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + user, err := server.Login(defaultLogin) + require.NoError(t, err) + + assert.Equal(t, "test", user.AuthId) + assert.True(t, connection.BindCalled) + + assert.Equal(t, "killa", adminUsername) + assert.Equal(t, "gorilla", adminPassword) + + assert.Equal(t, "test", username) + assert.Equal(t, "pwd", password) +} + +func TestServer_Login_UserWildcardBind(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{ + DN: "test", + } + connection.setSearchResult(&ldap.SearchResult{Entries: []*ldap.Entry{&entry}}) + + authBindUser := "" + authBindPassword := "" + + connection.BindProvider = func(name, pass string) error { + authBindUser = name + authBindPassword = pass + return nil + } + server := &Server{ + Config: &ServerConfig{ + BindDN: "cn=%s,ou=users,dc=grafana,dc=org", + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + _, err := server.Login(defaultLogin) + require.NoError(t, err) + + assert.Equal(t, "cn=user,ou=users,dc=grafana,dc=org", authBindUser) + assert.Equal(t, "pwd", authBindPassword) + assert.True(t, connection.BindCalled) } diff --git a/pkg/services/ldap/ldap_private_test.go b/pkg/services/ldap/ldap_private_test.go index 431f94f0d94a..d4d0f1c238d6 100644 --- a/pkg/services/ldap/ldap_private_test.go +++ b/pkg/services/ldap/ldap_private_test.go @@ -3,271 +3,252 @@ package ldap import ( "testing" + "github.com/stretchr/testify/require" + + "github.com/stretchr/testify/assert" + + "gopkg.in/ldap.v3" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" - . "github.com/smartystreets/goconvey/convey" - "gopkg.in/ldap.v3" ) -func TestLDAPPrivateMethods(t *testing.T) { - Convey("getSearchRequest()", t, func() { - Convey("with enabled GroupSearchFilterUserAttribute setting", func() { - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - Email: "email", - }, - GroupSearchFilterUserAttribute: "gansta", - SearchBaseDNs: []string{"BaseDNHere"}, - }, - log: log.New("test-logger"), - } - - result := server.getSearchRequest("killa", []string{"gorilla"}) - - So(result, ShouldResemble, &ldap.SearchRequest{ - BaseDN: "killa", - Scope: 2, - DerefAliases: 0, - SizeLimit: 0, - TimeLimit: 0, - TypesOnly: false, - Filter: "(|)", - Attributes: []string{ - "username", - "email", - "name", - "memberof", - "gansta", - }, - Controls: nil, - }) - }) +func TestServer_getSearchRequest(t *testing.T) { + expected := &ldap.SearchRequest{ + BaseDN: "killa", + Scope: 2, + DerefAliases: 0, + SizeLimit: 0, + TimeLimit: 0, + TypesOnly: false, + Filter: "(|)", + Attributes: []string{ + "username", + "email", + "name", + "memberof", + "gansta", + }, + Controls: nil, + } + + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", + Email: "email", + }, + GroupSearchFilterUserAttribute: "gansta", + SearchBaseDNs: []string{"BaseDNHere"}, + }, + log: log.New("test-logger"), + } + + result := server.getSearchRequest("killa", []string{"gorilla"}) + + assert.EqualValues(t, expected, result) +} + +func TestSerializeUsers(t *testing.T) { + t.Run("simple case", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", + Email: "email", + }, + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: &MockConnection{}, + log: log.New("test-logger"), + } + + entry := ldap.Entry{ + DN: "dn", + Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"roelgerrits"}}, + {Name: "surname", Values: []string{"Gerrits"}}, + {Name: "email", Values: []string{"roel@test.com"}}, + {Name: "name", Values: []string{"Roel"}}, + {Name: "memberof", Values: []string{"admins"}}, + }, + } + users := [][]*ldap.Entry{{&entry}} + + result, err := server.serializeUsers(users) + require.NoError(t, err) + + assert.Equal(t, "roelgerrits", result[0].Login) + assert.Equal(t, "roel@test.com", result[0].Email) + assert.Contains(t, result[0].Groups, "admins") }) - Convey("serializeUsers()", t, func() { - Convey("simple case", func() { - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - Email: "email", - }, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: &MockConnection{}, - log: log.New("test-logger"), - } - - entry := ldap.Entry{ - DN: "dn", - Attributes: []*ldap.EntryAttribute{ - {Name: "username", Values: []string{"roelgerrits"}}, - {Name: "surname", Values: []string{"Gerrits"}}, - {Name: "email", Values: []string{"roel@test.com"}}, - {Name: "name", Values: []string{"Roel"}}, - {Name: "memberof", Values: []string{"admins"}}, - }, - } - users := []*ldap.Entry{&entry} - - result, err := server.serializeUsers(users) - - So(err, ShouldBeNil) - So(result[0].Login, ShouldEqual, "roelgerrits") - So(result[0].Email, ShouldEqual, "roel@test.com") - So(result[0].Groups, ShouldContain, "admins") - }) - - Convey("without lastname", func() { - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - Email: "email", - }, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: &MockConnection{}, - log: log.New("test-logger"), - } - - entry := ldap.Entry{ - DN: "dn", - Attributes: []*ldap.EntryAttribute{ - {Name: "username", Values: []string{"roelgerrits"}}, - {Name: "email", Values: []string{"roel@test.com"}}, - {Name: "name", Values: []string{"Roel"}}, - {Name: "memberof", Values: []string{"admins"}}, - }, - } - users := []*ldap.Entry{&entry} - - result, err := server.serializeUsers(users) - - So(err, ShouldBeNil) - So(result[0].IsDisabled, ShouldBeFalse) - So(result[0].Name, ShouldEqual, "Roel") - }) - - Convey("a user without matching groups should be marked as disabled", func() { - server := &Server{ - Config: &ServerConfig{ - Groups: []*GroupToOrgRole{{ - GroupDN: "foo", - OrgId: 1, - OrgRole: models.ROLE_EDITOR, - }}, - }, - Connection: &MockConnection{}, - log: log.New("test-logger"), - } - - entry := ldap.Entry{ - DN: "dn", - Attributes: []*ldap.EntryAttribute{ - {Name: "memberof", Values: []string{"admins"}}, - }, - } - users := []*ldap.Entry{&entry} + t.Run("without lastname", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", + Email: "email", + }, + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: &MockConnection{}, + log: log.New("test-logger"), + } + + entry := ldap.Entry{ + DN: "dn", + Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"roelgerrits"}}, + {Name: "email", Values: []string{"roel@test.com"}}, + {Name: "name", Values: []string{"Roel"}}, + {Name: "memberof", Values: []string{"admins"}}, + }, + } + users := [][]*ldap.Entry{{&entry}} + + result, err := server.serializeUsers(users) + require.NoError(t, err) + + assert.False(t, result[0].IsDisabled) + assert.Equal(t, "Roel", result[0].Name) + }) - result, err := server.serializeUsers(users) + t.Run("mark user without matching group as disabled", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Groups: []*GroupToOrgRole{{ + GroupDN: "foo", + OrgId: 1, + OrgRole: models.ROLE_EDITOR, + }}, + }, + Connection: &MockConnection{}, + log: log.New("test-logger"), + } + + entry := ldap.Entry{ + DN: "dn", + Attributes: []*ldap.EntryAttribute{ + {Name: "memberof", Values: []string{"admins"}}, + }, + } + users := [][]*ldap.Entry{{&entry}} + + result, err := server.serializeUsers(users) + require.NoError(t, err) + + assert.Len(t, result, 1) + assert.True(t, result[0].IsDisabled) + }) +} - So(err, ShouldBeNil) - So(len(result), ShouldEqual, 1) - So(result[0].IsDisabled, ShouldBeTrue) - }) +func TestServer_validateGrafanaUser(t *testing.T) { + t.Run("no group config", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Groups: []*GroupToOrgRole{}, + }, + log: logger.New("test"), + } + + user := &models.ExternalUserInfo{ + Login: "markelog", + } + + err := server.validateGrafanaUser(user) + require.NoError(t, err) }) - Convey("validateGrafanaUser()", t, func() { - Convey("Returns error when user does not belong in any of the specified LDAP groups", func() { - server := &Server{ - Config: &ServerConfig{ - Groups: []*GroupToOrgRole{ - { - OrgId: 1, - }, + t.Run("user in group", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Groups: []*GroupToOrgRole{ + { + OrgId: 1, }, }, - log: logger.New("test"), - } - - user := &models.ExternalUserInfo{ - Login: "markelog", - } - - result := server.validateGrafanaUser(user) - - So(result, ShouldEqual, ErrInvalidCredentials) - }) - - Convey("Does not return error when group config is empty", func() { - server := &Server{ - Config: &ServerConfig{ - Groups: []*GroupToOrgRole{}, - }, - log: logger.New("test"), - } - - user := &models.ExternalUserInfo{ - Login: "markelog", - } + }, + log: logger.New("test"), + } - result := server.validateGrafanaUser(user) + user := &models.ExternalUserInfo{ + Login: "markelog", + OrgRoles: map[int64]models.RoleType{ + 1: "test", + }, + } - So(result, ShouldBeNil) - }) + err := server.validateGrafanaUser(user) + require.NoError(t, err) + }) - Convey("Does not return error when groups are there", func() { - server := &Server{ - Config: &ServerConfig{ - Groups: []*GroupToOrgRole{ - { - OrgId: 1, - }, + t.Run("user not in group", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Groups: []*GroupToOrgRole{ + { + OrgId: 1, }, }, - log: logger.New("test"), - } - - user := &models.ExternalUserInfo{ - Login: "markelog", - OrgRoles: map[int64]models.RoleType{ - 1: "test", - }, - } + }, + log: logger.New("test"), + } - result := server.validateGrafanaUser(user) + user := &models.ExternalUserInfo{ + Login: "markelog", + } - So(result, ShouldBeNil) - }) + err := server.validateGrafanaUser(user) + require.ErrorIs(t, err, ErrInvalidCredentials) }) +} - Convey("shouldAdminBind()", t, func() { - Convey("it should require admin userBind", func() { - server := &Server{ - Config: &ServerConfig{ - BindPassword: "test", - }, - } - - result := server.shouldAdminBind() - So(result, ShouldBeTrue) - }) - - Convey("it should not require admin userBind", func() { - server := &Server{ - Config: &ServerConfig{ - BindPassword: "", - }, - } +func TestServer_binds(t *testing.T) { + t.Run("single bind with cn wildcard", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + BindDN: "cn=%s,dc=grafana,dc=org", + }, + } - result := server.shouldAdminBind() - So(result, ShouldBeFalse) - }) + assert.True(t, server.shouldSingleBind()) + assert.Equal(t, "cn=test,dc=grafana,dc=org", server.singleBindDN("test")) }) - Convey("shouldSingleBind()", t, func() { - Convey("it should allow single bind", func() { - server := &Server{ - Config: &ServerConfig{ - BindDN: "cn=%s,dc=grafana,dc=org", - }, - } + t.Run("don't single bind", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + BindDN: "cn=admin,dc=grafana,dc=org", + }, + } - result := server.shouldSingleBind() - So(result, ShouldBeTrue) - }) + assert.False(t, server.shouldSingleBind()) + }) - Convey("it should not allow single bind", func() { - server := &Server{ - Config: &ServerConfig{ - BindDN: "cn=admin,dc=grafana,dc=org", - }, - } + t.Run("admin user bind", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + BindPassword: "test", + }, + } - result := server.shouldSingleBind() - So(result, ShouldBeFalse) - }) + assert.True(t, server.shouldAdminBind()) }) - Convey("singleBindDN()", t, func() { - Convey("it should allow single bind", func() { - server := &Server{ - Config: &ServerConfig{ - BindDN: "cn=%s,dc=grafana,dc=org", - }, - } + t.Run("don't admin user bind", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + BindPassword: "", + }, + } - result := server.singleBindDN("test") - So(result, ShouldEqual, "cn=test,dc=grafana,dc=org") - }) + assert.False(t, server.shouldAdminBind()) }) } diff --git a/pkg/services/ldap/ldap_test.go b/pkg/services/ldap/ldap_test.go index ea1fd049bf37..042ac045506c 100644 --- a/pkg/services/ldap/ldap_test.go +++ b/pkg/services/ldap/ldap_test.go @@ -2,226 +2,319 @@ package ldap import ( "errors" + "fmt" "testing" - "github.com/grafana/grafana/pkg/infra/log" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/ldap.v3" + + "github.com/grafana/grafana/pkg/infra/log" ) -func TestPublicAPI(t *testing.T) { - Convey("New()", t, func() { - Convey("Should return ", func() { - result := New(&ServerConfig{ +func TestNew(t *testing.T) { + result := New(&ServerConfig{ + Attr: AttributeMap{}, + SearchBaseDNs: []string{"BaseDNHere"}, + }) + + assert.Implements(t, (*IServer)(nil), result) +} + +func TestServer_Close(t *testing.T) { + t.Run("close the connection", func(t *testing.T) { + connection := &MockConnection{} + + server := &Server{ + Config: &ServerConfig{ Attr: AttributeMap{}, SearchBaseDNs: []string{"BaseDNHere"}, - }) + }, + Connection: connection, + } - So(result, ShouldImplement, (*IServer)(nil)) - }) + assert.NotPanics(t, server.Close) + assert.True(t, connection.CloseCalled) }) - Convey("Close()", t, func() { - Convey("Should close the connection", func() { - connection := &MockConnection{} + t.Run("panic if no connection", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{}, + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: nil, + } + + assert.Panics(t, server.Close) + }) +} - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{}, - SearchBaseDNs: []string{"BaseDNHere"}, +func TestServer_Users(t *testing.T) { + t.Run("one user", func(t *testing.T) { + conn := &MockConnection{} + entry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"roelgerrits"}}, + {Name: "surname", Values: []string{"Gerrits"}}, + {Name: "email", Values: []string{"roel@test.com"}}, + {Name: "name", Values: []string{"Roel"}}, + {Name: "memberof", Values: []string{"admins"}}, + }} + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + conn.setSearchResult(&result) + + // Set up attribute map without surname and email + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", }, - Connection: connection, - } + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: conn, + log: log.New("test-logger"), + } - So(server.Close, ShouldNotPanic) - So(connection.CloseCalled, ShouldBeTrue) - }) + searchResult, err := server.Users([]string{"roelgerrits"}) - Convey("Should panic if no connection is established", func() { - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{}, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: nil, - } + require.NoError(t, err) + assert.NotNil(t, searchResult) - So(server.Close, ShouldPanic) - }) + // User should be searched in ldap + assert.True(t, conn.SearchCalled) + // No empty attributes should be added to the search request + assert.Len(t, conn.SearchAttributes, 3) }) - Convey("Users()", t, func() { - Convey("Finds one user", func() { - MockConnection := &MockConnection{} - entry := ldap.Entry{ - DN: "dn", Attributes: []*ldap.EntryAttribute{ - {Name: "username", Values: []string{"roelgerrits"}}, - {Name: "surname", Values: []string{"Gerrits"}}, - {Name: "email", Values: []string{"roel@test.com"}}, - {Name: "name", Values: []string{"Roel"}}, - {Name: "memberof", Values: []string{"admins"}}, - }} - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - MockConnection.setSearchResult(&result) - - // Set up attribute map without surname and email - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - }, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: MockConnection, - log: log.New("test-logger"), - } - searchResult, err := server.Users([]string{"roelgerrits"}) + t.Run("error", func(t *testing.T) { + expected := errors.New("Killa-gorilla") + conn := &MockConnection{} + conn.setSearchError(expected) - So(err, ShouldBeNil) - So(searchResult, ShouldNotBeNil) + // Set up attribute map without surname and email + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: conn, + log: log.New("test-logger"), + } - // User should be searched in ldap - So(MockConnection.SearchCalled, ShouldBeTrue) + _, err := server.Users([]string{"roelgerrits"}) - // No empty attributes should be added to the search request - So(len(MockConnection.SearchAttributes), ShouldEqual, 3) - }) - - Convey("Handles a error", func() { - expected := errors.New("Killa-gorilla") - MockConnection := &MockConnection{} - MockConnection.setSearchError(expected) + assert.ErrorIs(t, err, expected) + }) - // Set up attribute map without surname and email - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: MockConnection, - log: log.New("test-logger"), - } + t.Run("no user", func(t *testing.T) { + conn := &MockConnection{} + result := ldap.SearchResult{Entries: []*ldap.Entry{}} + conn.setSearchResult(&result) - _, err := server.Users([]string{"roelgerrits"}) + // Set up attribute map without surname and email + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: conn, + log: log.New("test-logger"), + } - So(err, ShouldEqual, expected) - }) + searchResult, err := server.Users([]string{"roelgerrits"}) - Convey("Should return empty slice if none were found", func() { - MockConnection := &MockConnection{} - result := ldap.SearchResult{Entries: []*ldap.Entry{}} - MockConnection.setSearchResult(&result) + require.NoError(t, err) + assert.Empty(t, searchResult) + }) - // Set up attribute map without surname and email - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: MockConnection, - log: log.New("test-logger"), + t.Run("multiple DNs", func(t *testing.T) { + conn := &MockConnection{} + serviceDN := "dc=svc,dc=example,dc=org" + serviceEntry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"imgrenderer"}}, + {Name: "name", Values: []string{"Image renderer"}}, + }} + services := ldap.SearchResult{Entries: []*ldap.Entry{&serviceEntry}} + + userDN := "dc=users,dc=example,dc=org" + userEntry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"grot"}}, + {Name: "name", Values: []string{"Grot"}}, + }} + users := ldap.SearchResult{Entries: []*ldap.Entry{&userEntry}} + + conn.setSearchFunc(func(request *ldap.SearchRequest) (*ldap.SearchResult, error) { + switch request.BaseDN { + case userDN: + return &users, nil + case serviceDN: + return &services, nil + default: + return nil, fmt.Errorf("test case not defined for baseDN: '%s'", request.BaseDN) } - - searchResult, err := server.Users([]string{"roelgerrits"}) - - So(err, ShouldBeNil) - So(searchResult, ShouldBeEmpty) }) - }) - Convey("UserBind()", t, func() { - Convey("Should use provided DN and password", func() { - connection := &MockConnection{} - var actualUsername, actualPassword string - connection.BindProvider = func(username, password string) error { - actualUsername = username - actualPassword = password - return nil - } - server := &Server{ - Connection: connection, - Config: &ServerConfig{ - BindDN: "cn=admin,dc=grafana,dc=org", + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", }, - } + SearchBaseDNs: []string{serviceDN, userDN}, + }, + Connection: conn, + log: log.New("test-logger"), + } - dn := "cn=user,ou=users,dc=grafana,dc=org" - err := server.UserBind(dn, "pwd") - - So(err, ShouldBeNil) - So(actualUsername, ShouldEqual, dn) - So(actualPassword, ShouldEqual, "pwd") - }) + searchResult, err := server.Users([]string{"imgrenderer", "grot"}) + require.NoError(t, err) - Convey("Should handle an error", func() { - connection := &MockConnection{} - expected := &ldap.Error{ - ResultCode: uint16(25), - } - connection.BindProvider = func(username, password string) error { - return expected - } - server := &Server{ - Connection: connection, - Config: &ServerConfig{ - BindDN: "cn=%s,ou=users,dc=grafana,dc=org", - }, - log: log.New("test-logger"), - } - err := server.UserBind("user", "pwd") - So(err, ShouldEqual, expected) - }) + assert.Len(t, searchResult, 2) }) - Convey("AdminBind()", t, func() { - Convey("Should use admin DN and password", func() { - connection := &MockConnection{} - var actualUsername, actualPassword string - connection.BindProvider = func(username, password string) error { - actualUsername = username - actualPassword = password - return nil + t.Run("same user in multiple DNs", func(t *testing.T) { + conn := &MockConnection{} + firstDN := "dc=users1,dc=example,dc=org" + firstEntry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"grot"}}, + {Name: "name", Values: []string{"Grot the First"}}, + }} + firsts := ldap.SearchResult{Entries: []*ldap.Entry{&firstEntry}} + + secondDN := "dc=users2,dc=example,dc=org" + secondEntry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"grot"}}, + {Name: "name", Values: []string{"Grot the Second"}}, + }} + seconds := ldap.SearchResult{Entries: []*ldap.Entry{&secondEntry}} + + conn.setSearchFunc(func(request *ldap.SearchRequest) (*ldap.SearchResult, error) { + switch request.BaseDN { + case secondDN: + return &seconds, nil + case firstDN: + return &firsts, nil + default: + return nil, fmt.Errorf("test case not defined for baseDN: '%s'", request.BaseDN) } + }) - dn := "cn=admin,dc=grafana,dc=org" - - server := &Server{ - Connection: connection, - Config: &ServerConfig{ - BindPassword: "pwd", - BindDN: dn, + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", }, - } - - err := server.AdminBind() - - So(err, ShouldBeNil) - So(actualUsername, ShouldEqual, dn) - So(actualPassword, ShouldEqual, "pwd") - }) + SearchBaseDNs: []string{firstDN, secondDN}, + }, + Connection: conn, + log: log.New("test-logger"), + } + + res, err := server.Users([]string{"grot"}) + require.NoError(t, err) + require.Len(t, res, 1) + assert.Equal(t, "Grot the First", res[0].Name) + }) +} - Convey("Should handle an error", func() { - connection := &MockConnection{} - expected := &ldap.Error{ - ResultCode: uint16(25), - } - connection.BindProvider = func(username, password string) error { - return expected - } +func TestServer_UserBind(t *testing.T) { + t.Run("use provided DN and password", func(t *testing.T) { + connection := &MockConnection{} + var actualUsername, actualPassword string + connection.BindProvider = func(username, password string) error { + actualUsername = username + actualPassword = password + return nil + } + server := &Server{ + Connection: connection, + Config: &ServerConfig{ + BindDN: "cn=admin,dc=grafana,dc=org", + }, + } + + dn := "cn=user,ou=users,dc=grafana,dc=org" + err := server.UserBind(dn, "pwd") + + require.NoError(t, err) + assert.Equal(t, dn, actualUsername) + assert.Equal(t, "pwd", actualPassword) + }) - dn := "cn=admin,dc=grafana,dc=org" + t.Run("error", func(t *testing.T) { + connection := &MockConnection{} + expected := &ldap.Error{ + ResultCode: uint16(25), + } + connection.BindProvider = func(username, password string) error { + return expected + } + server := &Server{ + Connection: connection, + Config: &ServerConfig{ + BindDN: "cn=%s,ou=users,dc=grafana,dc=org", + }, + log: log.New("test-logger"), + } + err := server.UserBind("user", "pwd") + assert.ErrorIs(t, err, expected) + }) +} - server := &Server{ - Connection: connection, - Config: &ServerConfig{ - BindPassword: "pwd", - BindDN: dn, - }, - log: log.New("test-logger"), - } +func TestServer_AdminBind(t *testing.T) { + t.Run("use admin DN and password", func(t *testing.T) { + connection := &MockConnection{} + var actualUsername, actualPassword string + connection.BindProvider = func(username, password string) error { + actualUsername = username + actualPassword = password + return nil + } + + dn := "cn=admin,dc=grafana,dc=org" + + server := &Server{ + Connection: connection, + Config: &ServerConfig{ + BindPassword: "pwd", + BindDN: dn, + }, + } + + err := server.AdminBind() + require.NoError(t, err) + + assert.Equal(t, dn, actualUsername) + assert.Equal(t, "pwd", actualPassword) + }) - err := server.AdminBind() - So(err, ShouldEqual, expected) - }) + t.Run("error", func(t *testing.T) { + connection := &MockConnection{} + expected := &ldap.Error{ + ResultCode: uint16(25), + } + connection.BindProvider = func(username, password string) error { + return expected + } + + dn := "cn=admin,dc=grafana,dc=org" + + server := &Server{ + Connection: connection, + Config: &ServerConfig{ + BindPassword: "pwd", + BindDN: dn, + }, + log: log.New("test-logger"), + } + + err := server.AdminBind() + assert.ErrorIs(t, err, expected) }) } diff --git a/pkg/services/ldap/testing.go b/pkg/services/ldap/testing.go index 8bad83a2d926..cd9ff9184f47 100644 --- a/pkg/services/ldap/testing.go +++ b/pkg/services/ldap/testing.go @@ -6,10 +6,11 @@ import ( "gopkg.in/ldap.v3" ) +type searchFunc = func(request *ldap.SearchRequest) (*ldap.SearchResult, error) + // MockConnection struct for testing type MockConnection struct { - SearchResult *ldap.SearchResult - SearchError error + SearchFunc searchFunc SearchCalled bool SearchAttributes []string @@ -56,11 +57,19 @@ func (c *MockConnection) Close() { } func (c *MockConnection) setSearchResult(result *ldap.SearchResult) { - c.SearchResult = result + c.SearchFunc = func(request *ldap.SearchRequest) (*ldap.SearchResult, error) { + return result, nil + } } func (c *MockConnection) setSearchError(err error) { - c.SearchError = err + c.SearchFunc = func(request *ldap.SearchRequest) (*ldap.SearchResult, error) { + return nil, err + } +} + +func (c *MockConnection) setSearchFunc(fn searchFunc) { + c.SearchFunc = fn } // Search mocks Search connection function @@ -68,11 +77,7 @@ func (c *MockConnection) Search(sr *ldap.SearchRequest) (*ldap.SearchResult, err c.SearchCalled = true c.SearchAttributes = sr.Attributes - if c.SearchError != nil { - return nil, c.SearchError - } - - return c.SearchResult, nil + return c.SearchFunc(sr) } // Add mocks Add connection function