This repository has been archived by the owner on Aug 24, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 11
/
uaa_client.go
289 lines (231 loc) · 6.75 KB
/
uaa_client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
package auth
import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/dvsekhvalnov/jose2go"
)
// Metrics is the metrics dependency accepted by NewUAAClient.
type Metrics interface {
// NewGauge takes a gauge name and unit and returns a function that accepts
// float64 values.
// NOTE(review): presumably the returned function sets the gauge's current
// value; the implementation is not visible in this file, so confirm.
NewGauge(name, unit string) func(value float64)
}
// HTTPClient is the minimal HTTP transport UAAClient depends on. It is
// satisfied by *http.Client and allows a fake to be substituted.
type HTTPClient interface {
// Do sends the request and returns the response, or an error if the
// request could not be completed.
Do(r *http.Request) (*http.Response, error)
}
// UAAClient validates OAuth2 tokens against the public signing keys published
// by a UAA server and caches those keys between calls.
type UAAClient struct {
	// lastQueryTime is the UnixNano timestamp of the most recent key refresh
	// attempt. It is read and written with sync/atomic, so it is kept as the
	// first field: only the first word of an allocated struct is guaranteed
	// to be 64-bit aligned on 32-bit platforms, and misaligned 64-bit atomic
	// operations panic there.
	lastQueryTime int64

	// httpClient performs the requests to UAA.
	httpClient HTTPClient
	// uaa is the URL that token keys are fetched from.
	uaa *url.URL
	// log receives diagnostic messages such as refresh throttling notices.
	log *log.Logger

	// publicKeys caches the known signing keys: key id (string) mapped to
	// *rsa.PublicKey.
	publicKeys sync.Map

	// minimumRefreshInterval is the minimum time that must elapse between
	// two key refresh attempts.
	minimumRefreshInterval time.Duration
}
// NewUAAClient builds a UAAClient that fetches token keys from the
// "token_keys" path of uaaAddr using httpClient.
//
// The process is terminated through log.Fatalf when uaaAddr cannot be parsed
// as a URL. The refresh interval defaults to 30 seconds and can be changed
// with options, which are applied in the order given. The m argument is
// accepted but not used by this constructor.
func NewUAAClient(
	uaaAddr string,
	httpClient HTTPClient,
	m Metrics,
	log *log.Logger,
	opts ...UAAOption,
) *UAAClient {
	tokenKeysURL, parseErr := url.Parse(uaaAddr)
	if parseErr != nil {
		log.Fatalf("failed to parse UAA addr: %s", parseErr)
	}
	// Any path present in uaaAddr is replaced by the token keys endpoint.
	tokenKeysURL.Path = "token_keys"

	client := &UAAClient{
		minimumRefreshInterval: 30 * time.Second,
		publicKeys:             sync.Map{},
		log:                    log,
		httpClient:             httpClient,
		uaa:                    tokenKeysURL,
	}

	for i := range opts {
		opts[i](client)
	}

	return client
}
// UAAOption customizes a UAAClient while it is being constructed by
// NewUAAClient.
type UAAOption func(c *UAAClient)

// WithMinimumRefreshInterval returns an option that sets the minimum time
// between two token key refreshes to interval.
func WithMinimumRefreshInterval(interval time.Duration) UAAOption {
	setInterval := func(client *UAAClient) {
		client.minimumRefreshInterval = interval
	}
	return setInterval
}
// RefreshTokenKeys fetches the current token signing keys from UAA and
// reconciles the cached public keys with them: every returned key is stored
// (overwriting any key with the same id) and cached keys that UAA no longer
// returns are removed.
//
// Refreshes are throttled. If the previous attempt happened less than
// minimumRefreshInterval ago, a message is logged and nil is returned without
// contacting UAA.
//
// An error is returned when the request cannot be built or sent, when UAA
// answers with a status other than 200, when the body is not valid JSON, or
// when any key is empty, not PEM, not a PKIX public key, or not an RSA key.
// Keys processed before a failing key remain stored.
func (c *UAAClient) RefreshTokenKeys() error {
	lastQueryTime := atomic.LoadInt64(&c.lastQueryTime)
	nextAllowedRefreshTime := time.Unix(0, lastQueryTime).Add(c.minimumRefreshInterval)
	if time.Now().Before(nextAllowedRefreshTime) {
		c.log.Printf(
			"UAA TokenKey refresh throttled to every %s, try again in %s",
			c.minimumRefreshInterval,
			time.Until(nextAllowedRefreshTime).Round(time.Millisecond),
		)
		return nil
	}
	// Record the attempt before talking to UAA so that failed attempts are
	// throttled as well.
	atomic.CompareAndSwapInt64(&c.lastQueryTime, lastQueryTime, time.Now().UnixNano())

	req, err := http.NewRequest("GET", c.uaa.String(), nil)
	if err != nil {
		return fmt.Errorf("failed to create request to UAA: %s", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to get token keys from UAA: %s", err)
	}
	// Close the body on every path, including the non-200 one, so the
	// underlying connection is not leaked.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("got an invalid status code talking to UAA %v", resp.Status)
	}

	tokenKeys, err := unmarshalTokenKeys(resp.Body)
	if err != nil {
		return err
	}

	// Start with every cached key id; ids still present in the response are
	// removed below, leaving only the ones UAA no longer considers valid.
	staleKeyIds := make(map[string]struct{})
	c.publicKeys.Range(func(keyId, _ interface{}) bool {
		staleKeyIds[keyId.(string)] = struct{}{}
		return true
	})

	for _, tokenKey := range tokenKeys {
		publicKey, err := parseRSAPublicKeyPEM(tokenKey.Value)
		if err != nil {
			return err
		}

		// always overwrite the key stored at keyId because:
		// if you manually delete the UAA signing key from Credhub, UAA will
		// generate a new key with the same (default) keyId, which is something
		// along the lines of `key-1`
		c.publicKeys.Store(tokenKey.KeyId, publicKey)

		delete(staleKeyIds, tokenKey.KeyId)
	}

	for keyId := range staleKeyIds {
		c.publicKeys.Delete(keyId)
	}

	return nil
}

