Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/gorilla/websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
savsgio committed Jul 26, 2022
2 parents c48d95b + af47554 commit 77c751b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
18 changes: 15 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
Expand Down Expand Up @@ -318,14 +319,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}

netConn, err := netDial("tcp", hostPort)
if err != nil {
return nil, nil, err
}
if trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{
Conn: netConn,
})
}
if err != nil {
return nil, nil, err
}

defer func() {
if netConn != nil {
Expand Down Expand Up @@ -370,6 +371,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h

resp, err := http.ReadResponse(conn.br, req)
if err != nil {
if d.TLSClientConfig != nil {
for _, proto := range d.TLSClientConfig.NextProtos {
if proto != "http/1.1" {
return nil, nil, fmt.Errorf(
"websocket: protocol %q was given but is not supported;"+
"sharing tls.Config with net/http Transport can cause this error: %w",
proto, err,
)
}
}
}
return nil, nil, err
}

Expand Down
35 changes: 35 additions & 0 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1098,3 +1098,38 @@ func TestNetDialConnect(t *testing.T) {
}
}
}
func TestNextProtos(t *testing.T) {
ts := httptest.NewUnstartedServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
)
ts.EnableHTTP2 = true
ts.StartTLS()
defer ts.Close()

d := Dialer{
TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
}

r, err := ts.Client().Get(ts.URL)
if err != nil {
t.Fatalf("Get: %v", err)
}
r.Body.Close()

// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
// after the Client.Get call from net/http above.
var containsHTTP2 bool = false
for _, proto := range d.TLSClientConfig.NextProtos {
if proto == "h2" {
containsHTTP2 = true
}
}
if !containsHTTP2 {
t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
}

_, _, err = d.Dial(makeWsProto(ts.URL), nil)
if err == nil {
t.Fatalf("Dial succeeded, expect fail ")
}
}

0 comments on commit 77c751b

Please sign in to comment.