Skip to content

Commit

Permalink
fix: 修复 server 包异步分流消息的回调函数在取消分流渠道绑定后会在系统分流渠道执行的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
kercylan98 committed Jan 12, 2024
1 parent 3b71eca commit e760ef2
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 37 deletions.
16 changes: 13 additions & 3 deletions server/internal/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,19 @@ func (d *Dispatcher[P, M]) IncrCount(producer P, i int64) {
d.lock.Lock()
defer d.lock.Unlock()
d.mc += i
d.pmc[producer] += i
if d.expel && d.mc <= 0 {
close(d.abort)
pmc := d.pmc[producer] + i
d.pmc[producer] = pmc
if d.mc <= 0 {
if f := d.pmcF[producer]; f != nil && pmc <= 0 {
func(producer P) {
defer func(producer P) {
if err := super.RecoverTransform(recover()); err != nil {
log.Error("Dispatcher.ProducerDoneHandler", log.Any("producer", producer), log.Err(err))
}
}(producer)
f(producer, &Action[P, M]{d: d, unlock: true})
}(producer)
}
}
}

Expand Down
15 changes: 12 additions & 3 deletions server/message.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"github.com/kercylan98/minotaur/server/internal/dispatcher"
"github.com/kercylan98/minotaur/utils/collection"
"github.com/kercylan98/minotaur/utils/log"
"github.com/kercylan98/minotaur/utils/super"
Expand Down Expand Up @@ -75,16 +76,23 @@ func HasMessageType(mt MessageType) bool {

// Message 服务器消息
type Message struct {
producer string
dis *dispatcher.Dispatcher[string, *Message] // 指定消息发送到特定的分发器
conn *Conn
err error
ordinaryHandler func()
exceptionHandler func() error
errHandler func(err error)
marks []log.Field
packet []byte
err error
producer string
name string
t MessageType
marks []log.Field
}

// bindDispatcher 绑定分发器
func (slf *Message) bindDispatcher(dis *dispatcher.Dispatcher[string, *Message]) *Message {
slf.dis = dis
return slf
}

func (slf *Message) GetProducer() string {
Expand All @@ -103,6 +111,7 @@ func (slf *Message) reset() {
slf.t = 0
slf.marks = nil
slf.producer = ""
slf.dis = nil
}

// MessageType 返回消息类型
Expand Down
67 changes: 36 additions & 31 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,8 @@ func (srv *Server) GetMessageCount() int64 {
}

// UseShunt 切换连接所使用的消息分流渠道,当分流渠道 name 不存在时将会创建一个新的分流渠道,否则将会加入已存在的分流渠道
// - 默认情况下,所有连接都使用系统通道进行消息分发,当指定消息分流渠道时,将会使用指定的消息分流渠道进行消息分发
// - 默认情况下,所有连接都使用系统通道进行消息分发,当指定消息分流渠道且为分流消息类型时,将会使用指定的消息分流渠道进行消息分发
// - 分流渠道会在连接断开时标记为驱逐状态,当分流渠道中的所有消息处理完毕且没有新连接使用时,将会被清除
//
// 一些有趣的情况:
// - 当连接发送异步消息时,消息会被分为两部分,分别是异步部分和回调部分。异步部分会在当前的分流渠道中处理,而回调部分则是根据回调时所在的分流渠道进行处理
func (srv *Server) UseShunt(conn *Conn, name string) {
srv.dispatcherMgr.BindProducer(conn.GetID(), name)
}
Expand All @@ -324,15 +321,17 @@ func (srv *Server) pushMessage(message *Message) {
srv.messagePool.Release(message)
return
}
var d *dispatcher.Dispatcher[string, *Message]
switch message.t {
case MessageTypePacket,
MessageTypeShuntTicker, MessageTypeShuntAsync, MessageTypeShuntAsyncCallback,
MessageTypeUniqueShuntAsync, MessageTypeUniqueShuntAsyncCallback,
MessageTypeShunt:
d = srv.dispatcherMgr.GetDispatcher(message.conn.GetID())
case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeTicker:
d = srv.dispatcherMgr.GetSystemDispatcher()
var d = message.dis
if d == nil {
switch message.t {
case MessageTypePacket,
MessageTypeShuntTicker, MessageTypeShuntAsync, MessageTypeShuntAsyncCallback,
MessageTypeUniqueShuntAsync, MessageTypeUniqueShuntAsyncCallback,
MessageTypeShunt:
d = srv.dispatcherMgr.GetDispatcher(message.conn.GetID())
case MessageTypeSystem, MessageTypeAsync, MessageTypeUniqueAsync, MessageTypeAsyncCallback, MessageTypeUniqueAsyncCallback, MessageTypeTicker:
d = srv.dispatcherMgr.GetSystemDispatcher()
}
}
if d == nil {
return
Expand Down Expand Up @@ -403,8 +402,12 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher.Dispatcher[string,
fmt.Println(stack)
srv.OnMessageErrorEvent(msg, err)
}
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
switch msg.t {
case MessageTypeAsyncCallback, MessageTypeShuntAsyncCallback:
dispatcherIns.IncrCount(msg.producer, -1)
case MessageTypeUniqueAsyncCallback, MessageTypeUniqueShuntAsyncCallback:
dispatcherIns.AntiUnique(msg.name)
dispatcherIns.IncrCount(msg.producer, -1)
}

srv.low(msg, present, time.Millisecond*100)
Expand Down Expand Up @@ -455,25 +458,27 @@ func (srv *Server) dispatchMessage(dispatcherIns *dispatcher.Dispatcher[string,
}(cancel, srv, dispatcherIns, msg, present)
var err error
if msg.exceptionHandler != nil {
dispatcherIns.IncrCount(msg.producer, 1)
err = msg.exceptionHandler()
}
if msg.errHandler != nil {
if msg.conn == nil {
if msg.t == MessageTypeUniqueAsync {
srv.PushUniqueAsyncCallbackMessage(msg.name, err, msg.errHandler)
srv.pushUniqueAsyncCallbackMessage(dispatcherIns, msg.name, err, msg.errHandler)
return
}
srv.PushAsyncCallbackMessage(err, msg.errHandler)
srv.pushAsyncCallbackMessage(dispatcherIns, err, msg.errHandler)
return
}
if msg.t == MessageTypeUniqueShuntAsync {
srv.PushUniqueShuntAsyncCallbackMessage(msg.conn, msg.name, err, msg.errHandler)
srv.pushUniqueShuntAsyncCallbackMessage(dispatcherIns, msg.conn, msg.name, err, msg.errHandler)
return
}
srv.PushShuntAsyncCallbackMessage(msg.conn, err, msg.errHandler)
srv.pushShuntAsyncCallbackMessage(dispatcherIns, msg.conn, err, msg.errHandler)
return
}
dispatcherIns.AntiUnique(msg.name)
dispatcherIns.IncrCount(msg.producer, -1)
if err != nil {
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", string(debug.Stack())))
}
Expand Down Expand Up @@ -505,11 +510,11 @@ func (srv *Server) PushAsyncMessage(caller func() error, callback func(err error
srv.pushMessage(srv.messagePool.Get().castToAsyncMessage(caller, callback, mark...))
}

// PushAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息
// pushAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息
// - 异步消息回调将会通过一个接收 error 的函数进行处理,该函数将在系统分发器中执行
// - mark 为可选的日志标记,当发生异常时,将会在日志中进行体现
func (srv *Server) PushAsyncCallbackMessage(err error, callback func(err error), mark ...log.Field) {
srv.pushMessage(srv.messagePool.Get().castToAsyncCallbackMessage(err, callback, mark...))
func (srv *Server) pushAsyncCallbackMessage(dis *dispatcher.Dispatcher[string, *Message], err error, callback func(err error), mark ...log.Field) {
srv.pushMessage(srv.messagePool.Get().castToAsyncCallbackMessage(err, callback, mark...).bindDispatcher(dis))
}

// PushShuntAsyncMessage 向特定分发器中推送 MessageTypeAsync 消息,消息执行与 MessageTypeAsync 一致
Expand All @@ -519,10 +524,10 @@ func (srv *Server) PushShuntAsyncMessage(conn *Conn, caller func() error, callba
srv.pushMessage(srv.messagePool.Get().castToShuntAsyncMessage(conn, caller, callback, mark...))
}

// PushShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致
// - 需要注意的是,当未指定 UseShunt 时,将会通过 PushAsyncCallbackMessage 进行转发
func (srv *Server) PushShuntAsyncCallbackMessage(conn *Conn, err error, callback func(err error), mark ...log.Field) {
srv.pushMessage(srv.messagePool.Get().castToShuntAsyncCallbackMessage(conn, err, callback, mark...))
// pushShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致
// - 需要注意的是,当未指定 UseShunt 时,将会通过 pushAsyncCallbackMessage 进行转发
func (srv *Server) pushShuntAsyncCallbackMessage(dis *dispatcher.Dispatcher[string, *Message], conn *Conn, err error, callback func(err error), mark ...log.Field) {
srv.pushMessage(srv.messagePool.Get().castToShuntAsyncCallbackMessage(conn, err, callback, mark...).bindDispatcher(dis))
}

// PushPacketMessage 向服务器中推送 MessageTypePacket 消息
Expand Down Expand Up @@ -558,9 +563,9 @@ func (srv *Server) PushUniqueAsyncMessage(unique string, caller func() error, ca
srv.pushMessage(srv.messagePool.Get().castToUniqueAsyncMessage(unique, caller, callback, mark...))
}

// PushUniqueAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致
func (srv *Server) PushUniqueAsyncCallbackMessage(unique string, err error, callback func(err error), mark ...log.Field) {
srv.pushMessage(srv.messagePool.Get().castToUniqueAsyncCallbackMessage(unique, err, callback, mark...))
// pushUniqueAsyncCallbackMessage 向服务器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致
func (srv *Server) pushUniqueAsyncCallbackMessage(dis *dispatcher.Dispatcher[string, *Message], unique string, err error, callback func(err error), mark ...log.Field) {
srv.pushMessage(srv.messagePool.Get().castToUniqueAsyncCallbackMessage(unique, err, callback, mark...).bindDispatcher(dis))
}

// PushUniqueShuntAsyncMessage 向特定分发器中推送 MessageTypeAsync 消息,消息执行与 MessageTypeAsync 一致
Expand All @@ -570,10 +575,10 @@ func (srv *Server) PushUniqueShuntAsyncMessage(conn *Conn, unique string, caller
srv.pushMessage(srv.messagePool.Get().castToUniqueShuntAsyncMessage(conn, unique, caller, callback, mark...))
}

// PushUniqueShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致
// pushUniqueShuntAsyncCallbackMessage 向特定分发器中推送 MessageTypeAsyncCallback 消息,消息执行与 MessageTypeAsyncCallback 一致
// - 需要注意的是,当未指定 UseShunt 时,将会通过系统分流渠道进行转发
func (srv *Server) PushUniqueShuntAsyncCallbackMessage(conn *Conn, unique string, err error, callback func(err error), mark ...log.Field) {
srv.pushMessage(srv.messagePool.Get().castToUniqueShuntAsyncCallbackMessage(conn, unique, err, callback, mark...))
func (srv *Server) pushUniqueShuntAsyncCallbackMessage(dis *dispatcher.Dispatcher[string, *Message], conn *Conn, unique string, err error, callback func(err error), mark ...log.Field) {
srv.pushMessage(srv.messagePool.Get().castToUniqueShuntAsyncCallbackMessage(conn, unique, err, callback, mark...).bindDispatcher(dis))
}

// PushShuntMessage 向特定分发器中推送 MessageTypeShunt 消息,消息执行与 MessageTypeSystem 一致,不同的是将会在特定分发器中执行
Expand Down

0 comments on commit e760ef2

Please sign in to comment.