diff --git a/accesslog/middleware.go b/accesslog/middleware.go index 25901a01..126ac7ee 100644 --- a/accesslog/middleware.go +++ b/accesslog/middleware.go @@ -17,7 +17,10 @@ import ( "bytes" "fmt" "net" + "net/http" "os" + "path" + "runtime" "strings" "text/template" "time" @@ -42,6 +45,92 @@ const ( type AccessLogMiddleware struct { Format AccessLogFormat textTemplate *template.Template + + DisableLog func(statusCode int, r *rest.Request) bool + + recorder *rest.RecorderMiddleware +} + +const MaxTraceback = 32 + +func collectTrace() string { + var ( + trace [MaxTraceback]uintptr + traceback strings.Builder + ) + // Skip 4 + // = accesslog.LogFunc + // + accesslog.collectTrace + // + runtime.Callers + // + runtime.gopanic + n := runtime.Callers(4, trace[:]) + frames := runtime.CallersFrames(trace[:n]) + for frame, more := frames.Next(); frame.PC != 0 && + n >= 0; frame, more = frames.Next() { + funcName := frame.Function + if funcName == "" { + fmt.Fprint(&traceback, "???\n") + } else { + fmt.Fprintf(&traceback, "%s@%s:%d", + frame.Function, + path.Base(frame.File), + frame.Line, + ) + } + if more { + fmt.Fprintln(&traceback) + } + n-- + } + return traceback.String() +} + +func (mw *AccessLogMiddleware) LogFunc(startTime time.Time, w rest.ResponseWriter, r *rest.Request) { + util := &accessLogUtil{w, r} + fields := logrus.Fields{ + "type": r.Proto, + "ts": startTime. + Truncate(time.Millisecond). + Format(time.RFC3339Nano), + "method": r.Method, + "path": r.URL.Path, + "qs": r.URL.RawQuery, + } + statusCode := util.StatusCode() + + if panic := recover(); panic != nil { + trace := collectTrace() + fields["panic"] = panic + fields["trace"] = trace + // Wrap in recorder middleware to make sure the response is recorded + mw.recorder.MiddlewareFunc(func(w rest.ResponseWriter, r *rest.Request) { + rest.Error(w, "Internal Server Error", http.StatusInternalServerError) + })(w, r) + statusCode = http.StatusInternalServerError + } else if mw.DisableLog != nil && mw.DisableLog(statusCode, r) { + return + } + rspTime := time.Since(startTime) + r.Env["ELAPSED_TIME"] = &rspTime + // We do not need more than 3 digit fraction + if rspTime > time.Second { + rspTime = rspTime.Round(time.Millisecond) + } else if rspTime > time.Millisecond { + rspTime = rspTime.Round(time.Microsecond) + } + fields["responsetime"] = rspTime.String() + fields["byteswritten"] = util.BytesWritten() + fields["status"] = statusCode + + logger := requestlog.GetRequestLogger(r) + var level logrus.Level = logrus.InfoLevel + if statusCode >= 500 { + level = logrus.ErrorLevel + } else if statusCode >= 300 { + level = logrus.WarnLevel + } + logger.WithFields(fields). + Log(level, mw.executeTextTemplate(util)) } // MiddlewareFunc makes AccessLogMiddleware implement the Middleware interface. @@ -52,35 +141,14 @@ func (mw *AccessLogMiddleware) MiddlewareFunc(h rest.HandlerFunc) rest.HandlerFu mw.convertFormat() + // This middleware depends on RecorderMiddleware to work + mw.recorder = new(rest.RecorderMiddleware) return func(w rest.ResponseWriter, r *rest.Request) { - - // call the handler - h(w, r) - - util := &accessLogUtil{w, r} - logger := requestlog.GetRequestLogger(r) - logged := false - log := logger.WithFields(logrus.Fields{ - "type": TypeHTTP, - "ts": util.StartTime().Round(0), - "status": util.StatusCode(), - "responsetime": util.ResponseTime().Seconds(), - "byteswritten": util.BytesWritten(), - "method": r.Method, - "path": r.URL.Path, - "qs": r.URL.RawQuery, - }) - for pathSuffix, status := range DebugLogsByPathSuffix { - if util.StatusCode() == status && strings.HasSuffix(r.URL.Path, pathSuffix) { - log.Debug(mw.executeTextTemplate(util)) - logged = true - break - } - } - - if !logged { - log.Print(mw.executeTextTemplate(util)) - } + startTime := time.Now() + r.Env["START_TIME"] = &startTime + defer mw.LogFunc(startTime, w, r) + // call the handler inside recorder context + mw.recorder.MiddlewareFunc(h)(w, r) } } diff --git a/accesslog/middleware_gin.go b/accesslog/middleware_gin.go index 58798824..418c315e 100644 --- a/accesslog/middleware_gin.go +++ b/accesslog/middleware_gin.go @@ -17,151 +17,98 @@ package accesslog import ( "fmt" "net/http" - "runtime" - "strings" "time" "github.com/gin-gonic/gin" "github.com/mendersoftware/go-lib-micro/log" "github.com/mendersoftware/go-lib-micro/rest.utils" "github.com/pkg/errors" + "github.com/sirupsen/logrus" ) -const MaxTraceback = 32 - -func funcname(fn string) string { - // strip package path - i := strings.LastIndex(fn, "/") - fn = fn[i+1:] - // strip package name. - i = strings.Index(fn, ".") - fn = fn[i+1:] - return fn +type AccessLogger struct { + DisableLog func(c *gin.Context) bool } -func panicHandler(c *gin.Context, startTime time.Time) { +func (a AccessLogger) LogFunc(c *gin.Context, startTime time.Time) { + logCtx := logrus.Fields{ + "clientip": c.ClientIP(), + "method": c.Request.Method, + "path": c.Request.URL.Path, + "qs": c.Request.URL.RawQuery, + "ts": startTime. + Truncate(time.Millisecond). + Format(time.RFC3339Nano), + "type": c.Request.Proto, + "useragent": c.Request.UserAgent(), + } if r := recover(); r != nil { - l := log.FromContext(c.Request.Context()) - latency := time.Since(startTime) - trace := [MaxTraceback]uintptr{} - // Skip 3 - // = runtime.Callers + runtime.extern.Callers + runtime.gopanic - num := runtime.Callers(3, trace[:]) - var traceback strings.Builder - for i := 0; i < num; i++ { - fn := runtime.FuncForPC(trace[i]) - if fn == nil { - fmt.Fprintf(&traceback, "\n???") - continue - } - file, line := fn.FileLine(trace[i]) - fmt.Fprintf(&traceback, "\n%s(%s):%d", - file, funcname(fn.Name()), line, - ) - } + trace := collectTrace() + logCtx["trace"] = trace + logCtx["panic"] = r - logCtx := log.Ctx{ - "clientip": c.ClientIP(), - "method": c.Request.Method, - "path": c.Request.URL.Path, - "qs": c.Request.URL.RawQuery, - "responsetime": fmt.Sprintf("%dus", - latency.Round(time.Microsecond).Microseconds()), - "status": 500, - "ts": startTime.Format(time.RFC3339), - "type": c.Request.Proto, - "useragent": c.Request.UserAgent(), - "trace": traceback.String()[1:], - } - l = l.F(logCtx) func() { - // Panic is going to panic, but we - // immediately want to recover. - defer func() { recover() }() //nolint:errcheck - l.Panicf("[request panic] %s", r) + // Try to respond with an internal server error. + // If the connection is broken it might panic again. + defer func() { recover() }() // nolint:errcheck + rest.RenderError(c, + http.StatusInternalServerError, + errors.New("internal error"), + ) }() + } else if a.DisableLog != nil && a.DisableLog(c) { + return + } + latency := time.Since(startTime) + // We do not need more than 3 digit fraction + if latency > time.Second { + latency = latency.Round(time.Millisecond) + } else if latency > time.Millisecond { + latency = latency.Round(time.Microsecond) + } + code := c.Writer.Status() + logCtx["responsetime"] = latency.String() + logCtx["status"] = c.Writer.Status() + logCtx["byteswritten"] = c.Writer.Size() - // Try to respond with an internal server error. - // If the connection is broken it might panic again. - defer func() { recover() }() // nolint:errcheck - rest.RenderError(c, - http.StatusInternalServerError, - errors.New("internal error"), - ) + var logLevel logrus.Level = logrus.InfoLevel + if code >= 500 { + logLevel = logrus.ErrorLevel + } else if code >= 400 { + logLevel = logrus.WarnLevel } + if len(c.Errors) > 0 { + errs := c.Errors.Errors() + var errMsg string + if len(errs) == 1 { + errMsg = errs[0] + } else { + for i, err := range errs { + errMsg = errMsg + fmt.Sprintf( + "#%02d: %s\n", i+1, err, + ) + } + } + logCtx["error"] = errMsg + } + log.FromContext(c.Request.Context()). + WithFields(logCtx). + Log(logLevel) +} + +func (a AccessLogger) Middleware(c *gin.Context) { + startTime := time.Now() + defer a.LogFunc(c, startTime) + c.Next() } // Middleware provides accesslog middleware for the gin-gonic framework. -// This middleware will recover any panic from the occurring in the API -// handler and log it to panic level. +// This middleware will recover any panic from occurring in the API +// handler and log it to error level with panic and trace showing the panic +// message and traceback respectively. // If an error status is returned in the response, the middleware tries // to pop the topmost error from the gin.Context (c.Error) and puts it in // the "error" context to the final log entry. func Middleware() gin.HandlerFunc { - return func(c *gin.Context) { - startTime := time.Now() - defer panicHandler(c, startTime) - - c.Next() - - l := log.FromContext(c.Request.Context()) - latency := time.Since(startTime) - code := c.Writer.Status() - // Add status and response time to log context - size := c.Writer.Size() - if size < 0 { - size = 0 - } - logCtx := log.Ctx{ - "byteswritten": size, - "clientip": c.ClientIP(), - "method": c.Request.Method, - "path": c.Request.URL.Path, - "qs": c.Request.URL.RawQuery, - "responsetime": fmt.Sprintf("%dus", - latency.Round(time.Microsecond).Microseconds()), - "status": code, - "ts": startTime.Format(time.RFC3339), - "type": c.Request.Proto, - "useragent": c.Request.UserAgent(), - } - l = l.F(logCtx) - - if code < 400 { - logged := false - for pathSuffix, status := range DebugLogsByPathSuffix { - if code == status && strings.HasSuffix(c.Request.URL.Path, pathSuffix) { - l.Debug() - logged = true - break - } - } - - if !logged { - l.Info() - } - } else { - if len(c.Errors) > 0 { - errs := c.Errors.Errors() - var errMsg string - if len(errs) == 1 { - errMsg = errs[0] - } else { - for i, err := range errs { - errMsg = errMsg + fmt.Sprintf( - "#%02d: %s\n", i+1, err, - ) - } - } - l = l.F(log.Ctx{ - "error": errMsg, - }) - } else { - l = l.F(log.Ctx{ - "error": http.StatusText(code), - }) - } - l.Error() - } - } + return AccessLogger{}.Middleware } diff --git a/accesslog/middleware_gin_test.go b/accesslog/middleware_gin_test.go index 691492af..ec28d28e 100644 --- a/accesslog/middleware_gin_test.go +++ b/accesslog/middleware_gin_test.go @@ -16,7 +16,6 @@ package accesslog import ( "bytes" - "fmt" "net/http" "net/http/httptest" "testing" @@ -92,27 +91,6 @@ func TestMiddleware(t *testing.T) { "ts=", `error="#01: internal error 1\\n#02: internal error 2\\n"`, }, - }, { - Name: "ok, unexplained error", - - HandlerFunc: func(c *gin.Context) { - c.Status(http.StatusBadRequest) - _, _ = c.Writer.Write([]byte("bytes")) - }, - Fields: []string{ - "status=400", - `path=/test`, - `qs="foo=bar"`, - "method=GET", - "responsetime=", - "useragent=tester", - "byteswritten=5", - "ts=", - fmt.Sprintf( - `error="%s"`, - http.StatusText(http.StatusBadRequest), - ), - }, }, { Name: "error, panic in handler", @@ -129,9 +107,7 @@ func TestMiddleware(t *testing.T) { "useragent=tester", "ts=", // First three entries in the trace should match this: - `trace=".+middleware_gin_test\.go\(TestMiddleware\.func[0-9]*\):[0-9]+\\n` + - `.+\(\(\*Context\).Next\):[0-9]+\\n` + - `.+\(\(\*Context\).Next\):[0-9]+\\n`, + `trace=".+TestMiddleware\.func[0-9]*@middleware_gin_test\.go:[0-9]+\\n`, }, ExpectedBody: `{"error": "internal error"}`, }} diff --git a/accesslog/middleware_test.go b/accesslog/middleware_test.go new file mode 100644 index 00000000..afc686eb --- /dev/null +++ b/accesslog/middleware_test.go @@ -0,0 +1,125 @@ +// Copyright 2023 Northern.tech AS +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package accesslog + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/ant0ine/go-json-rest/rest" + "github.com/mendersoftware/go-lib-micro/log" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func TestMiddlewareLegacy(t *testing.T) { + testCases := []struct { + Name string + + HandlerFunc rest.HandlerFunc + + Fields []string + ExpectedBody string + }{{ + Name: "ok", + + HandlerFunc: func(w rest.ResponseWriter, r *rest.Request) { + w.WriteHeader(http.StatusNoContent) + }, + Fields: []string{ + "status=204", + `path=/test`, + `qs="foo=bar"`, + "method=GET", + "responsetime=", + "ts=", + }, + }, { + Name: "error, panic in handler", + + HandlerFunc: func(w rest.ResponseWriter, r *rest.Request) { + panic("!!!!!") + }, + + Fields: []string{ + "status=500", + `path=/test`, + `qs="foo=bar"`, + "method=GET", + "responsetime=", + "ts=", + // First three entries in the trace should match this: + `trace=".+TestMiddlewareLegacy\.func[0-9.]*@middleware_test\.go:[0-9.]+\\n`, + }, + ExpectedBody: `{"Error": "Internal Server Error"}`, + }} + + for i := range testCases { + tc := testCases[i] + t.Run(tc.Name, func(t *testing.T) { + app, err := rest.MakeRouter(rest.Get("/test", tc.HandlerFunc)) + if err != nil { + t.Error(err) + t.FailNow() + } + api := rest.NewApi() + var logBuf = bytes.NewBuffer(nil) + api.Use(rest.MiddlewareSimple( + func(h rest.HandlerFunc) rest.HandlerFunc { + logger := log.NewEmpty() + logger.Logger.SetLevel(logrus.InfoLevel) + logger.Logger.SetOutput(logBuf) + logger.Logger.SetFormatter(&logrus.TextFormatter{ + DisableColors: true, + FullTimestamp: true, + }) + return func(w rest.ResponseWriter, r *rest.Request) { + ctx := r.Request.Context() + ctx = log.WithContext(ctx, logger) + r.Request = r.Request.WithContext(ctx) + h(w, r) + t.Log(r.Env) + } + })) + api.Use(&AccessLogMiddleware{}) + api.SetApp(app) + handler := api.MakeHandler() + w := httptest.NewRecorder() + req, _ := http.NewRequest( + http.MethodGet, + "http://localhost/test?foo=bar", + nil, + ) + req.Header.Set("User-Agent", "tester") + + handler.ServeHTTP(w, req) + + logEntry := logBuf.String() + for _, field := range tc.Fields { + assert.Regexp(t, field, logEntry) + } + if tc.Fields == nil { + assert.Empty(t, logEntry) + } + if tc.ExpectedBody != "" { + if assert.NotNil(t, w.Body) { + assert.JSONEq(t, tc.ExpectedBody, w.Body.String()) + } + } + }) + } +} diff --git a/go.mod b/go.mod index d56b56bd..e74cbd1f 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/viper v1.17.0 github.com/stretchr/testify v1.8.4 - go.mongodb.org/mongo-driver v1.12.1 + go.mongodb.org/mongo-driver v1.13.0 gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 ) diff --git a/go.sum b/go.sum index 2d8b57cd..50cc0a1a 100644 --- a/go.sum +++ b/go.sum @@ -253,8 +253,8 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mongodb.org/mongo-driver v1.12.1 h1:nLkghSU8fQNaK7oUmDhQFsnrtcoNy7Z6LVFKsEecqgE= -go.mongodb.org/mongo-driver v1.12.1/go.mod h1:/rGBTebI3XYboVmgz+Wv3Bcbl3aD0QF9zl6kDDw18rQ= +go.mongodb.org/mongo-driver v1.13.0 h1:67DgFFjYOCMWdtTEmKFpV3ffWlFnh+CYZ8ZS/tXWUfY= +go.mongodb.org/mongo-driver v1.13.0/go.mod h1:/rGBTebI3XYboVmgz+Wv3Bcbl3aD0QF9zl6kDDw18rQ= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= diff --git a/redis/redis.go b/redis/redis.go index 1a781a7f..664dbba4 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -63,6 +63,14 @@ func ClientFromConnectionString( if err != nil { return nil, err } + // in case connection string was provided in form of host:port + // add scheme and parse again + if redisurl.Host == "" { + redisurl, err = url.Parse("redis://" + connectionString) + if err != nil { + return nil, err + } + } q := redisurl.Query() scheme := redisurl.Scheme cname := redisurl.Hostname() @@ -125,6 +133,9 @@ func ClientFromConnectionString( rdb = redis.NewClient(redisOpts) } } + if err != nil { + return nil, fmt.Errorf("redis: invalid connection string: %w", err) + } _, err = rdb. Ping(ctx). Result()