/
wsserver.go
158 lines (137 loc) · 4.44 KB
/
wsserver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package server
import (
	"crypto/subtle"
	"encoding/json"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/bwmarrin/discordgo"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/hoshinonyaruko/gensokyo-discord/Processor"
	"github.com/hoshinonyaruko/gensokyo-discord/callapi"
	"github.com/hoshinonyaruko/gensokyo-discord/config"
	"github.com/hoshinonyaruko/gensokyo-discord/mylog"
	"github.com/hoshinonyaruko/gensokyo-discord/wsclient"
)
// WebSocketServerClient wraps a single client connection accepted by the
// forward WebSocket server. *WebSocketServerClient satisfies
// callapi.WebSocketServerClienter.
type WebSocketServerClient struct {
	// Conn is the upgraded connection; wsHandler reads client frames from it
	// and SendMessage writes outgoing frames to it.
	Conn *websocket.Conn
}
var (
	// upgrader turns incoming HTTP requests into WebSocket connections.
	// CheckOrigin accepts every request regardless of its Origin header.
	// NOTE(review): when no server token is configured, nothing else gates
	// access, so any origin (including a cross-site page) can connect —
	// confirm this is intended.
	upgrader = websocket.Upgrader{
		CheckOrigin: func(*http.Request) bool { return true },
	}

	// Compile-time check that *WebSocketServerClient implements
	// callapi.WebSocketServerClienter.
	_ callapi.WebSocketServerClienter = (*WebSocketServerClient)(nil)
)
// WsHandlerWithDependencies returns a gin.HandlerFunc that serves forward
// WebSocket connections through wsHandler using the given session and
// processor. The dependencies are captured in a closure because gin handlers
// must have the fixed signature func(*gin.Context).
func WsHandlerWithDependencies(s *discordgo.Session, p *Processor.Processors) gin.HandlerFunc {
	handler := func(ctx *gin.Context) {
		wsHandler(s, p, ctx)
	}
	return handler
}
// wsServerClientsMu serializes the add/remove updates that wsHandler makes to
// p.WsServerClients. net/http serves each connection on its own goroutine, so
// without it two clients connecting or disconnecting at the same moment could
// race on the slice and silently drop an entry.
// NOTE(review): code outside this file that reads WsServerClients does not take
// this lock; hosting the lock on Processor.Processors would close that
// remaining race — TODO.
var wsServerClientsMu sync.Mutex

