From 4b8a57e6b94b54922ca0b6c9a3f81b12a88745b7 Mon Sep 17 00:00:00 2001 From: p53 Date: Mon, 8 Aug 2022 17:28:55 +0200 Subject: [PATCH] Fix linting issues (#197) * Fix linting issues * Linting updates --- .github/workflows/build.yml | 2 +- .golangci.yml | 7 +- Makefile | 2 +- cli.go | 45 +++++----- common_test.go | 53 +++++++----- config.go | 10 ++- config_test.go | 2 + doc.go | 8 +- forwarding.go | 18 +++- handlers.go | 164 +++++++++++++++++++++--------------- handlers_test.go | 18 +++- middleware.go | 70 +++++++++++---- middleware_test.go | 14 ++- misc.go | 8 +- oauth.go | 1 - resource.go | 5 +- server.go | 44 +++++++--- server_test.go | 91 ++++++++++++-------- stores.go | 2 +- user_context.go | 14 ++- utils.go | 8 +- 21 files changed, 383 insertions(+), 203 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cdc077a5..d7eead21 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -66,6 +66,6 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: - version: v1.33 + version: v1.46 args: "--out-${NO_FUTURE}format colored-line-number" github-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.golangci.yml b/.golangci.yml index 63a82cd9..8c3c8270 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,7 +1,7 @@ linters-settings: govet: check-shadowing: true - golint: + revive: min-confidence: 0 gocyclo: min-complexity: 64 @@ -20,6 +20,9 @@ linters-settings: linters: enable-all: true disable: + - thelper + - ireturn + - maintidx - wrapcheck - gosec - gocritic @@ -31,6 +34,8 @@ linters: - paralleltest - exhaustive - exhaustivestruct + - exhaustruct + - tagliatelle - maligned - unparam - lll diff --git a/Makefile b/Makefile index 7641e7b1..8731e28f 100644 --- a/Makefile +++ b/Makefile @@ -103,7 +103,7 @@ lint: @which golangci-lint 2>/dev/null ; if [ $$? -eq 1 ]; then \ go get -u github.com/golangci/golangci-lint/cmd/golangci-lint; \ fi - @golint . + @golangci-lint run . gofmt: @echo "--> Running gofmt check" diff --git a/cli.go b/cli.go index ca0e8c38..b777e539 100644 --- a/cli.go +++ b/cli.go @@ -88,8 +88,11 @@ func newOauthProxyApp() *cli.App { return app } -// getCommandLineOptions builds the command line options by reflecting the Config struct and extracting -// the tagged information +/* + getCommandLineOptions builds the command line options by reflecting + the Config struct and extracting the tagged information +*/ +//nolint:cyclop func getCommandLineOptions() []cli.Flag { defaults := newDefaultConfig() var flags []cli.Flag @@ -166,8 +169,12 @@ func getCommandLineOptions() []cli.Flag { return flags } -// parseCLIOptions parses the command line options and constructs a config object -func parseCLIOptions(cx *cli.Context, config *Config) (err error) { +/* + parseCLIOptions parses the command line options + and constructs a config object +*/ +//nolint:cyclop +func parseCLIOptions(cliCtx *cli.Context, config *Config) error { // step: we can ignore these options in the Config struct ignoredOptions := []string{"tag-data", "match-claims", "resources", "headers"} // step: iterate the Config and grab command line options via reflection @@ -181,53 +188,53 @@ func parseCLIOptions(cx *cli.Context, config *Config) (err error) { continue } - if cx.IsSet(name) { + if cliCtx.IsSet(name) { switch field.Type.Kind() { case reflect.Bool: - reflect.ValueOf(config).Elem().FieldByName(field.Name).SetBool(cx.Bool(name)) + reflect.ValueOf(config).Elem().FieldByName(field.Name).SetBool(cliCtx.Bool(name)) case reflect.String: - reflect.ValueOf(config).Elem().FieldByName(field.Name).SetString(cx.String(name)) + reflect.ValueOf(config).Elem().FieldByName(field.Name).SetString(cliCtx.String(name)) case reflect.Slice: - reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cx.StringSlice(name))) + reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cliCtx.StringSlice(name))) case reflect.Int: - reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cx.Int(name))) + reflect.ValueOf(config).Elem().FieldByName(field.Name).Set(reflect.ValueOf(cliCtx.Int(name))) case reflect.Int64: switch field.Type.String() { case durationType: - reflect.ValueOf(config).Elem().FieldByName(field.Name).SetInt(int64(cx.Duration(name))) + reflect.ValueOf(config).Elem().FieldByName(field.Name).SetInt(int64(cliCtx.Duration(name))) default: - reflect.ValueOf(config).Elem().FieldByName(field.Name).SetInt(cx.Int64(name)) + reflect.ValueOf(config).Elem().FieldByName(field.Name).SetInt(cliCtx.Int64(name)) } } } } - if cx.IsSet("tag") { - tags, err := decodeKeyPairs(cx.StringSlice("tag")) + if cliCtx.IsSet("tag") { + tags, err := decodeKeyPairs(cliCtx.StringSlice("tag")) if err != nil { return err } mergeMaps(config.Tags, tags) } - if cx.IsSet("match-claims") { - claims, err := decodeKeyPairs(cx.StringSlice("match-claims")) + if cliCtx.IsSet("match-claims") { + claims, err := decodeKeyPairs(cliCtx.StringSlice("match-claims")) if err != nil { return err } mergeMaps(config.MatchClaims, claims) } - if cx.IsSet("headers") { - headers, err := decodeKeyPairs(cx.StringSlice("headers")) + if cliCtx.IsSet("headers") { + headers, err := decodeKeyPairs(cliCtx.StringSlice("headers")) if err != nil { return err } mergeMaps(config.Headers, headers) } - if cx.IsSet("resources") { - for _, x := range cx.StringSlice("resources") { + if cliCtx.IsSet("resources") { + for _, x := range cliCtx.StringSlice("resources") { resource, err := newResource().parse(x) if err != nil { return fmt.Errorf("invalid resource %s, %s", x, err) diff --git a/common_test.go b/common_test.go index 02b2988c..e8c4380a 100644 --- a/common_test.go +++ b/common_test.go @@ -218,7 +218,7 @@ func (f *fakeProxy) getServiceURL() string { } // RunTests performs a series of requests against a fake proxy service -// nolint:gocyclo,funlen +//nolint:gocyclo,funlen,cyclop func (f *fakeProxy) RunTests(t *testing.T, requests []fakeRequest) { defer func() { f.idp.Close() @@ -801,14 +801,14 @@ func makeTestCodeFlowLogin(location string) (*http.Response, []*http.Cookie, err } // step: make the request - tr := &http.Transport{ + transport := &http.Transport{ TLSClientConfig: &tls.Config{ //nolint:gas InsecureSkipVerify: true, }, } - resp, err = tr.RoundTrip(req) + resp, err = transport.RoundTrip(req) if err != nil { return nil, nil, err @@ -861,8 +861,7 @@ func (f *fakeUpstreamService) ServeHTTP(wrt http.ResponseWriter, req *http.Reque }).ServeHTTP(wrt, req) } else { wrt.Header().Set("Content-Type", "application/json") - wrt.WriteHeader(http.StatusOK) - content, _ := json.Marshal(&fakeUpstreamResponse{ + content, err := json.Marshal(&fakeUpstreamResponse{ // r.RequestURI is what was received by the proxy. // r.URL.String() is what is actually sent to the upstream service. // KEYCLOAK-10864, KEYCLOAK-11276, KEYCLOAK-13315 @@ -871,6 +870,13 @@ func (f *fakeUpstreamService) ServeHTTP(wrt http.ResponseWriter, req *http.Reque Address: req.RemoteAddr, Headers: req.Header, }) + + if err != nil { + wrt.WriteHeader(http.StatusInternalServerError) + } else { + wrt.WriteHeader(http.StatusOK) + } + _, _ = wrt.Write(content) } } @@ -1258,7 +1264,8 @@ func (r *fakeAuthServer) userInfoHandler(wrt http.ResponseWriter, req *http.Requ }) } -func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) { +//nolint:cyclop +func (r *fakeAuthServer) tokenHandler(writer http.ResponseWriter, req *http.Request) { expires := time.Now().Add(r.expiration) refreshExpires := time.Now().Add(2 * r.expiration) token := newTestToken(r.getLocation()) @@ -1281,14 +1288,14 @@ func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) // sign the token with the private key jwtAccess, err := token.getToken() if err != nil { - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } // sign the token with the private key jwtRefresh, err := refreshToken.getToken() if err != nil { - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } @@ -1298,12 +1305,12 @@ func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) password := req.FormValue("password") if username == "" || password == "" { - w.WriteHeader(http.StatusBadRequest) + writer.WriteHeader(http.StatusBadRequest) return } if username == validUsername && password == validPassword { - renderJSON(http.StatusOK, w, req, tokenResponse{ + renderJSON(http.StatusOK, writer, req, tokenResponse{ IDToken: jwtAccess, AccessToken: jwtAccess, RefreshToken: jwtRefresh, @@ -1312,7 +1319,7 @@ func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) return } - renderJSON(http.StatusUnauthorized, w, req, map[string]string{ + renderJSON(http.StatusUnauthorized, writer, req, map[string]string{ "error": "invalid_grant", "error_description": "invalid user credentials", }) @@ -1326,13 +1333,13 @@ func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) clientSecret = p if clientID == "" || clientSecret == "" || !ok { - w.WriteHeader(http.StatusBadRequest) + writer.WriteHeader(http.StatusBadRequest) return } } if clientID == validUsername && clientSecret == validPassword { - renderJSON(http.StatusOK, w, req, tokenResponse{ + renderJSON(http.StatusOK, writer, req, tokenResponse{ IDToken: jwtAccess, AccessToken: jwtAccess, RefreshToken: jwtRefresh, @@ -1341,7 +1348,7 @@ func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) return } - renderJSON(http.StatusUnauthorized, w, req, map[string]string{ + renderJSON(http.StatusUnauthorized, writer, req, map[string]string{ "error": "invalid_grant", "error_description": "invalid client credentials", }) @@ -1349,7 +1356,7 @@ func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) oldRefreshToken, err := jwt.ParseSigned(req.FormValue("refresh_token")) if err != nil { - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } @@ -1358,7 +1365,7 @@ func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) err = oldRefreshToken.UnsafeClaimsWithoutVerification(stdClaims) if err != nil { - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } @@ -1374,37 +1381,37 @@ func (r *fakeAuthServer) tokenHandler(w http.ResponseWriter, req *http.Request) respBody, err := json.Marshal(expRefresh) if err != nil { - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } - w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write(respBody) + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write(respBody) return } - renderJSON(http.StatusOK, w, req, tokenResponse{ + renderJSON(http.StatusOK, writer, req, tokenResponse{ IDToken: jwtAccess, AccessToken: jwtAccess, ExpiresIn: float64(expires.Second()), }) case GrantTypeAuthCode: - renderJSON(http.StatusOK, w, req, tokenResponse{ + renderJSON(http.StatusOK, writer, req, tokenResponse{ IDToken: jwtAccess, AccessToken: jwtAccess, RefreshToken: jwtRefresh, ExpiresIn: float64(expires.Second()), }) case GrantTypeUmaTicket: - renderJSON(http.StatusOK, w, req, tokenResponse{ + renderJSON(http.StatusOK, writer, req, tokenResponse{ IDToken: jwtAccess, AccessToken: jwtAccess, RefreshToken: jwtRefresh, ExpiresIn: float64(expires.Second()), }) default: - w.WriteHeader(http.StatusBadRequest) + writer.WriteHeader(http.StatusBadRequest) } } diff --git a/config.go b/config.go index 078dd24f..ce0a4234 100644 --- a/config.go +++ b/config.go @@ -203,6 +203,7 @@ func (r *Config) isSameSiteValid() error { return nil } +//nolint:cyclop func (r *Config) isTLSFilesValid() error { if r.TLSCertificate != "" && r.TLSPrivateKey == "" { return errors.New("you have not provided a private key") @@ -237,6 +238,7 @@ func (r *Config) isTLSFilesValid() error { return nil } +//nolint:cyclop func (r *Config) isAdminTLSFilesValid() error { if r.TLSAdminCertificate != "" && r.TLSAdminPrivateKey == "" { return errors.New("you have not provided a private key for admin endpoint") @@ -290,13 +292,13 @@ func (r *Config) isTLSMinValid() error { switch strings.ToLower(r.TLSMinVersion) { case "": return fmt.Errorf("minimal TLS version should not be empty") - // nolint: goconst + //nolint: goconst case "tlsv1.0": - // nolint: goconst + //nolint: goconst case "tlsv1.1": - // nolint: goconst + //nolint: goconst case "tlsv1.2": - // nolint: goconst + //nolint: goconst case "tlsv1.3": default: return fmt.Errorf("invalid minimal TLS version specified") diff --git a/config_test.go b/config_test.go index 01bae7e0..7051c4fa 100644 --- a/config_test.go +++ b/config_test.go @@ -474,6 +474,7 @@ func TestIsSameSiteValid(t *testing.T) { } } +//nolint:cyclop func TestIsTLSFilesValid(t *testing.T) { testCases := []struct { Name string @@ -679,6 +680,7 @@ func TestIsTLSFilesValid(t *testing.T) { } } +//nolint:cyclop func TestIsAdminTLSFilesValid(t *testing.T) { testCases := []struct { Name string diff --git a/doc.go b/doc.go index 5e19e6cd..2b38d8ca 100644 --- a/doc.go +++ b/doc.go @@ -94,15 +94,15 @@ var ( ) oauthLatencyMetric = prometheus.NewSummaryVec( prometheus.SummaryOpts{ - Name: "proxy_oauth_request_latency_sec", - Help: "A summary of the request latancy for requests against the openid provider", + Name: "proxy_oauth_request_latency", + Help: "A summary of the request latancy for requests against the openid provider, in seconds", }, []string{"action"}, ) latencyMetric = prometheus.NewSummary( prometheus.SummaryOpts{ - Name: "proxy_request_duration_sec", - Help: "A summary of the http request latency for proxy requests", + Name: "proxy_request_duration", + Help: "A summary of the http request latency for proxy requests, in seconds", }, ) statusMetric = prometheus.NewCounterVec( diff --git a/forwarding.go b/forwarding.go index 79acdccd..177b5fb9 100644 --- a/forwarding.go +++ b/forwarding.go @@ -24,7 +24,11 @@ import ( "go.uber.org/zap" ) -// proxyMiddleware is responsible for handles reverse proxy request to the upstream endpoint +/* + proxyMiddleware is responsible for handles reverse proxy + request to the upstream endpoint +*/ +//nolint:cyclop func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { next.ServeHTTP(wrt, req) @@ -33,7 +37,16 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { ctxVal := req.Context().Value(contextScopeName) var scope *RequestScope if ctxVal != nil { - scope = ctxVal.(*RequestScope) + var assertOk bool + scope, assertOk = ctxVal.(*RequestScope) + + if !assertOk { + r.log.Error( + "assertion failed", + ) + return + } + if scope.AccessDenied { return } @@ -89,7 +102,6 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { } // forwardProxyHandler is responsible for signing outbound requests -// nolint:funlen func (r *oauthProxy) forwardProxyHandler() func(*http.Request, *http.Response) { return func(req *http.Request, resp *http.Response) { var token string diff --git a/handlers.go b/handlers.go index 805ed841..559882af 100644 --- a/handlers.go +++ b/handlers.go @@ -90,7 +90,7 @@ func (r *oauthProxy) oauthAuthorizationHandler(wrt http.ResponseWriter, req *htt if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) return } @@ -129,11 +129,13 @@ func (r *oauthProxy) oauthAuthorizationHandler(wrt http.ResponseWriter, req *htt r.redirectToURL(authURL, wrt, req, http.StatusSeeOther) } -// oauthCallbackHandler is responsible for handling the response from oauth service -// nolint:funlen -func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Request) { +/* + oauthCallbackHandler is responsible for handling the response from oauth service +*/ +//nolint:funlen,cyclop +func (r *oauthProxy) oauthCallbackHandler(writer http.ResponseWriter, req *http.Request) { if r.config.SkipTokenVerification { - w.WriteHeader(http.StatusNotAcceptable) + writer.WriteHeader(http.StatusNotAcceptable) return } @@ -141,7 +143,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) return } @@ -150,11 +152,11 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque code := req.URL.Query().Get("code") if code == "" { - r.accessError(w, req) + r.accessError(writer, req) return } - conf := r.newOAuth2Config(r.getRedirectionURL(w, req)) + conf := r.newOAuth2Config(r.getRedirectionURL(writer, req)) resp, err := exchangeAuthenticationCode( conf, @@ -164,19 +166,19 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque if err != nil { scope.Logger.Error("unable to exchange code for access token", zap.Error(err)) - r.accessForbidden(w, req) + r.accessForbidden(writer, req) return } - rawToken := "" + var rawToken string // Flow: once we exchange the authorization code we parse the ID Token; we then check for an access token, // if an access token is present and we can decode it, we use that as the session token, otherwise we default // to the ID Token. - rawIDToken, ok := resp.Extra("id_token").(string) + rawIDToken, assertOk := resp.Extra("id_token").(string) - if !ok { + if !assertOk { scope.Logger.Error("unable to obtain id token", zap.Error(err)) - r.accessForbidden(w, req) + r.accessForbidden(writer, req) return } @@ -197,7 +199,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque if err != nil { scope.Logger.Error("unable to verify the id token", zap.Error(err)) - r.accessForbidden(w, req) + r.accessForbidden(writer, req) return } @@ -205,7 +207,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque if err != nil { scope.Logger.Error("unable to parse id token", zap.Error(err)) - r.accessForbidden(w, req) + r.accessForbidden(writer, req) return } @@ -219,7 +221,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque if err != nil { scope.Logger.Error("unable to parse id token for claims", zap.Error(err)) - r.accessForbidden(w, req) + r.accessForbidden(writer, req) return } @@ -231,7 +233,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque if err != nil { scope.Logger.Error("unable to verify access token", zap.Error(err)) - r.accessForbidden(w, req) + r.accessForbidden(writer, req) return } } @@ -254,7 +256,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque if err != nil { scope.Logger.Error("unable to parse access token for claims", zap.Error(err)) - r.accessForbidden(w, req) + r.accessForbidden(writer, req) return } @@ -264,7 +266,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque if r.config.EnableEncryptedToken || r.config.ForceEncryptedCookie { if accessToken, err = encodeText(accessToken, r.config.EncryptionKey); err != nil { scope.Logger.Error("unable to encode the access token", zap.Error(err)) - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } } @@ -302,13 +304,13 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque zap.String("email", customClaims.Email), ) - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } // drop in the access token - cookie expiration = access token r.dropAccessTokenCookie( req, - w, + writer, accessToken, r.getAccessCookieExpiration(resp.RefreshToken), ) @@ -327,7 +329,7 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque zap.String("email", customClaims.Email), ) - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } @@ -352,12 +354,12 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque ) } default: - r.dropRefreshTokenCookie(req, w, encrypted, expiration) + r.dropRefreshTokenCookie(req, writer, encrypted, expiration) } } else { r.dropAccessTokenCookie( req, - w, + writer, accessToken, time.Until(stdClaims.Expiry.Time()), ) @@ -395,18 +397,21 @@ func (r *oauthProxy) oauthCallbackHandler(w http.ResponseWriter, req *http.Reque } scope.Logger.Debug("redirecting to", zap.String("location", redirectURI)) - r.redirectToURL(redirectURI, w, req, http.StatusSeeOther) + r.redirectToURL(redirectURI, writer, req, http.StatusSeeOther) } -// loginHandler provide's a generic endpoint for clients to perform a user_credentials login to the provider -func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { +/* + loginHandler provide's a generic endpoint for clients to perform a user_credentials login to the provider +*/ +//nolint:funlen,cyclop // refactor +func (r *oauthProxy) loginHandler(writer http.ResponseWriter, req *http.Request) { scope, assertOk := req.Context().Value(contextScopeName).(*RequestScope) if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } @@ -441,7 +446,7 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { errors.New("no credentials") } - conf := r.newOAuth2Config(r.getRedirectionURL(w, req)) + conf := r.newOAuth2Config(r.getRedirectionURL(writer, req)) start := time.Now() token, err := conf.PasswordCredentialsToken(ctx, username, password) @@ -475,7 +480,7 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { err } - w.Header().Set("Content-Type", "application/json") + writer.Header().Set("Content-Type", "application/json") idToken, assertOk := token.Extra("id_token").(string) if !assertOk { @@ -493,6 +498,8 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { } // step: are we encrypting the access token? + var plainIDToken string + if r.config.EnableEncryptedToken || r.config.ForceEncryptedCookie { if accessToken, err = encodeText(accessToken, r.config.EncryptionKey); err != nil { scope.Logger.Error("unable to encode the access token", zap.Error(err)) @@ -508,6 +515,8 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { err } + plainIDToken = idToken + if idToken, err = encodeText(idToken, r.config.EncryptionKey); err != nil { scope.Logger.Error("unable to encode the idToken token", zap.Error(err)) return "unable to encode the idToken token", @@ -531,7 +540,7 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { // drop in the access token - cookie expiration = access token r.dropAccessTokenCookie( req, - w, + writer, accessToken, r.getAccessCookieExpiration(token.RefreshToken), ) @@ -567,12 +576,12 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { ) } default: - r.dropRefreshTokenCookie(req, w, encrypted, expiration) + r.dropRefreshTokenCookie(req, writer, encrypted, expiration) } } else { r.dropAccessTokenCookie( req, - w, + writer, accessToken, time.Until(identity.expiresAt), ) @@ -580,7 +589,19 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { // @metric a token has been issued oauthTokensMetric.WithLabelValues("login").Inc() - scope, _ := token.Extra("scope").(string) + tokenScope := token.Extra("scope") + var tScope string + + if tokenScope != nil { + tScope, assertOk = tokenScope.(string) + + if !assertOk { + return "", + http.StatusInternalServerError, + fmt.Errorf("assertion failed") + } + } + var resp tokenResponse if r.config.EnableEncryptedToken { @@ -589,19 +610,19 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { AccessToken: accessToken, RefreshToken: refreshToken, ExpiresIn: expiresIn, - Scope: scope, + Scope: tScope, } } else { resp = tokenResponse{ - IDToken: token.Extra("id_token").(string), + IDToken: plainIDToken, AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, ExpiresIn: expiresIn, - Scope: scope, + Scope: tScope, } } - err = json.NewEncoder(w).Encode(resp) + err = json.NewEncoder(writer).Encode(resp) if err != nil { return "", http.StatusInternalServerError, err @@ -615,18 +636,21 @@ func (r *oauthProxy) loginHandler(w http.ResponseWriter, req *http.Request) { zap.String("client_ip", req.RemoteAddr), zap.Error(err)) - w.WriteHeader(code) + writer.WriteHeader(code) } } // emptyHandler is responsible for doing nothing func emptyHandler(w http.ResponseWriter, req *http.Request) {} -// logoutHandler performs a logout -// - if it's just a access token, the cookie is deleted -// - if the user has a refresh token, the token is invalidated by the provider -// - optionally, the user can be redirected by to a url -func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { +/* + logoutHandler performs a logout + - if it's just a access token, the cookie is deleted + - if the user has a refresh token, the token is invalidated by the provider + - optionally, the user can be redirected by to a url +*/ +//nolint:cyclop +func (r *oauthProxy) logoutHandler(writer http.ResponseWriter, req *http.Request) { // @check if the redirection is there var redirectURL string for k := range req.URL.Query() { @@ -647,9 +671,9 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } @@ -657,7 +681,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { user, err := r.getIdentity(req) if err != nil { - r.accessError(w, req) + r.accessError(writer, req) return } @@ -668,7 +692,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { if refresh, _, err := r.retrieveRefreshToken(req, user); err == nil { identityToken = refresh } - r.clearAllCookies(req, w) + r.clearAllCookies(req, writer) // @metric increment the logout counter oauthTokensMetric.WithLabelValues("logout").Inc() @@ -711,7 +735,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { sendTo, url.QueryEscape(redirectURL), ), - w, + writer, req, http.StatusSeeOther, ) @@ -743,7 +767,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { if err != nil { scope.Logger.Error("unable to retrieve the openid client", zap.Error(err)) - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } @@ -762,7 +786,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { if err != nil { scope.Logger.Error("unable to construct the revocation request", zap.Error(err)) - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } @@ -775,7 +799,7 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { if err != nil { scope.Logger.Error("unable to post to revocation endpoint", zap.Error(err)) - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } @@ -800,14 +824,14 @@ func (r *oauthProxy) logoutHandler(w http.ResponseWriter, req *http.Request) { zap.String("response", string(content)), ) - w.WriteHeader(http.StatusInternalServerError) + writer.WriteHeader(http.StatusInternalServerError) return } } // step: should we redirect the user if redirectURL != "" { - r.redirectToURL(redirectURL, w, req, http.StatusSeeOther) + r.redirectToURL(redirectURL, writer, req, http.StatusSeeOther) } } @@ -872,7 +896,8 @@ func (r *oauthProxy) healthHandler(w http.ResponseWriter, req *http.Request) { } // debugHandler is responsible for providing the pprof -func (r *oauthProxy) debugHandler(w http.ResponseWriter, req *http.Request) { +//nolint:cyclop +func (r *oauthProxy) debugHandler(writer http.ResponseWriter, req *http.Request) { const symbolProfile = "symbol" name := chi.URLParam(req, "name") @@ -886,24 +911,24 @@ func (r *oauthProxy) debugHandler(w http.ResponseWriter, req *http.Request) { case "block": fallthrough case "threadcreate": - pprof.Handler(name).ServeHTTP(w, req) + pprof.Handler(name).ServeHTTP(writer, req) case "cmdline": - pprof.Cmdline(w, req) + pprof.Cmdline(writer, req) case "profile": - pprof.Profile(w, req) + pprof.Profile(writer, req) case "trace": - pprof.Trace(w, req) + pprof.Trace(writer, req) case symbolProfile: - pprof.Symbol(w, req) + pprof.Symbol(writer, req) default: - w.WriteHeader(http.StatusNotFound) + writer.WriteHeader(http.StatusNotFound) } case http.MethodPost: switch name { case symbolProfile: - pprof.Symbol(w, req) + pprof.Symbol(writer, req) default: - w.WriteHeader(http.StatusNotFound) + writer.WriteHeader(http.StatusNotFound) } } } @@ -920,7 +945,10 @@ func (r *oauthProxy) proxyMetricsHandler(wrt http.ResponseWriter, req *http.Requ } // retrieveRefreshToken retrieves the refresh token from store or cookie -func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) (token, encrypted string, err error) { +func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) (string, string, error) { + var token string + var err error + switch r.useStore() { case true: token, err = r.GetRefreshToken(user.rawToken) @@ -929,12 +957,12 @@ func (r *oauthProxy) retrieveRefreshToken(req *http.Request, user *userContext) } if err != nil { - return + return token, "", err } - encrypted = token // returns encrypted, avoids encoding twice + encrypted := token // returns encrypted, avoids encoding twice token, err = decodeText(token, r.config.EncryptionKey) - return + return token, encrypted, err } func methodNotAllowHandlder(w http.ResponseWriter, req *http.Request) { diff --git a/handlers_test.go b/handlers_test.go index 730ebd49..550e4e5d 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -256,8 +256,14 @@ func TestSkipOpenIDProviderTLSVerifyLoginHandler(t *testing.T) { defer func() { if r := recover(); r != nil { + failure, assertOk := r.(string) + + if !assertOk { + t.Fatalf("assertion failed") + } + check := strings.Contains( - r.(string), + failure, "failed to retrieve the provider configuration from discovery url", ) assert.True(t, check) @@ -267,7 +273,7 @@ func TestSkipOpenIDProviderTLSVerifyLoginHandler(t *testing.T) { newFakeProxy(cfg, &fakeAuthConfig{EnableTLS: true}).RunTests(t, requests) } -// nolint:funlen +//nolint:funlen func TestTokenEncryptionLoginHandler(t *testing.T) { cfg := newFakeKeycloakConfig() uri := cfg.WithOAuthURI(loginURL) @@ -668,8 +674,14 @@ func TestSkipOpenIDProviderTLSVerifyLogoutHandler(t *testing.T) { defer func() { if r := recover(); r != nil { + failure, assertOk := r.(string) + + if !assertOk { + t.Fatalf("assertion failed") + } + check := strings.Contains( - r.(string), + failure, "failed to retrieve the provider configuration from discovery url", ) assert.True(t, check) diff --git a/middleware.go b/middleware.go index 5050cba2..eb39f0fd 100644 --- a/middleware.go +++ b/middleware.go @@ -103,7 +103,7 @@ func (r *oauthProxy) loggingMiddleware(next http.Handler) http.Handler { if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) return } @@ -112,7 +112,7 @@ func (r *oauthProxy) loggingMiddleware(next http.Handler) http.Handler { if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) return } @@ -149,8 +149,10 @@ func (r *oauthProxy) loggingMiddleware(next http.Handler) http.Handler { }) } -// authenticationMiddleware is responsible for verifying the access token -// nolint:funlen +/* + authenticationMiddleware is responsible for verifying the access token +*/ +//nolint:funlen,cyclop func (r *oauthProxy) authenticationMiddleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { @@ -158,7 +160,7 @@ func (r *oauthProxy) authenticationMiddleware() func(http.Handler) http.Handler if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) return } @@ -385,12 +387,12 @@ func (r *oauthProxy) authenticationMiddleware() func(http.Handler) http.Handler } if r.useStore() { - go func(old, new string, encrypted string) { + go func(old, newToken string, encrypted string) { if err := r.DeleteRefreshToken(old); err != nil { scope.Logger.Error("failed to remove old token", zap.Error(err)) } - if err := r.StoreRefreshToken(new, encrypted, refreshExpiresIn); err != nil { + if err := r.StoreRefreshToken(newToken, encrypted, refreshExpiresIn); err != nil { scope.Logger.Error("failed to store refresh token", zap.Error(err)) return } @@ -411,8 +413,10 @@ func (r *oauthProxy) authenticationMiddleware() func(http.Handler) http.Handler } } -// authorizationMiddleware is responsible for verifying permissions in access_token -// nolint:funlen +/* + authorizationMiddleware is responsible for verifying permissions in access_token +*/ +//nolint:cyclop func (r *oauthProxy) authorizationMiddleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { @@ -420,7 +424,7 @@ func (r *oauthProxy) authorizationMiddleware() func(http.Handler) http.Handler { if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) return } @@ -514,6 +518,7 @@ func (r *oauthProxy) authorizationMiddleware() func(http.Handler) http.Handler { } // checkClaim checks whether claim in userContext matches claimName, match. It can be String or Strings claim. +//nolint:cyclop func (r *oauthProxy) checkClaim(user *userContext, claimName string, match *regexp.Regexp, resourceURL string) bool { errFields := []zapcore.Field{ zap.String("claim", claimName), @@ -529,7 +534,14 @@ func (r *oauthProxy) checkClaim(user *userContext, claimName string, match *rege switch user.claims[claimName].(type) { case []interface{}: - for _, v := range user.claims[claimName].([]interface{}) { + claims, assertOk := user.claims[claimName].([]interface{}) + + if !assertOk { + r.log.Error("assertion failed") + return false + } + + for _, v := range claims { value, ok := v.(string) if !ok { @@ -552,6 +564,7 @@ func (r *oauthProxy) checkClaim(user *userContext, claimName string, match *rege return true } } + r.log.Warn( "claim requirement does not match any element claim group in token", append( @@ -566,7 +579,14 @@ func (r *oauthProxy) checkClaim(user *userContext, claimName string, match *rege return false case string: - if match.MatchString(user.claims[claimName].(string)) { + claims, assertOk := user.claims[claimName].(string) + + if !assertOk { + r.log.Error("assertion failed") + return false + } + + if match.MatchString(claims) { return true } @@ -574,7 +594,7 @@ func (r *oauthProxy) checkClaim(user *userContext, claimName string, match *rege "claim requirement does not match claim in token", append( errFields, - zap.String("issued", user.claims[claimName].(string)), + zap.String("issued", claims), zap.String("required", match.String()), )..., ) @@ -605,7 +625,7 @@ func (r *oauthProxy) admissionMiddleware(resource *Resource) func(http.Handler) if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) return } @@ -695,7 +715,14 @@ func (r *oauthProxy) identityHeadersMiddleware(custom []string) func(http.Handle return func(next http.Handler) http.Handler { return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - scope := req.Context().Value(contextScopeName).(*RequestScope) + scope, assertOk := req.Context().Value(contextScopeName).(*RequestScope) + + if !assertOk { + r.log.Error( + "assertion failed", + ) + return + } if scope.Identity != nil { user := scope.Identity @@ -752,7 +779,7 @@ func (r *oauthProxy) securityMiddleware(next http.Handler) http.Handler { if !assertOk { r.log.Error( - "Assertion failed", + "assertion failed", ) return } @@ -783,7 +810,7 @@ func (r *oauthProxy) methodCheckMiddleware(next http.Handler) http.Handler { } // proxyDenyMiddleware just block everything -func proxyDenyMiddleware(next http.Handler) http.Handler { +func (r *oauthProxy) proxyDenyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { ctxVal := req.Context().Value(contextScopeName) @@ -791,7 +818,14 @@ func proxyDenyMiddleware(next http.Handler) http.Handler { if ctxVal == nil { scope = &RequestScope{} } else { - scope = ctxVal.(*RequestScope) + var assertOk bool + scope, assertOk = ctxVal.(*RequestScope) + if !assertOk { + r.log.Error( + "assertion failed", + ) + return + } } scope.AccessDenied = true diff --git a/middleware_test.go b/middleware_test.go index ea5e1f08..ddb1bb93 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -118,6 +118,7 @@ func TestOauthRequests(t *testing.T) { newFakeProxy(cfg, &fakeAuthConfig{}).RunTests(t, requests) } +//nolint:cyclop func TestAdminListener(t *testing.T) { testCases := []struct { Name string @@ -1979,7 +1980,7 @@ func TestEnableUma(t *testing.T) { } } -// nolint:funlen +//nolint:funlen,cyclop func TestEnableUmaWithCache(t *testing.T) { cfg := newFakeKeycloakConfig() @@ -2333,7 +2334,14 @@ func TestEnableUmaWithCache(t *testing.T) { fProxy.RunTests(t, exSettings) - result := fProxy.proxy.store.(storage.RedisStore).Client.Keys("*") + redisStoreInstance, assertOk := fProxy.proxy.store.(storage.RedisStore) + + if !assertOk { + t.Fatalf("assertion failed") + } + + result := redisStoreInstance.Client.Keys("*") + if len(result.Val()) != testCase.ExpectedCacheEntries { t.Fatalf( "expected number of entries %d, got %d", @@ -2344,7 +2352,7 @@ func TestEnableUmaWithCache(t *testing.T) { if testCase.ExpectedCacheValues != authorization.UndefinedAuthz { for _, val := range result.Val() { - result := fProxy.proxy.store.(storage.RedisStore).Client.Get(val) + result := redisStoreInstance.Client.Get(val) if result.Val() != testCase.ExpectedCacheValues.String() { t.Fatalf( "expecting cached authz %s, got %s", diff --git a/misc.go b/misc.go index dde0e7c8..dbd5cc3c 100644 --- a/misc.go +++ b/misc.go @@ -64,7 +64,13 @@ func (r *oauthProxy) revokeProxy(w http.ResponseWriter, req *http.Request) conte case nil: scope = &RequestScope{AccessDenied: true} default: - scope = ctxVal.(*RequestScope) + var assertOk bool + scope, assertOk = ctxVal.(*RequestScope) + + if !assertOk { + r.log.Error("assertion failed") + scope = &RequestScope{AccessDenied: true} + } } scope.AccessDenied = true diff --git a/oauth.go b/oauth.go index 91ed9c98..a84096aa 100644 --- a/oauth.go +++ b/oauth.go @@ -28,7 +28,6 @@ import ( "gopkg.in/square/go-jose.v2/jwt" ) -//FIXME remove constants in the future which hopefully won't be necessary in the next releases const ( GrantTypeAuthCode = "authorization_code" GrantTypeUserCreds = "password" diff --git a/resource.go b/resource.go index ad2d7b9a..64d3a3d1 100644 --- a/resource.go +++ b/resource.go @@ -28,7 +28,10 @@ func newResource() *Resource { } } -// parse decodes a resource definition +/* + parse decodes a resource definition +*/ +//nolint:cyclop func (r *Resource) parse(resource string) (*Resource, error) { if resource == "" { return nil, errors.New("the resource has no options") diff --git a/server.go b/server.go index 158bef20..5a2d91e6 100644 --- a/server.go +++ b/server.go @@ -89,6 +89,7 @@ func init() { const allPath = "/*" // newProxy create's a new proxy from configuration +//nolint:cyclop func newProxy(config *Config) (*oauthProxy, error) { // create the service logger log, err := createLogger(config) @@ -244,6 +245,7 @@ func (r *oauthProxy) useDefaultStack(engine chi.Router) { } // createReverseProxy creates a reverse proxy +//nolint:cyclop func (r *oauthProxy) createReverseProxy() error { r.log.Info( "enabled reverse proxy mode, upstream url", @@ -259,7 +261,7 @@ func (r *oauthProxy) createReverseProxy() error { // @step: configure CORS middleware if len(r.config.CorsOrigins) > 0 { - c := cors.New(cors.Options{ + corsHandler := cors.New(cors.Options{ AllowedOrigins: r.config.CorsOrigins, AllowedMethods: r.config.CorsMethods, AllowedHeaders: r.config.CorsHeaders, @@ -269,7 +271,7 @@ func (r *oauthProxy) createReverseProxy() error { Debug: r.config.Verbose, }) - engine.Use(c.Handler) + engine.Use(corsHandler.Handler) } engine.Use(r.proxyMiddleware) @@ -298,7 +300,7 @@ func (r *oauthProxy) createReverseProxy() error { } // step: add the routing for oauth - engine.With(proxyDenyMiddleware).Route(r.config.BaseURI+r.config.OAuthURI, func(eng chi.Router) { + engine.With(r.proxyDenyMiddleware).Route(r.config.BaseURI+r.config.OAuthURI, func(eng chi.Router) { eng.MethodNotAllowed(methodNotAllowHandlder) eng.HandleFunc(authorizationURL, r.oauthAuthorizationHandler) eng.Get(callbackURL, r.oauthCallbackHandler) @@ -334,7 +336,7 @@ func (r *oauthProxy) createReverseProxy() error { } if r.config.ListenAdmin == "" { - engine.With(proxyDenyMiddleware).Mount(debugURL, debugEngine) + engine.With(r.proxyDenyMiddleware).Mount(debugURL, debugEngine) } } @@ -346,7 +348,7 @@ func (r *oauthProxy) createReverseProxy() error { admin.MethodNotAllowed(emptyHandler) admin.NotFound(emptyHandler) admin.Use(middleware.Recoverer) - admin.Use(proxyDenyMiddleware) + admin.Use(r.proxyDenyMiddleware) admin.Route("/", func(e chi.Router) { e.Mount(r.config.OAuthURI, adminEngine) if debugEngine != nil { @@ -421,7 +423,7 @@ func (r *oauthProxy) createReverseProxy() error { if res.URL == allPath && !res.WhiteListed && enableDefaultDenyStrict { middlewares = []func(http.Handler) http.Handler{ r.denyMiddleware, - proxyDenyMiddleware, + r.proxyDenyMiddleware, } } @@ -486,7 +488,12 @@ func (r *oauthProxy) createForwardingProxy() error { forwardingHandler := r.forwardProxyHandler() // set the http handler - proxy := r.upstream.(*goproxy.ProxyHttpServer) + proxy, assertOk := r.upstream.(*goproxy.ProxyHttpServer) + + if !assertOk { + return fmt.Errorf("assertion failed") + } + r.router = proxy // setup the tls configuration @@ -514,7 +521,13 @@ func (r *oauthProxy) createForwardingProxy() error { proxy.OnResponse().DoFunc(func(resp *http.Response, ctx *goproxy.ProxyCtx) *http.Response { // @NOTES, somewhat annoying but goproxy hands back a nil response on proxy client errors if resp != nil && r.config.EnableLogging { - start := ctx.UserData.(time.Time) + start, assertOk := ctx.UserData.(time.Time) + + if !assertOk { + r.log.Error("assertion failed") + return nil + } + latency := time.Since(start) latencyMetric.Observe(latency.Seconds()) @@ -540,6 +553,7 @@ func (r *oauthProxy) createForwardingProxy() error { } // Run starts the proxy service +//nolint:cyclop func (r *oauthProxy) Run() error { listener, err := r.createHTTPListener(makeListenerConfig(r.config)) @@ -727,6 +741,7 @@ func makeListenerConfig(config *Config) listenerConfig { var ErrHostNotConfigured = errors.New("acme/autocert: host not configured") // createHTTPListener is responsible for creating a listening socket +//nolint:cyclop func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, error) { var listener net.Listener var err error @@ -773,7 +788,7 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er if config.useLetsEncryptTLS { r.log.Info("enabling letsencrypt tls support") - m := autocert.Manager{ + manager := autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(config.letsEncryptCacheDir), HostPolicy: func(_ context.Context, host string) error { @@ -799,7 +814,7 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er }, } - getCertificate = m.GetCertificate + getCertificate = manager.GetCertificate } if config.useSelfSignedTLS { @@ -955,7 +970,13 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error { r.upstream = proxy // update the tls configuration of the reverse proxy - r.upstream.(*goproxy.ProxyHttpServer).Tr = &http.Transport{ + upstreamProxy, assertOk := r.upstream.(*goproxy.ProxyHttpServer) + + if !assertOk { + return fmt.Errorf("assertion failed") + } + + upstreamProxy.Tr = &http.Transport{ Dial: dialer, DisableKeepAlives: !r.config.UpstreamKeepalives, ExpectContinueTimeout: r.config.UpstreamExpectContinueTimeout, @@ -1056,6 +1077,7 @@ func (r *oauthProxy) Render(w io.Writer, name string, data interface{}) error { return r.templates.ExecuteTemplate(w, name, data) } +//nolint:cyclop func (r *oauthProxy) getPAT(done chan bool) { retry := 0 r.pat = &PAT{} diff --git a/server_test.go b/server_test.go index 29e2a2ca..341c2501 100644 --- a/server_test.go +++ b/server_test.go @@ -57,8 +57,8 @@ func TestNewKeycloakProxy(t *testing.T) { } func TestReverseProxyHeaders(t *testing.T) { - p := newFakeProxy(nil, &fakeAuthConfig{}) - token := newTestToken(p.idp.getLocation()) + proxy := newFakeProxy(nil, &fakeAuthConfig{}) + token := newTestToken(proxy.idp.getLocation()) token.addRealmRoles([]string{fakeAdminRole}) jwt, _ := token.getToken() uri := "/auth_all/test" @@ -84,7 +84,7 @@ func TestReverseProxyHeaders(t *testing.T) { ExpectedContentContains: `"uri":"` + uri + `"`, }, } - p.RunTests(t, requests) + proxy.RunTests(t, requests) } func TestAuthTokenHeader(t *testing.T) { @@ -483,25 +483,31 @@ func TestSkipOpenIDProviderTLSVerifyForwardingProxy(t *testing.T) { ExpectedContentContains: "Bearer ey", }, } - p := newFakeProxy(cfg, &fakeAuthConfig{EnableTLS: true}) + proxy := newFakeProxy(cfg, &fakeAuthConfig{EnableTLS: true}) <-time.After(time.Duration(100) * time.Millisecond) - p.RunTests(t, requests) + proxy.RunTests(t, requests) cfg.SkipOpenIDProviderTLSVerify = false defer func() { if r := recover(); r != nil { + failure, assertOk := r.(string) + + if !assertOk { + t.Fatalf("assertion failed") + } + check := strings.Contains( - r.(string), + failure, "failed to retrieve the provider configuration from discovery url", ) assert.True(t, check) } }() - p = newFakeProxy(cfg, &fakeAuthConfig{EnableTLS: true}) + proxy = newFakeProxy(cfg, &fakeAuthConfig{EnableTLS: true}) <-time.After(time.Duration(100) * time.Millisecond) - p.RunTests(t, requests) + proxy.RunTests(t, requests) } func TestForbiddenTemplate(t *testing.T) { @@ -610,8 +616,14 @@ func TestSkipOpenIDProviderTLSVerify(t *testing.T) { defer func() { if r := recover(); r != nil { + failure, assertOk := r.(string) + + if !assertOk { + t.Fatalf("assertion failed") + } + check := strings.Contains( - r.(string), + failure, "failed to retrieve the provider configuration from discovery url", ) assert.True(t, check) @@ -650,8 +662,14 @@ func TestOpenIDProviderProxy(t *testing.T) { defer func() { if r := recover(); r != nil { + failure, assertOk := r.(string) + + if !assertOk { + t.Fatalf("assertion failed") + } + check := strings.Contains( - r.(string), + failure, "failed to retrieve the provider configuration from discovery url", ) assert.True(t, check) @@ -682,8 +700,8 @@ func TestRequestIDHeader(t *testing.T) { func TestAuthTokenHeaderDisabled(t *testing.T) { c := newFakeKeycloakConfig() c.EnableTokenHeader = false - p := newFakeProxy(c, &fakeAuthConfig{}) - token := newTestToken(p.idp.getLocation()) + proxy := newFakeProxy(c, &fakeAuthConfig{}) + token := newTestToken(proxy.idp.getLocation()) jwt, _ := token.getToken() requests := []fakeRequest{ @@ -695,7 +713,7 @@ func TestAuthTokenHeaderDisabled(t *testing.T) { ExpectedCode: http.StatusOK, }, } - p.RunTests(t, requests) + proxy.RunTests(t, requests) } func TestAudienceHeader(t *testing.T) { @@ -1036,7 +1054,7 @@ func TestCustomResponseHeaders(t *testing.T) { c.ResponseHeaders = map[string]string{ "CustomReponseHeader": "True", } - p := newFakeProxy(c, &fakeAuthConfig{}) + proxy := newFakeProxy(c, &fakeAuthConfig{}) requests := []fakeRequest{ { @@ -1050,7 +1068,7 @@ func TestCustomResponseHeaders(t *testing.T) { ExpectedCode: http.StatusOK, }, } - p.RunTests(t, requests) + proxy.RunTests(t, requests) } func TestSkipClientIDDisabled(t *testing.T) { @@ -1058,13 +1076,13 @@ func TestSkipClientIDDisabled(t *testing.T) { // client for which was access token released, but this is not according spec // as access_token could be also other type not just JWT c := newFakeKeycloakConfig() - p := newFakeProxy(c, &fakeAuthConfig{}) + proxy := newFakeProxy(c, &fakeAuthConfig{}) // create two token, one with a bad client id - bad := newTestToken(p.idp.getLocation()) + bad := newTestToken(proxy.idp.getLocation()) bad.claims.Aud = "bad_client_id" badSigned, _ := bad.getToken() // and the good - good := newTestToken(p.idp.getLocation()) + good := newTestToken(proxy.idp.getLocation()) goodSigned, _ := good.getToken() requests := []fakeRequest{ { @@ -1096,18 +1114,18 @@ func TestSkipClientIDDisabled(t *testing.T) { SkipClientIDCheck: true, }, } - p.RunTests(t, requests) + proxy.RunTests(t, requests) } func TestSkipIssuer(t *testing.T) { c := newFakeKeycloakConfig() - p := newFakeProxy(c, &fakeAuthConfig{}) + proxy := newFakeProxy(c, &fakeAuthConfig{}) // create two token, one with a bad client id - bad := newTestToken(p.idp.getLocation()) + bad := newTestToken(proxy.idp.getLocation()) bad.claims.Iss = "bad_issuer" badSigned, _ := bad.getToken() // and the good - good := newTestToken(p.idp.getLocation()) + good := newTestToken(proxy.idp.getLocation()) goodSigned, _ := good.getToken() requests := []fakeRequest{ { @@ -1139,12 +1157,12 @@ func TestSkipIssuer(t *testing.T) { SkipIssuerCheck: true, }, } - p.RunTests(t, requests) + proxy.RunTests(t, requests) } func TestAuthTokenHeaderEnabled(t *testing.T) { - p := newFakeProxy(nil, &fakeAuthConfig{}) - token := newTestToken(p.idp.getLocation()) + proxy := newFakeProxy(nil, &fakeAuthConfig{}) + token := newTestToken(proxy.idp.getLocation()) signed, _ := token.getToken() requests := []fakeRequest{ @@ -1158,14 +1176,14 @@ func TestAuthTokenHeaderEnabled(t *testing.T) { ExpectedCode: http.StatusOK, }, } - p.RunTests(t, requests) + proxy.RunTests(t, requests) } func TestDisableAuthorizationCookie(t *testing.T) { cfg := newFakeKeycloakConfig() cfg.EnableAuthorizationCookies = false - p := newFakeProxy(cfg, &fakeAuthConfig{}) - token := newTestToken(p.idp.getLocation()) + proxy := newFakeProxy(cfg, &fakeAuthConfig{}) + token := newTestToken(proxy.idp.getLocation()) signed, _ := token.getToken() requests := []fakeRequest{ @@ -1181,9 +1199,10 @@ func TestDisableAuthorizationCookie(t *testing.T) { ExpectedProxy: true, }, } - p.RunTests(t, requests) + proxy.RunTests(t, requests) } +//nolint:cyclop func TestTLS(t *testing.T) { testProxyAddr := "127.0.0.1:14302" testCases := []struct { @@ -1509,6 +1528,7 @@ func TestCustomHTTPMethod(t *testing.T) { } } +//nolint:cyclop func TestStoreAuthz(t *testing.T) { cfg := newFakeKeycloakConfig() token := newTestToken("http://test") @@ -1553,7 +1573,7 @@ func TestStoreAuthz(t *testing.T) { testCase.Name, func(t *testing.T) { testCase.ProxySettings(&c) - p := newFakeProxy(&c, &fakeAuthConfig{}) + fProxy := newFakeProxy(&c, &fakeAuthConfig{}) url, err := url.Parse("http://test.com/test") @@ -1561,7 +1581,7 @@ func TestStoreAuthz(t *testing.T) { t.Fatal("Problem parsing url") } - err = p.proxy.StoreAuthz(jwt, url, authorization.AllowedAuthz, 1*time.Second) + err = fProxy.proxy.StoreAuthz(jwt, url, authorization.AllowedAuthz, 1*time.Second) if err != nil && !testCase.ExpectedFailure { t.Fatalf("error storing authz %v", err) @@ -1569,7 +1589,7 @@ func TestStoreAuthz(t *testing.T) { if !testCase.ExpectedFailure { url.Path += "/append" - err = p.proxy.StoreAuthz(jwt, url, authorization.AllowedAuthz, 1*time.Second) + err = fProxy.proxy.StoreAuthz(jwt, url, authorization.AllowedAuthz, 1*time.Second) if err != nil { t.Fatalf("error storing authz %v", err) @@ -1596,6 +1616,7 @@ func TestStoreAuthz(t *testing.T) { } } +//nolint:cyclop func TestGetAuthz(t *testing.T) { cfg := newFakeKeycloakConfig() token := newTestToken("http://test") @@ -1675,7 +1696,7 @@ func TestGetAuthz(t *testing.T) { testCase.Name, func(t *testing.T) { testCase.ProxySettings(&cfg) - p := newFakeProxy(&cfg, &fakeAuthConfig{}) + fProxy := newFakeProxy(&cfg, &fakeAuthConfig{}) url, err := url.Parse("http://test.com/test") @@ -1684,14 +1705,14 @@ func TestGetAuthz(t *testing.T) { } if !testCase.ExpectedFailure { - err = p.proxy.StoreAuthz(testCase.JWT, url, authorization.AllowedAuthz, 1*time.Second) + err = fProxy.proxy.StoreAuthz(testCase.JWT, url, authorization.AllowedAuthz, 1*time.Second) if err != nil { t.Fatalf("error storing authz %s", err) } } - dec, err := p.proxy.GetAuthz(testCase.JWT, url) + dec, err := fProxy.proxy.GetAuthz(testCase.JWT, url) if err != nil { if !testCase.ExpectedFailure { diff --git a/stores.go b/stores.go index 83beb060..1267a7fe 100644 --- a/stores.go +++ b/stores.go @@ -64,7 +64,7 @@ func (r *oauthProxy) DeleteRefreshToken(token string) error { } // StoreAuthz -// nolint:interfacer +//nolint:interfacer func (r *oauthProxy) StoreAuthz(token string, url *url.URL, value authorization.AuthzDecision, expiration time.Duration) error { if len(token) == 0 { return fmt.Errorf("token of zero length") diff --git a/user_context.go b/user_context.go index b126dd70..0667ec70 100644 --- a/user_context.go +++ b/user_context.go @@ -74,10 +74,20 @@ func extractIdentity(token *jwt.JSONWebToken) (*userContext, error) { // @step: extract the client roles from the access token for name, list := range customClaims.ResourceAccess { - scopes := list.(map[string]interface{}) + scopes, assertOk := list.(map[string]interface{}) + + if !assertOk { + return nil, fmt.Errorf("assertion failed") + } if roles, found := scopes[claimResourceRoles]; found { - for _, r := range roles.([]interface{}) { + rolesVal, assertOk := roles.([]interface{}) + + if !assertOk { + return nil, fmt.Errorf("assertion failed") + } + + for _, r := range rolesVal { roleList = append(roleList, fmt.Sprintf("%s:%s", name, r)) } } diff --git a/utils.go b/utils.go index 1a1b60fe..f873208e 100644 --- a/utils.go +++ b/utils.go @@ -355,12 +355,14 @@ func tryUpdateConnection(req *http.Request, writer http.ResponseWriter, endpoint defer server.Close() // @check the the response writer implements the Hijack method - if _, ok := writer.(http.Hijacker); !ok { - return errors.New("writer does not implement http.Hijacker method") + hijacker, assertOk := writer.(http.Hijacker) + + if !assertOk { + return fmt.Errorf("writer does not implement http.Hijacker method") } // @step: get the client connection object - client, _, err := writer.(http.Hijacker).Hijack() + client, _, err := hijacker.Hijack() if err != nil { return err