/
process.go
152 lines (137 loc) · 3.26 KB
/
process.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
/**
* @Author: cyj19
* @Date: 2022/2/24 14:49
*/
package server
import (
	"errors"
	"io"
	"log"
	"net"
	"reflect"
	"sync"

	"github.com/cyj19/sparrow/codec"
	"github.com/cyj19/sparrow/compressor"
	"github.com/cyj19/sparrow/protocol"
)
// process serves one client connection until the peer closes it or a read
// fails.
//
// Requests are decoded sequentially from conn, and each one is dispatched to
// its own goroutine (handleRequest). Encoded responses are funnelled through a
// SendChannel and written to conn by a single writer goroutine, so concurrent
// handlers never interleave bytes on the wire.
//
// Shutdown order matters. Once reading stops, process:
//  1. waits for every in-flight handler, so no handler can Send after Close;
//  2. closes the SendChannel;
//  3. waits for the writer to flush the responses that are still queued;
//  4. only then closes conn.
//
// NOTE(review): step 3 assumes SendChannel.Close closes sChannel.Ch (the
// original writer loop already relied on observing that close) — confirm.
func (s *Server) process(conn net.Conn) {
	defer conn.Close()

	sChannel := NewSendChannel(s.Option.SendChannelSize)

	// Writer goroutine: the only code that writes to conn.
	writerDone := make(chan struct{})
	go func() {
		defer close(writerDone)
		var writeErr error
		for respMsg := range sChannel.Ch {
			if writeErr != nil {
				// The connection is already broken. Keep draining so that
				// handlers blocked in Send can finish, but drop the data.
				continue
			}
			if _, writeErr = conn.Write(respMsg); writeErr != nil {
				log.Printf("conn.Write error:%v", writeErr)
				// No reply can be delivered any more. Closing conn makes the
				// read loop below fail and stop accepting new requests.
				conn.Close()
				continue
			}
			log.Println("finish write message...")
		}
	}()

	var handlers sync.WaitGroup
	defer func() {
		handlers.Wait()
		sChannel.Close()
		<-writerDone
	}()

	// Read loop: decode requests until the connection ends.
	for {
		message, err := protocol.DecodeMessage(conn)
		if err != nil {
			if errors.Is(err, io.EOF) {
				// The peer closed the connection.
				log.Printf("ip: %s close", conn.RemoteAddr())
			} else {
				log.Printf("protocol.DecodeMessage error:%v", err)
			}
			return
		}
		handlers.Add(1)
		go func(m *protocol.Message) {
			defer handlers.Done()
			s.handleRequest(sChannel, m)
		}(message)
	}
}
// handleRequest executes one decoded RPC request and queues the encoded reply
// on sChannel.
//
// It works in these steps:
//  1. Resolve the compressor and codec selected by the request header.
//  2. Look up the registered service and method.
//  3. Decompress and decode the payload into a fresh argument value.
//  4. Invoke the method via reflection.
//  5. Encode, compress and frame the reply, and hand it to sChannel.Send.
//
// Every failure is logged and the request is dropped; no error reply is sent.
// NOTE(review): the client therefore never learns about the failure and may
// wait on this call indefinitely. The visible protocol.Body has no error field
// to report it through.
func (s *Server) handleRequest(sChannel *SendChannel, reqMsg *protocol.Message) {
	// This runs in its own goroutine, so a nil dereference here would crash
	// the whole server rather than just this request.
	if reqMsg == nil || reqMsg.Body == nil {
		log.Println("rpc request message has no body")
		return
	}

	// Resolve the compression and serialization plugins named in the header.
	compressPlugin, ok := compressor.Get(compressor.CompressorType(reqMsg.Header.CompressorType))
	if !ok {
		log.Println("rpc not have this compressor type")
		return
	}
	codecPlugin, ok := codec.Get(codec.CodecType(reqMsg.Header.CodecType))
	if !ok {
		log.Println("rpc not have this codecType")
		return
	}

	serviceName := reqMsg.Body.ServiceName
	serviceMethod := reqMsg.Body.ServiceMethod

	// Look up the registered service instance and the requested method.
	srv, ok := s.serviceMap[serviceName]
	if !ok {
		log.Printf("the service:%s is not register", serviceName)
		return
	}
	method, ok := srv.methodMap[serviceMethod]
	if !ok {
		log.Printf("the method:%s is not register", serviceMethod)
		return
	}

	// Allocate fresh argument and reply values. argType and replyType are
	// unwrapped with Elem(), so they are presumably pointer types recorded at
	// registration time — confirm against the registration code.
	argVal := reflect.New(method.argType.Elem()).Interface()
	replyVal := reflect.New(method.replyType.Elem()).Interface()

	// Decompress the request payload. On failure, stop: decoding an
	// undecompressed payload would feed garbage (or, for codecs that accept
	// empty input, zero values) into the service method.
	var err error
	reqMsg.Body.Payload, err = compressPlugin.Unzip(reqMsg.Body.Payload)
	if err != nil {
		log.Printf("server compressor.Unzip error:%v", err)
		return
	}

	// Deserialize the payload into the argument value.
	if err = codecPlugin.Decode(reqMsg.Body.Payload, argVal); err != nil {
		log.Printf("server codecPlugin.Decode error:%v", err)
		return
	}

	// Call method(receiver, args, reply). The first return value is treated
	// as the call's error.
	results := method.method.Func.Call([]reflect.Value{srv.refVal, reflect.ValueOf(argVal), reflect.ValueOf(replyVal)})
	if errorVal := results[0].Interface(); errorVal != nil {
		log.Printf("%s.%s error:%v", serviceName, serviceMethod, errorVal)
		return
	}

	// Build the reply. It reuses the request header, so the reply carries the
	// same header values as the request (including its compressor and codec
	// type fields).
	body := &protocol.Body{
		Magic:         reqMsg.Body.Magic,
		ServiceName:   serviceName,
		ServiceMethod: serviceMethod,
	}

	// Serialize the reply value, then compress it.
	body.Payload, err = codecPlugin.Encode(replyVal)
	if err != nil {
		log.Printf("codecPlugin.Encode error:%v", err)
		return
	}
	body.Payload, err = compressPlugin.Zip(body.Payload)
	if err != nil {
		log.Printf("compressPlugin.Zip error:%v", err)
		return
	}

	msg := &protocol.Message{
		Header: reqMsg.Header,
		Body:   body,
	}
	msgData, err := protocol.EncodeMessage(msg)
	if err != nil {
		log.Printf("protocol.EncodeMessage error:%v", err)
		return
	}

	// Queue the framed reply for the connection's writer goroutine.
	if err = sChannel.Send(msgData); err != nil {
		log.Printf("sChannel.Send error:%v", err)
	}
}