-
Notifications
You must be signed in to change notification settings - Fork 6
/
masquerade.go
182 lines (159 loc) · 4.82 KB
/
masquerade.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
package fronted
import (
"fmt"
"net"
"net/http"
"sort"
"strings"
"sync"
"time"
)
// NumWorkers is the number of worker goroutines used for verifying
// (presumably masquerades; the workers themselves are outside this chunk).
const NumWorkers = 10
// defaultValidator is the ResponseValidator that Provider.ValidateResponse
// falls back to when a Provider has no Validator of its own. It rejects
// 403 Forbidden responses, which presumably indicate that the fronting
// provider refused to route the request.
var defaultValidator = NewStatusCodeValidator([]int{http.StatusForbidden})
// CA represents a certificate authority.
type CA struct {
	// CommonName is the CA's common name (presumably the certificate
	// subject CN; nothing in this chunk validates it against Cert).
	CommonName string
	// Cert is the CA certificate.
	Cert string // PEM-encoded
}
// Masquerade contains the data for a single masquerade host: the domain to
// use for fronting and, optionally, a pre-resolved IP address for it.
// NOTE(review): an earlier version of this comment also mentioned "the root
// CA", but this struct carries no CA data; confirm where the CA is tracked.
type Masquerade struct {
	// Domain is the domain to use for domain fronting.
	Domain string
	// IpAddress is a pre-resolved IP address to use instead of Domain
	// (if available; empty otherwise).
	IpAddress string
}
// masquerade is the internal, mutable wrapper around a Masquerade that also
// records when it last succeeded and which provider supplied it.
type masquerade struct {
	// Masquerade holds the domain and optional pre-resolved IP address.
	Masquerade
	// LastSucceeded is the most recent time at which this masquerade
	// succeeded. The zero time means it never has, or that markFailed has
	// reset it since. The lastSucceeded/markSucceeded/markFailed accessors
	// guard it with mx.
	// NOTE(review): because the field is exported, code (or encoders) that
	// touch it directly bypass mx; confirm no such concurrent access exists.
	LastSucceeded time.Time
	// ProviderID is the id of the DirectProvider this masquerade is provided by.
	ProviderID string
	// mx guards LastSucceeded for the accessor methods.
	mx sync.RWMutex
}
// lastSucceeded reports when this masquerade last succeeded. It returns the
// zero time if it never has, or if it was marked failed since. The value is
// read while holding the read lock.
func (m *masquerade) lastSucceeded() time.Time {
	m.mx.RLock()
	when := m.LastSucceeded
	m.mx.RUnlock()
	return when
}
// markSucceeded stamps the current time as this masquerade's most recent
// success. The timestamp is taken while the write lock is held, so the last
// writer always stores the latest time.
func (m *masquerade) markSucceeded() {
	m.mx.Lock()
	m.LastSucceeded = time.Now()
	m.mx.Unlock()
}
func (m *masquerade) markFailed() {
m.mx.Lock()
defer m.mx.Unlock()
m.LastSucceeded = time.Time{}
}
// Provider is a Direct fronting provider configuration.
type Provider struct {
	// HostAliases holds specific hostname mappings used for this provider:
	// Lookup remaps requests for a key hostname to the provider-specific
	// host name it maps to. Lookup lower-cases the hostname before the map
	// lookup, and NewProvider lower-cases the keys. A Provider built by hand
	// must therefore use lower-case keys.
	HostAliases map[string]string
	// PassthroughPatterns allows unaliased pass-through of hostnames
	// matching these patterns.
	// eg "*.cloudfront.net" for cloudfront provider
	// would permit all .cloudfront.net domains to
	// pass through without alias. Only suffix
	// patterns and exact matches are supported.
	// Matching is case-insensitive. A "*.x" pattern matches subdomains of x
	// but not x itself.
	PassthroughPatterns []string
	// TestURL is the URL used to vet masquerades for this provider.
	TestURL string
	// Masquerades lists the masquerade hosts available to this provider.
	Masquerades []*Masquerade
	// Validator is an optional response validator used to determine whether
	// fronting succeeded for this provider. ValidateResponse uses the
	// package default validator when this is nil.
	// NOTE(review): the original comment says a masquerade that fails
	// validation is discarded; that handling is outside this chunk.
	Validator ResponseValidator
}
// NewProvider creates a Provider with the given details.
//
// hosts maps hostnames to provider-specific aliases. Its keys are
// lower-cased so that Lookup, which lower-cases its input, matches them
// case-insensitively. Each entry of masquerades is copied into a fresh
// Masquerade, so later changes to the caller's values do not leak into the
// Provider. nil entries are skipped rather than dereferenced. validator may
// be nil, in which case ValidateResponse uses the package default.
// passthrough patterns are copied as given.
func NewProvider(hosts map[string]string, testURL string, masquerades []*Masquerade, validator ResponseValidator, passthrough []string) *Provider {
	d := &Provider{
		HostAliases:         make(map[string]string, len(hosts)),
		TestURL:             testURL,
		Masquerades:         make([]*Masquerade, 0, len(masquerades)),
		Validator:           validator,
		PassthroughPatterns: make([]string, 0, len(passthrough)),
	}
	for k, v := range hosts {
		d.HostAliases[strings.ToLower(k)] = v
	}
	for _, m := range masquerades {
		if m == nil {
			// A nil entry carries no domain or address. Skip it instead of
			// panicking on the field access below.
			continue
		}
		d.Masquerades = append(d.Masquerades, &Masquerade{Domain: m.Domain, IpAddress: m.IpAddress})
	}
	d.PassthroughPatterns = append(d.PassthroughPatterns, passthrough...)
	return d
}
// Lookup returns the host name this provider should use for hostname.
//
// Any ":port" suffix is ignored and matching is case-insensitive. A
// non-empty HostAliases entry wins. Otherwise, if the host equals a
// PassthroughPattern or matches a "*.suffix" pattern, the lower-cased host
// itself is returned. An empty result means this provider has no mapping
// for hostname.
func (p *Provider) Lookup(hostname string) string {
	host := hostname
	// Only the host portion matters when a port is given as well.
	if h, _, err := net.SplitHostPort(hostname); err == nil {
		host = h
	}
	host = strings.ToLower(host)

	if alias := p.HostAliases[host]; alias != "" {
		return alias
	}

	for _, raw := range p.PassthroughPatterns {
		pattern := strings.ToLower(raw)
		// "*.example.com" matches by its ".example.com" suffix.
		wildcardHit := strings.HasPrefix(pattern, "*.") && strings.HasSuffix(host, pattern[1:])
		if wildcardHit || pattern == host {
			return host
		}
	}
	return ""
}
// ValidateResponse validates a fronted response. It returns an error if the
// response failed to reach the origin, e.g. if the request was rejected by
// the provider. The Provider's own Validator is used when set. Otherwise
// the package-level defaultValidator applies.
func (p *Provider) ValidateResponse(res *http.Response) error {
	if p.Validator != nil {
		return p.Validator(res)
	}
	return defaultValidator(res)
}
// ResponseValidator is a validator for fronted responses. It returns an
// error if the response failed to reach the origin, e.g. if the request was
// rejected by the provider.
type ResponseValidator func(*http.Response) error
// NewStatusCodeValidator returns a ResponseValidator that rejects any
// response whose HTTP status code appears in reject. Every other response
// passes. The codes are copied into a set up front, so later changes to the
// reject slice do not affect the returned validator.
func NewStatusCodeValidator(reject []int) ResponseValidator {
	rejected := make(map[int]struct{}, len(reject))
	for _, code := range reject {
		rejected[code] = struct{}{}
	}
	return func(res *http.Response) error {
		if _, isRejected := rejected[res.StatusCode]; !isRejected {
			return nil
		}
		return fmt.Errorf("response status %d: %v", res.StatusCode, res.Status)
	}
}
// sortedMasquerades is a slice of masquerades that implements sort.Interface,
// ordering entries by Less (most recently succeeded first).
type sortedMasquerades []*masquerade

