diff --git a/nuts/client.go b/nuts/client.go index d2ce113..6106f53 100644 --- a/nuts/client.go +++ b/nuts/client.go @@ -45,14 +45,10 @@ func (o OAuth2TokenSource) Token(httpRequest *http.Request, authzServerURL *url. if err != nil { return nil, err } - // TODO: Might want to support DPoP as well - var tokenType = iam.ServiceAccessTokenRequestTokenTypeBearer - // TODO: Is this the right context to use? response, err := client.RequestServiceAccessToken(httpRequest.Context(), o.NutsSubject, iam.RequestServiceAccessTokenJSONRequestBody{ AuthorizationServer: authzServerURL.String(), Credentials: &additionalCredentials, Scope: scope, - TokenType: &tokenType, }) if err != nil { return nil, err @@ -69,9 +65,34 @@ func (o OAuth2TokenSource) Token(httpRequest *http.Request, authzServerURL *url. expiry = new(time.Time) *expiry = time.Now().Add(time.Duration(*accessTokenResponse.JSON200.ExpiresIn) * time.Second) } + tokenType := iam.ServiceAccessTokenRequestTokenType(accessTokenResponse.JSON200.TokenType) + var dPoPToken *string + if tokenType == iam.ServiceAccessTokenRequestTokenTypeDPoP { + if accessTokenResponse.JSON200.DpopKid == nil { + return nil, fmt.Errorf("type is DPoP but no DpopKid has been provided") + } + kid := *accessTokenResponse.JSON200.DpopKid + proof, err := client.CreateDPoPProof(httpRequest.Context(), kid, iam.CreateDPoPProofJSONRequestBody{ + Token: accessTokenResponse.JSON200.AccessToken, + Htm: httpRequest.Method, + Htu: httpRequest.URL.String(), + }) + if err != nil { + return nil, err + } + proofResponse, err := iam.ParseCreateDPoPProofResponse(proof) + if err != nil { + return nil, err + } + if proofResponse.JSON200 == nil { + return nil, fmt.Errorf("failed service dpop response: %s", accessTokenResponse.HTTPResponse.Status) + } + dPoPToken = &proofResponse.JSON200.Dpop + } return &oauth2.Token{ AccessToken: accessTokenResponse.JSON200.AccessToken, - TokenType: accessTokenResponse.JSON200.TokenType, + DPoPToken: dPoPToken, + TokenType: string(tokenType), Expiry: expiry, }, nil } diff --git a/nuts/client_test.go b/nuts/client_test.go index d4837a4..9da976c 100644 --- a/nuts/client_test.go +++ b/nuts/client_test.go @@ -15,12 +15,12 @@ import ( ) func TestOAuth2TokenSource_Token(t *testing.T) { - t.Run("ok", func(t *testing.T) { + t.Run("ok nodpop", func(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/internal/auth/v2/123abc/request-service-access-token", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600}`)) + _, _ = w.Write([]byte(`{"access_token":"test","token_type":"Bearer","expires_in":3600}`)) }) httpServer := httptest.NewServer(mux) tokenSource := OAuth2TokenSource{ @@ -35,8 +35,40 @@ func TestOAuth2TokenSource_Token(t *testing.T) { require.NoError(t, err) require.NotNil(t, token) + require.Nil(t, token.DPoPToken) require.Equal(t, "test", token.AccessToken) - require.Equal(t, "bearer", token.TokenType) + require.Equal(t, "Bearer", token.TokenType) + require.Greater(t, token.Expiry.Unix(), time.Now().Unix()) + require.Less(t, token.Expiry.Unix(), time.Now().Add(2*time.Hour).Unix()) + }) + t.Run("ok dpop", func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/internal/auth/v2/123abc/request-service-access-token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"test","token_type":"DPoP","expires_in":3600, "dpop_kid" : "kid"}`)) + }) + mux.HandleFunc("/internal/auth/v2/dpop/kid", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"dpop":"dpop321"}`)) + }) + httpServer := httptest.NewServer(mux) + tokenSource := OAuth2TokenSource{ + NutsSubject: "123abc", + NutsAPIURL: httpServer.URL, + } + expectedAuthServerURL, _ := url.Parse("https://auth.example.com") + httpRequest, _ := http.NewRequestWithContext(context.Background(), "GET", "https://resource.example.com", nil) + + token, err := tokenSource.Token(httpRequest, expectedAuthServerURL, "test") + + require.NoError(t, err) + require.NotNil(t, token) + + require.NotNil(t, token.DPoPToken) + require.Equal(t, "test", token.AccessToken) + require.Equal(t, "DPoP", token.TokenType) require.Greater(t, token.Expiry.Unix(), time.Now().Unix()) require.Less(t, token.Expiry.Unix(), time.Now().Add(2*time.Hour).Unix()) }) @@ -47,7 +79,7 @@ func TestOAuth2TokenSource_Token(t *testing.T) { require.NoError(t, json.NewDecoder(r.Body).Decode(&capturedRequest)) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600}`)) + _, _ = w.Write([]byte(`{"access_token":"test","token_type":"Bearer","expires_in":3600}`)) }) httpServer := httptest.NewServer(mux) tokenSource := OAuth2TokenSource{ @@ -67,8 +99,78 @@ func TestOAuth2TokenSource_Token(t *testing.T) { require.NoError(t, err) require.NotNil(t, token) + require.Nil(t, token.DPoPToken) require.Equal(t, "test", token.AccessToken) - require.Equal(t, "bearer", token.TokenType) + require.Equal(t, "Bearer", token.TokenType) require.NotEmpty(t, capturedRequest.Credentials) }) + t.Run("error dpop with no kid", func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/internal/auth/v2/123abc/request-service-access-token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"test","token_type":"DPoP","expires_in":3600}`)) + }) + mux.HandleFunc("/internal/auth/v2/dpop/kid", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"dpop":"dpop321"}`)) + }) + httpServer := httptest.NewServer(mux) + tokenSource := OAuth2TokenSource{ + NutsSubject: "123abc", + NutsAPIURL: httpServer.URL, + } + expectedAuthServerURL, _ := url.Parse("https://auth.example.com") + httpRequest, _ := http.NewRequestWithContext(context.Background(), "GET", "https://resource.example.com", nil) + + _, err := tokenSource.Token(httpRequest, expectedAuthServerURL, "test") + + require.Error(t, err) + }) + + t.Run("error broken request-service-access-token", func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/internal/auth/v2/123abc/request-service-access-token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_request"}`)) + }) + httpServer := httptest.NewServer(mux) + tokenSource := OAuth2TokenSource{ + NutsSubject: "123abc", + NutsAPIURL: httpServer.URL, + } + expectedAuthServerURL, _ := url.Parse("https://auth.example.com") + httpRequest, _ := http.NewRequestWithContext(context.Background(), "GET", "https://resource.example.com", nil) + + _, err := tokenSource.Token(httpRequest, expectedAuthServerURL, "test") + + require.Error(t, err) + }) + + t.Run("error broken dpop call", func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/internal/auth/v2/123abc/request-service-access-token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token":"test","token_type":"DPoP","expires_in":3600, "dpop_kid" : "kid"}`)) + }) + mux.HandleFunc("/internal/auth/v2/dpop/kid", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`error`)) + }) + httpServer := httptest.NewServer(mux) + tokenSource := OAuth2TokenSource{ + NutsSubject: "123abc", + NutsAPIURL: httpServer.URL, + } + expectedAuthServerURL, _ := url.Parse("https://auth.example.com") + httpRequest, _ := http.NewRequestWithContext(context.Background(), "GET", "https://resource.example.com", nil) + + _, err := tokenSource.Token(httpRequest, expectedAuthServerURL, "test") + + require.Error(t, err) + }) } diff --git a/oauth2/client.go b/oauth2/client.go index 553abb5..65d75dc 100644 --- a/oauth2/client.go +++ b/oauth2/client.go @@ -77,6 +77,9 @@ func (o *Transport) RoundTrip(httpRequest *http.Request) (*http.Response, error) } httpRequest = copyRequest(httpRequest, requestBody) httpRequest.Header.Set("Authorization", fmt.Sprintf("%s %s", token.TokenType, token.AccessToken)) + if token.DPoPToken != nil { + httpRequest.Header.Set("DPoP", *token.DPoPToken) + } httpResponse, err = client.RoundTrip(httpRequest) } return httpResponse, err diff --git a/oauth2/tokensource.go b/oauth2/tokensource.go index e8e5d82..b0b6870 100644 --- a/oauth2/tokensource.go +++ b/oauth2/tokensource.go @@ -9,6 +9,7 @@ import ( type Token struct { AccessToken string TokenType string + DPoPToken *string Expiry *time.Time }