diff --git a/client_test.go b/client_test.go index 890385c..a4c8ea7 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "errors" + "fmt" "io" "log" "net" @@ -712,12 +713,12 @@ func TestClientRetransmission(t *testing.T) { agent.start = func(id [TransactionIDSize]byte, deadline time.Time) error { if attempt == 0 { attempt++ - agent.h(Event{ + go agent.h(Event{ TransactionID: id, Error: ErrTransactionTimeOut, }) } else { - agent.h(Event{ + go agent.h(Event{ TransactionID: id, Message: response, }) @@ -734,6 +735,7 @@ func TestClientRetransmission(t *testing.T) { t.Fatal(err) } c.SetRTO(time.Second) + gotReads := make(chan struct{}) go func() { buf := make([]byte, 1500) readN, readErr := connL.Read(buf) @@ -750,6 +752,7 @@ func TestClientRetransmission(t *testing.T) { if !IsMessage(buf[:readN]) { t.Error("should be STUN") } + gotReads <- struct{}{} }() if doErr := c.Do(MustBuild(response, BindingRequest), func(event Event) { if event.Error != nil { @@ -758,5 +761,81 @@ func TestClientRetransmission(t *testing.T) { }); doErr != nil { t.Fatal(err) } + <-gotReads +} + +func testClientDoConcurrent(t *testing.T, concurrency int) { + response := MustBuild(TransactionID, BindingSuccess) + response.Encode() + connL, connR := net.Pipe() + defer connL.Close() + collector := new(manualCollector) + clock := &manualClock{current: time.Now()} + agent := &manualAgent{} + agent.start = func(id [TransactionIDSize]byte, deadline time.Time) error { + go agent.h(Event{ + TransactionID: id, + Message: response, + }) + return nil + } + c, err := NewClient(ClientOptions{ + Agent: agent, + Collector: collector, + Connection: connR, + Clock: clock, + }) + if err != nil { + t.Fatal(err) + } + c.SetRTO(time.Second) + connClosed := make(chan struct{}) + go func() { + defer func() { + connClosed <- struct{}{} + }() + buf := make([]byte, 1500) + for { + readN, readErr := connL.Read(buf) + if readErr != nil { + if readErr == io.EOF { + break + } + t.Error(readErr) + } + if !IsMessage(buf[:readN]) { + t.Error("should be STUN") + } + } + }() + wg := new(sync.WaitGroup) + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if doErr := c.Do(MustBuild(TransactionID, BindingRequest), func(event Event) { + if event.Error != nil { + t.Error("failed") + } + }); doErr != nil { + t.Error(doErr) + } + }() + } + wg.Wait() + if connErr := connR.Close(); connErr != nil { + t.Error(connErr) + } + <-connClosed +} +func TestClient_DoConcurrent(t *testing.T) { + t.Parallel() + for _, concurrency := range []int{ + 1, 5, 10, 25, 100, 500, + } { + t.Run(fmt.Sprintf("%d", concurrency), func(t *testing.T) { + testClientDoConcurrent(t, concurrency) + }) + } } diff --git a/go.test.sh b/go.test.sh index 8c3b012..b4ffaf7 100755 --- a/go.test.sh +++ b/go.test.sh @@ -12,6 +12,9 @@ go test ./... # test with "debug" tag go test -tags debug ./... +# test concurrency +go test -race -cpu=1,2,4 -run TestClient_DoConcurrent + for d in $(go list ./... | grep -v vendor); do go test -race -coverprofile=profile.out -covermode=atomic "$d" if [ -f profile.out ]; then