Skip to content

Commit

Permalink
refactor: 优化 server 消息类型,合并 Websocket 数据包监听到统一的 RegConnectionReceiveP…
Browse files Browse the repository at this point in the history
…acketEvent 中
  • Loading branch information
kercylan98 committed Jul 7, 2023
1 parent 6d27433 commit 8b90307
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 223 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ require (
github.com/nats-io/nats-server/v2 v2.9.16 // indirect
github.com/nats-io/nkeys v0.4.4 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/panjf2000/ants/v2 v2.8.1 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/smartystreets/assertions v1.13.1 // indirect
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
github.com/panjf2000/ants/v2 v2.4.7 h1:MZnw2JRyTJxFwtaMtUJcwE618wKD04POWk2gwwP4E2M=
github.com/panjf2000/ants/v2 v2.4.7/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A=
github.com/panjf2000/ants/v2 v2.8.1 h1:C+n/f++aiW8kHCExKlpX6X+okmxKXP7DWLutxuAPuwQ=
github.com/panjf2000/ants/v2 v2.8.1/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8=
github.com/panjf2000/gnet v1.6.6 h1:P6bApc54hnVcJVgH+SMe41mn47ECCajB6E/dKq27Y0c=
github.com/panjf2000/gnet v1.6.6/go.mod h1:KcOU7QsCaCBjeD5kyshBIamG3d9kAQtlob4Y0v0E+sc=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
Expand Down Expand Up @@ -228,6 +230,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
38 changes: 6 additions & 32 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,45 +156,29 @@ func (slf *Conn) IsWebsocket() bool {
return slf.server.network == NetworkWebsocket
}

// WriteString 向连接中写入字符串
// - 通过转换为[]byte调用 *Conn.Write
func (slf *Conn) WriteString(data string, messageType ...int) {
slf.Write([]byte(data), messageType...)
}

// WriteStringWithCallback 与 WriteString 相同,但是会在写入完成后调用 callback
// - 当 callback 为 nil 时,与 WriteString 相同
func (slf *Conn) WriteStringWithCallback(data string, callback func(err error), messageType ...int) {
slf.WriteWithCallback([]byte(data), callback, messageType...)
}

// Write 向连接中写入数据
// - messageType: websocket模式中指定消息类型
func (slf *Conn) Write(data []byte, messageType ...int) {
func (slf *Conn) Write(packet Packet) {
if slf.packetPool == nil {
return
}
cp := slf.packetPool.Get()
if len(messageType) > 0 {
cp.websocketMessageType = messageType[0]
}
cp.packet = data
cp.websocketMessageType = packet.WebsocketType
cp.packet = packet.Data
slf.mutex.Lock()
slf.packets = append(slf.packets, cp)
slf.mutex.Unlock()
}

// WriteWithCallback 与 Write 相同,但是会在写入完成后调用 callback
// - 当 callback 为 nil 时,与 Write 相同
func (slf *Conn) WriteWithCallback(data []byte, callback func(err error), messageType ...int) {
func (slf *Conn) WriteWithCallback(packet Packet, callback func(err error), messageType ...int) {
if slf.packetPool == nil {
return
}
cp := slf.packetPool.Get()
if len(messageType) > 0 {
cp.websocketMessageType = messageType[0]
}
cp.packet = data
cp.websocketMessageType = packet.WebsocketType
cp.packet = packet.Data
cp.callback = callback
slf.mutex.Lock()
slf.packets = append(slf.packets, cp)
Expand Down Expand Up @@ -233,18 +217,8 @@ func (slf *Conn) writeLoop(wait *sync.WaitGroup) {
slf.mutex.Unlock()
for i := 0; i < len(packets); i++ {
data := packets[i]
//if len(data.packet) == 0 {
// for _, packet := range packets {
// slf.packetPool.Release(packet)
// }
// slf.Close()
// return
//}
var err error
if slf.IsWebsocket() {
if data.websocketMessageType <= 0 {
data.websocketMessageType = slf.server.websocketWriteMessageType
}
err = slf.ws.WriteMessage(data.websocketMessageType, data.packet)
} else {
if slf.gn != nil {
Expand Down
26 changes: 8 additions & 18 deletions server/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,12 @@ package server
import "errors"

var (
ErrConstructed = errors.New("the Server must be constructed using the server.New function")
ErrCanNotSupportNetwork = errors.New("can not support network")
ErrMessageTypePacketAttrs = errors.New("MessageTypePacket must contain *Conn and []byte")
ErrWebsocketMessageTypePacketAttrs = errors.New("MessageTypePacket must contain *Conn and []byte and int(MessageType)")
ErrMessageTypeErrorAttrs = errors.New("MessageTypePacket must contain error and MessageErrorAction")
ErrMessageTypeCrossErrorAttrs = errors.New("MessageTypeCross must contain int64(server id) and []byte")
ErrMessageTypeTickerErrorAttrs = errors.New("MessageTypeTicker must contain func()")
ErrNetworkOnlySupportHttp = errors.New("the current network mode is not compatible with HttpRouter, only NetworkHttp is supported")
ErrNetworkOnlySupportGRPC = errors.New("the current network mode is not compatible with RegGrpcServer, only NetworkGRPC is supported")
ErrNetworkIncompatibleHttp = errors.New("the current network mode is not compatible with NetworkHttp")
ErrWebsocketMessageTypeException = errors.New("unknown message type, will not work")
ErrNotWebsocketUseMessageType = errors.New("message type filtering only supports websocket and does not take effect")
ErrWebsocketIllegalMessageType = errors.New("illegal message type")
ErrPleaseUseWebsocketHandle = errors.New("in Websocket mode, please use the RegConnectionReceiveWebsocketPacketEvent function to register")
ErrPleaseUseOrdinaryPacketHandle = errors.New("non Websocket mode, please use the RegConnectionReceivePacketEvent function to register")
ErrNoSupportCross = errors.New("the server does not support GetID or PushCrossMessage, please use the WithCross option to create the server")
ErrNoSupportTicker = errors.New("the server does not support Ticker, please use the WithTicker option to create the server")
ErrUnregisteredCrossName = errors.New("unregistered cross name, please use the WithCross option to create the server")
ErrConstructed = errors.New("the Server must be constructed using the server.New function")
ErrCanNotSupportNetwork = errors.New("can not support network")
ErrNetworkOnlySupportHttp = errors.New("the current network mode is not compatible with HttpRouter, only NetworkHttp is supported")
ErrNetworkOnlySupportGRPC = errors.New("the current network mode is not compatible with RegGrpcServer, only NetworkGRPC is supported")
ErrNetworkIncompatibleHttp = errors.New("the current network mode is not compatible with NetworkHttp")
ErrWebsocketIllegalMessageType = errors.New("illegal message type")
ErrNoSupportCross = errors.New("the server does not support GetID or PushCrossMessage, please use the WithCross option to create the server")
ErrNoSupportTicker = errors.New("the server does not support Ticker, please use the WithTicker option to create the server")
)
53 changes: 13 additions & 40 deletions server/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import (
type StartBeforeEventHandle func(srv *Server)
type StartFinishEventHandle func(srv *Server)
type StopEventHandle func(srv *Server)
type ConnectionReceivePacketEventHandle func(srv *Server, conn *Conn, packet []byte)
type ConnectionReceiveWebsocketPacketEventHandle func(srv *Server, conn *Conn, packet []byte, messageType int)
type ConnectionReceivePacketEventHandle func(srv *Server, conn *Conn, packet Packet)
type ConnectionOpenedEventHandle func(srv *Server, conn *Conn)
type ConnectionClosedEventHandle func(srv *Server, conn *Conn, err any)
type ReceiveCrossPacketEventHandle func(srv *Server, senderServerId int64, packet []byte)
Expand All @@ -24,16 +23,15 @@ type ConsoleCommandEventHandle func(srv *Server)

type event struct {
*Server
startBeforeEventHandles []StartBeforeEventHandle
startFinishEventHandles []StartFinishEventHandle
stopEventHandles []StopEventHandle
connectionReceivePacketEventHandles []ConnectionReceivePacketEventHandle
connectionReceiveWebsocketPacketEventHandles []ConnectionReceiveWebsocketPacketEventHandle
connectionOpenedEventHandles []ConnectionOpenedEventHandle
connectionClosedEventHandles []ConnectionClosedEventHandle
receiveCrossPacketEventHandles []ReceiveCrossPacketEventHandle
messageErrorEventHandles []MessageErrorEventHandle
messageLowExecEventHandles []MessageLowExecEventHandle
startBeforeEventHandles []StartBeforeEventHandle
startFinishEventHandles []StartFinishEventHandle
stopEventHandles []StopEventHandle
connectionReceivePacketEventHandles []ConnectionReceivePacketEventHandle
connectionOpenedEventHandles []ConnectionOpenedEventHandle
connectionClosedEventHandles []ConnectionClosedEventHandle
receiveCrossPacketEventHandles []ReceiveCrossPacketEventHandle
messageErrorEventHandles []MessageErrorEventHandle
messageLowExecEventHandles []MessageLowExecEventHandle

consoleCommandEventHandles map[string][]ConsoleCommandEventHandle

Expand Down Expand Up @@ -147,34 +145,16 @@ func (slf *event) RegConnectionReceivePacketEvent(handle ConnectionReceivePacket
if slf.network == NetworkHttp {
panic(ErrNetworkIncompatibleHttp)
}
if slf.network == NetworkWebsocket {
panic(ErrPleaseUseWebsocketHandle)
}
slf.connectionReceivePacketEventHandles = append(slf.connectionReceivePacketEventHandles, handle)
log.Info("Server", zap.String("RegEvent", runtimes.CurrentRunningFuncName()), zap.String("handle", reflect.TypeOf(handle).String()))
}

func (slf *event) OnConnectionReceivePacketEvent(conn *Conn, packet []byte) {
func (slf *event) OnConnectionReceivePacketEvent(conn *Conn, packet Packet) {
for _, handle := range slf.connectionReceivePacketEventHandles {
handle(slf.Server, conn, packet)
}
}

// RegConnectionReceiveWebsocketPacketEvent 在接收到Websocket数据包时将立刻执行被注册的事件处理函数
func (slf *event) RegConnectionReceiveWebsocketPacketEvent(handle ConnectionReceiveWebsocketPacketEventHandle) {
if slf.network != NetworkWebsocket {
panic(ErrPleaseUseOrdinaryPacketHandle)
}
slf.connectionReceiveWebsocketPacketEventHandles = append(slf.connectionReceiveWebsocketPacketEventHandles, handle)
log.Info("Server", zap.String("RegEvent", runtimes.CurrentRunningFuncName()), zap.String("handle", reflect.TypeOf(handle).String()))
}

func (slf *event) OnConnectionReceiveWebsocketPacketEvent(conn *Conn, packet []byte, messageType int) {
for _, handle := range slf.connectionReceiveWebsocketPacketEventHandles {
handle(slf.Server, conn, packet, messageType)
}
}

// RegReceiveCrossPacketEvent 在接收到跨服数据包时将立即执行被注册的事件处理函数
func (slf *event) RegReceiveCrossPacketEvent(handle ReceiveCrossPacketEventHandle) {
slf.receiveCrossPacketEventHandles = append(slf.receiveCrossPacketEventHandles, handle)
Expand Down Expand Up @@ -215,15 +195,8 @@ func (slf *event) check() {
switch slf.network {
case NetworkHttp, NetworkGRPC:
default:
switch slf.network {
case NetworkWebsocket:
if len(slf.connectionReceiveWebsocketPacketEventHandles) == 0 {
log.Warn("Server", zap.String("ConnectionReceiveWebsocketPacketEvent", "invalid server, no packets processed"))
}
default:
if len(slf.connectionReceivePacketEventHandles) == 0 {
log.Warn("Server", zap.String("ConnectionReceivePacketEvent", "invalid server, no packets processed"))
}
if len(slf.connectionReceivePacketEventHandles) == 0 {
log.Warn("Server", zap.String("ConnectionReceivePacketEvent", "invalid server, no packets processed"))
}
}

Expand Down
2 changes: 1 addition & 1 deletion server/gnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (slf *gNet) AfterWrite(c gnet.Conn, b []byte) {
}

func (slf *gNet) React(packet []byte, c gnet.Conn) (out []byte, action gnet.Action) {
slf.Server.pushMessage(MessageTypePacket, c.Context().(*Conn), bytes.Clone(packet))
PushPacketMessage(slf.Server, c.Context().(*Conn), append(bytes.Clone(packet), 0))
return nil, gnet.None
}

Expand Down
112 changes: 24 additions & 88 deletions server/message.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import "runtime/debug"
import (
"runtime/debug"
)

const (
// MessageTypePacket 数据包消息类型:该类型的数据将被发送到 ConnectionReceivePacketEvent 进行处理
Expand All @@ -24,6 +26,7 @@ var messageNames = map[MessageType]string{
MessageTypeError: "MessageTypeError",
MessageTypeCross: "MessageTypeCross",
MessageTypeTicker: "MessageTypeTicker",
MessageTypeAsync: "MessageTypeAsync",
}

const (
Expand Down Expand Up @@ -57,85 +60,6 @@ func (slf MessageType) String() string {
return messageNames[slf]
}

func (slf MessageType) deconstructWebSocketPacket(attrs ...any) (conn *Conn, packet []byte, messageType int) {
if len(attrs) != 3 {
panic(ErrWebsocketMessageTypePacketAttrs)
}
var ok bool
if conn, ok = attrs[0].(*Conn); !ok {
panic(ErrWebsocketMessageTypePacketAttrs)
}
if packet, ok = attrs[1].([]byte); !ok {
panic(ErrWebsocketMessageTypePacketAttrs)
}
if messageType, ok = attrs[2].(int); !ok {
panic(ErrWebsocketMessageTypePacketAttrs)
}
return
}

func (slf MessageType) deconstructPacket(attrs ...any) (conn *Conn, packet []byte) {
if len(attrs) != 2 {
panic(ErrMessageTypePacketAttrs)
}
var ok bool
if conn, ok = attrs[0].(*Conn); !ok {
panic(ErrMessageTypePacketAttrs)
}
if packet, ok = attrs[1].([]byte); !ok {
panic(ErrMessageTypePacketAttrs)
}
return
}

func (slf MessageType) deconstructError(attrs ...any) (err error, action MessageErrorAction, stack string) {
if len(attrs) != 3 {
panic(ErrMessageTypeErrorAttrs)
}
var ok bool
if err, ok = attrs[0].(error); !ok {
panic(ErrMessageTypeErrorAttrs)
}
if action, ok = attrs[1].(MessageErrorAction); !ok {
panic(ErrMessageTypeErrorAttrs)
}
stack = attrs[2].(string)
return
}

func (slf MessageType) deconstructCross(attrs ...any) (serverId int64, packet []byte) {
if len(attrs) != 2 {
panic(ErrMessageTypeCrossErrorAttrs)
}
var ok bool
if serverId, ok = attrs[0].(int64); !ok {
panic(ErrMessageTypeCrossErrorAttrs)
}
if packet, ok = attrs[1].([]byte); !ok {
panic(ErrMessageTypeCrossErrorAttrs)
}
return
}

func (slf MessageType) deconstructTicker(attrs ...any) (caller func()) {
if len(attrs) != 1 {
panic(ErrMessageTypeTickerErrorAttrs)
}
var ok bool
if caller, ok = attrs[0].(func()); !ok {
panic(ErrMessageTypeTickerErrorAttrs)
}
return
}

// PushWebsocketPacketMessage 向特定服务器中推送 WebsocketPacket 消息
func PushWebsocketPacketMessage(srv *Server, conn *Conn, packet []byte, messageType int) {
msg := srv.messagePool.Get()
msg.t = MessageTypePacket
msg.attrs = []any{conn, packet, messageType}
srv.pushMessage(msg)
}

// PushPacketMessage 向特定服务器中推送 Packet 消息
func PushPacketMessage(srv *Server, conn *Conn, packet []byte) {
msg := srv.messagePool.Get()
Expand All @@ -154,15 +78,27 @@ func PushErrorMessage(srv *Server, err error, action MessageErrorAction) {

// PushCrossMessage 向特定服务器中推送 Cross 消息
func PushCrossMessage(srv *Server, crossName string, serverId int64, packet []byte) {
if len(srv.cross) == 0 {
return
}
_, exist := srv.cross[crossName]
if !exist {
return
if serverId == srv.id {
msg := srv.messagePool.Get()
msg.t = MessageTypeCross
msg.attrs = []any{serverId, packet}
srv.pushMessage(msg)
} else {
if len(srv.cross) == 0 {
return
}
cross, exist := srv.cross[crossName]
if !exist {
return
}
_ = cross.PushMessage(serverId, packet)
}
}

// PushTickerMessage 向特定服务器中推送 Ticker 消息
func PushTickerMessage(srv *Server, caller func()) {
msg := srv.messagePool.Get()
msg.t = MessageTypeCross
msg.attrs = []any{serverId, packet}
msg.t = MessageTypeTicker
msg.attrs = []any{caller}
srv.pushMessage(msg)
}
Loading

0 comments on commit 8b90307

Please sign in to comment.