Skip to content

Commit

Permalink
Check for and report bad protocol in TLSClientConfig.NextProtos (#788)
Browse files Browse the repository at this point in the history
* return an error when Dialer.TLSClientConfig.NextProtos contains a protocol that is not http/1.1

* include the likely cause of the error in the error message

* check for nil-ness of Dialer.TLSClientConfig before attempting to run the check

* addressing the review

* move the NextProtos test into a separate file so that it can be run conditionally on go versions >= 1.14

* moving the new error check into existing http response error block to reduce the possibility of false positives

* wrapping the error in %w

* using %v instead of %w for compatibility with older versions of go

* Revert "using %v instead of %w for compatibility with older versions of go"

This reverts commit d34dd94.

* move the unit test back into the existing test code since golang build constraint is no longer necessary

Co-authored-by: Chan Kang <chankang@chankang17@gmail.com>
  • Loading branch information
ChannyClaus and Chan Kang committed Jun 21, 2022
1 parent 27d91a9 commit bc7ce89
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
12 changes: 12 additions & 0 deletions client.go
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
Expand Down Expand Up @@ -370,6 +371,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h

resp, err := http.ReadResponse(conn.br, req)
if err != nil {
if d.TLSClientConfig != nil {
for _, proto := range d.TLSClientConfig.NextProtos {
if proto != "http/1.1" {
return nil, nil, fmt.Errorf(
"websocket: protocol %q was given but is not supported;"+
"sharing tls.Config with net/http Transport can cause this error: %w",
proto, err,
)
}
}
}
return nil, nil, err
}

Expand Down
35 changes: 35 additions & 0 deletions client_server_test.go
Expand Up @@ -1098,3 +1098,38 @@ func TestNetDialConnect(t *testing.T) {
}
}
}
func TestNextProtos(t *testing.T) {
ts := httptest.NewUnstartedServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
)
ts.EnableHTTP2 = true
ts.StartTLS()
defer ts.Close()

d := Dialer{
TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
}

r, err := ts.Client().Get(ts.URL)
if err != nil {
t.Fatalf("Get: %v", err)
}
r.Body.Close()

// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
// after the Client.Get call from net/http above.
var containsHTTP2 bool = false
for _, proto := range d.TLSClientConfig.NextProtos {
if proto == "h2" {
containsHTTP2 = true
}
}
if !containsHTTP2 {
t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
}

_, _, err = d.Dial(makeWsProto(ts.URL), nil)
if err == nil {
t.Fatalf("Dial succeeded, expect fail ")
}
}

0 comments on commit bc7ce89

Please sign in to comment.