/
validate.go
82 lines (70 loc) · 2.06 KB
/
validate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package actions
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/mihaitodor/wormhole/config"
"github.com/mihaitodor/wormhole/transport"
)
// ValidateAction checks that an HTTP endpoint on the target host responds as
// expected. Run issues a GET request to Scheme://host[:Port]/UrlPath, where
// host comes from the transport connection. validate then compares the
// response status code and body against StatusCode and BodyContent.
//
// Fields are decoded via their mapstructure tags. The multi-word fields also
// carry yaml tags with the same snake_case names.
type ValidateAction struct {
	// ActionBase is embedded. Presumably it holds fields shared by all
	// actions — TODO confirm against its definition.
	ActionBase

	// Scheme is used as the URL scheme of the request.
	Scheme string `mapstructure:"scheme"`
	// Port, when non-zero, is appended to the connection's host. Zero keeps
	// the host exactly as the connection reports it.
	Port uint `mapstructure:"port"`
	// UrlPath is used as the path component of the request URL.
	UrlPath string `yaml:"url_path" mapstructure:"url_path"`
	// Retries is the total number of request attempts Run makes. Zero is
	// treated as a single attempt.
	Retries uint `mapstructure:"retries"`
	// Timeout bounds each individual request attempt made by validate.
	Timeout time.Duration `mapstructure:"timeout"`
	// StatusCode must equal the response status code exactly.
	StatusCode int `yaml:"status_code" mapstructure:"status_code"`
	// BodyContent must appear as a substring of the response body. An empty
	// value matches any body.
	BodyContent string `yaml:"body_content" mapstructure:"body_content"`
}
// validate performs one HTTP request and checks the response against the
// expected status code and body content.
//
// When a.Timeout is positive, the request is bounded by that timeout as well
// as by any deadline already on ctx. A zero (unset) or negative Timeout adds
// no extra deadline. Passing it to context.WithTimeout would yield an
// already-expired context, so every attempt would fail before being sent.
//
// It returns nil on success. Otherwise it returns an error describing the
// first check that failed; transport and read errors are wrapped with %w.
func (a *ValidateAction) validate(ctx context.Context, req *http.Request) error {
	if a.Timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, a.Timeout)
		defer cancel()
	}

	resp, err := http.DefaultClient.Do(req.WithContext(ctx))
	if err != nil {
		return fmt.Errorf("failed to execute request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != a.StatusCode {
		return fmt.Errorf("expected status %d but got %d instead", a.StatusCode, resp.StatusCode)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read response body: %w", err)
	}

	// An empty BodyContent always matches, because every string contains "".
	if !strings.Contains(string(body), a.BodyContent) {
		return errors.New("response does not contain expected content")
	}

	return nil
}
// Run validates that an HTTP endpoint on the connection's host responds as
// configured.
//
// It builds the URL from Scheme, the host reported by conn (with Port appended
// when non-zero) and UrlPath. It then issues a GET request up to Retries times
// (at least once) and returns nil on the first attempt that passes validate.
// The config argument is unused.
//
// Attempts run back to back with no delay. Retrying stops early once ctx is
// cancelled or past its deadline, since any further attempt would fail
// immediately. The returned error reports the number of attempts actually
// made and wraps the last validation error.
func (a *ValidateAction) Run(ctx context.Context, conn transport.Connection, _ config.Config) error {
	host := conn.GetHost()
	if a.Port != 0 {
		// A bare IPv6 literal must be bracketed before a port is appended.
		// Otherwise the port cannot be told apart from the address and the
		// request cannot be dialled.
		if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
			host = "[" + host + "]"
		}
		host = fmt.Sprintf("%s:%d", host, a.Port)
	}

	u := url.URL{
		Scheme: a.Scheme,
		Host:   host,
		Path:   a.UrlPath,
	}

	req, err := http.NewRequest(http.MethodGet, u.String(), nil)
	if err != nil {
		return fmt.Errorf("failed to create http request: %w", err)
	}

	// Retries is the total number of attempts; zero means a single attempt.
	attempts := 1
	if a.Retries > 0 {
		attempts = int(a.Retries)
	}

	made := 0
	for made < attempts {
		made++
		if err = a.validate(ctx, req); err == nil {
			return nil
		}
		if ctx.Err() != nil {
			break
		}
	}

	// Report the attempts actually made. a.Retries is 0 when unset, even
	// though one attempt was performed.
	return fmt.Errorf("failed to validate %q after %d attempts: %w", u.String(), made, err)
}