diff --git a/server/raft_transport.go b/server/raft_transport.go index a01aaefd..9265f22f 100644 --- a/server/raft_transport.go +++ b/server/raft_transport.go @@ -79,20 +79,35 @@ type natsConn struct { } func (n *natsConn) Read(b []byte) (int, error) { + var subTimeout time.Duration + n.mu.RLock() closed := n.closed - subTimeout := n.subTimeout - if subTimeout == 0 { - subTimeout = time.Duration(0x7FFFFFFFFFFFFFFF) + buf := n.pending + pendingSize := len(buf) + // We need the timeout only if we are going to call NextMsg, and if we + // have a pending, we won't. + if pendingSize == 0 { + subTimeout = n.subTimeout + if subTimeout == 0 { + subTimeout = time.Duration(0x7FFFFFFFFFFFFFFF) + } } n.mu.RUnlock() if closed { return 0, io.EOF } - buf := n.pending - if size := len(buf); size > 0 { - nb := copy(b, buf[:len(b)]) - if nb != size { + // If we have a pending, process that first. + if pendingSize > 0 { + // We will copy all data that we have if it can fit, or up to the + // caller's buffer size. + limit := pendingSize + if limit > len(b) { + limit = len(b) + } + nb := copy(b, buf[:limit]) + // If we did not copy everything, reduce size by what we copied. + if nb != pendingSize { buf = buf[nb:] } else { buf = nil diff --git a/server/raft_transport_test.go b/server/raft_transport_test.go index 2ed701af..6bfd0ac1 100644 --- a/server/raft_transport_test.go +++ b/server/raft_transport_test.go @@ -730,6 +730,47 @@ func TestRAFTTransportConnReader(t *testing.T) { }) } + firstPart := "Partial" + secondPart := " and then the rest" + if _, err := bToA.Write([]byte(firstPart + secondPart)); err != nil { + t.Fatalf("Error on write: %v", err) + } + n, err := fromB.Read(buf[:7]) + if err != nil { + t.Fatalf("Error on read: %v", err) + } + if string(buf[:n]) != firstPart { + t.Fatalf("Unexpected result: %q", buf[:n]) + } + // Now pass a buffer to Read() that is larger than what is left in pending + n, err = fromB.Read(buf[:]) + if err != nil { + t.Fatalf("Error on read: %v", err) + } + if string(buf[:n]) != secondPart { + t.Fatalf("Unexpected result: %q", buf[:n]) + } + + // Another test with a partial... + if _, err := bToA.Write([]byte("ab")); err != nil { + t.Fatalf("Error on write: %v", err) + } + n, err = fromB.Read(buf[:1]) + if err != nil { + t.Fatalf("Error on read: %v", err) + } + if string(buf[:n]) != "a" { + t.Fatalf("Unexpected result: %q", buf[:n]) + } + // There is only 1 byte that should be pending, but call with a large buffer. + n, err = fromB.Read(buf[:]) + if err != nil { + t.Fatalf("Error on read: %v", err) + } + if string(buf[:n]) != "b" { + t.Fatalf("Unexpected result: %q", buf[:n]) + } + // Write empty message should not go out if n, err := bToA.Write(nil); err != nil || n != 0 { t.Fatalf("Write nil should return 0, nil, got %v and %v", n, err) @@ -741,7 +782,7 @@ func TestRAFTTransportConnReader(t *testing.T) { } // Consume all at once - n, err := fromB.Read(buf[:]) + n, err = fromB.Read(buf[:]) if err != nil || n != 3 { t.Fatalf("Unexpected error on read, n=%v err=%v", n, err) } diff --git a/server/server_test.go b/server/server_test.go index 43f76d89..41d983a0 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1005,20 +1005,24 @@ func TestProtocolOrder(t *testing.T) { // Mix pub and subscribe calls ch = make(chan bool) - errCh = make(chan error) + errCh = make(chan error, 1) startSubAt := 50 var sub stan.Subscription var err error + first := true for i := 1; i <= 100; i++ { if err := sc.Publish("foo", []byte("hello")); err != nil { t.Fatalf("Unexpected error on publish: %v", err) } if i == startSubAt { sub, err = sc.Subscribe("foo", func(m *stan.Msg) { - if m.Sequence == uint64(startSubAt)+1 { - ch <- true - } else if len(errCh) == 0 { - errCh <- fmt.Errorf("Received message %v instead of %v", m.Sequence, startSubAt+1) + if first { + if m.Sequence == uint64(startSubAt)+1 { + ch <- true + } else { + errCh <- fmt.Errorf("Received message %v instead of %v", m.Sequence, startSubAt+1) + } + first = false } }) if err != nil { @@ -1037,6 +1041,7 @@ func TestProtocolOrder(t *testing.T) { sub.Unsubscribe() // Acks should be processed before Connection close + errCh = make(chan error, 1) for i := 0; i < total; i++ { rcv := int32(0) sc2, err := stan.Connect(clusterName, "otherclient")