-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from jtrw/develop
Add function to get real IP from request headers
- Loading branch information
Showing
2 changed files
with
225 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package rest | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"net" | ||
"net/http" | ||
"strings" | ||
) | ||
|
||
type ipRange struct { | ||
start net.IP | ||
end net.IP | ||
} | ||
|
||
var privateRanges = []ipRange{ | ||
{start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")}, | ||
{start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")}, | ||
{start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")}, | ||
{start: net.ParseIP("192.0.0.0"), end: net.ParseIP("192.0.0.255")}, | ||
{start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")}, | ||
{start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")}, | ||
{start: net.ParseIP("::1"), end: net.ParseIP("::1")}, | ||
{start: net.ParseIP("fc00::"), end: net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")}, | ||
{start: net.ParseIP("fe80::"), end: net.ParseIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")}, | ||
} | ||
|
||
// Get returns real ip from the given request | ||
// Prioritize public IPs over private IPs | ||
func GetRealIP(r *http.Request) (string, error) { | ||
var firstIP string | ||
for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} { | ||
addresses := strings.Split(r.Header.Get(h), ",") | ||
for i := len(addresses) - 1; i >= 0; i-- { | ||
ip := strings.TrimSpace(addresses[i]) | ||
realIP := net.ParseIP(ip) | ||
if firstIP == "" && realIP != nil { | ||
firstIP = ip | ||
} | ||
if !realIP.IsGlobalUnicast() || isPrivateSubnet(realIP) { | ||
continue | ||
} | ||
return ip, nil | ||
} | ||
} | ||
|
||
if firstIP != "" { | ||
return firstIP, nil | ||
} | ||
|
||
ip, _, err := net.SplitHostPort(r.RemoteAddr) | ||
if err != nil { | ||
return "", fmt.Errorf("can't parse ip %q: %w", r.RemoteAddr, err) | ||
} | ||
if netIP := net.ParseIP(ip); netIP == nil { | ||
return "", fmt.Errorf("no valid ip found") | ||
} | ||
|
||
return ip, nil | ||
} | ||
|
||
// isPrivateSubnet - check to see if this ip is in a private subnet | ||
func isPrivateSubnet(ipAddress net.IP) bool { | ||
|
||
// inRange - check to see if a given ip address is within a range given | ||
inRange := func(r ipRange, ipAddress net.IP) bool { | ||
// ensure the IPs are in the same format for comparison | ||
ipAddress = ipAddress.To16() | ||
r.start = r.start.To16() | ||
r.end = r.end.To16() | ||
return bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) <= 0 | ||
} | ||
|
||
for _, r := range privateRanges { | ||
if inRange(r, ipAddress) { | ||
return true | ||
} | ||
} | ||
return false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
package rest | ||
|
||
import ( | ||
"log" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestGetFromHeaders(t *testing.T) { | ||
t.Run("single X-Real-IP", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("Something", "1234567") | ||
req.Header.Add("X-Real-IP", "8.8.8.8") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "8.8.8.8", adr) | ||
}) | ||
t.Run("X-Forwarded-For last public", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("Something", "1234567") | ||
req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2, 30.30.30.1") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "30.30.30.1", adr) | ||
}) | ||
t.Run("X-Forwarded-For last private", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("Something", "1234567") | ||
req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2,192.168.1.1,10.0.0.65") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "1.1.1.2", adr) | ||
}) | ||
t.Run("X-Forwarded-For public im the middle", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("Something", "1234567") | ||
req.Header.Add("X-Forwarded-For", "192.168.1.1, 8.8.8.8, 10.0.0.65") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "8.8.8.8", adr) | ||
}) | ||
t.Run("X-Forwarded-For all private", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("Something", "1234567") | ||
req.Header.Add("X-Forwarded-For", "192.168.1.1,10.0.0.65") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "10.0.0.65", adr) | ||
}) | ||
t.Run("X-Forwarded-For public, X-Real-IP private", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("Something", "1234567") | ||
req.Header.Add("X-Forwarded-For", "30.30.30.1") | ||
req.Header.Add("X-Real-Ip", "10.0.0.1") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "30.30.30.1", adr) | ||
}) | ||
t.Run("X-Forwarded-For and X-Real-IP public", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("Something", "1234567") | ||
req.Header.Add("X-Forwarded-For", "30.30.30.1") | ||
req.Header.Add("X-Real-Ip", "8.8.8.8") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "30.30.30.1", adr) | ||
}) | ||
t.Run("X-Forwarded-For private and X-Real-IP public", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("Something", "1234567") | ||
req.Header.Add("X-Forwarded-For", "10.0.0.2,192.168.1.1") | ||
req.Header.Add("X-Real-Ip", "8.8.8.8") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "8.8.8.8", adr) | ||
}) | ||
t.Run("RemoteAddr fallback", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.RemoteAddr = "192.0.2.1:1234" | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "192.0.2.1", adr) | ||
}) | ||
t.Run("X-Forwarded-For and X-Real-IP missing, no RemoteAddr either", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
ip, err := GetRealIP(req) | ||
assert.Error(t, err) | ||
assert.Equal(t, "", ip) | ||
}) | ||
t.Run("X-Real-IP IPv6", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("X-Real-IP", "2001:0db8:85a3:0000:0000:8a2e:0370:7334") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "2001:0db8:85a3:0000:0000:8a2e:0370:7334", adr) | ||
}) | ||
t.Run("X-Forwarded-For last IPv6 public", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.Header.Add("X-Forwarded-For", "2001:db8::ff00:42:8329,::1,fc00::") | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "2001:db8::ff00:42:8329", adr) | ||
}) | ||
|
||
t.Run("RemoteAddr IPv6 fallback", func(t *testing.T) { | ||
req, err := http.NewRequest("GET", "/something", http.NoBody) | ||
assert.NoError(t, err) | ||
req.RemoteAddr = "[2001:db8::ff00:42:8329]:1234" | ||
adr, err := GetRealIP(req) | ||
require.NoError(t, err) | ||
assert.Equal(t, "2001:db8::ff00:42:8329", adr) | ||
}) | ||
} | ||
|
||
func TestGetFromRemoteAddr(t *testing.T) { | ||
ts := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { | ||
log.Printf("%v", r) | ||
adr, err := GetRealIP(r) | ||
require.NoError(t, err) | ||
assert.Equal(t, "127.0.0.1", adr) | ||
})) | ||
|
||
req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody) | ||
require.NoError(t, err) | ||
client := http.Client{Timeout: time.Second} | ||
_, err = client.Do(req) | ||
require.NoError(t, err) | ||
} |