Skip to content

Commit

Permalink
feat: server 包新增 WithWebsocketUpgrade 函数,支持自定义 websocket.Upgrader
Browse files Browse the repository at this point in the history
  • Loading branch information
kercylan98 committed Dec 25, 2023
1 parent 7efe88a commit e960d07
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 55 deletions.
12 changes: 12 additions & 0 deletions server/constants.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"github.com/gorilla/websocket"
"net/http"
"time"
)

Expand All @@ -17,3 +19,13 @@ const (
DefaultDispatcherBufferSize = 1024 * 16
DefaultConnWriteBufferSize = 1024 * 1
)

func DefaultWebsocketUpgrader() *websocket.Upgrader {
return &websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
}
52 changes: 33 additions & 19 deletions server/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package server

import (
"github.com/gin-contrib/pprof"
"github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/utils/log"
"github.com/kercylan98/minotaur/utils/timer"
"google.golang.org/grpc"
Expand Down Expand Up @@ -30,25 +31,38 @@ type option struct {
}

type runtime struct {
deadlockDetect time.Duration // 是否开启死锁检测
supportMessageTypes map[int]bool // websocket模式下支持的消息类型
certFile, keyFile string // TLS文件
tickerPool *timer.Pool // 定时器池
ticker *timer.Ticker // 定时器
tickerAutonomy bool // 定时器是否独立运行
connTickerSize int // 连接定时器大小
websocketReadDeadline time.Duration // websocket连接超时时间
websocketCompression int // websocket压缩等级
websocketWriteCompression bool // websocket写入压缩
limitLife time.Duration // 限制最大生命周期
packetWarnSize int // 数据包大小警告
messageStatisticsDuration time.Duration // 消息统计时长
messageStatisticsLimit int // 消息统计数量
messageStatistics []*atomic.Int64 // 消息统计数量
messageStatisticsLock *sync.RWMutex // 消息统计锁
dispatcherBufferSize int // 消息分发器缓冲区大小
connWriteBufferSize int // 连接写入缓冲区大小
disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道
deadlockDetect time.Duration // 是否开启死锁检测
supportMessageTypes map[int]bool // websocket 模式下支持的消息类型
certFile, keyFile string // TLS文件
tickerPool *timer.Pool // 定时器池
ticker *timer.Ticker // 定时器
tickerAutonomy bool // 定时器是否独立运行
connTickerSize int // 连接定时器大小
websocketReadDeadline time.Duration // websocket 连接超时时间
websocketCompression int // websocket 压缩等级
websocketWriteCompression bool // websocket 写入压缩
limitLife time.Duration // 限制最大生命周期
packetWarnSize int // 数据包大小警告
messageStatisticsDuration time.Duration // 消息统计时长
messageStatisticsLimit int // 消息统计数量
messageStatistics []*atomic.Int64 // 消息统计数量
messageStatisticsLock *sync.RWMutex // 消息统计锁
dispatcherBufferSize int // 消息分发器缓冲区大小
connWriteBufferSize int // 连接写入缓冲区大小
disableAutomaticReleaseShunt bool // 是否禁用自动释放分流渠道
websocketUpgrader *websocket.Upgrader // websocket 升级器
}

// WithWebsocketUpgrade 通过指定 websocket.Upgrader 的方式创建服务器
// - 默认值为 DefaultWebsocketUpgrader
// - 该选项仅在创建 NetworkWebsocket 服务器时有效
func WithWebsocketUpgrade(upgrader *websocket.Upgrader) Option {
return func(srv *Server) {
if srv.network != NetworkWebsocket {
return
}
srv.websocketUpgrader = upgrader
}
}

// WithDisableAutomaticReleaseShunt 通过禁用自动释放分流渠道的方式创建服务器
Expand Down
43 changes: 11 additions & 32 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/kercylan98/minotaur/server/internal/logger"
"github.com/kercylan98/minotaur/utils/concurrent"
"github.com/kercylan98/minotaur/utils/log"
Expand Down Expand Up @@ -196,12 +195,8 @@ func (slf *Server) Run(addr string) error {

go func(conn *Conn) {
defer func() {
if err := recover(); err != nil {
e, ok := err.(error)
if !ok {
e = fmt.Errorf("%v", err)
}
conn.Close(e)
if err := super.RecoverTransform(recover()); err != nil {
conn.Close(err)
}
}()

Expand Down Expand Up @@ -254,16 +249,12 @@ func (slf *Server) Run(addr string) error {
pattern = addr[index:]
slf.addr = slf.addr[:index]
}
var upgrade = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
if slf.websocketUpgrader == nil {
slf.websocketUpgrader = DefaultWebsocketUpgrader()
}
http.HandleFunc(pattern, func(writer http.ResponseWriter, request *http.Request) {
ip := request.Header.Get("X-Real-IP")
ws, err := upgrade.Upgrade(writer, request, nil)
ws, err := slf.websocketUpgrader.Upgrade(writer, request, nil)
if err != nil {
return
}
Expand All @@ -289,12 +280,8 @@ func (slf *Server) Run(addr string) error {
slf.OnConnectionOpenedEvent(conn)

defer func() {
if err := recover(); err != nil {
e, ok := err.(error)
if !ok {
e = fmt.Errorf("%v", err)
}
conn.Close(e)
if err := super.RecoverTransform(recover()); err != nil {
conn.Close(err)
}
}()
for !conn.IsClosed() {
Expand Down Expand Up @@ -734,15 +721,11 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
if msg.t != MessageTypeAsync && msg.t != MessageTypeUniqueAsync && msg.t != MessageTypeShuntAsync && msg.t != MessageTypeUniqueShuntAsync {
defer func(msg *Message) {
super.Handle(cancel)
if err := recover(); err != nil {
if err := super.RecoverTransform(recover()); err != nil {
stack := string(debug.Stack())
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.String("Info", msg.String()), log.Any("error", err), log.String("stack", stack))
fmt.Println(stack)
e, ok := err.(error)
if !ok {
e = fmt.Errorf("%v", err)
}
slf.OnMessageErrorEvent(msg, e)
slf.OnMessageErrorEvent(msg, err)
}
if msg.t == MessageTypeUniqueAsyncCallback || msg.t == MessageTypeUniqueShuntAsyncCallback {
dispatcher.antiUnique(msg.name)
Expand Down Expand Up @@ -782,18 +765,14 @@ func (slf *Server) dispatchMessage(dispatcher *dispatcher, msg *Message) {
case MessageTypeAsync, MessageTypeShuntAsync, MessageTypeUniqueAsync, MessageTypeUniqueShuntAsync:
if err := slf.ants.Submit(func() {
defer func() {
if err := recover(); err != nil {
if err := super.RecoverTransform(recover()); err != nil {
if msg.t == MessageTypeUniqueAsync || msg.t == MessageTypeUniqueShuntAsync {
dispatcher.antiUnique(msg.name)
}
stack := string(debug.Stack())
log.Error("Server", log.String("MessageType", messageNames[msg.t]), log.Any("error", err), log.String("stack", stack))
fmt.Println(stack)
e, ok := err.(error)
if !ok {
e = fmt.Errorf("%v", err)
}
slf.OnMessageErrorEvent(msg, e)
slf.OnMessageErrorEvent(msg, err)
}
super.Handle(cancel)
slf.low(msg, present, time.Second)
Expand Down
21 changes: 17 additions & 4 deletions server/service_test.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
package server_test

import (
"fmt"
"github.com/kercylan98/minotaur/server"
"testing"
"time"
)

type TestService struct {
}
type TestService struct{}

func (ts *TestService) OnInit(srv *server.Server) {
srv.RegStartFinishEvent(func(srv *server.Server) {
println("Server started")
fmt.Println("server start finish")
})

srv.RegStopEvent(func(srv *server.Server) {
println("Server stopped")
fmt.Println("server stop")
})
}

Expand All @@ -24,7 +24,20 @@ func TestBindService(t *testing.T) {

server.BindService(srv, new(TestService))

if err := srv.RunNone(); err != nil {
t.Fatal(err)
}
}

func ExampleBindService() {
srv := server.New(server.NetworkNone, server.WithLimitLife(time.Second))
server.BindService(srv, new(TestService))

if err := srv.RunNone(); err != nil {
panic(err)
}

// Output:
// server start finish
// server stop
}

0 comments on commit e960d07

Please sign in to comment.