-
Notifications
You must be signed in to change notification settings - Fork 1
/
tcp_server.go
144 lines (129 loc) · 2.78 KB
/
tcp_server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package tcpproxy
import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
)
// Sentinel errors used by TCPServer.
var (
	// ErrServerClosed is returned by ListenAndServe and Serve once the server
	// has been closed.
	ErrServerClosed = errors.New("tcp: Server closed")

	// ErrAbortHandler is not referenced in the code visible here.
	// NOTE(review): presumably mirrors net/http.ErrAbortHandler (a value a
	// handler can panic with to abort) — confirm against conn.serve.
	ErrAbortHandler = errors.New("tcp: abort TCPHandler")
)

// Context keys attached by the server.
var (
	// ServerContextKey maps to the *TCPServer that accepted the connection;
	// Serve stores it in the context passed to each connection.
	ServerContextKey = &contextKey{"tcp-server"}

	// LocalAddrContextKey is not set in the code visible here.
	// NOTE(review): presumably holds the connection's local address — confirm.
	LocalAddrContextKey = &contextKey{"local-addr"}
)
// onceCloseListener wraps a net.Listener so that closing it is idempotent.
// The wrapped listener is closed at most once, and every call to Close
// reports the error produced by that single close.
type onceCloseListener struct {
	net.Listener
	once     sync.Once
	closeErr error
}

var _ net.Listener = (*onceCloseListener)(nil)

// Close closes the wrapped listener on the first call and returns the
// recorded result of that close on every call.
func (l *onceCloseListener) Close() error {
	l.once.Do(l.close)
	return l.closeErr
}

