diff --git a/config.go b/config.go index 6418b5ad..48553098 100644 --- a/config.go +++ b/config.go @@ -34,34 +34,34 @@ func newDefaultConfig() *Config { hostnames = append(hostnames, []string{"localhost", "127.0.0.1"}...) return &Config{ - AccessTokenDuration: time.Duration(720) * time.Hour, - CookieAccessName: "kc-access", - CookieRefreshName: "kc-state", - EnableAuthorizationCookies: true, - EnableAuthorizationHeader: true, - EnableDefaultDeny: true, - EnableSessionCookies: true, - EnableTokenHeader: true, - HTTPOnlyCookie: true, - Headers: make(map[string]string), - LetsEncryptCacheDir: "./cache/", - MatchClaims: make(map[string]string), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 50, - OAuthURI: "/oauth", - OpenIDProviderTimeout: 30 * time.Second, - PreserveHost: false, - SelfSignedTLSExpiration: 3 * time.Hour, - SelfSignedTLSHostnames: hostnames, - RequestIDHeader: "X-Request-ID", - ResponseHeaders: make(map[string]string), - SecureCookie: true, - ServerIdleTimeout: 120 * time.Second, - ServerReadTimeout: 10 * time.Second, - ServerWriteTimeout: 10 * time.Second, - SkipOpenIDProviderTLSVerify: false, - SkipUpstreamTLSVerify: true, - Tags: make(map[string]string, 0), + AccessTokenDuration: time.Duration(720) * time.Hour, + CookieAccessName: "kc-access", + CookieRefreshName: "kc-state", + EnableAuthorizationCookies: true, + EnableAuthorizationHeader: true, + EnableDefaultDeny: true, + EnableSessionCookies: true, + EnableTokenHeader: true, + HTTPOnlyCookie: true, + Headers: make(map[string]string), + LetsEncryptCacheDir: "./cache/", + MatchClaims: make(map[string]string), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 50, + OAuthURI: "/oauth", + OpenIDProviderTimeout: 30 * time.Second, + PreserveHost: false, + SelfSignedTLSExpiration: 3 * time.Hour, + SelfSignedTLSHostnames: hostnames, + RequestIDHeader: "X-Request-ID", + ResponseHeaders: make(map[string]string), + SecureCookie: true, + ServerIdleTimeout: 120 * time.Second, + ServerReadTimeout: 10 * time.Second, + ServerWriteTimeout: 10 * time.Second, + SkipOpenIDProviderTLSVerify: false, + SkipUpstreamTLSVerify: true, + Tags: make(map[string]string, 0), UpstreamExpectContinueTimeout: 10 * time.Second, UpstreamKeepaliveTimeout: 10 * time.Second, UpstreamKeepalives: true, diff --git a/cookies.go b/cookies.go index eb4a996d..74fdf721 100644 --- a/cookies.go +++ b/cookies.go @@ -20,6 +20,8 @@ import ( "strconv" "strings" "time" + + "github.com/satori/go.uuid" ) // dropCookie drops a cookie into the response @@ -84,6 +86,13 @@ func (r *oauthProxy) dropRefreshTokenCookie(req *http.Request, w http.ResponseWr } } +// dropStateParameterCookie drops a state parameter cookie into the response +func (r *oauthProxy) writeStateParameterCookie(req *http.Request, w http.ResponseWriter) string { + uuid := uuid.NewV4().String() + r.dropCookie(w, req.Host, "OAuth_Token_Request_State", uuid, 0) + return uuid +} + // clearAllCookies is just a helper function for the below func (r *oauthProxy) clearAllCookies(req *http.Request, w http.ResponseWriter) { r.clearAccessTokenCookie(req, w) diff --git a/handlers.go b/handlers.go index 7a1f8a3a..99541d9c 100644 --- a/handlers.go +++ b/handlers.go @@ -55,6 +55,12 @@ func (r *oauthProxy) getRedirectionURL(w http.ResponseWriter, req *http.Request) redirect = r.config.RedirectionURL } + state, _ := req.Cookie("OAuth_Token_Request_State") + if state != nil && req.URL.Query().Get("state") != state.Value { + r.log.Error("State parameter mismatch") + w.WriteHeader(http.StatusForbidden) + return "" + } return fmt.Sprintf("%s%s", redirect, r.config.WithOAuthURI("callback")) } diff --git a/handlers_test.go b/handlers_test.go index 8042a7ef..a796a061 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -216,7 +216,7 @@ func TestServiceRedirect(t *testing.T) { URI: "/admin", Redirects: true, ExpectedCode: http.StatusTemporaryRedirect, - ExpectedLocation: "/oauth/authorize?state=L2FkbWlu", + ExpectedLocation: "/oauth/authorize?state", }, { URI: "/admin", @@ -242,25 +242,25 @@ func TestAuthorizationURL(t *testing.T) { { URI: "/admin", Redirects: true, - ExpectedLocation: "/oauth/authorize?state=L2FkbWlu", + ExpectedLocation: "/oauth/authorize?state", ExpectedCode: http.StatusTemporaryRedirect, }, { URI: "/admin/test", Redirects: true, - ExpectedLocation: "/oauth/authorize?state=L2FkbWluL3Rlc3Q=", + ExpectedLocation: "/oauth/authorize?state", ExpectedCode: http.StatusTemporaryRedirect, }, { URI: "/help/../admin", Redirects: true, - ExpectedLocation: "/oauth/authorize?state=L2FkbWlu", + ExpectedLocation: "/oauth/authorize?state", ExpectedCode: http.StatusTemporaryRedirect, }, { URI: "/admin?test=yes&test1=test", Redirects: true, - ExpectedLocation: "/oauth/authorize?state=L2FkbWluP3Rlc3Q9eWVzJnRlc3QxPXRlc3Q=", + ExpectedLocation: "/oauth/authorize?state", ExpectedCode: http.StatusTemporaryRedirect, }, { diff --git a/middleware_test.go b/middleware_test.go index 9c144e3b..fe78786d 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -21,6 +21,7 @@ import ( "log" "net" "net/http" + "net/url" "strings" "testing" "time" @@ -28,6 +29,7 @@ import ( "github.com/coreos/go-oidc/jose" "github.com/go-resty/resty" "github.com/rs/cors" + uuid "github.com/satori/go.uuid" "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -214,8 +216,14 @@ func (f *fakeProxy) RunTests(t *testing.T, requests []fakeRequest) { assert.Equal(t, c.ExpectedCode, status, "case %d, expected status code: %d, got: %d", i, c.ExpectedCode, status) } if c.ExpectedLocation != "" { - l := resp.Header().Get("Location") - assert.Equal(t, c.ExpectedLocation, l, "case %d, expected location: %s, got: %s", i, c.ExpectedLocation, l) + l, _ := url.Parse(resp.Header().Get("Location")) + assert.True(t, strings.Contains(l.String(), c.ExpectedLocation), "Expected location to contain %s", l.String()) + if l.Query().Get("state") != "" { + state, err := uuid.FromString(l.Query().Get("state")) + if err != nil { + assert.Fail(t, "Expected state parameter with valid UUID, got: %s with error %s", state.String(), err) + } + } } if len(c.ExpectedHeaders) > 0 { for k, v := range c.ExpectedHeaders { diff --git a/misc.go b/misc.go index be7995ce..5e24d906 100644 --- a/misc.go +++ b/misc.go @@ -17,7 +17,6 @@ package main import ( "context" - "encoding/base64" "fmt" "net/http" "path" @@ -95,8 +94,10 @@ func (r *oauthProxy) redirectToAuthorization(w http.ResponseWriter, req *http.Re w.WriteHeader(http.StatusUnauthorized) return r.revokeProxy(w, req) } + // step: add a state referrer to the authorization page - authQuery := fmt.Sprintf("?state=%s", base64.StdEncoding.EncodeToString([]byte(req.URL.RequestURI()))) + uuid := r.writeStateParameterCookie(req, w) + authQuery := fmt.Sprintf("?state=%s", uuid) // step: if verification is switched off, we can't authorization if r.config.SkipTokenVerification { diff --git a/misc_test.go b/misc_test.go index 54e57f02..d1f9661a 100644 --- a/misc_test.go +++ b/misc_test.go @@ -35,7 +35,7 @@ func TestRedirectToAuthorization(t *testing.T) { { URI: "/admin", Redirects: true, - ExpectedLocation: "/oauth/authorize?state=L2FkbWlu", + ExpectedLocation: "/oauth/authorize?state", ExpectedCode: http.StatusTemporaryRedirect, }, } @@ -50,7 +50,7 @@ func TestRedirectToAuthorizationWith303Enabled(t *testing.T) { { URI: "/admin", Redirects: true, - ExpectedLocation: "/oauth/authorize?state=L2FkbWlu", + ExpectedLocation: "/oauth/authorize?state", ExpectedCode: http.StatusSeeOther, }, }