Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion security/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ type secCtxKey uint8

const (
failedBasicAuth secCtxKey = iota
oauth2SchemeName
)

func FailedBasicAuth(r *http.Request) string {
Expand All @@ -89,6 +90,18 @@ func FailedBasicAuthCtx(ctx context.Context) string {
return v
}

func OAuth2SchemeName(r *http.Request) string {
return OAuth2SchemeNameCtx(r.Context())
}

func OAuth2SchemeNameCtx(ctx context.Context) string {
v, ok := ctx.Value(oauth2SchemeName).(string)
if !ok {
return ""
}
return v
}

// BasicAuth creates a basic auth authenticator with the provided authentication function
func BasicAuth(authenticate UserPassAuthentication) runtime.Authenticator {
return BasicAuthRealm(DefaultRealmName, authenticate)
Expand Down Expand Up @@ -224,6 +237,8 @@ func BearerAuth(name string, authenticate ScopedTokenAuthentication) runtime.Aut
return false, nil, nil
}

rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
*r.Request = *r.Request.WithContext(rctx)
p, err := authenticate(token, r.RequiredScopes)
return true, p, err
})
Expand Down Expand Up @@ -252,7 +267,8 @@ func BearerAuthCtx(name string, authenticate ScopedTokenAuthenticationCtx) runti
return false, nil, nil
}

ctx, p, err := authenticate(r.Request.Context(), token, r.RequiredScopes)
rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
ctx, p, err := authenticate(rctx, token, r.RequiredScopes)
*r.Request = *r.Request.WithContext(ctx)
return true, p, err
})
Expand Down
8 changes: 8 additions & 0 deletions security/bearer_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func TestValidBearerAuth(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, "admin", usr)
assert.NoError(t, err)
assert.Equal(t, OAuth2SchemeName(req1), "owners_auth")

req2, _ := http.NewRequest("GET", "/blah", nil)
req2.Header.Set("Authorization", "Bearer token123")
Expand All @@ -37,6 +38,7 @@ func TestValidBearerAuth(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, "admin", usr)
assert.NoError(t, err)
assert.Equal(t, OAuth2SchemeName(req2), "owners_auth")

body := url.Values(map[string][]string{})
body.Set("access_token", "token123")
Expand All @@ -47,6 +49,7 @@ func TestValidBearerAuth(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, "admin", usr)
assert.NoError(t, err)
assert.Equal(t, OAuth2SchemeName(req3), "owners_auth")

mpbody := bytes.NewBuffer(nil)
writer := multipart.NewWriter(mpbody)
Expand All @@ -59,6 +62,7 @@ func TestValidBearerAuth(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, "admin", usr)
assert.NoError(t, err)
assert.Equal(t, OAuth2SchemeName(req4), "owners_auth")
}

func TestInvalidBearerAuth(t *testing.T) {
Expand Down Expand Up @@ -162,6 +166,7 @@ func TestValidBearerAuthCtx(t *testing.T) {
assert.Equal(t, wisdom, req1.Context().Value(original))
assert.Equal(t, extraWisdom, req1.Context().Value(extra))
assert.Nil(t, req1.Context().Value(reason))
assert.Equal(t, OAuth2SchemeName(req1), "owners_auth")

req2, _ := http.NewRequest("GET", "/blah", nil)
req2 = req2.WithContext(context.WithValue(req2.Context(), original, wisdom))
Expand All @@ -174,6 +179,7 @@ func TestValidBearerAuthCtx(t *testing.T) {
assert.Equal(t, wisdom, req2.Context().Value(original))
assert.Equal(t, extraWisdom, req2.Context().Value(extra))
assert.Nil(t, req2.Context().Value(reason))
assert.Equal(t, OAuth2SchemeName(req2), "owners_auth")

body := url.Values(map[string][]string{})
body.Set("access_token", "token123")
Expand All @@ -188,6 +194,7 @@ func TestValidBearerAuthCtx(t *testing.T) {
assert.Equal(t, wisdom, req3.Context().Value(original))
assert.Equal(t, extraWisdom, req3.Context().Value(extra))
assert.Nil(t, req3.Context().Value(reason))
assert.Equal(t, OAuth2SchemeName(req3), "owners_auth")

mpbody := bytes.NewBuffer(nil)
writer := multipart.NewWriter(mpbody)
Expand All @@ -204,6 +211,7 @@ func TestValidBearerAuthCtx(t *testing.T) {
assert.Equal(t, wisdom, req4.Context().Value(original))
assert.Equal(t, extraWisdom, req4.Context().Value(extra))
assert.Nil(t, req4.Context().Value(reason))
assert.Equal(t, OAuth2SchemeName(req4), "owners_auth")
}

func TestInvalidBearerAuthCtx(t *testing.T) {
Expand Down