diff --git a/deviceauth.go b/deviceauth.go index e99c92f39..e783a9437 100644 --- a/deviceauth.go +++ b/deviceauth.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "mime" "net/http" "net/url" "strings" @@ -116,10 +117,38 @@ func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAu return nil, fmt.Errorf("oauth2: cannot auth device: %v", err) } if code := r.StatusCode; code < 200 || code > 299 { - return nil, &RetrieveError{ + retrieveError := &RetrieveError{ Response: r, Body: body, } + + content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) + switch content { + case "application/x-www-form-urlencoded", "text/plain": + // some endpoints return a query string + vals, err := url.ParseQuery(string(body)) + if err != nil { + return nil, retrieveError + } + retrieveError.ErrorCode = vals.Get("error") + retrieveError.ErrorDescription = vals.Get("error_description") + retrieveError.ErrorURI = vals.Get("error_uri") + default: + var tj struct { + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + ErrorCode string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` + } + if json.Unmarshal(body, &tj) != nil { + return nil, retrieveError + } + retrieveError.ErrorCode = tj.ErrorCode + retrieveError.ErrorDescription = tj.ErrorDescription + retrieveError.ErrorURI = tj.ErrorURI + } + + return nil, retrieveError } da := &DeviceAuthResponse{} diff --git a/deviceauth_test.go b/deviceauth_test.go index 0e61a2559..72fc394be 100644 --- a/deviceauth_test.go +++ b/deviceauth_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "net/http" + "net/http/httptest" "strings" "testing" "time" @@ -101,3 +103,49 @@ func ExampleConfig_DeviceAuth() { } fmt.Println(token) } + +func TestDeviceAuthTokenRetrieveErrorJSON(t *testing.T) { + for _, responseFun := range []func(w http.ResponseWriter){ + func(w http.ResponseWriter) { + w.Header().Set("Content-type", "application/x-www-form-urlencoded") + // "The authorization server responds with an HTTP 400 (Bad Request)" https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`error=invalid_grant&error_description=sometext`)) + }, + func(w http.ResponseWriter) { + w.Header().Set("Content-type", "application/json") + // "The authorization server responds with an HTTP 400 (Bad Request)" https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error": "invalid_grant", "error_description": "sometext"}`)) + }, + } { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() != "/device" { + t.Errorf("Unexpected device auth request URL, %v is found.", r.URL) + } + responseFun(w) + })) + defer ts.Close() + conf := newConf(ts.URL) + _, err := conf.DeviceAuth(context.Background()) + if err == nil { + t.Fatalf("got no error, expected one") + } + re, ok := err.(*RetrieveError) + if !ok { + t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err) + } + expected := `oauth2: "invalid_grant" "sometext"` + if errStr := err.Error(); errStr != expected { + t.Fatalf("got %#v, expected %#v", errStr, expected) + } + expected = "invalid_grant" + if re.ErrorCode != expected { + t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected) + } + expected = "sometext" + if re.ErrorDescription != expected { + t.Fatalf("got %#v, expected %#v", re.ErrorDescription, expected) + } + } +} diff --git a/oauth2_test.go b/oauth2_test.go index 5db78f21e..e996b8013 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -31,8 +31,9 @@ func newConf(url string) *Config { RedirectURL: "REDIRECT_URL", Scopes: []string{"scope1", "scope2"}, Endpoint: Endpoint{ - AuthURL: url + "/auth", - TokenURL: url + "/token", + AuthURL: url + "/auth", + DeviceAuthURL: url + "/device", + TokenURL: url + "/token", }, } }