From e960d07f49adb83359f92543053b2efe1e35d182 Mon Sep 17 00:00:00 2001 From: kercylan98 Date: Mon, 25 Dec 2023 14:40:02 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20server=20=E5=8C=85=E6=96=B0=E5=A2=9E=20?= =?UTF-8?q?WithWebsocketUpgrade=20=E5=87=BD=E6=95=B0=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E8=87=AA=E5=AE=9A=E4=B9=89=20websocket.Upgrader?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/constants.go | 12 ++++++++++ server/options.go | 52 +++++++++++++++++++++++++++--------------- server/server.go | 43 +++++++++------------------------- server/service_test.go | 21 +++++++++++++---- 4 files changed, 73 insertions(+), 55 deletions(-) diff --git a/server/constants.go b/server/constants.go index 8bd70982..634dacef 100644 --- a/server/constants.go +++ b/server/constants.go @@ -1,6 +1,8 @@ package server import ( + "github.com/gorilla/websocket" + "net/http" "time" ) @@ -17,3 +19,13 @@ const ( DefaultDispatcherBufferSize = 1024 * 16 DefaultConnWriteBufferSize = 1024 * 1 ) + +func DefaultWebsocketUpgrader() *websocket.Upgrader { + return &websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, + } +} diff --git a/server/options.go b/server/options.go index 19dc7725..d10a30f8 100644 --- a/server/options.go +++ b/server/options.go @@ -2,6 +2,7 @@ package server import ( "github.com/gin-contrib/pprof" + "github.com/gorilla/websocket" "github.com/kercylan98/minotaur/utils/log" "github.com/kercylan98/minotaur/utils/timer" "google.golang.org/grpc" @@ -30,25 +31,38 @@ type option struct { } type runtime struct { - deadlockDetect time.Duration // 是否开启死锁检测 - supportMessageTypes map[int]bool // websocket模式下支持的消息类型 - certFile, keyFile string // TLS文件 - tickerPool *timer.Pool // 定时器池 - ticker *timer.Ticker // 定时器 - tickerAutonomy bool // 定时器是否独立运行 - connTickerSize int // 连接定时器大小 - websocketReadDeadline time.Duration // websocket连接超时时间 - websocketCompression int // websocket压缩等级 - websocketWriteCompression bool // websocket写入压缩 - limitLife time.Duration // 限制最大生命周期 - packetWarnSize int // 数据包大小警告 - messageStatisticsDuration time.Duration // 消息统计时长 - messageStatisticsLimit int // 消息统计数量 - messageStatistics []*atomic.Int64 // 消息统计数量 - messageStatisticsLock *sync.RWMutex // 消息统计锁 - dispatcherBufferSize int // 消息分发器缓冲区大小 - connWriteBufferSize int // 连接写入缓冲区大小 - disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道 + deadlockDetect time.Duration // 是否开启死锁检测 + supportMessageTypes map[int]bool // websocket 模式下支持的消息类型 + certFile, keyFile string // TLS文件 + tickerPool *timer.Pool // 定时器池 + ticker *timer.Ticker // 定时器 + tickerAutonomy bool // 定时器是否独立运行 + connTickerSize int // 连接定时器大小 + websocketReadDeadline time.Duration // websocket 连接超时时间 + websocketCompression int // websocket 压缩等级 + websocketWriteCompression bool // websocket 写入压缩 + limitLife time.Duration // 限制最大生命周期 + packetWarnSize int // 数据包大小警告 + messageStatisticsDuration time.Duration // 消息统计时长 + messageStatisticsLimit int // 消息统计数量 + messageStatistics []*atomic.Int64 // 消息统计数量 + messageStatisticsLock *sync.RWMutex // 消息统计锁 + dispatcherBufferSize int // 消息分发器缓冲区大小 + connWriteBufferSize int // 连接写入缓冲区大小 + disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道 + websocketUpgrader *websocket.Upgrader // websocket 升级器 +} + +// WithWebsocketUpgrade 通过指定 websocket.Upgrader 的方式创建服务器 +// - 默认值为 DefaultWebsocketUpgrader +// - 该选项仅在创建 NetworkWebsocket 服务器时有效 +func WithWebsocketUpgrade(upgrader *websocket.Upgrader) Option { + return func(srv *Server) { + if srv.network != NetworkWebsocket { + return + } + srv.websocketUpgrader = upgrader + } } // WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器 diff --git a/server/server.go b/server/server.go index 9b84fcf0..6688cea4 100644 --- a/server/server.go +++ b/server/server.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "github.com/kercylan98/minotaur/server/internal/logger" "github.com/kercylan98/minotaur/utils/concurrent" "github.com/kercylan98/minotaur/utils/log" @@ -196,12 +195,8 @@ func (slf *Server) Run(addr string) error { go func(conn *Conn) { defer func() { - if err := recover(); err != nil { - e, ok := err.(error) - if !ok { - e = fmt.Errorf("%v", err) - } - conn.Close(e) + if err := super.RecoverTransform(recover()); err != nil { + conn.Close(err) } }() @@ -254,16 +249,12 @@ func (slf *Server) Run(addr string) error { pattern = addr[index:] slf.addr = slf.addr[:index] } - var upgrade = websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - CheckOrigin: func(r *http.Request) bool { - return true - }, + if slf.websocketUpgrader == nil { + slf.websocketUpgrader = DefaultWebsocketUpgrader() } http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) { ip := request.Header.Get("X-Real-IP") - ws, err := upgrade.Upgrade(writer, request, nil) + ws, err := slf.websocketUpgrader.Upgrade(writer, request, nil) if err != nil { return } @@ -289,12 +280,8 @@ func (slf *Server) Run(addr string) error { slf.OnConnectionOpenedEvent(conn) defer func() { - if err := recover(); err != nil { - e, ok := err.(error) - if !ok { - e = fmt.Errorf("%v", err) - } - conn.Close(e) + if err := super.RecoverTransform(recover()); err != nil { + conn.Close(err) } }() for !conn.IsClosed() { @@ -734,15 +721,11 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) { if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync { defer func(msg *Message) { super.Handle(cancel) - if err := recover(); err != nil { + if err := super.RecoverTransform(recover()); err != nil { stack := string(debug.Stack()) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("error", err), log.String("stack", stack)) fmt.Println(stack) - e, ok := err.(error) - if !ok { - e = fmt.Errorf("%v", err) - } - slf.OnMessageErrorEvent(msg, e) + slf.OnMessageErrorEvent(msg, err) } if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback { dispatcher.antiUnique(msg.name) @@ -782,18 +765,14 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) { case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync: if err := slf.ants.Submit(func() { defer func() { - if err := recover(); err != nil { + if err := super.RecoverTransform(recover()); err != nil { if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync { dispatcher.antiUnique(msg.name) } stack := string(debug.Stack()) log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack)) fmt.Println(stack) - e, ok := err.(error) - if !ok { - e = fmt.Errorf("%v", err) - } - slf.OnMessageErrorEvent(msg, e) + slf.OnMessageErrorEvent(msg, err) } super.Handle(cancel) slf.low(msg, present, time.Second) diff --git a/server/service_test.go b/server/service_test.go index a3bae8e0..1ffb7104 100644 --- a/server/service_test.go +++ b/server/service_test.go @@ -1,21 +1,21 @@ package server_test import ( + "fmt" "github.com/kercylan98/minotaur/server" "testing" "time" ) -type TestService struct { -} +type TestService struct{} func (ts *TestService) OnInit(srv *server.Server) { srv.RegStartFinishEvent(func(srv *server.Server) { - println("Server started") + fmt.Println("server start finish") }) srv.RegStopEvent(func(srv *server.Server) { - println("Server stopped") + fmt.Println("server stop") }) } @@ -24,7 +24,20 @@ func TestBindService(t *testing.T) { server.BindService(srv, new(TestService)) + if err := srv.RunNone(); err != nil { + t.Fatal(err) + } +} + +func ExampleBindService() { + srv := server.New(server.NetworkNone, server.WithLimitLife(time.Second)) + server.BindService(srv, new(TestService)) + if err := srv.RunNone(); err != nil { panic(err) } + + // Output: + // server start finish + // server stop }