// parseRSAPublicKeyPEM decodes a PEM-encoded PKIX public key and returns it
// as an RSA public key, or an error describing why value is not one.
func parseRSAPublicKeyPEM(value string) (*rsa.PublicKey, error) {
	if value == "" {
		return nil, fmt.Errorf("received an empty token key from UAA")
	}

	block, _ := pem.Decode([]byte(value))
	if block == nil {
		return nil, fmt.Errorf("failed to parse PEM block containing the public key")
	}

	parsedKey, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("error parsing public key: %s", err)
	}

	publicKey, isRSAPublicKey := parsedKey.(*rsa.PublicKey)
	if !isRSAPublicKey {
		// There is no underlying error here, so report the key type received.
		return nil, fmt.Errorf("did not get a valid RSA key from UAA: got %T", parsedKey)
	}

	return publicKey, nil
}
// AlgorithmError reports that a token was signed with an algorithm this
// package does not accept.
type AlgorithmError struct {
	// Alg is the algorithm name taken from the token header.
	Alg string
}

// Error implements the error interface.
func (e AlgorithmError) Error() string {
	message := fmt.Sprintf("unsupported algorithm: %s", e.Alg)
	return message
}
// UnknownTokenKeyError reports that a token refers to a signing key id that
// is not among the known keys.
type UnknownTokenKeyError struct {
	// Kid is the key id taken from the token header.
	Kid string
}

