Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,23 +158,35 @@ func (c *Context) IsWebSocket() bool {
return strings.EqualFold(upgrade, "websocket") && strings.Contains(strings.ToLower(connection), "upgrade")
}

func isValidProto(proto string) bool {
if proto == "" {
return false
}
for _, p := range []string{"http", "https", "ws", "wss"} {
if strings.EqualFold(proto, p) {
return true
}
}
return false
}

// Scheme returns the HTTP protocol scheme, `http` or `https`.
func (c *Context) Scheme() string {
// Can't use `r.Request.URL.Scheme`
// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
if c.IsTLS() {
return "https"
}
if scheme := c.request.Header.Get(HeaderXForwardedProto); scheme != "" {
if scheme := c.request.Header.Get(HeaderXForwardedProto); isValidProto(scheme) {
return scheme
}
if scheme := c.request.Header.Get(HeaderXForwardedProtocol); scheme != "" {
if scheme := c.request.Header.Get(HeaderXForwardedProtocol); isValidProto(scheme) {
return scheme
}
if ssl := c.request.Header.Get(HeaderXForwardedSsl); ssl == "on" {
return "https"
}
if scheme := c.request.Header.Get(HeaderXUrlScheme); scheme != "" {
if scheme := c.request.Header.Get(HeaderXUrlScheme); isValidProto(scheme) {
return scheme
}
return "http"
Expand Down
181 changes: 148 additions & 33 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1090,60 +1090,175 @@ func TestContext_Request(t *testing.T) {
}

func TestContext_Scheme(t *testing.T) {
tests := []struct {
c *Context
s string
var testCases = []struct {
name string
givenIsTLS bool
givenHeaders http.Header
expect string
}{
{
&Context{
request: &http.Request{
TLS: &tls.ConnectionState{},
},
name: "defaults to http without TLS or headers",
givenIsTLS: false,
givenHeaders: nil,
expect: "http",
},
{
name: "returns https when TLS is enabled",
givenIsTLS: true,
givenHeaders: nil,
expect: "https",
},
{
name: "TLS takes precedence over forwarded proto",
givenIsTLS: true,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{"http"},
},
"https",
expect: "https",
},
{
&Context{
request: &http.Request{
Header: http.Header{HeaderXForwardedProto: []string{"https"}},
},
name: "uses X-Forwarded-Proto http",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{"http"},
},
"https",
expect: "http",
},
{
&Context{
request: &http.Request{
Header: http.Header{HeaderXForwardedProtocol: []string{"http"}},
},
name: "uses X-Forwarded-Proto https",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{"https"},
},
"http",
expect: "https",
},
{
&Context{
request: &http.Request{
Header: http.Header{HeaderXForwardedSsl: []string{"on"}},
},
name: "X-Forwarded-Proto is case insensitive",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{"HTTPS"},
},
"https",
expect: "HTTPS",
},
{
&Context{
request: &http.Request{
Header: http.Header{HeaderXUrlScheme: []string{"https"}},
},
name: "uses X-Forwarded-Proto ws",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{"ws"},
},
"https",
expect: "ws",
},
{
&Context{
request: &http.Request{},
name: "uses X-Forwarded-Proto wss",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{"wss"},
},
"http",
expect: "wss",
},
{
name: "ignores invalid X-Forwarded-Proto and uses X-Forwarded-Protocol",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{"ftp"},
HeaderXForwardedProtocol: []string{"https"},
},
expect: "https",
},
{
name: "uses X-Forwarded-Protocol",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProtocol: []string{"https"},
},
expect: "https",
},
{
name: "X-Forwarded-Proto takes precedence over X-Forwarded-Protocol",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{"http"},
HeaderXForwardedProtocol: []string{"https"},
},
expect: "http",
},
{
name: "uses X-Forwarded-Ssl on",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedSsl: []string{"on"},
},
expect: "https",
},
{
name: "X-Forwarded-Ssl on is case sensitive",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedSsl: []string{"ON"},
},
expect: "http",
},
{
name: "X-Forwarded-Protocol takes precedence over X-Forwarded-Ssl",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProtocol: []string{"http"},
HeaderXForwardedSsl: []string{"on"},
},
expect: "http",
},
{
name: "uses X-Url-Scheme",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXUrlScheme: []string{"https"},
},
expect: "https",
},
{
name: "X-Forwarded-Ssl takes precedence over X-Url-Scheme",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedSsl: []string{"on"},
HeaderXUrlScheme: []string{"http"},
},
expect: "https",
},
{
name: "ignores invalid forwarded headers and falls back to http",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{"ftp"},
HeaderXForwardedProtocol: []string{"smtp"},
HeaderXForwardedSsl: []string{"off"},
HeaderXUrlScheme: []string{"file"},
},
expect: "http",
},
{
name: "ignores empty forwarded proto and uses X-Url-Scheme",
givenIsTLS: false,
givenHeaders: http.Header{
HeaderXForwardedProto: []string{""},
HeaderXUrlScheme: []string{"https"},
},
expect: "https",
},
}

for _, tt := range tests {
assert.Equal(t, tt.s, tt.c.Scheme())
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenHeaders != nil {
req.Header = tc.givenHeaders
}
c := NewContext(req, nil)
if tc.givenIsTLS {
c.request.TLS = &tls.ConnectionState{}
}

assert.Equal(t, tc.expect, c.Scheme())
})
}
}

Expand Down
Loading