// close performs the real close of the wrapped listener and records its
// error. It is only run through l.once.
func (l *onceCloseListener) close() {
	err := l.Listener.Close()
	l.closeErr = err
}
// TCPHandler is implemented by types that serve a single TCP connection.
//
// ServeTCP receives a context and the connection to serve.
// NOTE(review): the call site is presumably conn.serve, which is not
// visible here. Confirm there which context is passed and who closes conn.
type TCPHandler interface {
	ServeTCP(ctx context.Context, conn net.Conn)
}
// TCPServer is a TCP server. It accepts connections on Addr (via
// ListenAndServe) or on a listener passed to Serve. Each accepted connection
// is served in its own goroutine.
type TCPServer struct {
	// Addr is the TCP address ListenAndServe listens on. ListenAndServe
	// rejects an empty Addr.
	Addr string
	// Handler serves connections.
	// NOTE(review): not referenced in the code visible here; presumably used
	// by conn.serve — confirm.
	Handler TCPHandler
	// NOTE(review): err is not read or written in the code visible here.
	err error
	// BaseCtx is the parent of the context given to every connection. Serve
	// sets it to context.Background() when nil and adds ServerContextKey.
	BaseCtx context.Context
	// WriteTimeout and ReadTimeout, when non-zero, are applied by newConn as
	// write/read deadlines measured from the time the connection is set up.
	WriteTimeout time.Duration
	ReadTimeout  time.Duration
	// KeepAliveTimeout, when non-zero, enables TCP keep-alive with this
	// period. It only applies to connections that are *net.TCPConn.
	KeepAliveTimeout time.Duration
	// mu is held by getDoneChan while it reads or lazily creates doneChan.
	mu sync.Mutex
	// inShutdown is set to 1 by Close and read atomically by shuttingDown.
	inShutdown int32
	// doneChan is closed by Close to tell the Serve loop to stop.
	doneChan chan struct{}
	// l is the once-closable wrapper around the listener given to Serve.
	l *onceCloseListener
}
// shuttingDown reports whether the inShutdown flag has been raised. Close
// raises it; the flag is read with an atomic load.
func (s *TCPServer) shuttingDown() bool {
	flag := atomic.LoadInt32(&s.inShutdown)
	return !(flag == 0)
}
// Close immediately stops the server. It:
//   - marks the server as shutting down,
//   - closes the done channel watched by Serve,
//   - closes the listener Serve is accepting on.
//
// A blocked Accept then fails and Serve returns ErrServerClosed.
//
// Close may be called before ListenAndServe/Serve and may be called more than
// once; the done channel and the listener are each closed only once. It
// returns the listener's close error, or nil if no listener has been set yet.
// Connections already handed to conn.serve are not tracked by the code here
// and are left running.
func (s *TCPServer) Close() error {
	atomic.StoreInt32(&s.inShutdown, 1)

	s.mu.Lock()
	defer s.mu.Unlock()

	// doneChan does not exist yet if Close runs before Serve/ListenAndServe.
	if s.doneChan == nil {
		s.doneChan = make(chan struct{})
	}
	// Closing an already-closed channel panics, so close it only once.
	select {
	case <-s.doneChan:
	default:
		close(s.doneChan)
	}

	// Close the listener itself (calling s.Close here would recurse forever).
	// onceCloseListener makes this safe alongside Serve's deferred close.
	if s.l != nil {
		return s.l.Close()
	}
	return nil
}
// ListenAndServe listens on the TCP address s.Addr and then calls Serve. The
// listener is wrapped in tcpKeepAliveListener (defined elsewhere) before it
// is passed to Serve.
//
// It returns ErrServerClosed if the server is already shutting down. It
// returns an error if Addr is empty or net.Listen fails. Otherwise it returns
// whatever Serve returns.
func (s *TCPServer) ListenAndServe() error {
	if s.shuttingDown() {
		return ErrServerClosed
	}
	// Make sure doneChan exists, creating it under s.mu via getDoneChan.
	// Writing s.doneChan directly raced with Close/getDoneChan callers.
	s.getDoneChan()

	addr := s.Addr
	if addr == "" {
		return errors.New("need addr")
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	// net.Listen with a "tcp" network returns a *net.TCPListener.
	return s.Serve(tcpKeepAliveListener{
		ln.(*net.TCPListener),
	})
}
// Serve accepts incoming connections on l and serves each one in a new
// goroutine via newConn(...).serve.
//
// l is wrapped in a onceCloseListener, recorded in s.l so Close can close it,
// and closed when Serve returns. Each connection context derives from
// s.BaseCtx (set to context.Background() if nil) and carries s under
// ServerContextKey.
//
// Serve always returns a non-nil error:
//   - ErrServerClosed once Close has been called, even if Close ran before
//     Serve started.
//   - The Accept error if l was closed by something other than Close.
//
// Any other Accept error is logged and retried after an exponential backoff
// (5ms, doubling, capped at 1s). A persistent failure such as file-descriptor
// exhaustion therefore does not busy-loop.
func (s *TCPServer) Serve(l net.Listener) error {
	ln := &onceCloseListener{Listener: l}
	// Publish the listener under s.mu so a concurrent Close can see it.
	s.mu.Lock()
	s.l = ln
	s.mu.Unlock()
	defer ln.Close()

	// Close may already have run before s.l was set, so it could not have
	// closed this listener; stop here instead of accepting.
	if s.shuttingDown() {
		return ErrServerClosed
	}

	if s.BaseCtx == nil {
		s.BaseCtx = context.Background()
	}
	ctx := context.WithValue(s.BaseCtx, ServerContextKey, s)

	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var delay time.Duration // current backoff; reset after a successful Accept
	for {
		rw, err := ln.Accept()
		if err != nil {
			select {
			case <-s.getDoneChan():
				return ErrServerClosed
			default:
			}
			if errors.Is(err, net.ErrClosed) {
				// The listener is gone; retrying would spin forever.
				return err
			}
			if delay == 0 {
				delay = minAcceptDelay
			} else {
				delay *= 2
			}
			if delay > maxAcceptDelay {
				delay = maxAcceptDelay
			}
			fmt.Printf("accept fail., err:%v\n", err)
			// Back off before retrying, but stop at once if Close is called.
			select {
			case <-s.getDoneChan():
				return ErrServerClosed
			case <-time.After(delay):
			}
			continue
		}
		delay = 0
		c := s.newConn(rw)
		go c.serve(ctx)
	}
}
// newConn binds an accepted connection to s and applies the server's
// per-connection settings before the connection is served.
//
// Non-zero ReadTimeout/WriteTimeout become absolute deadlines counted from
// this call. Unless something later resets them (e.g. conn.serve, which is
// not visible here), they bound the whole connection rather than each
// individual read or write. A non-zero KeepAliveTimeout enables TCP
// keep-alive with that period, but only when rwc is a *net.TCPConn.
func (s *TCPServer) newConn(rwc net.Conn) *conn {
	// Errors from the setters are deliberately discarded: the settings are
	// applied best-effort.
	if rt := s.ReadTimeout; rt != 0 {
		_ = rwc.SetReadDeadline(time.Now().Add(rt))
	}
	if wt := s.WriteTimeout; wt != 0 {
		_ = rwc.SetWriteDeadline(time.Now().Add(wt))
	}
	if ka := s.KeepAliveTimeout; ka != 0 {
		if tc, isTCP := rwc.(*net.TCPConn); isTCP {
			_ = tc.SetKeepAlive(true)
			_ = tc.SetKeepAlivePeriod(ka)
		}
	}
	return &conn{server: s, rwc: rwc}
}
// getDoneChan returns the server's done channel, which Close closes to
// signal shutdown. The channel is created lazily while s.mu is held, so
// concurrent callers all observe the same channel.
func (s *TCPServer) getDoneChan() <-chan struct{} {
	s.mu.Lock()
	if s.doneChan == nil {
		s.doneChan = make(chan struct{})
	}
	ch := s.doneChan
	s.mu.Unlock()
	return ch
}