Skip to content

Commit

Permalink
fix: poller read all data before connection close
Browse files Browse the repository at this point in the history
  • Loading branch information
joway committed Feb 28, 2023
1 parent 9ddc97b commit 2206ca1
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 30 deletions.
70 changes: 64 additions & 6 deletions netpoll_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import (
"context"
"errors"
"math/rand"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -248,16 +251,18 @@ func TestCloseCallbackWhenOnConnect(t *testing.T) {
MustNil(t, err)
}

func TestCloseAndWrite(t *testing.T) {
func TestServerReadAndClose(t *testing.T) {
var network, address = "tcp", ":18888"
var sendMsg = []byte("hello")
var closed int32
var loop = newTestEventLoop(network, address,
func(ctx context.Context, connection Connection) error {
_, err := connection.Reader().Next(len(sendMsg))
MustNil(t, err)

err = connection.Close()
MustNil(t, err)
atomic.AddInt32(&closed, 1)
return nil
},
)
Expand All @@ -269,7 +274,10 @@ func TestCloseAndWrite(t *testing.T) {
err = conn.Writer().Flush()
MustNil(t, err)

time.Sleep(time.Millisecond * 100) // wait for poller close connection
for atomic.LoadInt32(&closed) == 0 {
runtime.Gosched() // wait for poller close connection
}
time.Sleep(time.Millisecond * 50)
_, err = conn.Writer().WriteBinary(sendMsg)
MustNil(t, err)
err = conn.Writer().Flush()
Expand All @@ -279,9 +287,59 @@ func TestCloseAndWrite(t *testing.T) {
MustNil(t, err)
}

func TestClientWriteAndClose(t *testing.T) {
var (
network, address = "tcp", ":18889"
connnum = 10
packetsize, packetnum = 1000 * 5, 1
recvbytes int32 = 0
)
var loop = newTestEventLoop(network, address,
func(ctx context.Context, connection Connection) error {
buf, err := connection.Reader().Next(connection.Reader().Len())
if errors.Is(err, ErrConnClosed) {
return err
}
MustNil(t, err)
atomic.AddInt32(&recvbytes, int32(len(buf)))
return nil
},
)
var wg sync.WaitGroup
for i := 0; i < connnum; i++ {
wg.Add(1)
go func() {
defer wg.Done()
var conn, err = DialConnection(network, address, time.Second)
MustNil(t, err)
sendMsg := make([]byte, packetsize)
for j := 0; j < packetnum; j++ {
_, err = conn.Write(sendMsg)
MustNil(t, err)
}
err = conn.Close()
MustNil(t, err)
}()
}
wg.Wait()
exceptbytes := int32(packetsize * packetnum * connnum)
for atomic.LoadInt32(&recvbytes) != exceptbytes {
t.Logf("left %d bytes not received", exceptbytes-atomic.LoadInt32(&recvbytes))
runtime.Gosched()
}
err := loop.Shutdown(context.Background())
MustNil(t, err)
}

func newTestEventLoop(network, address string, onRequest OnRequest, opts ...Option) EventLoop {
var listener, _ = CreateListener(network, address)
var eventLoop, _ = NewEventLoop(onRequest, opts...)
go eventLoop.Serve(listener)
return eventLoop
ln, err := CreateListener(network, address)
if err != nil {
panic(err)
}
elp, err := NewEventLoop(onRequest, opts...)
if err != nil {
panic(err)
}
go elp.Serve(ln)
return elp
}
23 changes: 23 additions & 0 deletions poll_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,26 @@ func (p *defaultPoll) onhups() {
}
}(hups)
}

// readall read all left data before close connection
func readall(op *FDOperator, br barrier) (err error) {
var bs = br.bs
var ivs = br.ivs
var n int
for {
bs = op.Inputs(br.bs)
if len(bs) == 0 {
return nil
}

TryRead:
n, err = ioread(op.FD, bs, ivs)
op.InputAck(n)
if err != nil {
return err
}
if n == 0 && err == nil {
goto TryRead
}
}
}
29 changes: 17 additions & 12 deletions poll_default_bsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func (p *defaultPoll) Wait() error {
barriers[i].ivs = make([]syscall.Iovec, caps)
}
// wait
var triggerRead, triggerWrite, triggerHup bool
for {
n, err := syscall.Kevent(p.fd, nil, events, nil)
if err != nil && err != syscall.EINTR {
Expand All @@ -75,20 +76,23 @@ func (p *defaultPoll) Wait() error {
return err
}
for i := 0; i < n; i++ {
var fd = int(events[i].Ident)
evt := events[i]
triggerRead = evt.Filter == syscall.EVFILT_READ && evt.Flags&syscall.EV_ENABLE != 0
triggerWrite = evt.Filter == syscall.EVFILT_WRITE && evt.Flags&syscall.EV_ENABLE != 0
triggerHup = evt.Flags&syscall.EV_EOF != 0

// trigger
if fd == 0 {
if evt.Ident == 0 {
// clean trigger
atomic.StoreUint32(&p.trigger, 0)
continue
}
var operator = p.getOperator(fd, unsafe.Pointer(&events[i].Udata))
if operator == nil || !operator.do() {
var operator = *(**FDOperator)(unsafe.Pointer(&evt.Udata))
if !operator.do() {
continue
}

// check poll in
if events[i].Filter == syscall.EVFILT_READ && events[i].Flags&syscall.EV_ENABLE != 0 {
if triggerRead {
if operator.OnRead != nil {
// for non-connection
operator.OnRead(p)
Expand All @@ -105,15 +109,16 @@ func (p *defaultPoll) Wait() error {
}
}
}

// check hup
if events[i].Flags&syscall.EV_EOF != 0 {
if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close
if err = readall(operator, barriers[i]); err != nil {
logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error())
}
}
if triggerHup {
p.appendHup(operator)
continue
}

// check poll out
if events[i].Filter == syscall.EVFILT_WRITE && events[i].Flags&syscall.EV_ENABLE != 0 {
if triggerWrite {
if operator.OnWrite != nil {
// for non-connection
operator.OnWrite(p)
Expand Down
30 changes: 18 additions & 12 deletions poll_default_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"unsafe"
)

// Includes defaultPoll/multiPoll/uringPoll...
func openPoll() Poll {
return openDefaultPoll()
}
Expand Down Expand Up @@ -109,9 +108,16 @@ func (p *defaultPoll) Wait() (err error) {
}

func (p *defaultPoll) handler(events []epollevent) (closed bool) {
var triggerRead, triggerWrite, triggerHup, triggerError bool
for i := range events {
operator := p.getOperator(0, unsafe.Pointer(&events[i].data))
if operator == nil || !operator.do() {
evt := events[i].events
triggerRead = evt&syscall.EPOLLIN != 0
triggerWrite = evt&syscall.EPOLLOUT != 0
triggerHup = evt&(syscall.EPOLLHUP|syscall.EPOLLRDHUP) != 0
triggerError = evt&syscall.EPOLLERR != 0

var operator = *(**FDOperator)(unsafe.Pointer(&events[i].data))
if !operator.do() {
continue
}
// trigger or exit gracefully
Expand All @@ -130,9 +136,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) {
continue
}

evt := events[i].events
// check poll in
if evt&syscall.EPOLLIN != 0 {
if triggerRead {
if operator.OnRead != nil {
// for non-connection
operator.OnRead(p)
Expand All @@ -151,13 +155,16 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) {
logger.Printf("NETPOLL: operator has critical problem! event=%d operator=%v", evt, operator)
}
}

// check hup
if evt&(syscall.EPOLLHUP|syscall.EPOLLRDHUP) != 0 {
if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close
if err := readall(operator, p.barriers[i]); err != nil {
logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error())
}
}
if triggerHup {
p.appendHup(operator)
continue
}
if evt&syscall.EPOLLERR != 0 {
if triggerError {
// Under block-zerocopy, the kernel may give an error callback, which is not a real error, just an EAGAIN.
// So here we need to check this error, if it is EAGAIN then do nothing, otherwise still mark as hup.
if _, _, _, _, err := syscall.Recvmsg(operator.FD, nil, nil, syscall.MSG_ERRQUEUE); err != syscall.EAGAIN {
Expand All @@ -167,8 +174,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) {
}
continue
}
// check poll out
if evt&syscall.EPOLLOUT != 0 {
if triggerWrite {
if operator.OnWrite != nil {
// for non-connection
operator.OnWrite(p)
Expand Down

0 comments on commit 2206ca1

Please sign in to comment.