This repository has been archived by the owner on Sep 15, 2022. It is now read-only.
/
transport.go
185 lines (161 loc) · 4.88 KB
/
transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
package lbtransport
import (
"context"
"io"
"net"
"net/http"
"sync"
http_ctxtags "github.com/improbable-eng/go-httpwares/tags"
"github.com/improbable-eng/kedge/pkg/http/ctxtags"
"github.com/improbable-eng/kedge/pkg/reporter"
"github.com/improbable-eng/kedge/pkg/reporter/errtypes"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc/naming"
)
// tripper is an http.RoundTripper that spreads requests for one named backend
// across the addresses reported by a naming watcher.
type tripper struct {
	// targetName is the address handed to the resolver in New; it is also used
	// in tags, error messages and metric labels.
	targetName string
	// parent performs the actual round trip once a target address was chosen.
	parent http.RoundTripper
	// policy supplies a fresh picker for every RoundTrip call.
	policy LBPolicy
	// mu guards currentTargets and irrecoverableErr.
	mu sync.RWMutex
	// currentTargets is the latest target set published by run.
	currentTargets []*Target
	// irrecoverableErr is set by run when the watcher fails; once it is set,
	// RoundTrip rejects every request.
	irrecoverableErr error
}
// failedDialsCounter counts dial attempts that failed against a resolved
// address. Label values are supplied in the order "resolve_addr", "target".
var failedDialsCounter = prometheus.NewCounterVec(
	prometheus.CounterOpts{
		Namespace: "kedge",
		Subsystem: "http_lbtransport",
		Name:      "failed_dials",
		Help:      "Total number of failed dials that are in resolver and should blacklist the target.",
	},
	[]string{"resolve_addr", "target"},
)
// init registers failedDialsCounter with the default Prometheus registry.
// MustRegister panics if the registration fails.
func init() {
	prometheus.MustRegister(failedDialsCounter)
}
// New builds a load-balanced RoundTripper for the single backend named by
// targetAddr.
//
// targetAddr is resolved through the given grpc naming.Resolver; the resulting
// watcher is consumed in a background goroutine that keeps the tripper's target
// list up to date. Cancelling ctx closes the watcher. Requests are forwarded to
// parent, with the concrete address chosen per request by policy.
//
// An error is returned only when the initial Resolve call fails.
func New(ctx context.Context, targetAddr string, parent http.RoundTripper, resolver naming.Resolver, policy LBPolicy) (*tripper, error) {
	watcher, err := resolver.Resolve(targetAddr)
	if err != nil {
		return nil, errors.Wrapf(err, "tripper: failed to do initial resolve for target %s", targetAddr)
	}

	lb := &tripper{
		targetName:     targetAddr,
		parent:         parent,
		policy:         policy,
		currentTargets: []*Target{},
	}

	// Close the watcher as soon as the caller's context ends.
	go func() {
		<-ctx.Done()
		watcher.Close()
	}()
	// Consume resolver updates in the background.
	go lb.run(ctx, watcher)

	return lb, nil
}
// run consumes updates from watcher until ctx is done or the watcher fails,
// publishing the resulting target list to s.currentTargets after every batch.
//
// A watcher error is treated as fatal: it is stored in s.irrecoverableErr, the
// published target list is emptied and the loop exits.
func (s *tripper) run(ctx context.Context, watcher naming.Watcher) {
	var resolved []*Target
	for ctx.Err() == nil {
		// Next blocks until the resolver has something new to report.
		batch, err := watcher.Next()
		if err != nil {
			s.mu.Lock()
			s.irrecoverableErr = err
			s.currentTargets = []*Target{}
			s.mu.Unlock()
			return
		}

		for _, update := range batch {
			switch update.Op {
			case naming.Add:
				resolved = append(resolved, &Target{DialAddr: update.Addr})
			case naming.Delete:
				// Filter into a fresh slice: the previous one has already been
				// published and may still be read by in-flight RoundTrip calls.
				var remaining []*Target
				for _, existing := range resolved {
					if existing.DialAddr != update.Addr {
						remaining = append(remaining, existing)
					}
				}
				resolved = remaining
			}
		}

		s.mu.Lock()
		s.currentTargets = resolved
		s.mu.Unlock()
	}
}
// RoundTrip sends r to one of the currently resolved targets through the
// parent RoundTripper.
//
// The request fails immediately when the watcher has reported a fatal error or
// when no targets are known. Otherwise the policy's picker chooses a target,
// r.URL.Host is rewritten to that target's dial address and the request is
// forwarded. When the parent fails with a dial error, the target is excluded
// and another one is picked; any other error is reported and returned as is.
func (s *tripper) RoundTrip(r *http.Request) (*http.Response, error) {
	inboundTags := http_ctxtags.ExtractInbound(r)
	inboundTags.Set(ctxtags.TagForBackendTarget, s.targetName)

	// Snapshot the shared state so the rest of the call runs without the lock.
	s.mu.RLock()
	candidates, watcherErr := s.currentTargets, s.irrecoverableErr
	s.mu.RUnlock()

	switch {
	case watcherErr != nil:
		wrapped := errors.Wrapf(watcherErr, "lb: critical naming.Watcher error for target %s. Tripper is closed.", s.targetName)
		reporter.Extract(r).ReportError(errtypes.IrrecoverableWatcherError, wrapped)
		closeIfNotNil(r.Body)
		return nil, wrapped
	case len(candidates) == 0:
		noTargets := errors.Errorf("lb: no backend is available for %s. 0 resolved addresses.", s.targetName)
		reporter.Extract(r).ReportError(errtypes.NoResolutionAvailable, noTargets)
		closeIfNotNil(r.Body)
		return nil, noTargets
	}

	if original := r.Body; original != nil {
		// Take ownership of the body: the same reader cannot be handed to the
		// parent transport more than once, so it is wrapped in a replayable
		// reader and the original is closed when this call returns.
		defer closeIfNotNil(original)
		r.Body = newReplayableReader(original)
	}

	picker := s.policy.Picker()
	for {
		chosen, pickErr := picker.Pick(r, candidates)
		if pickErr != nil {
			pickErr = errors.Wrapf(pickErr, "lb: failed choosing valid target for %s", s.targetName)
			reporter.Extract(r).ReportError(errtypes.NoConnToAllResolvedAddresses, pickErr)
			return nil, pickErr
		}

		// The downstream tripper (usually http.DefaultTransport) dials and pools
		// connections by URL.Host, so point it at the chosen address.
		// See http.connectMethodKey.
		r.URL.Host = chosen.DialAddr
		inboundTags.Set(ctxtags.TagForTargetAddress, chosen.DialAddr)
		if r.Body != nil {
			// Start every attempt from the beginning of the body.
			r.Body.(*replayableReader).rewind()
		}

		resp, tripErr := s.parent.RoundTrip(r)
		switch {
		case tripErr == nil:
			return resp, nil
		case !isDialError(tripErr):
			reporter.Extract(r).ReportError(errtypes.TransportUnknownError, tripErr)
			return resp, tripErr
		}

		// Dial failed: count it and try again without this target. The loop
		// relies on the picker to honour the exclusion and eventually fail Pick.
		failedDialsCounter.WithLabelValues(s.targetName, chosen.DialAddr).Inc()
		picker.ExcludeTarget(chosen)
	}
}
// isDialError reports whether err is, or wraps, a *net.OpError produced while
// dialing (Op == "dial").
//
// The error chain is followed through both the standard Unwrap() method and
// the Cause() method used by github.com/pkg/errors, so a dial failure is still
// recognised when a parent RoundTripper wrapped it (for example in a
// *url.Error). The first *net.OpError found in the chain decides the result;
// a nil error and chains without any *net.OpError yield false.
func isDialError(err error) bool {
	for err != nil {
		if opErr, ok := err.(*net.OpError); ok {
			return opErr.Op == "dial"
		}
		switch wrapper := err.(type) {
		case interface{ Unwrap() error }:
			err = wrapper.Unwrap()
		case interface{ Cause() error }:
			err = wrapper.Cause()
		default:
			return false
		}
	}
	return false
}
// closeIfNotNil closes r unless it is nil. The error returned by Close is
// deliberately discarded: callers use this for best-effort cleanup only.
func closeIfNotNil(r io.Closer) {
	if r == nil {
		return
	}
	_ = r.Close()
}