Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

馃敟 Feature: Add TrustedProxies feature #1397

Merged
merged 4 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 33 additions & 0 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,34 @@ type Config struct {
//
// Default: NetworkTCP4
Network string

// If you find yourself behind some sort of proxy, like a load balancer,
// then certain header information may be sent to you using special X-Forwarded-* headers or the Forwarded header.
// For example, the Host HTTP header is usually used to return the requested host.
// But when you鈥檙e behind a proxy, the actual host may be stored in an X-Forwarded-Host header.
//
// If you are behind a proxy, you should enable TrustedProxyCheck to prevent header spoofing.
// If you enable EnableTrustedProxyCheck and leave TrustedProxies empty Fiber will skip
// all headers that could be spoofed.
// If request ip in TrustedProxies whitelist then:
// 1. c.Protocol() get value from X-Forwarded-Proto, X-Forwarded-Protocol, X-Forwarded-Ssl or X-Url-Scheme header
// 2. c.IP() get value from ProxyHeader header.
// 3. c.Hostname() get value from X-Forwarded-Host header
// But if request ip NOT in Trusted Proxies whitelist then:
// 1. c.Protocol() WON't get value from X-Forwarded-Proto, X-Forwarded-Protocol, X-Forwarded-Ssl or X-Url-Scheme header,
// will return https in case when tls connection is handled by the app, of http otherwise
// 2. c.IP() WON'T get value from ProxyHeader header, will return RemoteIP() from fasthttp context
// 3. c.Hostname() WON'T get value from X-Forwarded-Host header, fasthttp.Request.URI().Host()
// will be used to get the hostname.
//
// Default: false
EnableTrustedProxyCheck bool `json:"enable_trusted_proxy_check"`

// Read EnableTrustedProxyCheck doc.
//
// Default: []string
TrustedProxies []string `json:"trusted_proxies"`
trustedProxiesMap map[string]struct{}
}

// Static defines configuration options when defining static assets.
Expand Down Expand Up @@ -417,6 +445,11 @@ func New(config ...Config) *App {
app.config.Network = NetworkTCP4
}

app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies))
for _, ip := range app.config.TrustedProxies {
app.config.trustedProxiesMap[ip] = struct{}{}
}

// Init app
app.init()

Expand Down
25 changes: 23 additions & 2 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,18 +500,26 @@ func (c *Ctx) Get(key string, defaultValue ...string) string {
return defaultString(c.app.getString(c.fasthttp.Request.Header.Peek(key)), defaultValue)
}

// Hostname contains the hostname derived from the Host HTTP header.
// Hostname contains the hostname derived from the X-Forwarded-Host or Host HTTP header.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting instead.
// Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
func (c *Ctx) Hostname() string {
if c.IsProxyTrusted() {
if host := c.Get(HeaderXForwardedHost); len(host) > 0 {
return host
}
}
return c.app.getString(c.fasthttp.Request.URI().Host())
}

