diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index 3dfda7658..08b0ef836 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -35,24 +35,34 @@ func (c *RefreshTokenGrantHandler) HandleTokenEndpointRequest(ctx context.Contex refresh := req.PostForm.Get("refresh_token") signature := c.RefreshTokenStrategy.RefreshTokenSignature(refresh) - accessRequest, err := c.RefreshTokenGrantStorage.GetRefreshTokenSession(ctx, signature, nil) + originalRequest, err := c.RefreshTokenGrantStorage.GetRefreshTokenSession(ctx, signature, request.GetSession()) if errors.Cause(err) == fosite.ErrNotFound { return errors.Wrap(fosite.ErrInvalidRequest, err.Error()) } else if err != nil { return errors.Wrap(fosite.ErrServerError, err.Error()) } + if !originalRequest.GetGrantedScopes().Has("offline") { + return errors.Wrap(fosite.ErrScopeNotGranted, "The client is not allowed to use grant type refresh_token") + + } + // The authorization server MUST ... validate the refresh token. if err := c.RefreshTokenStrategy.ValidateRefreshToken(ctx, request, refresh); err != nil { return errors.Wrap(fosite.ErrInvalidRequest, err.Error()) } // The authorization server MUST ... and ensure that the refresh token was issued to the authenticated client - if accessRequest.GetClient().GetID() != request.GetClient().GetID() { + if originalRequest.GetClient().GetID() != request.GetClient().GetID() { return errors.Wrap(fosite.ErrInvalidRequest, "Client ID mismatch") } - request.Merge(accessRequest) + request.SetSession(originalRequest.GetSession()) + request.SetRequestedScopes(originalRequest.GetRequestedScopes()) + for _, scope := range originalRequest.GetGrantedScopes() { + request.GrantScope(scope) + } + return nil } diff --git a/handler/oauth2/flow_refresh_test.go b/handler/oauth2/flow_refresh_test.go index c01ca1657..a7629200e 100644 --- a/handler/oauth2/flow_refresh_test.go +++ b/handler/oauth2/flow_refresh_test.go @@ -55,7 +55,9 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { { description: "should fail because validation failed", setup: func() { - store.EXPECT().GetRefreshTokenSession(nil, "refreshtokensig", nil).Return(&fosite.Request{}, nil) + store.EXPECT().GetRefreshTokenSession(nil, "refreshtokensig", nil).Return(&fosite.Request{ + GrantedScopes:[]string{"offline"}, + }, nil) chgen.EXPECT().ValidateRefreshToken(nil, areq, "some.refreshtokensig").Return(errors.New("")) }, expectErr: fosite.ErrInvalidRequest, @@ -67,7 +69,10 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { ID: "foo", GrantTypes: fosite.Arguments{"refresh_token"}, } - store.EXPECT().GetRefreshTokenSession(nil, "refreshtokensig", nil).Return(&fosite.Request{Client: &fosite.DefaultClient{ID: ""}}, nil) + store.EXPECT().GetRefreshTokenSession(nil, "refreshtokensig", nil).Return(&fosite.Request{ + Client: &fosite.DefaultClient{ID: ""}, + GrantedScopes:[]string{"offline"}, + }, nil) chgen.EXPECT().ValidateRefreshToken(nil, areq, "some.refreshtokensig").AnyTimes().Return(nil) }, expectErr: fosite.ErrInvalidRequest, @@ -77,19 +82,19 @@ func TestRefreshFlow_HandleTokenEndpointRequest(t *testing.T) { setup: func() { store.EXPECT().GetRefreshTokenSession(nil, "refreshtokensig", nil).Return(&fosite.Request{ Client: &fosite.DefaultClient{ID: "foo"}, - GrantedScopes: fosite.Arguments{"foo"}, + GrantedScopes: fosite.Arguments{"foo", "offline"}, Scopes: fosite.Arguments{"foo", "bar"}, Session: sess, Form: url.Values{"foo": []string{"bar"}}, - RequestedAt: time.Now().Round(time.Hour), + RequestedAt: time.Now().Add(-time.Hour).Round(time.Hour), }, nil) }, expect: func() { assert.Equal(t, sess, areq.Session) - assert.Equal(t, time.Now().Round(time.Hour), areq.RequestedAt) - assert.Equal(t, fosite.Arguments{"foo"}, areq.GrantedScopes) + assert.NotEqual(t, time.Now().Add(-time.Hour).Round(time.Hour), areq.RequestedAt) + assert.Equal(t, fosite.Arguments{"foo", "offline"}, areq.GrantedScopes) assert.Equal(t, fosite.Arguments{"foo", "bar"}, areq.Scopes) - assert.Equal(t, url.Values{"foo": []string{"bar"}}, areq.Form) + assert.NotEqual(t, url.Values{"foo": []string{"bar"}}, areq.Form) }, }, } {