// Len returns the number of masquerades in the slice.
func (m sortedMasquerades) Len() int {
	return len(m)
}

// Swap exchanges the masquerades at indices i and j.
func (m sortedMasquerades) Swap(i, j int) {
	tmp := m[i]
	m[i] = m[j]
	m[j] = tmp
}
// Less orders masquerades so that the most recently successful come first.
// Masquerades with equal success times, including those that never
// succeeded, are ordered by IpAddress and then by Domain. The result is
// therefore deterministic even when IpAddress is empty or shared.
//
// NOTE: if another goroutine marks masquerades while a sort is running,
// successive comparisons may still disagree with each other. Snapshot the
// times first if a strictly consistent order is required.
func (m sortedMasquerades) Less(i, j int) bool {
	// Read each success time exactly once (one lock round-trip each), so a
	// single comparison sees one consistent value per element.
	ti, tj := m[i].lastSucceeded(), m[j].lastSucceeded()
	switch {
	case ti.After(tj):
		return true
	case tj.After(ti):
		return false
	case m[i].IpAddress != m[j].IpAddress:
		return m[i].IpAddress < m[j].IpAddress
	default:
		// Final tie-break: sort.Sort is not stable, so without it entries
		// sharing a time and IP address would come out in arbitrary order.
		return m[i].Domain < m[j].Domain
	}
}
// sortedCopy returns a new slice holding the same masquerade pointers,
// ordered by Less. The receiver's order is left untouched. The result is
// never nil, even for an empty receiver.
func (m sortedMasquerades) sortedCopy() sortedMasquerades {
	out := append(make(sortedMasquerades, 0, len(m)), m...)
	sort.Sort(out)
	return out
}