Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aboard Read()/Write() by SetDeadline() with past time #51

Closed
wants to merge 7 commits into from
5 changes: 0 additions & 5 deletions bench_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package yamux

import (
"fmt"
"testing"
)

Expand Down Expand Up @@ -85,7 +84,6 @@ func BenchmarkSendRecvLarge(b *testing.B) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

const sendSize = 512 * 1024 * 1024
const recvSize = 4 * 1024

Expand All @@ -107,9 +105,6 @@ func BenchmarkSendRecvLarge(b *testing.B) {
b.Fatalf("err: %v", err)
}
}

fmt.Printf("Capacity of rcv buffer = %v, length of rcv window = %v\n", stream.recvBuf.Cap(), stream.recvWindow)

}
close(recvDone)
}()
Expand Down
53 changes: 35 additions & 18 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,17 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error {
// potential shutdown. Since there's the expectation that sends can happen
// in a timely manner, we enforce the connection write timeout here.
func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
select {
case <-timer.C:
default:
}
timerPool.Put(t)
}()

ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
select {
Expand All @@ -355,8 +364,17 @@ func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) e
// the send happens right here, we enforce the connection write timeout if we
// can't queue the header to be sent.
func (s *Session) sendNoWait(hdr header) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
select {
case <-timer.C:
default:
}
timerPool.Put(t)
}()

select {
case s.sendCh <- sendReady{Hdr: hdr}:
Expand Down Expand Up @@ -414,11 +432,20 @@ func (s *Session) recv() {
}
}

// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
var (
handlers = []func(*Session, header) error{
typeData: (*Session).handleStreamMessage,
typeWindowUpdate: (*Session).handleStreamMessage,
typePing: (*Session).handlePing,
typeGoAway: (*Session).handleGoAway,
}
)

