Skip to content

Commit

Permalink
Merge pull request #17 from jtrw/develop
Browse files Browse the repository at this point in the history
Add function to get real IP from request headers
  • Loading branch information
nilBora committed Apr 9, 2024
2 parents 415f385 + 2c2096f commit 4ebe433
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 0 deletions.
80 changes: 80 additions & 0 deletions realip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package rest

import (
"bytes"
"fmt"
"net"
"net/http"
"strings"
)

type ipRange struct {
start net.IP
end net.IP
}

var privateRanges = []ipRange{
{start: net.ParseIP("10.0.0.0"), end: net.ParseIP("10.255.255.255")},
{start: net.ParseIP("100.64.0.0"), end: net.ParseIP("100.127.255.255")},
{start: net.ParseIP("172.16.0.0"), end: net.ParseIP("172.31.255.255")},
{start: net.ParseIP("192.0.0.0"), end: net.ParseIP("192.0.0.255")},
{start: net.ParseIP("192.168.0.0"), end: net.ParseIP("192.168.255.255")},
{start: net.ParseIP("198.18.0.0"), end: net.ParseIP("198.19.255.255")},
{start: net.ParseIP("::1"), end: net.ParseIP("::1")},
{start: net.ParseIP("fc00::"), end: net.ParseIP("fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
{start: net.ParseIP("fe80::"), end: net.ParseIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")},
}

// Get returns real ip from the given request
// Prioritize public IPs over private IPs
func GetRealIP(r *http.Request) (string, error) {
var firstIP string
for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} {
addresses := strings.Split(r.Header.Get(h), ",")
for i := len(addresses) - 1; i >= 0; i-- {
ip := strings.TrimSpace(addresses[i])
realIP := net.ParseIP(ip)
if firstIP == "" && realIP != nil {
firstIP = ip
}
if !realIP.IsGlobalUnicast() || isPrivateSubnet(realIP) {
continue
}
return ip, nil
}
}

if firstIP != "" {
return firstIP, nil
}

ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return "", fmt.Errorf("can't parse ip %q: %w", r.RemoteAddr, err)
}
if netIP := net.ParseIP(ip); netIP == nil {
return "", fmt.Errorf("no valid ip found")
}

return ip, nil
}

// isPrivateSubnet - check to see if this ip is in a private subnet
func isPrivateSubnet(ipAddress net.IP) bool {

// inRange - check to see if a given ip address is within a range given
inRange := func(r ipRange, ipAddress net.IP) bool {
// ensure the IPs are in the same format for comparison
ipAddress = ipAddress.To16()
r.start = r.start.To16()
r.end = r.end.To16()
return bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) <= 0
}

for _, r := range privateRanges {
if inRange(r, ipAddress) {
return true
}
}
return false
}
145 changes: 145 additions & 0 deletions realip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package rest

import (
"log"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetFromHeaders(t *testing.T) {
t.Run("single X-Real-IP", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Real-IP", "8.8.8.8")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "8.8.8.8", adr)
})
t.Run("X-Forwarded-For last public", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2, 30.30.30.1")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "30.30.30.1", adr)
})
t.Run("X-Forwarded-For last private", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "8.8.8.8,1.1.1.2,192.168.1.1,10.0.0.65")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "1.1.1.2", adr)
})
t.Run("X-Forwarded-For public im the middle", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "192.168.1.1, 8.8.8.8, 10.0.0.65")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "8.8.8.8", adr)
})
t.Run("X-Forwarded-For all private", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "192.168.1.1,10.0.0.65")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "10.0.0.65", adr)
})
t.Run("X-Forwarded-For public, X-Real-IP private", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "30.30.30.1")
req.Header.Add("X-Real-Ip", "10.0.0.1")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "30.30.30.1", adr)
})
t.Run("X-Forwarded-For and X-Real-IP public", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "30.30.30.1")
req.Header.Add("X-Real-Ip", "8.8.8.8")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "30.30.30.1", adr)
})
t.Run("X-Forwarded-For private and X-Real-IP public", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("Something", "1234567")
req.Header.Add("X-Forwarded-For", "10.0.0.2,192.168.1.1")
req.Header.Add("X-Real-Ip", "8.8.8.8")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "8.8.8.8", adr)
})
t.Run("RemoteAddr fallback", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.RemoteAddr = "192.0.2.1:1234"
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "192.0.2.1", adr)
})
t.Run("X-Forwarded-For and X-Real-IP missing, no RemoteAddr either", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
ip, err := GetRealIP(req)
assert.Error(t, err)
assert.Equal(t, "", ip)
})
t.Run("X-Real-IP IPv6", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("X-Real-IP", "2001:0db8:85a3:0000:0000:8a2e:0370:7334")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "2001:0db8:85a3:0000:0000:8a2e:0370:7334", adr)
})
t.Run("X-Forwarded-For last IPv6 public", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.Header.Add("X-Forwarded-For", "2001:db8::ff00:42:8329,::1,fc00::")
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "2001:db8::ff00:42:8329", adr)
})

t.Run("RemoteAddr IPv6 fallback", func(t *testing.T) {
req, err := http.NewRequest("GET", "/something", http.NoBody)
assert.NoError(t, err)
req.RemoteAddr = "[2001:db8::ff00:42:8329]:1234"
adr, err := GetRealIP(req)
require.NoError(t, err)
assert.Equal(t, "2001:db8::ff00:42:8329", adr)
})
}

func TestGetFromRemoteAddr(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
log.Printf("%v", r)
adr, err := GetRealIP(r)
require.NoError(t, err)
assert.Equal(t, "127.0.0.1", adr)
}))

req, err := http.NewRequest("GET", ts.URL+"/something", http.NoBody)
require.NoError(t, err)
client := http.Client{Timeout: time.Second}
_, err = client.Do(req)
require.NoError(t, err)
}

0 comments on commit 4ebe433

Please sign in to comment.