Skip to content

Commit

Permalink
修复Websocket消息类型过滤不设置时无法接收数据包的问题,服务器增加连接分流功能
Browse files Browse the repository at this point in the history
  • Loading branch information
kercylan98 committed May 15, 2023
1 parent df4aa30 commit 926b69b
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 29 deletions.
11 changes: 11 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package main

import "github.com/kercylan98/minotaur/server"

// 无意义的测试main入口
func main() {
srv := server.New(server.NetworkWebsocket, server.WithConnectPacketDiversion(3, 2))
if err := srv.Run(":8999"); err != nil {
panic(err)
}
}
50 changes: 36 additions & 14 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,51 @@ import (
"github.com/gorilla/websocket"
"github.com/panjf2000/gnet"
"github.com/xtaci/kcp-go/v5"
"net"
"strings"
)

// newKcpConn 创建一个处理KCP的连接
func newKcpConn(session *kcp.UDPSession) *Conn {
return &Conn{
ip: session.RemoteAddr().String(),
kcp: session,
c := &Conn{
remoteAddr: session.RemoteAddr(),
ip: session.RemoteAddr().String(),
kcp: session,
write: func(data []byte) error {
_, err := session.Write(data)
return err
},
data: map[any]any{},
}
if index := strings.LastIndex(c.ip, ":"); index != -1 {
c.ip = c.ip[0:index]
}
return c
}

// newKcpConn 创建一个处理GNet的连接
func newGNetConn(conn gnet.Conn) *Conn {
return &Conn{
ip: conn.RemoteAddr().String(),
gn: conn,
c := &Conn{
remoteAddr: conn.RemoteAddr(),
ip: conn.RemoteAddr().String(),
gn: conn,
write: func(data []byte) error {
return conn.AsyncWrite(data)
},
data: map[any]any{},
}
if index := strings.LastIndex(c.ip, ":"); index != -1 {
c.ip = c.ip[0:index]
}
return c
}

// newKcpConn 创建一个处理WebSocket的连接
func newWebsocketConn(ws *websocket.Conn, ip string) *Conn {
return &Conn{
ip: ip,
ws: ws,
remoteAddr: ws.RemoteAddr(),
ip: ip,
ws: ws,
write: func(data []byte) error {
return ws.WriteMessage(websocket.BinaryMessage, data)
},
Expand All @@ -45,15 +58,24 @@ func newWebsocketConn(ws *websocket.Conn, ip string) *Conn {

// Conn 服务器连接
type Conn struct {
ip string
ws *websocket.Conn
gn gnet.Conn
kcp *kcp.UDPSession
write func(data []byte) error
data map[any]any
remoteAddr net.Addr
ip string
ws *websocket.Conn
gn gnet.Conn
kcp *kcp.UDPSession
write func(data []byte) error
data map[any]any
}

func (slf *Conn) RemoteAddr() net.Addr {
return slf.remoteAddr
}

func (slf *Conn) GetID() string {
return slf.remoteAddr.String()
}

func (slf *Conn) GetIP() string {
return slf.ip
}

Expand Down
1 change: 1 addition & 0 deletions server/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ var (
ErrWebsocketIllegalMessageType = errors.New("illegal message type")
ErrPleaseUseWebsocketHandle = errors.New("in Websocket mode, please use the RegConnectionReceiveWebsocketPacketEvent function to register")
ErrPleaseUseOrdinaryPacketHandle = errors.New("non Websocket mode, please use the RegConnectionReceivePacketEvent function to register")
ErrOnlySupportSocket = errors.New("only supports Socket programming")
)
6 changes: 5 additions & 1 deletion server/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ func (slf *event) RegConnectionOpenedEvent(handle ConnectionOpenedEventHandle) {
}

func (slf *event) OnConnectionOpenedEvent(conn *Conn) {
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()))
if len(slf.diversionMessageChannels) == 0 {
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()))
} else {
log.Debug("Server", zap.String("ConnectionOpened", conn.GetID()), zap.Int("Node", slf.diversionConsistency.PickNode(conn.GetID())))
}
for _, handle := range slf.connectionOpenedEventHandles {
handle(slf.Server, conn)
}
Expand Down
20 changes: 20 additions & 0 deletions server/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"github.com/kercylan98/minotaur/utils/hash"
"github.com/kercylan98/minotaur/utils/log"
"go.uber.org/zap"
"google.golang.org/grpc"
Expand All @@ -21,6 +22,25 @@ const (

type Option func(srv *Server)

// WithConnectPacketDiversion 通过连接数据包消息分流的方式创建服务器
// - 连接消息分流后数据包消息将会从其他消息类型中独立出来,并且由多个消息管道及协程进行处理
// - 默认不会进行消息分流
// - 需要注意并发编程
func WithConnectPacketDiversion(diversionNumber, channelSize int) Option {
return func(srv *Server) {
if srv.network == NetworkHttp || srv.network == NetworkGRPC {
log.Warn("WithConnectPacketDiversion", zap.String("Network", string(srv.network)), zap.Error(ErrOnlySupportSocket))
return
}
srv.diversionMessageChannels = make([]chan *message, diversionNumber)
srv.diversionConsistency = hash.NewConsistency(3)
for i := 0; i < diversionNumber; i++ {
srv.diversionMessageChannels[i] = make(chan *message, channelSize)
srv.diversionConsistency.AddNode(i + 1)
}
}
}

// WithTLS 通过安全传输层协议TLS创建服务器
// - 支持:Http、Websocket
func WithTLS(certFile, keyFile string) Option {
Expand Down
45 changes: 35 additions & 10 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/utils/hash"
"github.com/kercylan98/minotaur/utils/log"
"github.com/kercylan98/minotaur/utils/synchronization"
"github.com/panjf2000/gnet"
Expand Down Expand Up @@ -62,14 +63,16 @@ type Server struct {
isShutdown atomic.Bool // 是否已关闭
closeChannel chan struct{} // 关闭信号

gServer *gNet // TCP或UDP模式下的服务器
messagePool *synchronization.Pool[*message] // 消息池
messagePoolSize int // 消息池大小
messageChannel chan *message // 消息管道
initMessageChannel bool // 消息管道是否已经初始化
multiple bool // 是否为多服务器模式下运行
prod bool // 是否为生产模式
core int // 消息处理核心数
gServer *gNet // TCP或UDP模式下的服务器
messagePool *synchronization.Pool[*message] // 消息池
messagePoolSize int // 消息池大小
messageChannel chan *message // 消息管道
initMessageChannel bool // 消息管道是否已经初始化
multiple bool // 是否为多服务器模式下运行
prod bool // 是否为生产模式
core int // 消息处理核心数
diversionMessageChannels []chan *message // 分流消息管道
diversionConsistency *hash.Consistency // 哈希一致性分流器
}

// Run 使用特定地址运行服务器
Expand Down Expand Up @@ -118,6 +121,15 @@ func (slf *Server) Run(addr string) error {
slf.dispatchMessage(message)
}
}()
go func() {
for i := 0; i < len(slf.diversionMessageChannels); i++ {
go func(channel chan *message) {
for message := range channel {
slf.dispatchMessage(message)
}
}(slf.diversionMessageChannels[i])
}
}()
}
}

Expand Down Expand Up @@ -249,7 +261,7 @@ func (slf *Server) Run(addr string) error {
if err != nil {
panic(err)
}
if !slf.supportMessageTypes[messageType] {
if len(slf.supportMessageTypes) > 0 && !slf.supportMessageTypes[messageType] {
panic(ErrWebsocketIllegalMessageType)
}
slf.PushMessage(MessageTypePacket, conn, packet, messageType)
Expand Down Expand Up @@ -311,6 +323,11 @@ func (slf *Server) IsDev() bool {
// Shutdown 停止运行服务器
func (slf *Server) Shutdown(err error) {
slf.isShutdown.Store(true)
if len(slf.diversionMessageChannels) > 0 {
for i := 0; i < len(slf.diversionMessageChannels); i++ {
close(slf.diversionMessageChannels[i])
}
}
if slf.initMessageChannel {
if slf.gServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
Expand Down Expand Up @@ -364,7 +381,12 @@ func (slf *Server) PushMessage(messageType MessageType, attrs ...any) {
msg := slf.messagePool.Get()
msg.t = messageType
msg.attrs = attrs
slf.messageChannel <- msg
if messageType == MessageTypePacket && len(slf.diversionMessageChannels) > 0 {
conn := attrs[0].(*Conn)
slf.diversionMessageChannels[slf.diversionConsistency.PickNode(conn.ip)] <- msg
} else {
slf.messageChannel <- msg
}
}

// dispatchMessage 消息分发
Expand All @@ -381,6 +403,9 @@ func (slf *Server) dispatchMessage(msg *message) {
case MessageTypePacket:
if slf.network == NetworkWebsocket {
conn, packet, messageType := msg.t.deconstructWebSocketPacket(msg.attrs...)
if slf.diversionConsistency != nil {
slf.diversionConsistency.PickNode(conn)
}
slf.OnConnectionReceiveWebsocketPacketEvent(conn, packet, messageType)
} else {
conn, packet := msg.t.deconstructPacket(msg.attrs...)
Expand Down
14 changes: 10 additions & 4 deletions utils/hash/consistency.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ import (
"strings"
)

func NewConsistency(replicas int) *Consistency {
return &Consistency{
replicas: replicas,
}
}

// Consistency 一致性哈希生成
//
// https://blog.csdn.net/zhpCSDN921011/article/details/126845397
type Consistency struct {
Replicas int // 虚拟节点的数量
replicas int // 虚拟节点的数量
keys []int // 所有虚拟节点的哈希值
hashMap map[int]int // 虚拟节点的哈希值: 节点(虚拟节点映射到真实节点)
}
Expand All @@ -22,11 +28,11 @@ func (slf *Consistency) AddNode(keys ...int) {
if slf.hashMap == nil {
slf.hashMap = map[int]int{}
}
if slf.Replicas == 0 {
slf.Replicas = 3
if slf.replicas == 0 {
slf.replicas = 3
}
for _, key := range keys {
for i := 0; i < slf.Replicas; i++ {
for i := 0; i < slf.replicas; i++ {
// 计算虚拟节点哈希值
hash := int(crc32.ChecksumIEEE([]byte(strconv.Itoa(i) + strconv.Itoa(key))))

Expand Down

0 comments on commit 926b69b

Please sign in to comment.