From 7fd613642282944805ba6f4f30bd5501b1f74e99 Mon Sep 17 00:00:00 2001 From: Gus Eggert Date: Mon, 30 Jan 2023 08:02:52 -0500 Subject: [PATCH] Fix dial panic when ctx is nil When the ctx is nil, http.NewRequestWithContext returns a "net/http: nil Context" error and a nil request. In this case, the dial function panics because it assumes the req is never nil. This checks the returning error and returns it, so that callers get an error instead of a panic in that scenario. --- dial.go | 5 ++++- dial_test.go | 22 ++++++++++++++++------ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/dial.go b/dial.go index 7a7787ff..0ae0d570 100644 --- a/dial.go +++ b/dial.go @@ -157,7 +157,10 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } - req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to build HTTP request: %w", err) + } req.Header = opts.HTTPHeader.Clone() req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") diff --git a/dial_test.go b/dial_test.go index 28c255c6..80ba9a3d 100644 --- a/dial_test.go +++ b/dial_test.go @@ -23,10 +23,11 @@ func TestBadDials(t *testing.T) { t.Parallel() testCases := []struct { - name string - url string - opts *DialOptions - rand readerFunc + name string + url string + opts *DialOptions + rand readerFunc + nilCtx bool }{ { name: "badURL", @@ -46,6 +47,11 @@ func TestBadDials(t *testing.T) { return 0, io.EOF }, }, + { + name: "nilContext", + url: "http://localhost", + nilCtx: true, + }, } for _, tc := range testCases { @@ -53,8 +59,12 @@ func TestBadDials(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() + var ctx context.Context + var cancel func() + if !tc.nilCtx { + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + } if tc.rand == nil { tc.rand = rand.Reader.Read