-
Notifications
You must be signed in to change notification settings - Fork 15
/
auth_token.go
136 lines (111 loc) · 3.81 KB
/
auth_token.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package form3
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"sync"

	"github.com/google/uuid"
)
// tokenBasedAuthEndpoint is the path, on the same scheme and host as the
// request being authenticated, to which TokenTransport POSTs its client ID and
// secret (as HTTP Basic credentials) to obtain a new bearer token.
//
// #nosec G101
const tokenBasedAuthEndpoint = "/v1/oauth2/token"

// authJSONResponse is the part of the token endpoint's JSON response body that
// TokenTransport reads: the access token it then sends as a bearer token.
type authJSONResponse struct {
	AccessToken string `json:"access_token"`
}
// TokenOption configures a TokenTransport. Options are passed to
// NewTokenTransport and applied in the order given.
type TokenOption func(*TokenTransport)
// TokenTransport is an http.RoundTripper that authenticates each request with
// a bearer token. When the server answers 401 Unauthorized or 403 Forbidden,
// it exchanges the configured client ID and secret (sent as HTTP Basic
// credentials) for a new token at tokenBasedAuthEndpoint on the same host,
// stores that token for later requests, and retries the original request once.
//
// Token based authentication is deprecated in favour of request signing.
//
// A TokenTransport is safe for concurrent use by multiple goroutines, as every
// http.RoundTripper must be. Construct it with NewTokenTransport.
type TokenTransport struct {
	clientSecret        string
	clientID            uuid.UUID
	token               string
	underlyingTransport http.RoundTripper

	// mu guards token: RoundTrip reads it for every request and replaces it
	// after a successful re-authentication, possibly from many goroutines.
	mu sync.Mutex
}

// WithClientID sets the client ID sent as the HTTP Basic username when a new
// token is requested.
//
// Deprecated: token based authentication is deprecated and will be removed at
// some point in the future in favour of request signing.
func WithClientID(clientID uuid.UUID) TokenOption {
	return func(t *TokenTransport) {
		t.clientID = clientID
	}
}

// WithClientSecret sets the client secret sent as the HTTP Basic password when
// a new token is requested.
//
// Deprecated: token based authentication is deprecated and will be removed at
// some point in the future in favour of request signing.
func WithClientSecret(secret string) TokenOption {
	return func(t *TokenTransport) {
		t.clientSecret = secret
	}
}

// WithInitialToken sets the bearer token attached to requests until the server
// rejects it with 401 or 403 and a new one is fetched.
//
// Deprecated: token based authentication is deprecated and will be removed at
// some point in the future in favour of request signing.
func WithInitialToken(token string) TokenOption {
	return func(t *TokenTransport) {
		t.token = token
	}
}

// WithUnderlyingTokenTransport sets the RoundTripper used to send both API
// requests and token requests. It defaults to http.DefaultTransport.
//
// Deprecated: token based authentication is deprecated and will be removed at
// some point in the future in favour of request signing.
func WithUnderlyingTokenTransport(tr http.RoundTripper) TokenOption {
	return func(t *TokenTransport) {
		t.underlyingTransport = tr
	}
}

// NewTokenTransport returns a TokenTransport that sends requests through
// http.DefaultTransport, then applies opts in order.
//
// Deprecated: token based authentication is deprecated and will be removed at
// some point in the future in favour of request signing.
func NewTokenTransport(opts ...TokenOption) *TokenTransport {
	t := &TokenTransport{
		underlyingTransport: http.DefaultTransport,
	}
	for _, opt := range opts {
		opt(t)
	}
	return t
}

