diff --git a/middleware/realip.go b/middleware/realip.go index 55c95a89..d91f2d1c 100644 --- a/middleware/realip.go +++ b/middleware/realip.go @@ -9,9 +9,11 @@ import ( "strings" ) -var trueClientIP = http.CanonicalHeaderKey("True-Client-IP") -var xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") -var xRealIP = http.CanonicalHeaderKey("X-Real-IP") +var defaultHeaders = []string{ + "True-Client-IP", // Cloudflare Enterprise plan + "X-Real-IP", + "X-Forwarded-For", +} // RealIP is a middleware that sets a http.Request's RemoteAddr to the results // of parsing either the True-Client-IP, X-Real-IP or the X-Forwarded-For headers @@ -30,7 +32,7 @@ var xRealIP = http.CanonicalHeaderKey("X-Real-IP") // how you're using RemoteAddr, vulnerable to an attack of some sort). func RealIP(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { - if rip := realIP(r); rip != "" { + if rip := getRealIP(r, defaultHeaders); rip != "" { r.RemoteAddr = rip } h.ServeHTTP(w, r) @@ -39,22 +41,33 @@ func RealIP(h http.Handler) http.Handler { return http.HandlerFunc(fn) } -func realIP(r *http.Request) string { - var ip string - - if tcip := r.Header.Get(trueClientIP); tcip != "" { - ip = tcip - } else if xrip := r.Header.Get(xRealIP); xrip != "" { - ip = xrip - } else if xff := r.Header.Get(xForwardedFor); xff != "" { - i := strings.Index(xff, ",") - if i == -1 { - i = len(xff) +// RealIPFromHeaders is a middleware that sets a http.Request's RemoteAddr to the results +// of parsing the custom headers. +// +// usage: +// r.Use(RealIPFromHeaders("CF-Connecting-IP")) +func RealIPFromHeaders(headers ...string) func(http.Handler) http.Handler { + f := func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if rip := getRealIP(r, headers); rip != "" { + r.RemoteAddr = rip + } + h.ServeHTTP(w, r) } - ip = xff[:i] + return http.HandlerFunc(fn) } - if ip == "" || net.ParseIP(ip) == nil { - return "" + return f +} + +func getRealIP(r *http.Request, headers []string) string { + for _, header := range headers { + if ip := r.Header.Get(header); ip != "" { + ips := strings.Split(ip, ",") + if ips[0] == "" || net.ParseIP(ips[0]) == nil { + continue + } + return ips[0] + } } - return ip + return "" } diff --git a/middleware/realip_test.go b/middleware/realip_test.go index 1ab5e95e..97370323 100644 --- a/middleware/realip_test.go +++ b/middleware/realip_test.go @@ -113,3 +113,52 @@ func TestInvalidIP(t *testing.T) { t.Fatal("Invalid IP used.") } } + +func TestCustomIPHeader(t *testing.T) { + var customHeaderKey = "X-CUSTOM-IP" + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Add(customHeaderKey, "100.100.100.100") + w := httptest.NewRecorder() + + r := chi.NewRouter() + r.Use(RealIPFromHeaders(customHeaderKey)) + + realIP := "" + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + realIP = r.RemoteAddr + w.Write([]byte("Hello World")) + }) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatal("Response Code should be 200") + } + + if realIP != "100.100.100.100" { + t.Fatal("Test get real IP precedence error.") + } +} + +func TestCustomIPHeaderWithoutDefault(t *testing.T) { + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Add("X-REAL-IP", "100.100.100.100") + w := httptest.NewRecorder() + + r := chi.NewRouter() + r.Use(RealIPFromHeaders("CF-Connecting-IP")) + + realIP := "" + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + realIP = r.RemoteAddr + w.Write([]byte("Hello World")) + }) + r.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatal("Response Code should be 200") + } + + if realIP != "" { + t.Fatal("Invalid IP used.") + } +}