Skip to content

Commit

Permalink
Merge dbd1506 into 2e9dcb5
Browse files Browse the repository at this point in the history
  • Loading branch information
moredure committed Feb 10, 2022
2 parents 2e9dcb5 + dbd1506 commit 5232552
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 17 deletions.
85 changes: 85 additions & 0 deletions src/net/dial_windows_test.go
@@ -0,0 +1,85 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package net

import (
"context"
"errors"
"net/internal/socktest"
"syscall"
"testing"
"time"
)

func init() {
isEADDRINUSE = func(err error) bool {
return errors.Is(err, syscall.EADDRINUSE)
}
}

// Issue 16523
func TestDialContextCancelRace(t *testing.T) {
oldTestHookCanceledDial := testHookCanceledDial
defer func() {
testHookCanceledDial = oldTestHookCanceledDial
}()

ln := newLocalListener(t, "tcp")
listenerDone := make(chan struct{})
go func() {
defer close(listenerDone)
c, err := ln.Accept()
if err == nil {
c.Close()
}
}()
defer func() { <-listenerDone }()
defer ln.Close()

sawCancel := make(chan bool, 1)
testHookCanceledDial = func() {
sawCancel <- true
}

ctx, cancelCtx := context.WithCancel(context.Background())
sw.Set(socktest.FilterConnect, func(*socktest.Status) (socktest.AfterFilter, error) {
return func(*socktest.Status) error {
cancelCtx()
// And wait for the "interrupter" goroutine to
// cancel the dial by messing with its write
// timeout before returning.
select {
case <-sawCancel:
t.Logf("saw cancel")
case <-time.After(5 * time.Second):
t.Errorf("didn't see cancel after 5 seconds")
}
return context.Canceled
}, nil
})
defer sw.Set(socktest.FilterConnect, nil)

var d Dialer
c, err := d.DialContext(ctx, "tcp", ln.Addr().String())
if err == nil {
c.Close()
t.Fatal("unexpected successful dial; want context canceled error")
}

select {
case <-ctx.Done():
case <-time.After(5 * time.Second):
t.Fatal("expected context to be canceled")
}

oe, ok := err.(*OpError)
if !ok || oe.Op != "dial" {
t.Fatalf("Dial error = %#v; want dial *OpError", err)
}

if oe.Err != errCanceled {
t.Errorf("DialContext = (%v, %v); want OpError with error %v", c, err, errCanceled)
}
}
37 changes: 21 additions & 16 deletions src/net/fd_windows.go
Expand Up @@ -86,26 +86,31 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.
}
}

// Wait for the goroutine converting context.Done into a write timeout
// to exist, otherwise our caller might cancel the context and
// cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
done := make(chan bool) // must be unbuffered
defer func() { done <- true }()
go func() {
select {
case <-ctx.Done():
// Force the runtime's poller to immediately give
// up waiting for writability.
fd.pfd.SetWriteDeadline(aLongTimeAgo)
<-done
case <-done:
}
}()
// Start the "interrupter" goroutine, if this context might be canceled.
ctxDone := ctx.Done()
if ctxDone != nil {
// Wait for the goroutine converting context.Done into a write timeout
// to exist, otherwise our caller might cancel the context and
// cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
done := make(chan struct{})
defer func() { done <- struct{}{} }()
go func() {
select {
case <-ctxDone:
// Force the runtime's poller to immediately give
// up waiting for writability.
fd.pfd.SetWriteDeadline(aLongTimeAgo)
testHookCanceledDial()
<-done
case <-done:
}
}()
}

// Call ConnectEx API.
if err := fd.pfd.ConnectEx(ra); err != nil {
select {
case <-ctx.Done():
case <-ctxDone:
return nil, mapErr(ctx.Err())
default:
if _, ok := err.(syscall.Errno); ok {
Expand Down
3 changes: 2 additions & 1 deletion src/net/hook_windows.go
Expand Up @@ -11,7 +11,8 @@ import (
)

var (
testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349
testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349
testHookCanceledDial = func() {} // for golang.org/issue/16523

// Placeholders for socket system calls.
socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
Expand Down

0 comments on commit 5232552

Please sign in to comment.