diff --git a/oidc/oidc.go b/oidc/oidc.go index a9098f3c..89209018 100644 --- a/oidc/oidc.go +++ b/oidc/oidc.go @@ -36,8 +36,9 @@ const ( ) var ( - errNoAtHash = errors.New("id token did not have an access token hash") - errInvalidAtHash = errors.New("access token hash does not match value in ID token") + errNoAtHash = errors.New("id token did not have an access token hash") + errInvalidAtHash = errors.New("access token hash does not match value in ID token") + ErrUserInfoNotSupported = errors.New("oidc: user info endpoint is not supported by this provider") ) type contextKey int @@ -306,7 +307,7 @@ func (u *UserInfo) Claims(v interface{}) error { // UserInfo uses the token source to query the provider's user info endpoint. func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*UserInfo, error) { if p.userInfoURL == "" { - return nil, errors.New("oidc: user info endpoint is not supported by this provider") + return nil, ErrUserInfoNotSupported } req, err := http.NewRequest("GET", p.userInfoURL, nil) diff --git a/oidc/oidc_test.go b/oidc/oidc_test.go index bbbbda42..d2c13757 100644 --- a/oidc/oidc_test.go +++ b/oidc/oidc_test.go @@ -362,14 +362,19 @@ func (ts *testServer) run(t *testing.T) string { ] }` + var userInfoJSON string + if ts.userInfo != "" { + userInfoJSON = fmt.Sprintf(`"userinfo_endpoint": "%s/userinfo",`, server.URL) + } + wellKnown := fmt.Sprintf(`{ "issuer": "%[1]s", "authorization_endpoint": "%[1]s/auth", "token_endpoint": "%[1]s/token", "jwks_uri": "%[1]s/keys", - "userinfo_endpoint": "%[1]s/userinfo", + %[2]s "id_token_signing_alg_values_supported": ["RS256"] - }`, server.URL) + }`, server.URL, userInfoJSON) newMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, req *http.Request) { _, err := io.WriteString(w, wellKnown) @@ -383,13 +388,15 @@ func (ts *testServer) run(t *testing.T) string { w.WriteHeader(500) } }) - newMux.HandleFunc("/userinfo", func(w http.ResponseWriter, req *http.Request) { - w.Header().Add("Content-Type", ts.contentType) - _, err := io.WriteString(w, ts.userInfo) - if err != nil { - w.WriteHeader(500) - } - }) + if ts.userInfo != "" { + newMux.HandleFunc("/userinfo", func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("Content-Type", ts.contentType) + _, err := io.WriteString(w, ts.userInfo) + if err != nil { + w.WriteHeader(500) + } + }) + } t.Cleanup(server.Close) return server.URL } @@ -415,6 +422,7 @@ func TestUserInfoEndpoint(t *testing.T) { name string server testServer wantUserInfo UserInfo + wantError error }{ { name: "basic json userinfo", @@ -489,6 +497,14 @@ func TestUserInfoEndpoint(t *testing.T) { claims: []byte(userInfoJSONCognitoVariant), }, }, + { + name: "no userinfo endpoint", + server: testServer{ + contentType: "application/json", + userInfo: "", + }, + wantError: ErrUserInfoNotSupported, + }, } for _, test := range tests { @@ -504,8 +520,12 @@ func TestUserInfoEndpoint(t *testing.T) { fakeOauthToken := oauth2.Token{} info, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(&fakeOauthToken)) - if err != nil { - t.Fatalf("failed to get userinfo %v", err) + if err != test.wantError { + t.Fatalf("expected UserInfo err %v got %v", test.wantError, err) + } + + if test.wantError != nil { + return } if info.Email != test.wantUserInfo.Email {