-
Notifications
You must be signed in to change notification settings - Fork 0
/
tcp.go
71 lines (60 loc) · 1.53 KB
/
tcp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package emqd
import (
"io"
"net"
"sync"
log "github.com/ericluj/elog"
"github.com/ericluj/emq/internal/common"
"github.com/ericluj/emq/internal/protocol"
)
// TCPServer serves TCP client connections for an EMQD instance: Handle runs
// one connection from protocol negotiation to teardown, and Close closes
// every client that is still registered.
type TCPServer struct {
// emqd is handed to every Protocol created for an accepted connection.
emqd *EMQD
// conns holds the live clients: the key is the connection's RemoteAddr()
// and the value is the protocol.Client built for it. Entries are added
// before IOLoop starts and removed after it returns.
conns sync.Map
}
// Handle serves a single client connection for its whole lifetime.
//
// The client must first send common.ProtoMagicLen bytes of protocol magic;
// the magic selects the protocol implementation, which leaves room for
// future protocol upgrades. If the magic is unknown, the client is sent a
// common.FrameTypeError frame carrying common.BadProtocolBytes and the
// connection is dropped. Otherwise a client is created through the selected
// protocol, registered in s.conns while its IOLoop runs, and unregistered
// when IOLoop returns. conn is always closed before Handle returns.
func (s *TCPServer) Handle(conn net.Conn) {
	// Close the connection on every exit path.
	defer conn.Close()

	// Capture the address once so the same key is used for Store and Delete.
	addr := conn.RemoteAddr()
	log.Infof("TCP: new client %s", addr)

	// Read the protocol magic the client sends first.
	buf := make([]byte, common.ProtoMagicLen)
	if _, err := io.ReadFull(conn, buf); err != nil {
		log.Errorf("ReadFull: %v", err)
		return
	}

	// The switch is the single place that validates the magic. A separate
	// pre-check would make the error-frame branch below unreachable.
	pm := string(buf)
	var prot protocol.Protocol
	switch pm {
	case common.ProtoMagic:
		log.Infof("client %s: desired protocol magic '%s'", addr, pm)
		prot = &Protocol{emqd: s.emqd}
	default:
		// Tell the peer why it is being dropped before the deferred Close runs.
		if err := protocol.SendFrameData(conn, common.FrameTypeError, common.BadProtocolBytes); err != nil {
			log.Errorf("SendFrameData: %v", err)
		}
		// pm is arbitrary client bytes; %q keeps control characters out of the log.
		log.Infof("client %s: bad protocol magic %q", addr, pm)
		return
	}

	client := prot.NewClient(conn)
	s.conns.Store(addr, client)
	// Deferred after conn.Close, so it runs first: unregister, then close.
	defer s.conns.Delete(addr)

	// Run the client's request loop until it finishes.
	if err := prot.IOLoop(client); err != nil {
		log.Errorf("IOLoop: %v, client %s", err, addr)
	}
}
// Close calls Close on every protocol.Client currently stored in s.conns.
// It does not remove the entries itself, and it ignores whatever the
// individual Close calls report.
func (s *TCPServer) Close() {
	closeClient := func(_, value interface{}) bool {
		client := value.(protocol.Client)
		client.Close()
		// Returning true keeps Range visiting the remaining clients.
		return true
	}
	s.conns.Range(closeClient)
}