Skip to content

Commit

Permalink
Add oidc tls skip verify to forwarding proxy/refreshing token
Browse files Browse the repository at this point in the history
  • Loading branch information
p53 committed Jan 23, 2021
1 parent d72a4d3 commit 1c88a15
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 5 deletions.
20 changes: 18 additions & 2 deletions forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package main

import (
"context"
"crypto/tls"
"fmt"
"net/http"
"time"

"go.uber.org/zap"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
)

Expand Down Expand Up @@ -87,6 +89,15 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler {
// forwardProxyHandler is responsible for signing outbound requests
func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) {
ctx := context.Background()

if r.config.SkipOpenIDProviderTLSVerify {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
sslcli := &http.Client{Transport: tr}
ctx = context.WithValue(ctx, oauth2.HTTPClient, sslcli)
}

conf := r.newOAuth2Config(r.config.RedirectionURL)

// the loop state
Expand Down Expand Up @@ -119,7 +130,12 @@ func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) {
zap.String("username", r.config.ForwardingUsername))

// step: login into the service
resp, err := conf.PasswordCredentialsToken(ctx, r.config.ForwardingUsername, r.config.ForwardingPassword)
resp, err := conf.PasswordCredentialsToken(
ctx,
r.config.ForwardingUsername,
r.config.ForwardingPassword,
)

if err != nil {
r.log.Error("failed to login to authentication service", zap.Error(err))
// step: back-off and reschedule
Expand Down Expand Up @@ -169,7 +185,7 @@ func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) {
zap.String("expires", state.expiration.Format(time.RFC3339)))

// step: attempt to refresh the access
token, rawToken, newRefreshToken, expiration, _, err := getRefreshedToken(conf, state.refresh)
token, rawToken, newRefreshToken, expiration, _, err := getRefreshedToken(conf, r, state.refresh)
state.rawToken = rawToken
if err != nil {
state.login = true
Expand Down
31 changes: 31 additions & 0 deletions handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ package main

import (
"net/http"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestDebugHandler(t *testing.T) {
Expand Down Expand Up @@ -178,6 +181,20 @@ func TestSkipOpenIDProviderTLSVerifyLoginHandler(t *testing.T) {
},
}
newFakeProxy(c, &fakeAuthConfig{EnableTLS: true}).RunTests(t, requests)

c.SkipOpenIDProviderTLSVerify = false

defer func() {
if r := recover(); r != nil {
check := strings.Contains(
r.(string),
"failed to retrieve the provider configuration from discovery url",
)
assert.True(t, check)
}
}()

newFakeProxy(c, &fakeAuthConfig{EnableTLS: true}).RunTests(t, requests)
}

func TestLogoutHandlerBadRequest(t *testing.T) {
Expand Down Expand Up @@ -247,6 +264,20 @@ func TestSkipOpenIDProviderTLSVerifyLogoutHandler(t *testing.T) {
},
}
newFakeProxy(c, &fakeAuthConfig{EnableTLS: true}).RunTests(t, requests)

c.SkipOpenIDProviderTLSVerify = false

defer func() {
if r := recover(); r != nil {
check := strings.Contains(
r.(string),
"failed to retrieve the provider configuration from discovery url",
)
assert.True(t, check)
}
}()

newFakeProxy(c, &fakeAuthConfig{EnableTLS: true}).RunTests(t, requests)
}

func TestTokenHandler(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func (r *oauthProxy) authenticationMiddleware() func(http.Handler) http.Handler
// exp: expiration of the access token
// expiresIn: expiration of the ID token
conf := r.newOAuth2Config(r.config.RedirectionURL)
_, newRawAccToken, newRefreshToken, accessExpiresAt, refreshExpiresIn, err := getRefreshedToken(conf, refresh)
_, newRawAccToken, newRefreshToken, accessExpiresAt, refreshExpiresIn, err := getRefreshedToken(conf, r, refresh)
if err != nil {
switch err {
case ErrRefreshTokenExpired:
Expand Down
20 changes: 18 additions & 2 deletions oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,24 @@ func (r *oauthProxy) newOAuth2Config(redirectionURL string) *oauth2.Config {
// NOTE: we may be able to extract the specific (non-standard) claim refresh_expires_in and refresh_expires
// from response.RawBody.
// When not available, keycloak provides us with the same (for now) expiry value for ID token.
func getRefreshedToken(conf *oauth2.Config, t string) (jwt.JSONWebToken, string, string, time.Time, time.Duration, error) {
tkn, err := conf.TokenSource(context.Background(), &oauth2.Token{RefreshToken: t}).Token()
func getRefreshedToken(conf *oauth2.Config, r *oauthProxy, t string) (jwt.JSONWebToken, string, string, time.Time, time.Duration, error) {
ctx, cancel := context.WithTimeout(
context.Background(),
r.config.OpenIDProviderTimeout,
)

if r.config.SkipOpenIDProviderTLSVerify {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
sslcli := &http.Client{Transport: tr}
ctx = context.WithValue(ctx, oauth2.HTTPClient, sslcli)
}

defer cancel()

tkn, err := conf.TokenSource(ctx, &oauth2.Token{RefreshToken: t}).Token()

if err != nil {
if strings.Contains(err.Error(), "refresh token has expired") {
return jwt.JSONWebToken{}, "", "", time.Time{}, time.Duration(0), ErrRefreshTokenExpired
Expand Down
38 changes: 38 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,44 @@ func TestForwardingProxy(t *testing.T) {
p.RunTests(t, requests)
}

func TestSkipOpenIDProviderTLSVerifyForwardingProxy(t *testing.T) {
cfg := newFakeKeycloakConfig()
cfg.EnableForwarding = true
cfg.ForwardingDomains = []string{}
cfg.ForwardingUsername = validUsername
cfg.ForwardingPassword = validPassword
cfg.SkipOpenIDProviderTLSVerify = true
s := httptest.NewServer(&fakeUpstreamService{})
requests := []fakeRequest{
{
URL: s.URL + "/test",
ProxyRequest: true,
ExpectedProxy: true,
ExpectedCode: http.StatusOK,
ExpectedContentContains: "Bearer ey",
},
}
p := newFakeProxy(cfg, &fakeAuthConfig{EnableTLS: true})
<-time.After(time.Duration(100) * time.Millisecond)
p.RunTests(t, requests)

cfg.SkipOpenIDProviderTLSVerify = false

defer func() {
if r := recover(); r != nil {
check := strings.Contains(
r.(string),
"failed to retrieve the provider configuration from discovery url",
)
assert.True(t, check)
}
}()

p = newFakeProxy(cfg, &fakeAuthConfig{EnableTLS: true})
<-time.After(time.Duration(100) * time.Millisecond)
p.RunTests(t, requests)
}

func TestForbiddenTemplate(t *testing.T) {
cfg := newFakeKeycloakConfig()
cfg.ForbiddenPage = "templates/forbidden.html.tmpl"
Expand Down

0 comments on commit 1c88a15

Please sign in to comment.