// RoundTrip implements http.RoundTripper.
//
// The request body, if any, is read fully into memory so it can be replayed on
// the retry, and the caller's body is then closed. The caller's *http.Request
// is never modified; each attempt is sent on a clone. If the first response is
// 401 or 403, it is closed and discarded, a new token is fetched, and the
// response to the single retry is returned as-is.
func (t *TokenTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	var body []byte
	if req.Body != nil {
		buffered, err := ioutil.ReadAll(req.Body)
		// RoundTrip must close the request body, even on error. Whatever was
		// read is already buffered, so a close failure is not actionable.
		_ = req.Body.Close()
		if err != nil {
			return nil, fmt.Errorf("could not read original request body: %w", err)
		}
		body = buffered
	}

	res, err := t.transport().RoundTrip(t.authorizedClone(req, body, t.currentToken()))
	if res == nil || (res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden) {
		return res, err
	}
	// This response is discarded in favour of the retry; close it so the
	// underlying connection is released rather than leaked.
	_ = res.Body.Close()

	token, err := t.fetchToken(req)
	if err != nil {
		return nil, err
	}
	t.setToken(token)
	return t.transport().RoundTrip(t.authorizedClone(req, body, token))
}

// transport returns the RoundTripper used for all outgoing requests, falling
// back to http.DefaultTransport when none was configured (zero-value struct).
func (t *TokenTransport) transport() http.RoundTripper {
	if t.underlyingTransport != nil {
		return t.underlyingTransport
	}
	return http.DefaultTransport
}

// currentToken returns the bearer token to attach to the next request.
func (t *TokenTransport) currentToken() string {
	t.mu.Lock()
	defer t.mu.Unlock()
	return t.token
}

// setToken records a freshly fetched token for subsequent requests.
func (t *TokenTransport) setToken(token string) {
	t.mu.Lock()
	t.token = token
	t.mu.Unlock()
}

// authorizedClone returns a deep copy of req with the SDK User-Agent, a fresh
// reader over body (when req had a body) and, if token is non-empty, a bearer
// Authorization header replacing any existing one. req itself is not touched,
// as the http.RoundTripper contract requires.
func (t *TokenTransport) authorizedClone(req *http.Request, body []byte, token string) *http.Request {
	r := req.Clone(req.Context())
	if r.Header == nil {
		r.Header = make(http.Header)
	}
	// Set rather than Add: User-Agent and Authorization are single-valued.
	r.Header.Set("User-Agent", UserAgent)
	if token != "" {
		r.Header.Set("Authorization", "Bearer "+token)
	}
	if r.Body != nil {
		if len(body) == 0 {
			r.Body = http.NoBody
		} else {
			r.Body = ioutil.NopCloser(bytes.NewReader(body))
		}
		r.ContentLength = int64(len(body))
		r.GetBody = func() (io.ReadCloser, error) {
			return ioutil.NopCloser(bytes.NewReader(body)), nil
		}
	}
	return r
}

// fetchToken POSTs the client ID and secret, as HTTP Basic credentials, to
// tokenBasedAuthEndpoint on req's scheme and host and returns the access token
// from the JSON response. The token request inherits req's context.
func (t *TokenTransport) fetchToken(req *http.Request) (string, error) {
	authURL := req.URL.Scheme + "://" + req.URL.Host + tokenBasedAuthEndpoint
	authRequest, err := http.NewRequestWithContext(req.Context(), http.MethodPost, authURL, nil)
	if err != nil {
		return "", fmt.Errorf("could not build auth request: %w", err)
	}
	authRequest.SetBasicAuth(t.clientID.String(), t.clientSecret)

	authResponse, err := t.transport().RoundTrip(authRequest)
	if err != nil {
		return "", fmt.Errorf("error authenticating: %w", err)
	}
	defer authResponse.Body.Close()

	if authResponse.StatusCode != http.StatusOK {
		return "", fmt.Errorf("non 200 status code getting token, status code: %d", authResponse.StatusCode)
	}
	authBody, err := ioutil.ReadAll(authResponse.Body)
	if err != nil {
		return "", fmt.Errorf("could not read auth response body: %w", err)
	}
	var authJSON authJSONResponse
	if err := json.Unmarshal(authBody, &authJSON); err != nil {
		return "", fmt.Errorf("could not parse auth json response: %w", err)
	}
	// Retrying with an empty bearer token would only be rejected again.
	if authJSON.AccessToken == "" {
		return "", fmt.Errorf("auth response did not contain an access token")
	}
	return authJSON.AccessToken, nil
}