// IP returns the remote IP address of the request.
// Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
func (c *Ctx) IP() string {
if len(c.app.config.ProxyHeader) > 0 {
if c.IsProxyTrusted() && len(c.app.config.ProxyHeader) > 0 {
return c.Get(c.app.config.ProxyHeader)
}

return c.fasthttp.RemoteIP().String()
}

Expand Down Expand Up @@ -713,11 +721,15 @@ func (c *Ctx) Path(override ...string) string {
}

// Protocol contains the request protocol string: http or https for TLS requests.
// Use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
func (c *Ctx) Protocol() string {
if c.fasthttp.IsTLS() {
return "https"
}
scheme := "http"
if !c.IsProxyTrusted() {
return scheme
}
c.fasthttp.Request.Header.VisitAll(func(key, val []byte) {
if len(key) < 12 {
return // X-Forwarded-
Expand Down Expand Up @@ -1169,3 +1181,12 @@ func (c *Ctx) configDependentPaths() {
c.treePath = c.detectionPath[:3]
}
}

func (c *Ctx) IsProxyTrusted() bool {
if !c.app.config.EnableTrustedProxyCheck {
return true
}

_, trustProxy := c.app.config.trustedProxiesMap[c.fasthttp.RemoteIP().String()]
return trustProxy
}
108 changes: 108 additions & 0 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,42 @@ func Test_Ctx_Hostname(t *testing.T) {
utils.AssertEqual(t, "google.com", c.Hostname())
}

// go test -run Test_Ctx_Hostname_Untrusted
func Test_Ctx_Hostname_UntrustedProxy(t *testing.T) {
t.Parallel()
// Don't trust any proxy
{
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{}})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
utils.AssertEqual(t, "google.com", c.Hostname())
app.ReleaseCtx(c)
}
// Trust to specific proxy list
{
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.8.0.0", "0.8.0.1"}})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
utils.AssertEqual(t, "google.com", c.Hostname())
app.ReleaseCtx(c)
}
}

// go test -run Test_Ctx_Hostname_Trusted
func Test_Ctx_Hostname_TrustedProxy(t *testing.T) {
t.Parallel()
{
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0", "0.8.0.1"}})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
utils.AssertEqual(t, "google1.com", c.Hostname())
app.ReleaseCtx(c)
}
}

// go test -run Test_Ctx_IP
func Test_Ctx_IP(t *testing.T) {
t.Parallel()
Expand All @@ -800,6 +836,26 @@ func Test_Ctx_IP_ProxyHeader(t *testing.T) {
utils.AssertEqual(t, "", c.IP())
}

// go test -run Test_Ctx_IP_UntrustedProxy
func Test_Ctx_IP_UntrustedProxy(t *testing.T) {
t.Parallel()
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.8.0.1"}, ProxyHeader: HeaderXForwardedFor})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.Set(HeaderXForwardedFor, "0.0.0.1")
defer app.ReleaseCtx(c)
utils.AssertEqual(t, "0.0.0.0", c.IP())
}

// go test -run Test_Ctx_IP_TrustedProxy
func Test_Ctx_IP_TrustedProxy(t *testing.T) {
t.Parallel()
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0"}, ProxyHeader: HeaderXForwardedFor})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.Set(HeaderXForwardedFor, "0.0.0.1")
defer app.ReleaseCtx(c)
utils.AssertEqual(t, "0.0.0.1", c.IP())
}

// go test -run Test_Ctx_IPs -parallel
func Test_Ctx_IPs(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -1126,6 +1182,58 @@ func Benchmark_Ctx_Protocol(b *testing.B) {
utils.AssertEqual(b, "http", res)
}

// go test -run Test_Ctx_Protocol_TrustedProxy
func Test_Ctx_Protocol_TrustedProxy(t *testing.T) {
t.Parallel()
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.0.0.0"}})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

c.Request().Header.Set(HeaderXForwardedProto, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXForwardedProtocol, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXUrlScheme, "https")
utils.AssertEqual(t, "https", c.Protocol())
c.Request().Header.Reset()

utils.AssertEqual(t, "http", c.Protocol())
}

// go test -run Test_Ctx_Protocol_UnTrustedProxy
func Test_Ctx_Protocol_UnTrustedProxy(t *testing.T) {
t.Parallel()
app := New(Config{EnableTrustedProxyCheck: true, TrustedProxies: []string{"0.8.0.1"}})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

c.Request().Header.Set(HeaderXForwardedProto, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXForwardedProtocol, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXForwardedSsl, "on")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Reset()

c.Request().Header.Set(HeaderXUrlScheme, "https")
utils.AssertEqual(t, "http", c.Protocol())
c.Request().Header.Reset()

utils.AssertEqual(t, "http", c.Protocol())
}

// go test -run Test_Ctx_Query
func Test_Ctx_Query(t *testing.T) {
t.Parallel()
Expand Down