// wsHandler serves one forward WebSocket client connection.
//
// The client token is taken from the Authorization header, with an optional
// "Token " or "Bearer " prefix stripped. Only when that header is absent is it
// taken from the access_token query parameter. If config.GetWsServerToken() is
// non-empty, a missing token is answered with 401 and a mismatched one with
// 403, and the request is not upgraded.
//
// Once upgraded, the client is appended to p.WsServerClients and sent a
// lifecycle "connect" meta event. Every text frame it sends is then passed to
// processWSMessage. The handler returns when a read fails. The client is then
// removed from p.WsServerClients, after which its connection is closed.
func wsHandler(s *discordgo.Session, p *Processor.Processors, c *gin.Context) {
	var token string
	if authHeader := c.Request.Header.Get("Authorization"); authHeader != "" {
		switch {
		case strings.HasPrefix(authHeader, "Token "):
			token = strings.TrimPrefix(authHeader, "Token ")
		case strings.HasPrefix(authHeader, "Bearer "):
			token = strings.TrimPrefix(authHeader, "Bearer ")
		default:
			// No recognized scheme: use the whole header value as the token.
			token = authHeader
		}
	} else {
		token = c.Query("access_token")
	}

	// Authentication is only enforced when a server token is configured.
	// Credentials are deliberately kept out of the logs: the request headers
	// carry Authorization, and a wrong token may be a near-miss of the real one.
	if validToken := config.GetWsServerToken(); validToken != "" {
		switch {
		case token == "":
			mylog.Printf("Connection failed due to missing token. IP: %s", c.ClientIP())
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Missing token"})
			return
		case subtle.ConstantTimeCompare([]byte(token), []byte(validToken)) != 1:
			// Constant-time comparison avoids revealing, through response
			// timing, how many leading bytes of the guess were correct.
			mylog.Printf("Connection failed due to incorrect token. IP: %s", c.ClientIP())
			c.JSON(http.StatusForbidden, gin.H{"error": "Incorrect token"})
			return
		}
	}

	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		mylog.Printf("Failed to set websocket upgrade: %+v", err)
		return
	}
	// Deferred calls run last-in-first-out. Registering Close first means the
	// client is removed from WsServerClients (deferred below) before its
	// connection is closed.
	defer conn.Close()

	clientIP := c.ClientIP()
	mylog.Printf("WebSocket client connected. IP: %s", clientIP)

	client := &WebSocketServerClient{Conn: conn}

	// Add this client to the Processor's WsServerClients list, and remove it
	// again when the handler returns.
	wsServerClientsMu.Lock()
	p.WsServerClients = append(p.WsServerClients, client)
	wsServerClientsMu.Unlock()
	defer func() {
		wsServerClientsMu.Lock()
		defer wsServerClientsMu.Unlock()
		for i, wsClient := range p.WsServerClients {
			if wsClient == client {
				// The three-index slice caps capacity at i, so append copies
				// the tail into a fresh array. The old array, which a
				// concurrent reader may still be ranging over, is not shifted
				// in place.
				p.WsServerClients = append(p.WsServerClients[:i:i], p.WsServerClients[i+1:]...)
				break
			}
		}
	}()

	// Announce the successful connection with a lifecycle meta event.
	botID := config.GetAppID()
	connectEvent := map[string]interface{}{
		"meta_event_type": "lifecycle",
		"post_type":       "meta_event",
		"self_id":         botID,
		"sub_type":        "connect",
		"time":            int(time.Now().Unix()),
	}
	if err = client.SendMessage(connectEvent); err != nil {
		mylog.Printf("Error sending connection success message: %v\n", err)
	}

	// Read until a read fails (e.g. the peer disconnects). Only text frames are
	// dispatched; other message types are ignored. The payload is named msg so
	// it does not shadow the Processor parameter p.
	for {
		messageType, msg, err := conn.ReadMessage()
		if err != nil {
			mylog.Printf("Error reading message: %v", err)
			return
		}
		if messageType == websocket.TextMessage {
			processWSMessage(client, msg, s)
		}
	}
}
// processWSMessage decodes one text frame from a forward WebSocket client into
// a callapi.ActionMessage and dispatches it via callapi.CallAPIFromDict. The
// client is passed along, presumably so the API result can be sent back over
// the same connection (TODO confirm in callapi). Frames that are not valid
// JSON for ActionMessage are logged together with the raw payload and dropped.
func processWSMessage(client *WebSocketServerClient, msg []byte, s *discordgo.Session) {
	var action callapi.ActionMessage
	if err := json.Unmarshal(msg, &action); err != nil {
		mylog.Printf("Error unmarshalling message: %v, Original message: %s", err, string(msg))
		return
	}

	// wsclient.TruncateMessage presumably bounds the logged text to 500
	// characters — see that helper for the exact unit.
	mylog.Println("Received from WebSocket onebotv11 client:", wsclient.TruncateMessage(action, 500))

	callapi.CallAPIFromDict(client, s, action)
}
// SendMessage JSON-encodes message and writes it to the client as one text
// frame. If marshalling fails, the error is logged and returned and nothing is
// written; otherwise the result of the write is returned.
//
// NOTE(review): gorilla/websocket allows at most one concurrent writer per
// connection. If SendMessage can be reached from several goroutines at once
// (e.g. event broadcasts alongside API replies), the write needs a per-client
// mutex — TODO confirm against the callers.
func (c *WebSocketServerClient) SendMessage(message map[string]interface{}) error {
	payload, marshalErr := json.Marshal(message)
	if marshalErr != nil {
		mylog.Println("Error marshalling message:", marshalErr)
		return marshalErr
	}
	return c.Conn.WriteMessage(websocket.TextMessage, payload)
}
// Close closes the client's underlying WebSocket connection and returns any
// error from doing so. The receiver is named c to match SendMessage.
func (c *WebSocketServerClient) Close() error {
	return c.Conn.Close()
}