Skip to content

Commit

Permalink
Move functions in handlers to common pkg (#466)
Browse files Browse the repository at this point in the history
Move common parts in handlers, refactor
  • Loading branch information
p53 committed May 21, 2024
1 parent 679ec15 commit 28b055b
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 249 deletions.
146 changes: 2 additions & 144 deletions pkg/keycloak/proxy/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"errors"
"fmt"
"io"
"net"

"net/http"
"net/url"
Expand All @@ -36,6 +35,7 @@ import (
"github.com/gogatekeeper/gatekeeper/pkg/constant"
"github.com/gogatekeeper/gatekeeper/pkg/encryption"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/cookie"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/handlers"
"github.com/gogatekeeper/gatekeeper/pkg/proxy/metrics"
"github.com/gogatekeeper/gatekeeper/pkg/storage"
"github.com/gogatekeeper/gatekeeper/pkg/utils"
Expand All @@ -44,67 +44,6 @@ import (
"golang.org/x/oauth2"
)

type DiscoveryResponse struct {
ExpiredURL string `json:"expired_endpoint"`
LogoutURL string `json:"logout_endpoint"`
TokenURL string `json:"token_endpoint"`
LoginURL string `json:"login_endpoint"`
}

// getRedirectionURL returns the redirectionURL for the oauth flow
func getRedirectionURL(
logger *zap.Logger,
redirectionURL string,
noProxy bool,
noRedirects bool,
secureCookie bool,
cookieOAuthStateName string,
withOAuthURI func(string) string,
) func(wrt http.ResponseWriter, req *http.Request) string {
return func(wrt http.ResponseWriter, req *http.Request) string {
var redirect string

switch redirectionURL {
case "":
var scheme string
var host string

if noProxy && !noRedirects {
scheme = req.Header.Get("X-Forwarded-Proto")
host = req.Header.Get("X-Forwarded-Host")
} else {
// need to determine the scheme, cx.Request.URL.Scheme doesn't have it, best way is to default
// and then check for TLS
scheme = constant.UnsecureScheme
host = req.Host
if req.TLS != nil {
scheme = constant.SecureScheme
}
}

if scheme == constant.UnsecureScheme && secureCookie {
hint := "you have secure cookie set to true but using http "
hint += "use https or secure cookie false"
logger.Warn(hint)
}

redirect = fmt.Sprintf("%s://%s", scheme, host)
default:
redirect = redirectionURL
}

state, _ := req.Cookie(cookieOAuthStateName)

if state != nil && req.URL.Query().Get("state") != state.Value {
logger.Error("state parameter mismatch")
wrt.WriteHeader(http.StatusForbidden)
return ""
}

return fmt.Sprintf("%s%s", redirect, withOAuthURI(constant.CallbackURL))
}
}

// oauthAuthorizationHandler is responsible for performing the redirection to oauth provider
//
//nolint:cyclop
Expand Down Expand Up @@ -756,7 +695,7 @@ func logoutHandler(
identityToken = refresh
}

idToken, _, err := retrieveIDToken(
idToken, _, err := handlers.RetrieveIDToken(
cookieIDTokenName,
enableEncryptedToken,
forceEncryptedCookie,
Expand Down Expand Up @@ -952,23 +891,6 @@ func tokenHandler(
}
}

// proxyMetricsHandler forwards the request into the prometheus handler
func proxyMetricsHandler(
localhostMetrics bool,
accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context,
metricsHandler http.Handler,
) func(wrt http.ResponseWriter, req *http.Request) {
return func(wrt http.ResponseWriter, req *http.Request) {
if localhostMetrics {
if !net.ParseIP(utils.RealIP(req)).IsLoopback() {
accessForbidden(wrt, req)
return
}
}
metricsHandler.ServeHTTP(wrt, req)
}
}

// retrieveRefreshToken retrieves the refresh token from store or cookie
func retrieveRefreshToken(
store storage.Storage,
Expand All @@ -995,67 +917,3 @@ func retrieveRefreshToken(
token, err = encryption.DecodeText(token, encryptionKey)
return token, encrypted, err
}

// retrieveIDToken retrieves the id token from cookie
func retrieveIDToken(
cookieIDTokenName string,
enableEncryptedToken bool,
forceEncryptedCookie bool,
encryptionKey string,
req *http.Request,
) (string, string, error) {
var token string
var err error
var encrypted string

token, err = utils.GetTokenInCookie(req, cookieIDTokenName)

if err != nil {
return token, "", err
}

if enableEncryptedToken || forceEncryptedCookie {
encrypted = token
token, err = encryption.DecodeText(token, encryptionKey)
}

return token, encrypted, err
}

// discoveryHandler provides endpoint info
func discoveryHandler(
logger *zap.Logger,
withOAuthURI func(string) string,
) func(wrt http.ResponseWriter, _ *http.Request) {
return func(wrt http.ResponseWriter, _ *http.Request) {
resp := &DiscoveryResponse{
ExpiredURL: withOAuthURI(constant.ExpiredURL),
LogoutURL: withOAuthURI(constant.LogoutURL),
TokenURL: withOAuthURI(constant.TokenURL),
LoginURL: withOAuthURI(constant.LoginURL),
}

respBody, err := json.Marshal(resp)

if err != nil {
logger.Error(
apperrors.ErrMarshallDiscoveryResp.Error(),
zap.String("error", err.Error()),
)

wrt.WriteHeader(http.StatusInternalServerError)
return
}

wrt.Header().Set("Content-Type", "application/json")
wrt.WriteHeader(http.StatusOK)
_, err = wrt.Write(respBody)

if err != nil {
logger.Error(
apperrors.ErrDiscoveryResponseWrite.Error(),
zap.String("error", err.Error()),
)
}
}
}
6 changes: 3 additions & 3 deletions pkg/keycloak/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func (r *OauthProxy) CreateReverseProxy() error {
r.Config.EncryptionKey,
)

r.getRedirectionURL = getRedirectionURL(
r.getRedirectionURL = handlers.GetRedirectionURL(
r.Log,
r.Config.RedirectionURL,
r.Config.NoProxy,
Expand Down Expand Up @@ -423,7 +423,7 @@ func (r *OauthProxy) CreateReverseProxy() error {
)
adminEngine.Get(
constant.MetricsURL,
proxyMetricsHandler(
handlers.ProxyMetricsHandler(
r.Config.LocalhostMetrics,
r.accessForbidden,
r.metricsHandler,
Expand Down Expand Up @@ -550,7 +550,7 @@ func (r *OauthProxy) CreateReverseProxy() error {
tokenHandler(r.GetIdentity, r.Config.CookieAccessName, r.accessError),
)
eng.Post(constant.LoginURL, loginHand)
eng.Get(constant.DiscoveryURL, discoveryHandler(r.Log, r.WithOAuthURI))
eng.Get(constant.DiscoveryURL, handlers.DiscoveryHandler(r.Log, r.WithOAuthURI))

if r.Config.ListenAdmin == "" {
eng.Mount("/", adminEngine)
Expand Down
150 changes: 150 additions & 0 deletions pkg/proxy/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,29 @@ limitations under the License.
package handlers

import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/pprof"

"github.com/go-chi/chi/v5"
"github.com/gogatekeeper/gatekeeper/pkg/apperrors"
"github.com/gogatekeeper/gatekeeper/pkg/constant"
"github.com/gogatekeeper/gatekeeper/pkg/encryption"
proxycore "github.com/gogatekeeper/gatekeeper/pkg/proxy/core"
"github.com/gogatekeeper/gatekeeper/pkg/utils"
"go.uber.org/zap"
)

type DiscoveryResponse struct {
ExpiredURL string `json:"expired_endpoint"`
LogoutURL string `json:"logout_endpoint"`
TokenURL string `json:"token_endpoint"`
LoginURL string `json:"login_endpoint"`
}

// EmptyHandler is responsible for doing nothing
func EmptyHandler(_ http.ResponseWriter, _ *http.Request) {}

Expand Down Expand Up @@ -78,3 +93,138 @@ func MethodNotAllowHandlder(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusMethodNotAllowed)
_, _ = w.Write(nil)
}

// ProxyMetricsHandler forwards the request into the prometheus handler
func ProxyMetricsHandler(
localhostMetrics bool,
accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context,
metricsHandler http.Handler,
) func(wrt http.ResponseWriter, req *http.Request) {
return func(wrt http.ResponseWriter, req *http.Request) {
if localhostMetrics {
if !net.ParseIP(utils.RealIP(req)).IsLoopback() {
accessForbidden(wrt, req)
return
}
}
metricsHandler.ServeHTTP(wrt, req)
}
}

// RetrieveIDToken retrieves the id token from cookie
func RetrieveIDToken(
cookieIDTokenName string,
enableEncryptedToken bool,
forceEncryptedCookie bool,
encryptionKey string,
req *http.Request,
) (string, string, error) {
var token string
var err error
var encrypted string

token, err = utils.GetTokenInCookie(req, cookieIDTokenName)

if err != nil {
return token, "", err
}

if enableEncryptedToken || forceEncryptedCookie {
encrypted = token
token, err = encryption.DecodeText(token, encryptionKey)
}

return token, encrypted, err
}

// discoveryHandler provides endpoint info
func DiscoveryHandler(
logger *zap.Logger,
withOAuthURI func(string) string,
) func(wrt http.ResponseWriter, _ *http.Request) {
return func(wrt http.ResponseWriter, _ *http.Request) {
resp := &DiscoveryResponse{
ExpiredURL: withOAuthURI(constant.ExpiredURL),
LogoutURL: withOAuthURI(constant.LogoutURL),
TokenURL: withOAuthURI(constant.TokenURL),
LoginURL: withOAuthURI(constant.LoginURL),
}

respBody, err := json.Marshal(resp)

if err != nil {
logger.Error(
apperrors.ErrMarshallDiscoveryResp.Error(),
zap.String("error", err.Error()),
)

wrt.WriteHeader(http.StatusInternalServerError)
return
}

wrt.Header().Set("Content-Type", "application/json")
wrt.WriteHeader(http.StatusOK)
_, err = wrt.Write(respBody)

if err != nil {
logger.Error(
apperrors.ErrDiscoveryResponseWrite.Error(),
zap.String("error", err.Error()),
)
}
}
}

// getRedirectionURL returns the redirectionURL for the oauth flow
func GetRedirectionURL(
logger *zap.Logger,
redirectionURL string,
noProxy bool,
noRedirects bool,
secureCookie bool,
cookieOAuthStateName string,
withOAuthURI func(string) string,
) func(wrt http.ResponseWriter, req *http.Request) string {
return func(wrt http.ResponseWriter, req *http.Request) string {
var redirect string

switch redirectionURL {
case "":
var scheme string
var host string

if noProxy && !noRedirects {
scheme = req.Header.Get("X-Forwarded-Proto")
host = req.Header.Get("X-Forwarded-Host")
} else {
// need to determine the scheme, cx.Request.URL.Scheme doesn't have it, best way is to default
// and then check for TLS
scheme = constant.UnsecureScheme
host = req.Host
if req.TLS != nil {
scheme = constant.SecureScheme
}
}

if scheme == constant.UnsecureScheme && secureCookie {
hint := "you have secure cookie set to true but using http "
hint += "use https or secure cookie false"
logger.Warn(hint)
}

redirect = fmt.Sprintf("%s://%s", scheme, host)
default:
redirect = redirectionURL
}

state, _ := req.Cookie(cookieOAuthStateName)

if state != nil && req.URL.Query().Get("state") != state.Value {
logger.Error("state parameter mismatch")
wrt.WriteHeader(http.StatusForbidden)
return ""
}

return fmt.Sprintf("%s%s", redirect, withOAuthURI(constant.CallbackURL))
}
}
Loading

0 comments on commit 28b055b

Please sign in to comment.