diff --git a/auth/providers/azure/graph/graph.go b/auth/providers/azure/graph/graph.go index b2ab25393..6c7829745 100644 --- a/auth/providers/azure/graph/graph.go +++ b/auth/providers/azure/graph/graph.go @@ -150,10 +150,8 @@ func (u *UserInfo) getExpandedGroups(ids []string) (*GroupList, error) { // Generally in federated directories the email address is the userPrincipalName func (u *UserInfo) GetGroups(userPrincipal string) ([]string, error) { // Make sure things are logged in before continuing - if u.isExpired() { - if err := u.login(); err != nil { - return nil, err - } + if err := u.login(); err != nil { + return nil, err } // Get the group IDs for the user @@ -181,8 +179,7 @@ func (u *UserInfo) Name() string { return getterName } -// New returns a new UserInfo object that is authenticated to the MS Graph API. -// If authentication fails, an error will be returned +// New returns a new UserInfo object func New(clientID, clientSecret, tenantName string) (*UserInfo, error) { parsedLogin, err := url.Parse(fmt.Sprintf(loginURL, tenantName)) if err != nil { @@ -198,10 +195,6 @@ func New(clientID, clientSecret, tenantName string) (*UserInfo, error) { clientID: clientID, clientSecret: clientSecret, } - err = u.login() - if err != nil { - return nil, err - } return u, nil } diff --git a/auth/providers/azure/graph/graph_test.go b/auth/providers/azure/graph/graph_test.go index abccae860..2268b9dfc 100644 --- a/auth/providers/azure/graph/graph_test.go +++ b/auth/providers/azure/graph/graph_test.go @@ -248,6 +248,10 @@ func TestGetGroups(t *testing.T) { ] }` mux := http.NewServeMux() + mux.Handle("/login", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{ "token_type": "Bearer", "expires_in": 8459, "access_token": "secret"}`)) + })) mux.Handle("/users/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.Write([]byte(validBody1)) @@ -258,9 +262,12 @@ func TestGetGroups(t *testing.T) { })) ts := httptest.NewServer(mux) apiURL, _ := url.Parse(ts.URL) + loginURL, _ := url.Parse(ts.URL + "/login") + u := &UserInfo{ client: http.DefaultClient, apiURL: apiURL, + loginURL: loginURL, headers: http.Header{}, clientID: "jason", clientSecret: "bourne", diff --git a/server/handler_test.go b/server/handler_test.go index bba098fff..2d8e88bac 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -7,8 +7,15 @@ import ( "net/http/httptest" "testing" + "github.com/appscode/guard/auth/providers/appscode" + "github.com/appscode/guard/auth/providers/azure" + "github.com/appscode/guard/auth/providers/github" + "github.com/appscode/guard/auth/providers/gitlab" + "github.com/appscode/guard/auth/providers/google" + "github.com/appscode/guard/auth/providers/ldap" "github.com/appscode/kutil/tools/certstore" "github.com/google/gofuzz" + "github.com/pkg/errors" "github.com/spf13/afero" "github.com/stretchr/testify/assert" auth "k8s.io/api/authentication/v1" @@ -85,3 +92,77 @@ func TestServeHTTP(t *testing.T) { assert.Nil(t, err, "response body must be of kind TokenReview") } } + +func TestGetAuthProviderClient(t *testing.T) { + const invalideAuthProvider = "invalid_auth_provider" + + testData := []struct { + testName string + authProvider string + expectedErr error + }{ + { + "get github client", + github.OrgType, + nil, + }, + { + "get google client", + google.OrgType, + nil, + }, + { + "get appscode client", + appscode.OrgType, + nil, + }, + { + "get gitlab client", + gitlab.OrgType, + nil, + }, + { + "get azure client", + azure.OrgType, + nil, + }, + { + "get LDAP client", + ldap.OrgType, + nil, + }, + { + "unknown auth providername", + invalideAuthProvider, + errors.Errorf("Client is using unknown organization %s", invalideAuthProvider), + }, + } + s := Server{ + RecommendedOptions: NewRecommendedOptions(), + } + + // https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols-oidc + // https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-protocols-oauth-code#jwt-token-claims + s.RecommendedOptions.Azure.TenantID = "7fe81447-da57-4385-becb-6de57f21477e" + + for _, test := range testData { + t.Run(test.testName, func(t *testing.T) { + + client, err := s.getAuthProviderClient(test.authProvider, "") + + if test.expectedErr == nil { + assert.Nil(t, err, "expected error nil") + + if err == nil { + assert.Equal(t, test.authProvider, client.UID()) + } + } else { + assert.NotNil(t, err) + + if err != nil { + assert.EqualError(t, err, test.expectedErr.Error()) + } + } + }) + } +}