From c6fb0f145b0532c403fe569a5a73c78c43f86cd5 Mon Sep 17 00:00:00 2001 From: Ivan Porto Carrero Date: Sat, 17 Aug 2019 14:49:32 -0700 Subject: [PATCH] forward name to oauth2 context and provide an accessor Signed-off-by: Ivan Porto Carrero --- security/authenticator.go | 18 +++++++++++++++++- security/bearer_auth_test.go | 8 ++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/security/authenticator.go b/security/authenticator.go index 21be9a1b..5d058b8d 100644 --- a/security/authenticator.go +++ b/security/authenticator.go @@ -75,6 +75,7 @@ type secCtxKey uint8 const ( failedBasicAuth secCtxKey = iota + oauth2SchemeName ) func FailedBasicAuth(r *http.Request) string { @@ -89,6 +90,18 @@ func FailedBasicAuthCtx(ctx context.Context) string { return v } +func OAuth2SchemeName(r *http.Request) string { + return OAuth2SchemeNameCtx(r.Context()) +} + +func OAuth2SchemeNameCtx(ctx context.Context) string { + v, ok := ctx.Value(oauth2SchemeName).(string) + if !ok { + return "" + } + return v +} + // BasicAuth creates a basic auth authenticator with the provided authentication function func BasicAuth(authenticate UserPassAuthentication) runtime.Authenticator { return BasicAuthRealm(DefaultRealmName, authenticate) @@ -224,6 +237,8 @@ func BearerAuth(name string, authenticate ScopedTokenAuthentication) runtime.Aut return false, nil, nil } + rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name) + *r.Request = *r.Request.WithContext(rctx) p, err := authenticate(token, r.RequiredScopes) return true, p, err }) @@ -252,7 +267,8 @@ func BearerAuthCtx(name string, authenticate ScopedTokenAuthenticationCtx) runti return false, nil, nil } - ctx, p, err := authenticate(r.Request.Context(), token, r.RequiredScopes) + rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name) + ctx, p, err := authenticate(rctx, token, r.RequiredScopes) *r.Request = *r.Request.WithContext(ctx) return true, p, err }) diff --git a/security/bearer_auth_test.go b/security/bearer_auth_test.go index 8289061d..57362b7a 100644 --- a/security/bearer_auth_test.go +++ b/security/bearer_auth_test.go @@ -29,6 +29,7 @@ func TestValidBearerAuth(t *testing.T) { assert.True(t, ok) assert.Equal(t, "admin", usr) assert.NoError(t, err) + assert.Equal(t, OAuth2SchemeName(req1), "owners_auth") req2, _ := http.NewRequest("GET", "/blah", nil) req2.Header.Set("Authorization", "Bearer token123") @@ -37,6 +38,7 @@ func TestValidBearerAuth(t *testing.T) { assert.True(t, ok) assert.Equal(t, "admin", usr) assert.NoError(t, err) + assert.Equal(t, OAuth2SchemeName(req2), "owners_auth") body := url.Values(map[string][]string{}) body.Set("access_token", "token123") @@ -47,6 +49,7 @@ func TestValidBearerAuth(t *testing.T) { assert.True(t, ok) assert.Equal(t, "admin", usr) assert.NoError(t, err) + assert.Equal(t, OAuth2SchemeName(req3), "owners_auth") mpbody := bytes.NewBuffer(nil) writer := multipart.NewWriter(mpbody) @@ -59,6 +62,7 @@ func TestValidBearerAuth(t *testing.T) { assert.True(t, ok) assert.Equal(t, "admin", usr) assert.NoError(t, err) + assert.Equal(t, OAuth2SchemeName(req4), "owners_auth") } func TestInvalidBearerAuth(t *testing.T) { @@ -162,6 +166,7 @@ func TestValidBearerAuthCtx(t *testing.T) { assert.Equal(t, wisdom, req1.Context().Value(original)) assert.Equal(t, extraWisdom, req1.Context().Value(extra)) assert.Nil(t, req1.Context().Value(reason)) + assert.Equal(t, OAuth2SchemeName(req1), "owners_auth") req2, _ := http.NewRequest("GET", "/blah", nil) req2 = req2.WithContext(context.WithValue(req2.Context(), original, wisdom)) @@ -174,6 +179,7 @@ func TestValidBearerAuthCtx(t *testing.T) { assert.Equal(t, wisdom, req2.Context().Value(original)) assert.Equal(t, extraWisdom, req2.Context().Value(extra)) assert.Nil(t, req2.Context().Value(reason)) + assert.Equal(t, OAuth2SchemeName(req2), "owners_auth") body := url.Values(map[string][]string{}) body.Set("access_token", "token123") @@ -188,6 +194,7 @@ func TestValidBearerAuthCtx(t *testing.T) { assert.Equal(t, wisdom, req3.Context().Value(original)) assert.Equal(t, extraWisdom, req3.Context().Value(extra)) assert.Nil(t, req3.Context().Value(reason)) + assert.Equal(t, OAuth2SchemeName(req3), "owners_auth") mpbody := bytes.NewBuffer(nil) writer := multipart.NewWriter(mpbody) @@ -204,6 +211,7 @@ func TestValidBearerAuthCtx(t *testing.T) { assert.Equal(t, wisdom, req4.Context().Value(original)) assert.Equal(t, extraWisdom, req4.Context().Value(extra)) assert.Nil(t, req4.Context().Value(reason)) + assert.Equal(t, OAuth2SchemeName(req4), "owners_auth") } func TestInvalidBearerAuthCtx(t *testing.T) {