diff --git a/const.go b/const.go index 9e11ba3..c326f96 100644 --- a/const.go +++ b/const.go @@ -114,7 +114,8 @@ const ( const ( // initialStreamWindow is the initial stream window size - initialStreamWindow uint32 = 256 * 1024 + initialStreamWindow uint32 = 64 * 1024 + maxStreamWindow uint32 = 16 * 1024 * 1024 ) const ( diff --git a/mux.go b/mux.go index 8b9b154..dd5b598 100644 --- a/mux.go +++ b/mux.go @@ -60,7 +60,7 @@ func DefaultConfig() *Config { EnableKeepAlive: true, KeepAliveInterval: 30 * time.Second, ConnectionWriteTimeout: 10 * time.Second, - MaxStreamWindowSize: initialStreamWindow, + MaxStreamWindowSize: maxStreamWindow, LogOutput: os.Stderr, ReadBufSize: 4096, MaxMessageSize: 64 * 1024, // Means 64KiB/10s = 52kbps minimum speed. diff --git a/session.go b/session.go index 55b4e28..b10d466 100644 --- a/session.go +++ b/session.go @@ -50,6 +50,8 @@ type Session struct { pingID uint32 activePing *ping + rtt int64 // to be accessed atomically, in nanoseconds + // streams maps a stream id to a stream, and inflight has an entry // for any outgoing stream that has not yet been established. Both are // protected by streamLock. @@ -129,6 +131,7 @@ func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Sessio } go s.recv() go s.send() + go s.measureRTT() return s } @@ -291,6 +294,19 @@ func (s *Session) goAway(reason uint32) header { return hdr } +func (s *Session) measureRTT() { + rtt, err := s.Ping() + if err != nil { + return + } + atomic.StoreInt64(&s.rtt, rtt.Nanoseconds()) +} + +// 0 if we don't yet have a measurement +func (s *Session) getRTT() time.Duration { + return time.Duration(atomic.LoadInt64(&s.rtt)) +} + // Ping is used to measure the RTT response time func (s *Session) Ping() (dur time.Duration, err error) { // Prepare a ping. diff --git a/session_norace_test.go b/session_norace_test.go index 4c45bd7..3e70056 100644 --- a/session_norace_test.go +++ b/session_norace_test.go @@ -159,7 +159,7 @@ func TestLargeWindow(t *testing.T) { if err != nil { t.Fatal(err) } - buf := make([]byte, conf.MaxStreamWindowSize) + buf := make([]byte, int(initialStreamWindow)) n, err := stream.Write(buf) if err != nil { t.Fatalf("err: %v", err) diff --git a/session_test.go b/session_test.go index 7f99f98..57f33d6 100644 --- a/session_test.go +++ b/session_test.go @@ -1171,7 +1171,7 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) { wg.Add(1) // Choose a huge flood size that we know will result in a window update. - flood := int64(client.config.MaxStreamWindowSize) + flood := int64(initialStreamWindow) var wr *Stream // The server will accept a new stream and then flood data to it. @@ -1186,8 +1186,8 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) { } sendWindow := atomic.LoadUint32(&wr.sendWindow) - if sendWindow != client.config.MaxStreamWindowSize { - t.Errorf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, sendWindow) + if sendWindow != initialStreamWindow { + t.Errorf("sendWindow: exp=%d, got=%d", initialStreamWindow, sendWindow) return } @@ -1221,8 +1221,9 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) { } var ( - exp = uint32(flood / 2) - sendWindow uint32 + expWithoutWindowIncrease = uint32(flood / 2) + expWithWindowIncrease = uint32(flood * 3 / 2) + sendWindow uint32 ) // This test is racy. Wait a short period, then longer and longer. At @@ -1230,11 +1231,11 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) { for i := 1; i < 15; i++ { time.Sleep(time.Duration(i*i) * time.Millisecond) sendWindow = atomic.LoadUint32(&wr.sendWindow) - if sendWindow == exp { + if sendWindow == expWithoutWindowIncrease || sendWindow == expWithWindowIncrease { return } } - t.Errorf("sendWindow: exp=%d, got=%d", exp, sendWindow) + t.Errorf("sendWindow: exp=%d or %d, got=%d", expWithoutWindowIncrease, expWithWindowIncrease, sendWindow) } func TestSession_sendMsg_Timeout(t *testing.T) { diff --git a/stream.go b/stream.go index 037c22d..195968d 100644 --- a/stream.go +++ b/stream.go @@ -55,7 +55,7 @@ func newStream(session *Session, id uint32, state streamState) *Stream { sendWindow: initialStreamWindow, readDeadline: makePipeDeadline(), writeDeadline: makePipeDeadline(), - recvBuf: newSegmentedBuffer(initialStreamWindow), + recvBuf: newSegmentedBuffer(initialStreamWindow, session.config.MaxStreamWindowSize, session.getRTT), recvNotifyCh: make(chan struct{}, 1), sendNotifyCh: make(chan struct{}, 1), } @@ -202,11 +202,8 @@ func (s *Stream) sendWindowUpdate() error { // Determine the flags if any flags := s.sendFlags() - // Determine the delta update - max := s.session.config.MaxStreamWindowSize - // Update our window - needed, delta := s.recvBuf.GrowTo(max, flags != 0) + needed, delta := s.recvBuf.GrowTo(flags != 0, time.Now()) if !needed { return nil } diff --git a/util.go b/util.go index 177eb98..d3b980a 100644 --- a/util.go +++ b/util.go @@ -3,7 +3,9 @@ package yamux import ( "fmt" "io" + "math" "sync" + "time" pool "github.com/libp2p/go-buffer-pool" ) @@ -65,18 +67,30 @@ func min(values ...uint32) uint32 { // < len (5) > < cap (5) > // type segmentedBuffer struct { - cap uint32 - len uint32 - bm sync.Mutex + cap uint32 + len uint32 + windowSize uint32 + maxWindowSize uint32 + bm sync.Mutex // read position in b[0]. // We must not reslice any of the buffers in b, as we need to put them back into the pool. readPos int b [][]byte + + epochStart time.Time + getRTT func() time.Duration } // NewSegmentedBuffer allocates a ring buffer. -func newSegmentedBuffer(initialCapacity uint32) segmentedBuffer { - return segmentedBuffer{cap: initialCapacity, b: make([][]byte, 0)} +func newSegmentedBuffer(initialCapacity, maxWindowSize uint32, getRTT func() time.Duration) segmentedBuffer { + return segmentedBuffer{ + cap: initialCapacity, + windowSize: initialCapacity, + maxWindowSize: maxWindowSize, + b: make([][]byte, 0), + epochStart: time.Now(), + getRTT: getRTT, + } } // Len is the amount of data in the receive buffer. @@ -88,20 +102,37 @@ func (s *segmentedBuffer) Len() uint32 { // If the space to write into + current buffer size has grown to half of the window size, // grow up to that max size, and indicate how much additional space was reserved. -func (s *segmentedBuffer) GrowTo(max uint32, force bool) (bool, uint32) { +func (s *segmentedBuffer) GrowTo(force bool, now time.Time) (bool, uint32) { + grow, delta := s.growTo(force, now) + if grow { + s.epochStart = now + } + return grow, delta +} + +func (s *segmentedBuffer) growTo(force bool, now time.Time) (bool, uint32) { s.bm.Lock() defer s.bm.Unlock() currentWindow := s.cap + s.len - if currentWindow >= max { + if currentWindow >= s.windowSize { return force, 0 } - delta := max - currentWindow + delta := s.windowSize - currentWindow - if delta < (max/2) && !force { + if delta < (s.windowSize/2) && !force { return false, 0 } + if rtt := s.getRTT(); rtt > 0 && now.Sub(s.epochStart) < 2*rtt { + if s.windowSize > math.MaxUint32/2 { + s.windowSize = min(math.MaxUint32, s.maxWindowSize) + } else { + s.windowSize = min(s.windowSize*2, s.maxWindowSize) + } + delta = s.windowSize - currentWindow + } + s.cap += delta return true, delta } diff --git a/util_test.go b/util_test.go index 90b9cbe..21076f3 100644 --- a/util_test.go +++ b/util_test.go @@ -5,6 +5,7 @@ import ( "io" "io/ioutil" "testing" + "time" ) func TestAsyncSendErr(t *testing.T) { @@ -53,7 +54,7 @@ func TestMin(t *testing.T) { } func TestSegmentedBuffer(t *testing.T) { - buf := newSegmentedBuffer(100) + buf := newSegmentedBuffer(100, 100, func() time.Duration { return 0 }) assert := func(len, cap uint32) { if buf.Len() != len { t.Fatalf("expected length %d, got %d", len, buf.Len()) @@ -79,10 +80,10 @@ func TestSegmentedBuffer(t *testing.T) { t.Fatalf("expected to read 2 bytes, read %d", n) } assert(1, 97) - if grew, amount := buf.GrowTo(100, false); grew || amount != 0 { + if grew, amount := buf.GrowTo(false, time.Now()); grew || amount != 0 { t.Fatal("should not grow when too small") } - if grew, amount := buf.GrowTo(100, true); !grew || amount != 2 { + if grew, amount := buf.GrowTo(true, time.Now()); !grew || amount != 2 { t.Fatal("should have grown by 2") } @@ -90,7 +91,7 @@ func TestSegmentedBuffer(t *testing.T) { t.Fatal(err) } assert(51, 49) - if grew, amount := buf.GrowTo(100, false); grew || amount != 0 { + if grew, amount := buf.GrowTo(false, time.Now()); grew || amount != 0 { t.Fatal("should not grow when data hasn't been read") } read, err := io.CopyN(ioutil.Discard, &buf, 50) @@ -102,8 +103,64 @@ func TestSegmentedBuffer(t *testing.T) { } assert(1, 49) - if grew, amount := buf.GrowTo(100, false); !grew || amount != 50 { + if grew, amount := buf.GrowTo(false, time.Now()); !grew || amount != 50 { t.Fatal("should have grown when below half, even with reserved space") } assert(1, 99) } + +func TestSegmentedBuffer_WindowAutoSizing(t *testing.T) { + receiveAndConsume := func(buf *segmentedBuffer, size uint32) { + if err := buf.Append(bytes.NewReader(make([]byte, size)), size); err != nil { + t.Fatal(err) + } + if _, err := buf.Read(make([]byte, size)); err != nil { + t.Fatal(err) + } + } + const rtt = 10 * time.Millisecond + const initialWindow uint32 = 10 + t.Run("finding the window size", func(t *testing.T) { + buf := newSegmentedBuffer(initialWindow, 1000*initialWindow, func() time.Duration { return rtt }) + start := time.Now() + // Consume a maximum of 1234 bytes per RTT. + // We expect the window to be scaled such that we send one update every 2 RTTs. + now := start + delta := initialWindow + for i := 0; i < 100; i++ { + now = now.Add(rtt) + size := delta + if size > 1234 { + size = 1234 + } + receiveAndConsume(&buf, size) + grow, d := buf.GrowTo(false, now) + if grow { + delta = d + } + } + if !(buf.windowSize > 2*1234 && buf.windowSize < 3*1234) { + t.Fatalf("unexpected window size: %d", buf.windowSize) + } + }) + t.Run("capping the window size", func(t *testing.T) { + const maxWindow = 78 * initialWindow + buf := newSegmentedBuffer(initialWindow, maxWindow, func() time.Duration { return rtt }) + start := time.Now() + // Consume a maximum of 1234 bytes per RTT. + // We expect the window to be scaled such that we send one update every 2 RTTs. + now := start + delta := initialWindow + for i := 0; i < 100; i++ { + now = now.Add(rtt) + receiveAndConsume(&buf, delta) + grow, d := buf.GrowTo(false, now) + if grow { + delta = d + } + } + if buf.windowSize != maxWindow { + t.Fatalf("expected the window size to be at max (%d), got %d", maxWindow, buf.windowSize) + } + }) +}