Skip to content

Commit

Permalink
Extend DialOptions to allow Host header override
Browse files Browse the repository at this point in the history
  • Loading branch information
bendiscz authored and nhooyr committed Oct 13, 2023
1 parent 3f26c9f commit f7bed7c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
7 changes: 7 additions & 0 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ type DialOptions struct {
// HTTPHeader specifies the HTTP headers included in the handshake request.
HTTPHeader http.Header

// Host optionally overrides the Host HTTP header to send. If empty, the value
// of URL.Host will be used.
Host string

// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
Subprotocols []string

Expand Down Expand Up @@ -168,6 +172,9 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
if err != nil {
return nil, fmt.Errorf("failed to create new http request: %w", err)
}
if len(opts.Host) > 0 {
req.Host = opts.Host
}
req.Header = opts.HTTPHeader.Clone()
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
Expand Down
60 changes: 60 additions & 0 deletions dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package websocket

import (
"bytes"
"context"
"crypto/rand"
"io"
Expand Down Expand Up @@ -118,6 +119,65 @@ func TestBadDials(t *testing.T) {
})
}

func Test_verifyHostOverride(t *testing.T) {
testCases := []struct {
name string
host string
exp string
}{
{
name: "noOverride",
host: "",
exp: "example.com",
},
{
name: "hostOverride",
host: "example.net",
exp: "example.net",
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

rt := func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Host", tc.exp, r.Host)

h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))

return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Header: h,
Body: mockBody{bytes.NewBufferString("hi")},
}, nil
}

_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
HTTPClient: mockHTTPClient(rt),
Host: tc.host,
})
assert.Success(t, err)
})
}

}

type mockBody struct {
*bytes.Buffer
}

func (mb mockBody) Close() error {
return nil
}

func Test_verifyServerHandshake(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit f7bed7c

Please sign in to comment.