From 55ad88a12b502cd91140b1d5c2fe745689c46b2f Mon Sep 17 00:00:00 2001 From: "Manu Mtz.-Almeida" Date: Mon, 8 Feb 2021 14:08:35 +0100 Subject: [PATCH] refactor move logic to remoteIP() --- context.go | 50 +++++++++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/context.go b/context.go index 9fa3cb8190..470a755641 100644 --- a/context.go +++ b/context.go @@ -729,42 +729,50 @@ func (c *Context) ShouldBindBodyWith(obj interface{}, bb binding.BindingBody) (e // X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. // Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP. func (c *Context) ClientIP() string { - ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) - if err != nil { - return "" - } - remoteIP := net.ParseIP(ip) - if remoteIP == nil { - return "" - } - if c.engine.AppEngine { if addr := c.requestHeader("X-Appengine-Remote-Addr"); addr != "" { return addr } } - if c.shouldCheckIPHeaders() { - for _, cidr := range c.engine.trustedCIDRs { - if cidr.Contains(remoteIP) { - for _, headerName := range c.engine.RemoteIPHeaders { - ip, valid := validateHeader(c.requestHeader(headerName)) - if valid { - return ip - } - } + remoteIP, trusted := c.RemoteIP() + if remoteIP == nil { + return "" + } + if trusted { + for _, headerName := range c.engine.RemoteIPHeaders { + ip, valid := validateHeader(c.requestHeader(headerName)) + if valid { + return ip } } } - return remoteIP.String() } -func (c *Context) shouldCheckIPHeaders() bool { - return c.engine.ForwardedByClientIP && +func (c *Context) RemoteIP() (net.IP, bool) { + ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) + if err != nil { + return nil, false + } + remoteIP := net.ParseIP(ip) + if remoteIP == nil { + return nil, false + } + + shouldCheckTrustedIP := c.engine.ForwardedByClientIP && c.engine.RemoteIPHeaders != nil && len(c.engine.RemoteIPHeaders) > 0 && c.engine.trustedCIDRs != nil + + if shouldCheckTrustedIP { + for _, cidr := range c.engine.trustedCIDRs { + if cidr.Contains(remoteIP) { + return remoteIP, true + } + } + } + return remoteIP, false } func validateHeader(header string) (clientIP string, valid bool) {