// recvLoop continues to receive data until a fatal error is encountered
func (s *Session) recvLoop() error {
defer close(s.recvDoneCh)
hdr := header(make([]byte, headerSize))
var handler func(header) error
for {
// Read the header
if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
Expand All @@ -434,22 +461,12 @@ func (s *Session) recvLoop() error {
return ErrInvalidVersion
}

// Switch on the type
switch hdr.MsgType() {
case typeData:
handler = s.handleStreamMessage
case typeWindowUpdate:
handler = s.handleStreamMessage
case typeGoAway:
handler = s.handleGoAway
case typePing:
handler = s.handlePing
default:
mt := hdr.MsgType()
if mt < typeData || mt > typeGoAway {
return ErrInvalidMsgType
}

// Invoke the handler
if err := handler(hdr); err != nil {
if err := handlers[mt](s, hdr); err != nil {
return err
}
}
Expand Down
120 changes: 112 additions & 8 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,12 @@ func TestSendData_Large(t *testing.T) {
defer client.Close()
defer server.Close()

data := make([]byte, 512*1024)
const (
sendSize = 250 * 1024 * 1024
recvSize = 4 * 1024
)

data := make([]byte, sendSize)
for idx := range data {
data[idx] = byte(idx % 256)
}
Expand All @@ -390,16 +395,17 @@ func TestSendData_Large(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}

buf := make([]byte, 4*1024)
for i := 0; i < 128; i++ {
var sz int
buf := make([]byte, recvSize)
for i := 0; i < sendSize/recvSize; i++ {
n, err := stream.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4*1024 {
if n != recvSize {
t.Fatalf("short read: %d", n)
}
sz += n
for idx := range buf {
if buf[idx] != byte(idx%256) {
t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
Expand All @@ -410,6 +416,8 @@ func TestSendData_Large(t *testing.T) {
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}

t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
}()

go func() {
Expand Down Expand Up @@ -439,7 +447,7 @@ func TestSendData_Large(t *testing.T) {
}()
select {
case <-doneCh:
case <-time.After(time.Second):
case <-time.After(5 * time.Second):
panic("timeout")
}
}
Expand Down Expand Up @@ -688,6 +696,20 @@ func TestReadDeadline(t *testing.T) {
if _, err := stream.Read(buf); err != ErrTimeout {
t.Fatalf("err: %v", err)
}

// reset to no dead line
if err := stream.SetReadDeadline(time.Time{}); err != nil {
t.Fatalf("err: %v", err)
}

go func() {
time.Sleep(5 * time.Microsecond)
stream.SetReadDeadline(time.Unix(1, 0)) // net.aLongTimeAgo
}()

if _, err := stream.Read(buf); err != ErrTimeout {
t.Fatalf("err: %v", err)
}
}

func TestWriteDeadline(t *testing.T) {
Expand All @@ -712,15 +734,43 @@ func TestWriteDeadline(t *testing.T) {
}

buf := make([]byte, 512)
ok := false
for i := 0; i < int(initialStreamWindow); i++ {
_, err := stream.Write(buf)
if err != nil && err == ErrTimeout {
ok = true
break
} else if err != nil {
t.Fatalf("err: %v", err)
}
}
if !ok {
t.Fatalf("Expected timeout")
}

// reset to no dead line
if err := stream.SetWriteDeadline(time.Time{}); err != nil {
t.Fatalf("err: %v", err)
}

go func() {
time.Sleep(5 * time.Microsecond)
stream.SetWriteDeadline(time.Unix(1, 0)) // net.aLongTimeAgo
}()

ok = false
for i := 0; i < int(initialStreamWindow); i++ {
_, err := stream.Write(buf)
if err != nil && err == ErrTimeout {
return
ok = true
break
} else if err != nil {
t.Fatalf("err: %v", err)
}
}
t.Fatalf("Expected timeout")
if !ok {
t.Fatalf("Expected timeout")
}
}

func TestBacklogExceeded(t *testing.T) {
Expand Down Expand Up @@ -1026,6 +1076,60 @@ func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
wg.Wait()
}

func TestSession_PartialReadWindowUpdate(t *testing.T) {
client, server := testClientServerConfig(testConfNoKeepAlive())
defer client.Close()
defer server.Close()

var wg sync.WaitGroup
wg.Add(1)

// Choose a huge flood size that we know will result in a window update.
flood := int64(client.config.MaxStreamWindowSize)
var wr *Stream

// The server will accept a new stream and then flood data to it.
go func() {
defer wg.Done()

var err error
wr, err = server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer wr.Close()

if wr.sendWindow != client.config.MaxStreamWindowSize {
t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
}

n, err := wr.Write(make([]byte, flood))
if err != nil {
t.Fatalf("err: %v", err)
}
if int64(n) != flood {
t.Fatalf("short write: %d", n)
}
if wr.sendWindow != 0 {
t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
}
}()

stream, err := client.OpenStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()

wg.Wait()

_, err = stream.Read(make([]byte, flood/2+1))

if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
}
}

func TestSession_sendNoWait_Timeout(t *testing.T) {
client, server := testClientServerConfig(testConfNoKeepAlive())
defer client.Close()
Expand Down
29 changes: 20 additions & 9 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,18 +242,25 @@ func (s *Stream) sendWindowUpdate() error {

// Determine the delta update
max := s.session.config.MaxStreamWindowSize
delta := max - atomic.LoadUint32(&s.recvWindow)
var bufLen uint32
s.recvLock.Lock()
if s.recvBuf != nil {
bufLen = uint32(s.recvBuf.Len())
}
delta := (max - bufLen) - s.recvWindow

// Determine the flags if any
flags := s.sendFlags()

// Check if we can omit the update
if delta < (max/2) && flags == 0 {
s.recvLock.Unlock()
return nil
}

// Update our window
atomic.AddUint32(&s.recvWindow, delta)
s.recvWindow += delta
s.recvLock.Unlock()

// Send the header
s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
Expand Down Expand Up @@ -396,16 +403,18 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
if length == 0 {
return nil
}
if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
return ErrRecvWindowExceeded
}

// Wrap in a limited reader
conn = &io.LimitedReader{R: conn, N: int64(length)}

// Copy into buffer
s.recvLock.Lock()

if length > s.recvWindow {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
return ErrRecvWindowExceeded
}

if s.recvBuf == nil {
// Allocate the receive buffer just-in-time to fit the full data frame.
// This way we can read in the whole packet without further allocations.
Expand All @@ -418,7 +427,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
}

// Decrement the receive window
atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
s.recvWindow -= length
s.recvLock.Unlock()

// Unblock any readers
Expand All @@ -437,15 +446,17 @@ func (s *Stream) SetDeadline(t time.Time) error {
return nil
}

// SetReadDeadline sets the deadline for future Read calls.
// SetReadDeadline sets the deadline for pending and future Read calls.
func (s *Stream) SetReadDeadline(t time.Time) error {
s.readDeadline.Store(t)
asyncNotify(s.recvNotifyCh)
return nil
}

// SetWriteDeadline sets the deadline for future Write calls
// SetWriteDeadline sets the deadline for pending and future Write calls
func (s *Stream) SetWriteDeadline(t time.Time) error {
s.writeDeadline.Store(t)
asyncNotify(s.sendNotifyCh)
return nil
}

Expand Down
15 changes: 15 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
package yamux

import (
"sync"
"time"
)

var (
timerPool = &sync.Pool{
New: func() interface{} {
timer := time.NewTimer(time.Hour * 1e6)
timer.Stop()
return timer
},
}
)

// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
if ch == nil {
Expand Down