diff --git a/context.go b/context.go index ec7fdd998..5b9a7b149 100644 --- a/context.go +++ b/context.go @@ -158,6 +158,18 @@ func (c *Context) IsWebSocket() bool { return strings.EqualFold(upgrade, "websocket") && strings.Contains(strings.ToLower(connection), "upgrade") } +func isValidProto(proto string) bool { + if proto == "" { + return false + } + for _, p := range []string{"http", "https", "ws", "wss"} { + if strings.EqualFold(proto, p) { + return true + } + } + return false +} + // Scheme returns the HTTP protocol scheme, `http` or `https`. func (c *Context) Scheme() string { // Can't use `r.Request.URL.Scheme` @@ -165,16 +177,16 @@ func (c *Context) Scheme() string { if c.IsTLS() { return "https" } - if scheme := c.request.Header.Get(HeaderXForwardedProto); scheme != "" { + if scheme := c.request.Header.Get(HeaderXForwardedProto); isValidProto(scheme) { return scheme } - if scheme := c.request.Header.Get(HeaderXForwardedProtocol); scheme != "" { + if scheme := c.request.Header.Get(HeaderXForwardedProtocol); isValidProto(scheme) { return scheme } if ssl := c.request.Header.Get(HeaderXForwardedSsl); ssl == "on" { return "https" } - if scheme := c.request.Header.Get(HeaderXUrlScheme); scheme != "" { + if scheme := c.request.Header.Get(HeaderXUrlScheme); isValidProto(scheme) { return scheme } return "http" diff --git a/context_test.go b/context_test.go index 9376f0f41..21a7af099 100644 --- a/context_test.go +++ b/context_test.go @@ -1090,60 +1090,175 @@ func TestContext_Request(t *testing.T) { } func TestContext_Scheme(t *testing.T) { - tests := []struct { - c *Context - s string + var testCases = []struct { + name string + givenIsTLS bool + givenHeaders http.Header + expect string }{ { - &Context{ - request: &http.Request{ - TLS: &tls.ConnectionState{}, - }, + name: "defaults to http without TLS or headers", + givenIsTLS: false, + givenHeaders: nil, + expect: "http", + }, + { + name: "returns https when TLS is enabled", + givenIsTLS: true, + givenHeaders: nil, + expect: "https", + }, + { + name: "TLS takes precedence over forwarded proto", + givenIsTLS: true, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{"http"}, }, - "https", + expect: "https", }, { - &Context{ - request: &http.Request{ - Header: http.Header{HeaderXForwardedProto: []string{"https"}}, - }, + name: "uses X-Forwarded-Proto http", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{"http"}, }, - "https", + expect: "http", }, { - &Context{ - request: &http.Request{ - Header: http.Header{HeaderXForwardedProtocol: []string{"http"}}, - }, + name: "uses X-Forwarded-Proto https", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{"https"}, }, - "http", + expect: "https", }, { - &Context{ - request: &http.Request{ - Header: http.Header{HeaderXForwardedSsl: []string{"on"}}, - }, + name: "X-Forwarded-Proto is case insensitive", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{"HTTPS"}, }, - "https", + expect: "HTTPS", }, { - &Context{ - request: &http.Request{ - Header: http.Header{HeaderXUrlScheme: []string{"https"}}, - }, + name: "uses X-Forwarded-Proto ws", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{"ws"}, }, - "https", + expect: "ws", }, { - &Context{ - request: &http.Request{}, + name: "uses X-Forwarded-Proto wss", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{"wss"}, }, - "http", + expect: "wss", + }, + { + name: "ignores invalid X-Forwarded-Proto and uses X-Forwarded-Protocol", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{"ftp"}, + HeaderXForwardedProtocol: []string{"https"}, + }, + expect: "https", + }, + { + name: "uses X-Forwarded-Protocol", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProtocol: []string{"https"}, + }, + expect: "https", + }, + { + name: "X-Forwarded-Proto takes precedence over X-Forwarded-Protocol", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{"http"}, + HeaderXForwardedProtocol: []string{"https"}, + }, + expect: "http", + }, + { + name: "uses X-Forwarded-Ssl on", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedSsl: []string{"on"}, + }, + expect: "https", + }, + { + name: "X-Forwarded-Ssl on is case sensitive", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedSsl: []string{"ON"}, + }, + expect: "http", + }, + { + name: "X-Forwarded-Protocol takes precedence over X-Forwarded-Ssl", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProtocol: []string{"http"}, + HeaderXForwardedSsl: []string{"on"}, + }, + expect: "http", + }, + { + name: "uses X-Url-Scheme", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXUrlScheme: []string{"https"}, + }, + expect: "https", + }, + { + name: "X-Forwarded-Ssl takes precedence over X-Url-Scheme", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedSsl: []string{"on"}, + HeaderXUrlScheme: []string{"http"}, + }, + expect: "https", + }, + { + name: "ignores invalid forwarded headers and falls back to http", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{"ftp"}, + HeaderXForwardedProtocol: []string{"smtp"}, + HeaderXForwardedSsl: []string{"off"}, + HeaderXUrlScheme: []string{"file"}, + }, + expect: "http", + }, + { + name: "ignores empty forwarded proto and uses X-Url-Scheme", + givenIsTLS: false, + givenHeaders: http.Header{ + HeaderXForwardedProto: []string{""}, + HeaderXUrlScheme: []string{"https"}, + }, + expect: "https", }, } - for _, tt := range tests { - assert.Equal(t, tt.s, tt.c.Scheme()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenHeaders != nil { + req.Header = tc.givenHeaders + } + c := NewContext(req, nil) + if tc.givenIsTLS { + c.request.TLS = &tls.ConnectionState{} + } + + assert.Equal(t, tc.expect, c.Scheme()) + }) } }