-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tcp.go
124 lines (112 loc) · 2.96 KB
/
tcp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package greenlight
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"regexp"
"time"
)
// DefaultTCPMaxBytes is the response read limit, in bytes (32 KiB), that
// NewTCPChecker falls back to when the configuration leaves max_bytes at zero.
var DefaultTCPMaxBytes = 32 * 1024
// TCPCheckConfig holds the YAML-configurable settings of a TCP check.
type TCPCheckConfig struct {
	// Host is the host name or address to connect to.
	Host string `yaml:"host"`
	// Port is the port to connect to; it is kept as a string and joined with Host.
	Port string `yaml:"port"`
	// Send is the payload written right after connecting; empty sends nothing.
	Send string `yaml:"send"`
	// Quit is the payload TCPChecker.Run writes just before it returns.
	// NOTE(review): the YAML key is "quiet", which looks like a typo for
	// "quit" — confirm before renaming, since existing configs may rely on it.
	Quit string `yaml:"quiet"`
	// MaxBytes is the size of the buffer used to read the response.
	MaxBytes int `yaml:"max_bytes"`
	// ExpectPattern is a regular expression the response must match; when it
	// is empty no response is read.
	ExpectPattern string `yaml:"expect_pattern"`
	// TLS makes the checker connect through TLS instead of plain TCP.
	TLS bool `yaml:"tls"`
	// NoCheckCertificate is passed to the TLS configuration as InsecureSkipVerify.
	NoCheckCertificate bool `yaml:"no_check_certificate"`
}
// TCPChecker performs a TCP (optionally TLS) connectivity check with an
// optional send/expect exchange. Instances are normally built by NewTCPChecker.
type TCPChecker struct {
	// Host and Port form the address that Run dials.
	Host, Port string
	// Send is written right after connecting and Quit just before Run
	// returns; an empty string skips the respective write.
	Send, Quit string
	// MaxBytes is the size of the buffer Run uses to read the response.
	MaxBytes int
	// ExpectPattern, when non-nil, must match the bytes read from the peer.
	ExpectPattern *regexp.Regexp
	// Timeout limits how long Run may spend connecting and exchanging data.
	Timeout time.Duration
	// TLS selects a TLS connection; NoCheckCertificate is handed to the TLS
	// configuration as InsecureSkipVerify.
	TLS, NoCheckCertificate bool

	// name is the check name returned by Name and attached to log records.
	name string
}
// Name returns the checker's name.
func (p *TCPChecker) Name() string {
	return p.name
}
// NewTCPChecker builds a TCPChecker from cfg.
//
// The check name and timeout are taken from cfg itself; every other setting,
// including the quit payload, is copied from cfg.TCP. A non-empty
// expect_pattern is compiled once here so Run can reuse it, and a pattern
// that fails to compile is reported as a wrapped error. A max_bytes that is
// zero or negative is replaced by DefaultTCPMaxBytes, because Run allocates
// its read buffer with that size and a negative length would panic.
func NewTCPChecker(cfg *CheckConfig) (*TCPChecker, error) {
	p := &TCPChecker{
		name:               cfg.Name,
		Timeout:            cfg.Timeout,
		Host:               cfg.TCP.Host,
		Port:               cfg.TCP.Port,
		Send:               cfg.TCP.Send,
		Quit:               cfg.TCP.Quit,
		MaxBytes:           cfg.TCP.MaxBytes,
		TLS:                cfg.TCP.TLS,
		NoCheckCertificate: cfg.TCP.NoCheckCertificate,
	}
	if cfg.TCP.ExpectPattern != "" {
		pt, err := regexp.Compile(cfg.TCP.ExpectPattern)
		if err != nil {
			return nil, fmt.Errorf("invalid expect_pattern: %w", err)
		}
		p.ExpectPattern = pt
	}
	if p.MaxBytes <= 0 {
		p.MaxBytes = DefaultTCPMaxBytes
	}
	return p, nil
}
// Run executes the check once.
//
// It connects to Host:Port (through TLS when TLS is set), writes Send if it
// is non-empty, and, when ExpectPattern is set, reads up to MaxBytes from the
// peer and matches them against the pattern. Finally it writes Quit if that
// is non-empty. The whole exchange is bounded by Timeout and by any earlier
// deadline carried by ctx.
//
// It returns nil on success, or an error naming the step that failed:
// connect, set deadline, send, read, or an unexpected response. A failure to
// write Quit is only logged, because the check itself has already passed.
func (p *TCPChecker) Run(ctx context.Context) error {
	logger := newLoggerFromContext(ctx).With("name", p.name, "module", "tcpchecker")
	ctx, cancel := context.WithTimeout(ctx, p.Timeout)
	defer cancel()

	addr := net.JoinHostPort(p.Host, p.Port)
	conn, err := dialTCP(ctx, addr, p.TLS, p.NoCheckCertificate, p.Timeout)
	if err != nil {
		return fmt.Errorf("tcp connect failed: %w", err)
	}
	defer conn.Close()

	// Never let socket I/O outlive the context: restarting the clock after
	// the dial would allow the check to run for up to twice Timeout and to
	// ignore an earlier deadline set by the caller.
	deadline := time.Now().Add(p.Timeout)
	if d, ok := ctx.Deadline(); ok && d.Before(deadline) {
		deadline = d
	}
	if err := conn.SetDeadline(deadline); err != nil {
		return fmt.Errorf("tcp set deadline failed: %w", err)
	}
	logger.Debug("connected " + addr)

	if p.Send != "" {
		logger.Debug("send " + p.Send)
		if _, err := io.WriteString(conn, p.Send); err != nil {
			return fmt.Errorf("tcp send failed: %w", err)
		}
	}

	if p.ExpectPattern != nil {
		buf := make([]byte, p.MaxBytes)
		// Only one Read is issued, so the pattern is matched against the
		// first chunk the peer delivers, not the complete response.
		n, err := bufio.NewReader(conn).Read(buf)
		if err != nil {
			return fmt.Errorf("tcp read failed: %w", err)
		}
		resp := buf[:n]
		logger.Debug("read " + string(resp))
		if !p.ExpectPattern.Match(resp) {
			return fmt.Errorf("tcp unexpected response: %s", string(resp))
		}
	}

	if p.Quit != "" {
		logger.Debug("quit " + p.Quit)
		// Best effort: the check has already succeeded, so a failed farewell
		// is recorded but does not turn the result into an error.
		if _, err := io.WriteString(conn, p.Quit); err != nil {
			logger.Debug("quit failed " + err.Error())
		}
	}
	return nil
}
// dialTCP opens a connection to address over the "tcp" network.
//
// A net.Dialer configured with timeout is always used for the underlying
// connection. When useTLS is true that dialer is wrapped in a tls.Dialer
// whose configuration sets InsecureSkipVerify to noCheckCertificate;
// otherwise the plain connection is returned. ctx is passed to DialContext
// in both cases.
func dialTCP(ctx context.Context, address string, useTLS bool, noCheckCertificate bool, timeout time.Duration) (net.Conn, error) {
	plain := &net.Dialer{Timeout: timeout}
	if !useTLS {
		return plain.DialContext(ctx, "tcp", address)
	}

	tlsConf := &tls.Config{InsecureSkipVerify: noCheckCertificate}
	secure := &tls.Dialer{NetDialer: plain, Config: tlsConf}
	return secure.DialContext(ctx, "tcp", address)
}