Skip to content

Commit

Permalink
feat: add WithOnDisconnect callback (#303)
Browse files Browse the repository at this point in the history
  • Loading branch information
joway committed Feb 2, 2024
1 parent 329188d commit c3792e8
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 5 deletions.
2 changes: 1 addition & 1 deletion connection.go
Expand Up @@ -19,7 +19,7 @@ import (
"time"
)

// CloseCallback will be called when the connection is closed.
// CloseCallback will be called after the connection is closed.
// Return: error is unused which will be ignored directly.
type CloseCallback func(connection Connection) error

Expand Down
26 changes: 22 additions & 4 deletions connection_onevent.go
Expand Up @@ -48,10 +48,11 @@ type gracefulExit interface {
// OnPrepare, OnRequest, CloseCallback share the lock processing,
// which is a CAS lock and can only be cleared by OnRequest.
type onEvent struct {
ctx context.Context
onConnectCallback atomic.Value
onRequestCallback atomic.Value
closeCallbacks atomic.Value // value is latest *callbackNode
ctx context.Context
onConnectCallback atomic.Value
onDisconnectCallback atomic.Value
onRequestCallback atomic.Value
closeCallbacks atomic.Value // value is latest *callbackNode
}

type callbackNode struct {
Expand All @@ -67,6 +68,14 @@ func (c *connection) SetOnConnect(onConnect OnConnect) error {
return nil
}

// SetOnDisconnect set the OnDisconnect callback.
func (c *connection) SetOnDisconnect(onDisconnect OnDisconnect) error {
if onDisconnect != nil {
c.onDisconnectCallback.Store(onDisconnect)
}
return nil
}

// SetOnRequest initialize ctx when setting OnRequest.
func (c *connection) SetOnRequest(onRequest OnRequest) error {
if onRequest == nil {
Expand Down Expand Up @@ -99,6 +108,7 @@ func (c *connection) AddCloseCallback(callback CloseCallback) error {
func (c *connection) onPrepare(opts *options) (err error) {
if opts != nil {
c.SetOnConnect(opts.onConnect)
c.SetOnDisconnect(opts.onDisconnect)
c.SetOnRequest(opts.onRequest)
c.SetReadTimeout(opts.readTimeout)
c.SetWriteTimeout(opts.writeTimeout)
Expand Down Expand Up @@ -150,6 +160,14 @@ func (c *connection) onConnect() {
)
}

func (c *connection) onDisconnect() {
var onDisconnect, _ = c.onDisconnectCallback.Load().(OnDisconnect)
if onDisconnect == nil {
return
}
onDisconnect(c.ctx, c)
}

// onRequest is responsible for executing the closeCallbacks after the connection has been closed.
func (c *connection) onRequest() (needTrigger bool) {
var onRequest, ok = c.onRequestCallback.Load().(OnRequest)
Expand Down
4 changes: 4 additions & 0 deletions connection_reactor.go
Expand Up @@ -30,6 +30,10 @@ func (c *connection) onHup(p Poll) error {
}
c.triggerRead(Exception(ErrEOF, "peer close"))
c.triggerWrite(Exception(ErrConnClosed, "peer close"))

// call Disconnect callback first
c.onDisconnect()

// It depends on closing by user if OnConnect and OnRequest is nil, otherwise it needs to be released actively.
// It can be confirmed that the OnRequest goroutine has been exited before closeCallback executing,
// and it is safe to close the buffer at this time.
Expand Down
14 changes: 14 additions & 0 deletions eventloop.go
Expand Up @@ -34,6 +34,15 @@ type EventLoop interface {
Shutdown(ctx context.Context) error
}

/* The Connection Callback Sequence Diagram
| Connection State | Callback Function | Notes
| Connected but not initialized | OnPrepare | Conn is not registered into poller
| Connected and initialized | OnConnect | Conn is ready for read or write
| Read first byte | OnRequest | Conn is ready for read or write
| Peer closed but conn is active | OnDisconnect | Conn access will race with OnRequest function
| Self closed and conn is closed | CloseCallback | Conn is destroyed
*/

// OnPrepare is used to inject custom preparation at connection initialization,
// which is optional but important in some scenarios. For example, a qps limiter
// can be set by closing overloaded connections directly in OnPrepare.
Expand Down Expand Up @@ -63,6 +72,11 @@ type OnPrepare func(connection Connection) context.Context
// }
type OnConnect func(ctx context.Context, connection Connection) context.Context

// OnDisconnect is called once connection is going to be closed.
// OnDisconnect must return as quick as possible because it will block poller.
// OnDisconnect is different from CloseCallback, you could check with "The Connection Callback Sequence Diagram" section.
type OnDisconnect func(ctx context.Context, connection Connection)

// OnRequest defines the function for handling connection. When data is sent from the connection peer,
// netpoll actively reads the data in LT mode and places it in the connection's input buffer.
// Generally, OnRequest starts handling the data in the following way:
Expand Down
8 changes: 8 additions & 0 deletions netpoll_options.go
Expand Up @@ -77,6 +77,13 @@ func WithOnConnect(onConnect OnConnect) Option {
}}
}

// WithOnDisconnect registers the OnDisconnect method to EventLoop.
func WithOnDisconnect(onDisconnect OnDisconnect) Option {
return Option{func(op *options) {
op.onDisconnect = onDisconnect
}}
}

// WithReadTimeout sets the read timeout of connections.
func WithReadTimeout(timeout time.Duration) Option {
return Option{func(op *options) {
Expand Down Expand Up @@ -106,6 +113,7 @@ type Option struct {
type options struct {
onPrepare OnPrepare
onConnect OnConnect
onDisconnect OnDisconnect
onRequest OnRequest
readTimeout time.Duration
writeTimeout time.Duration
Expand Down
60 changes: 60 additions & 0 deletions netpoll_test.go
Expand Up @@ -136,6 +136,66 @@ func TestOnConnectWrite(t *testing.T) {
MustNil(t, err)
}

func TestOnDisconnect(t *testing.T) {
var ctxKey = struct{}{}
var network, address = "tcp", ":8888"
var canceled, closed int32
var conns int32 = 100
req := "ping"
var loop = newTestEventLoop(network, address,
func(ctx context.Context, connection Connection) error {
cancelFunc, _ := ctx.Value(ctxKey).(context.CancelFunc)
MustTrue(t, cancelFunc != nil)
Assert(t, ctx.Done() != nil)

buf, err := connection.Reader().Next(4) // should consumed all data
MustNil(t, err)
Equal(t, string(buf), req)
select {
case <-ctx.Done():
atomic.AddInt32(&canceled, 1)
case <-time.After(time.Second):
}
return nil
},
WithOnConnect(func(ctx context.Context, conn Connection) context.Context {
conn.AddCloseCallback(func(connection Connection) error {
atomic.AddInt32(&closed, 1)
return nil
})
ctx, cancel := context.WithCancel(ctx)
return context.WithValue(ctx, ctxKey, cancel)
}),
WithOnDisconnect(func(ctx context.Context, conn Connection) {
cancelFunc, _ := ctx.Value(ctxKey).(context.CancelFunc)
MustTrue(t, cancelFunc != nil)
cancelFunc()
}),
)

for i := int32(0); i < conns; i++ {
var conn, err = DialConnection(network, address, time.Second)
MustNil(t, err)

_, err = conn.Writer().WriteString(req)
MustNil(t, err)
err = conn.Writer().Flush()
MustNil(t, err)

err = conn.Close()
MustNil(t, err)
}
for atomic.LoadInt32(&closed) < conns {
t.Logf("closed: %d, canceled: %d", atomic.LoadInt32(&closed), atomic.LoadInt32(&canceled))
runtime.Gosched()
}
Equal(t, atomic.LoadInt32(&closed), conns)
Equal(t, atomic.LoadInt32(&canceled), conns)

err := loop.Shutdown(context.Background())
MustNil(t, err)
}

func TestGracefulExit(t *testing.T) {
var network, address = "tcp", ":8888"

Expand Down

0 comments on commit c3792e8

Please sign in to comment.