From 148c725e2928459dc220f1fe1bdbb300383b494e Mon Sep 17 00:00:00 2001 From: Laurence Date: Mon, 29 Jan 2024 12:19:05 +0000 Subject: [PATCH] Implement offered comment --- config/config.go | 2 +- router/router.go | 37 +++++++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/config/config.go b/config/config.go index a8fbc2fe..0d8e26d4 100644 --- a/config/config.go +++ b/config/config.go @@ -40,7 +40,7 @@ type Configuration struct { AllowHeaders []string } - TrustedProxies []string + TrustXRealIP bool `default:"false"` } Database struct { Dialect string `default:"sqlite3"` diff --git a/router/router.go b/router/router.go index c59513db..c6d00347 100644 --- a/router/router.go +++ b/router/router.go @@ -28,18 +28,7 @@ import ( func Create(db *database.GormDatabase, vInfo *model.VersionInfo, conf *config.Configuration) (*gin.Engine, func()) { g := gin.New() - if conf.Server.TrustedProxies != nil { - g.SetTrustedProxies(conf.Server.TrustedProxies) - g.ForwardedByClientIP = true - } - - g.Use(func(ctx *gin.Context) { - if localAddr, ok := ctx.Request.Context().Value(http.LocalAddrContextKey).(net.Addr); ok && localAddr.Network() == "unix" { - ctx.Request.RemoteAddr = "127.0.0.1:65535" // set remote address to localhost for unix socket requests - } - }) - - g.Use(gin.LoggerWithFormatter(logFormatter), gin.Recovery(), gerror.Handler(), location.Default()) + g.Use(gin.LoggerWithFormatter(useXRealIP(conf.Server.TrustXRealIP, logFormatter)), gin.Recovery(), gerror.Handler(), location.Default()) g.NoRoute(gerror.NotFound()) if conf.Server.SSL.Enabled != nil && conf.Server.SSL.RedirectToHTTPS != nil && *conf.Server.SSL.Enabled && *conf.Server.SSL.RedirectToHTTPS { @@ -259,3 +248,27 @@ func (fs *onlyImageFS) Open(name string) (http.File, error) { } return fs.inner.Open(name) } + +func useXRealIP(trustedProxy bool, inner gin.LogFormatter) gin.LogFormatter { + return func(params gin.LogFormatterParams) string { + params.ClientIP = getClientIp(trustedProxy, params.Request) + return inner(params) + } +} + +func getClientIp(trustedProxy bool, req *http.Request) string { + if trustedProxy { + realIpParts := strings.SplitN(req.Header.Get("x-real-ip"), ",", 2) + if ip := strings.TrimSpace(realIpParts[0]); ip != "" { + return ip + } + } + + addr := req.RemoteAddr + if addr == "@" { + return "socket" + } else if host, _, err := net.SplitHostPort(addr); err == nil { + return host + } + return addr +}