// Error implements the error interface.
func (e UnknownTokenKeyError) Error() string {
	message := fmt.Sprintf("using unknown token key: %s", e.Kid)
	return message
}
// Read verifies token, which may carry a leading "bearer " prefix, and
// returns the client context derived from it.
//
// The token must be signed with RS256 by a key whose id is known to (or can
// be fetched by) this client, and it must not be expired. IsAdmin is set when
// the token carries the "doppler.firehose" or "logs.admin" scope. The
// returned context holds the token exactly as it was passed in.
//
// An error is returned for an empty token, an unsupported algorithm, an
// unknown or missing key id, a failed signature check, an undecodable payload
// or an expired token. When decoding fails for a reason other than the
// algorithm or key id, a key refresh is started in the background.
func (c *UAAClient) Read(token string) (Oauth2ClientContext, error) {
	if token == "" {
		return Oauth2ClientContext{}, errors.New("missing token")
	}

	payload, _, err := jose.Decode(trimBearer(token), func(headers map[string]interface{}, payload string) interface{} {
		// Header values come from the untrusted token, so they may be absent
		// or of an unexpected type; never assert their type unchecked.
		alg, _ := headers["alg"].(string)
		if alg != "RS256" {
			return AlgorithmError{Alg: alg}
		}

		keyId, hasKeyId := headers["kid"].(string)
		if !hasKeyId {
			// Without a usable key id there is no key to look up.
			return UnknownTokenKeyError{}
		}

		publicKey, err := c.loadOrFetchPublicKey(keyId)
		if err != nil {
			return err
		}

		return publicKey
	})
	if err != nil {
		switch err.(type) {
		case AlgorithmError, UnknownTokenKeyError:
			// refreshing keys cannot help with these
		default:
			// we're specifically trying to catch "crypto/rsa: verification error",
			// which generally means we've tried to verify a token with the
			// wrong public key. just in case UAA has rolled the key, but
			// kept the same keyId, let's renew our keys.
			go func() {
				if refreshErr := c.RefreshTokenKeys(); refreshErr != nil {
					c.log.Printf("failed to refresh UAA token keys: %s", refreshErr)
				}
			}()
		}

		return Oauth2ClientContext{}, fmt.Errorf("failed to decode token: %s", err.Error())
	}

	decodedToken, err := decodeToken(strings.NewReader(payload))
	if err != nil {
		return Oauth2ClientContext{}, fmt.Errorf("failed to unmarshal token: %s", err.Error())
	}

	if time.Now().After(decodedToken.ExpTime) {
		return Oauth2ClientContext{}, fmt.Errorf("token is expired, exp = %s", decodedToken.ExpTime)
	}

	var isAdmin bool
	for _, scope := range decodedToken.Scope {
		if scope == "doppler.firehose" || scope == "logs.admin" {
			isAdmin = true
			break
		}
	}

	return Oauth2ClientContext{
		IsAdmin:   isAdmin,
		Token:     token,
		ExpiresAt: decodedToken.ExpTime,
	}, nil
}
// loadOrFetchPublicKey returns the cached RSA public key for keyId. When the
// key is not cached it asks for a key refresh once and looks again.
//
// A refresh failure is logged rather than returned; the caller receives
// UnknownTokenKeyError whenever the key is still missing afterwards.
func (c *UAAClient) loadOrFetchPublicKey(keyId string) (*rsa.PublicKey, error) {
	if publicKey, ok := c.publicKeys.Load(keyId); ok {
		return publicKey.(*rsa.PublicKey), nil
	}

	// Log the failure so an unreachable or misbehaving UAA can be told apart
	// from a genuinely unknown key id. keyId comes from the token, hence %q.
	if err := c.RefreshTokenKeys(); err != nil {
		c.log.Printf("failed to refresh UAA token keys while looking up key %q: %s", keyId, err)
	}

	if publicKey, ok := c.publicKeys.Load(keyId); ok {
		return publicKey.(*rsa.PublicKey), nil
	}

	return nil, UnknownTokenKeyError{Kid: keyId}
}
// bearerRE matches a case-insensitive "bearer" prefix together with the
// whitespace that follows it.
var bearerRE = regexp.MustCompile(`(?i)^bearer\s+`)

// trimBearer returns authToken without its leading "bearer" prefix and the
// whitespace after it. Input without such a prefix is returned unchanged.
func trimBearer(authToken string) string {
	prefix := bearerRE.FindStringIndex(authToken)
	if prefix == nil {
		return authToken
	}
	// The pattern is anchored, so the match always starts at offset 0.
	return authToken[prefix[1]:]
}
// TODO: move key processing to a method of tokenKey
type (
	// tokenKey is a single signing key entry of the token keys response.
	tokenKey struct {
		KeyId string `json:"kid"`
		Value string `json:"value"`
	}

	// tokenKeys is the envelope of the token keys response.
	tokenKeys struct {
		Keys []tokenKey `json:"keys"`
	}
)

// unmarshalTokenKeys decodes a JSON token keys document from r and returns
// its "keys" entries. On a decoding failure it returns an empty, non-nil
// slice together with the error.
func unmarshalTokenKeys(r io.Reader) ([]tokenKey, error) {
	envelope := tokenKeys{}

	decodeErr := json.NewDecoder(r).Decode(&envelope)
	if decodeErr == nil {
		return envelope.Keys, nil
	}

	return []tokenKey{}, fmt.Errorf("unable to decode json token keys from UAA: %s", decodeErr)
}
// decodedToken holds the token payload fields this package reads.
type decodedToken struct {
	Value string   `json:"value"`
	Scope []string `json:"scope"`
	// Exp is the expiry as it appears in the payload, in seconds since the
	// Unix epoch.
	Exp float64 `json:"exp"`
	// ExpTime is Exp converted to a time.Time; it is never read from JSON.
	ExpTime time.Time `json:"-"`
}

// decodeToken decodes a JSON token payload from r and fills in ExpTime from
// the whole-second part of Exp. On a decoding failure it returns the zero
// decodedToken together with the error.
func decodeToken(r io.Reader) (decodedToken, error) {
	decoder := json.NewDecoder(r)

	parsed := decodedToken{}
	if decodeErr := decoder.Decode(&parsed); decodeErr != nil {
		return decodedToken{}, fmt.Errorf("unable to decode json token from UAA: %s", decodeErr)
	}

	expirySeconds := int64(parsed.Exp)
	parsed.ExpTime = time.Unix(expirySeconds, 0)

	return parsed, nil
}