diff --git a/gossip/comm/comm_impl.go b/gossip/comm/comm_impl.go index 5a0222b4bae..c2605f2e809 100644 --- a/gossip/comm/comm_impl.go +++ b/gossip/comm/comm_impl.go @@ -112,7 +112,7 @@ func NewCommInstanceWithServer(port int, sec SecurityProvider, pkID common.PKIid // NewCommInstance creates a new comm instance that binds itself to the given gRPC server func NewCommInstance(s *grpc.Server, sec SecurityProvider, PKIID common.PKIidType, dialOpts ...grpc.DialOption) (Comm, error) { - commInst, err := NewCommInstanceWithServer(-1, sec, PKIID, dialOpts ...) + commInst, err := NewCommInstanceWithServer(-1, sec, PKIID, dialOpts...) if err != nil { return nil, err } @@ -162,12 +162,12 @@ func (c *commImpl) createConnection(endpoint string, expectedPKIID common.PKIidT if stream, err := cl.GossipStream(context.Background()); err == nil { pkiID, err := c.authenticateRemotePeer(stream) - if expectedPKIID != nil && !bytes.Equal(pkiID, expectedPKIID) { - // PKIID is nil when we don't know the remote PKI id's - c.logger.Warning("Remote endpoint claims to be a different peer, expected", expectedPKIID, "but got", pkiID) - return nil, fmt.Errorf("Authentication failure") - } if err == nil { + if expectedPKIID != nil && !bytes.Equal(pkiID, expectedPKIID) { + // PKIID is nil when we don't know the remote PKI id's + c.logger.Warning("Remote endpoint claims to be a different peer, expected", expectedPKIID, "but got", pkiID) + return nil, fmt.Errorf("Authentication failure") + } conn := newConnection(cl, cc, stream, nil) conn.pkiID = pkiID conn.logger = c.logger diff --git a/gossip/comm/comm_test.go b/gossip/comm/comm_test.go index e78a97c6710..8f33b845f4c 100644 --- a/gossip/comm/comm_test.go +++ b/gossip/comm/comm_test.go @@ -165,28 +165,26 @@ func TestHandshake(t *testing.T) { } func TestBasic(t *testing.T) { + t.Parallel() comm1, _ := newCommInstance(2000, naiveSec) comm2, _ := newCommInstance(3000, naiveSec) - defer comm1.Stop() - defer comm2.Stop() - time.Sleep(time.Duration(3) * time.Second) - msgs := make(chan *proto.GossipMessage, 2) - go func() { - m := <-comm2.Accept(acceptAll) - msgs <- m.GetGossipMessage() - }() - go func() { - m := <-comm1.Accept(acceptAll) - msgs <- m.GetGossipMessage() - }() - comm1.Send(createGossipMsg(), &RemotePeer{PKIID: []byte("localhost:3000"), Endpoint: "localhost:3000"}) - time.Sleep(time.Second) - comm2.Send(createGossipMsg(), &RemotePeer{PKIID: []byte("localhost:2000"), Endpoint: "localhost:2000"}) + m1 := comm1.Accept(acceptAll) + m2 := comm2.Accept(acceptAll) + out := make(chan uint64, 2) + reader := func(ch <-chan ReceivedMessage) { + m := <- ch + out <- m.GetGossipMessage().Nonce + } + go reader(m1) + go reader(m2) + comm1.Send(createGossipMsg(), remotePeer(3000)) time.Sleep(time.Second) - assert.Equal(t, 2, len(msgs)) + comm2.Send(createGossipMsg(), remotePeer(2000)) + waitForMessages(t, out, 2, "Didn't receive 2 messages") } func TestBlackListPKIid(t *testing.T) { + t.Parallel() comm1, _ := newCommInstance(1611, naiveSec) comm2, _ := newCommInstance(1612, naiveSec) comm3, _ := newCommInstance(1613, naiveSec) @@ -196,7 +194,7 @@ func TestBlackListPKIid(t *testing.T) { defer comm3.Stop() defer comm4.Stop() - reader := func(out chan uint64, in <-chan ReceivedMessage) { + reader := func(instance string, out chan uint64, in <-chan ReceivedMessage) { for { msg := <-in if msg == nil { @@ -206,52 +204,48 @@ func TestBlackListPKIid(t *testing.T) { } } - sender := func(comm Comm, port int, n int) { - endpoint := fmt.Sprintf("localhost:%d", port) - for i := 0; i < n; i++ { - comm.Send(createGossipMsg(), &RemotePeer{Endpoint: endpoint, PKIID: []byte(endpoint)}) - time.Sleep(time.Duration(1) * time.Second) - } - } + out1 := make(chan uint64, 4) + out2 := make(chan uint64, 4) + out3 := make(chan uint64, 4) + out4 := make(chan uint64, 4) - out1 := make(chan uint64, 5) - out2 := make(chan uint64, 5) - out3 := make(chan uint64, 10) - out4 := make(chan uint64, 10) - - go reader(out1, comm1.Accept(acceptAll)) - go reader(out2, comm2.Accept(acceptAll)) - go reader(out3, comm3.Accept(acceptAll)) - go reader(out4, comm4.Accept(acceptAll)) + go reader("comm1", out1, comm1.Accept(acceptAll)) + go reader("comm2", out2, comm2.Accept(acceptAll)) + go reader("comm3", out3, comm3.Accept(acceptAll)) + go reader("comm4", out4, comm4.Accept(acceptAll)) // have comm1 BL comm3 comm1.BlackListPKIid([]byte("localhost:1613")) // make comm3 send to 1 and 2 - go sender(comm3, 1611, 5) - go sender(comm3, 1612, 5) + comm3.Send(createGossipMsg(), remotePeer(1612)) // out2++ + comm3.Send(createGossipMsg(), remotePeer(1611)) - // make comm1 and comm2 send to comm3 - go sender(comm1, 1613, 5) - go sender(comm2, 1613, 5) + waitForMessages(t, out2, 1, "comm2 should have received 1 message") - // make comm1 and comm2 send to comm4 which is not blacklisted - go sender(comm1, 1614, 5) - go sender(comm2, 1614, 5) + // make comm1 and comm2 send to comm3 + comm1.Send(createGossipMsg(), remotePeer(1613)) + comm2.Send(createGossipMsg(), remotePeer(1613)) // out3++ + waitForMessages(t, out3, 1, "comm3 should have received 1 message") - time.Sleep(time.Duration(1) * time.Second) + // make comm1 and comm2 send to comm4 which is not blacklisted // out4 += 4 + comm1.Send(createGossipMsg(), remotePeer(1614)) + comm2.Send(createGossipMsg(), remotePeer(1614)) + comm1.Send(createGossipMsg(), remotePeer(1614)) + comm2.Send(createGossipMsg(), remotePeer(1614)) - // blacklist comm3 mid-sending + // blacklist comm3 by comm2 comm2.BlackListPKIid([]byte("localhost:1613")) - time.Sleep(time.Duration(5) * time.Second) - assert.Equal(t, 0, len(out1), "Comm instance 1 received messages(%d) from comm3 although comm3 is black listed", len(out1)) - assert.True(t, len(out2) < 2, "Comm instance 2 received too many messages(%d) from comm3 although comm3 is black listed", len(out2)) - assert.True(t, len(out3) < 3, "Comm instance 3 received too many messages(%d) although black listed", len(out3)) - assert.Equal(t, 10, len(out4), "Comm instance 4 didn't receive all messages sent to it") + // send from comm1 and comm2 to comm3 again + comm1.Send(createGossipMsg(), remotePeer(1613)) // shouldn't have an effect + comm2.Send(createGossipMsg(), remotePeer(1613)) // shouldn't have an effect + + waitForMessages(t, out4, 4, "comm1 should have received 4 messages") } func TestParallelSend(t *testing.T) { + t.Parallel() comm1, _ := newCommInstance(5611, naiveSec) comm2, _ := newCommInstance(5612, naiveSec) defer comm1.Stop() @@ -266,7 +260,7 @@ func TestParallelSend(t *testing.T) { emptyMsg := createGossipMsg() go func() { defer wg.Done() - comm1.Send(emptyMsg, &RemotePeer{Endpoint: "localhost:5612", PKIID: []byte("localhost:5612")}) + comm1.Send(emptyMsg, remotePeer(5612)) }() } wg.Wait() @@ -293,6 +287,7 @@ func TestParallelSend(t *testing.T) { } func TestResponses(t *testing.T) { + t.Parallel() comm1, _ := newCommInstance(8611, naiveSec) comm2, _ := newCommInstance(8612, naiveSec) @@ -316,7 +311,7 @@ func TestResponses(t *testing.T) { responsesFromComm1 := comm2.Accept(acceptAll) ticker := time.NewTicker(time.Duration(6000) * time.Millisecond) - comm2.Send(msg, &RemotePeer{PKIID: []byte("localhost:8611"), Endpoint: "localhost:8611"}) + comm2.Send(msg, remotePeer(8611)) time.Sleep(time.Duration(100) * time.Millisecond) select { @@ -331,6 +326,7 @@ func TestResponses(t *testing.T) { } func TestAccept(t *testing.T) { + t.Parallel() comm1, _ := newCommInstance(7611, naiveSec) comm2, _ := newCommInstance(7612, naiveSec) @@ -348,11 +344,13 @@ func TestAccept(t *testing.T) { var evenResults []uint64 var oddResults []uint64 + out := make(chan uint64, defRecvBuffSize) sem := make(chan struct{}, 0) readIntoSlice := func(a *[]uint64, ch <-chan ReceivedMessage) { for m := range ch { *a = append(*a, m.GetGossipMessage().Nonce) + out <- m.GetGossipMessage().Nonce } sem <- struct{}{} } @@ -361,10 +359,10 @@ func TestAccept(t *testing.T) { go readIntoSlice(&oddResults, oddNONCES) for i := 0; i < defRecvBuffSize; i++ { - comm2.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:7611", PKIID: []byte("localhost:7611")}) + comm2.Send(createGossipMsg(), remotePeer(7611)) } - time.Sleep(time.Duration(5) * time.Second) + waitForMessages(t, out, defRecvBuffSize, "Didn't receive all messages sent") comm1.Stop() comm2.Stop() @@ -386,6 +384,7 @@ func TestAccept(t *testing.T) { } func TestReConnections(t *testing.T) { + t.Parallel() comm1, _ := newCommInstance(3611, naiveSec) comm2, _ := newCommInstance(3612, naiveSec) @@ -406,10 +405,10 @@ func TestReConnections(t *testing.T) { go reader(out2, comm2.Accept(acceptAll)) // comm1 connects to comm2 - comm1.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:3612", PKIID: []byte("localhost:3612")}) + comm1.Send(createGossipMsg(), remotePeer(3612)) time.Sleep(100 * time.Millisecond) // comm2 sends to comm1 - comm2.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:3611", PKIID: []byte("localhost:3611")}) + comm2.Send(createGossipMsg(), remotePeer(3611)) time.Sleep(100 * time.Millisecond) assert.Equal(t, 1, len(out2)) @@ -419,39 +418,43 @@ func TestReConnections(t *testing.T) { comm1, _ = newCommInstance(3611, naiveSec) go reader(out1, comm1.Accept(acceptAll)) time.Sleep(300 * time.Millisecond) - comm2.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:3611", PKIID: []byte("localhost:3611")}) + comm2.Send(createGossipMsg(), remotePeer(3611)) time.Sleep(100 * time.Millisecond) assert.Equal(t, 2, len(out1)) } func TestProbe(t *testing.T) { + t.Parallel() comm1, _ := newCommInstance(6611, naiveSec) defer comm1.Stop() comm2, _ := newCommInstance(6612, naiveSec) time.Sleep(time.Duration(1) * time.Second) - assert.NoError(t, comm1.Probe(&RemotePeer{Endpoint: "localhost:6612", PKIID: []byte("localhost:6612")})) - assert.Error(t, comm1.Probe(&RemotePeer{Endpoint: "localhost:9012", PKIID: []byte("localhost:9012")})) + assert.NoError(t, comm1.Probe(remotePeer(6612))) + assert.Error(t, comm1.Probe(remotePeer(9012))) comm2.Stop() time.Sleep(time.Second) - assert.Error(t, comm1.Probe(&RemotePeer{Endpoint: "localhost:6612", PKIID: []byte("localhost:6612")})) + assert.Error(t, comm1.Probe(remotePeer(6612))) comm2, _ = newCommInstance(6612, naiveSec) defer comm2.Stop() time.Sleep(time.Duration(1) * time.Second) - assert.NoError(t, comm2.Probe(&RemotePeer{Endpoint: "localhost:6611", PKIID: []byte("localhost:6611")})) - assert.NoError(t, comm1.Probe(&RemotePeer{Endpoint: "localhost:6612", PKIID: []byte("localhost:6612")})) + assert.NoError(t, comm2.Probe(remotePeer(6611))) + assert.NoError(t, comm1.Probe(remotePeer(6612))) } func TestPresumedDead(t *testing.T) { - comm1, _ := newCommInstance(7611, naiveSec) - defer comm1.Stop() - comm2, _ := newCommInstance(7612, naiveSec) - go comm1.Send(createGossipMsg(), &RemotePeer{PKIID: []byte("localhost:7612"), Endpoint: "localhost:7612"}) + t.Parallel() + comm1, _ := newCommInstance(4611, naiveSec) + comm2, _ := newCommInstance(4612, naiveSec) + go comm1.Send(createGossipMsg(), remotePeer(4612)) <-comm2.Accept(acceptAll) comm2.Stop() - for i := 0; i < 5; i++ { - comm1.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:7612", PKIID: []byte("localhost:7612")}) - time.Sleep(time.Second) - } + go func() { + for i := 0; i < 5; i++ { + comm1.Send(createGossipMsg(), remotePeer(4612)) + time.Sleep(time.Millisecond * 200) + } + }() + ticker := time.NewTicker(time.Second * time.Duration(3)) select { case <-ticker.C: @@ -472,3 +475,28 @@ func createGossipMsg() *proto.GossipMessage { }, } } + +func remotePeer(port int) *RemotePeer { + endpoint := fmt.Sprintf("localhost:%d", port) + return &RemotePeer{Endpoint: endpoint, PKIID: []byte(endpoint)} +} + +func waitForMessages(t *testing.T, msgChan chan uint64, count int, errMsg string) { + c := 0 + waiting := true + ticker := time.NewTicker(time.Duration(5) * time.Second) + for waiting { + select { + case <-msgChan: + c++ + if c == count { + waiting = false + } + break + case <-ticker.C: + waiting = false + break + } + } + assert.Equal(t, count, c, errMsg) +}