-
Notifications
You must be signed in to change notification settings - Fork 15
/
auth_token_based.go
121 lines (99 loc) · 3.39 KB
/
auth_token_based.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package form3
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"sync"
)
// tokenBasedAuthEndpoint is the OAuth2 token path, appended to the configured
// host URL when a new bearer token has to be requested.
const tokenBasedAuthEndpoint = "/v1/oauth2/token"
// TokenBasedClientConfig holds the settings for an HTTP client that
// authenticates with bearer tokens obtained from the token endpoint.
// Build one with NewTokenBasedClientConfig and adjust it with the With*
// methods before passing it to NewTokenBasedHTTPClient.
type TokenBasedClientConfig struct {
	// Client credentials, sent via HTTP basic auth when requesting a token.
	clientID, clientSecret string
	// hostURL is the base URL that the token endpoint path is appended to.
	hostURL *url.URL
	// initialToken, when non-empty, is sent before any token has been fetched.
	initialToken string
	// underlyingTransport performs the actual HTTP round trips.
	underlyingTransport http.RoundTripper
}
// tokenBasedTransport is an http.RoundTripper that attaches a bearer token to
// outgoing requests. When the server rejects a request with 401 or 403 it
// fetches a fresh token from the auth endpoint and retries the request once.
//
// It is safe for concurrent use, as http.RoundTripper requires: the cached
// token is guarded by mu.
type tokenBasedTransport struct {
	config              *TokenBasedClientConfig
	underlyingTransport http.RoundTripper

	mu    sync.Mutex // guards token
	token string     // cached bearer token; empty means no Authorization header is added
}

// authJSONResponse is the part of the token endpoint's JSON response that this
// client reads.
type authJSONResponse struct {
	AccessToken string `json:"access_token"`
}

// NewTokenBasedClientConfig returns a config for the given client credentials
// and host. The token endpoint path is appended to hostURL's string form.
// Requests go through http.DefaultTransport unless overridden with
// WithUnderlyingTransport.
func NewTokenBasedClientConfig(clientID, clientSecret string, hostURL *url.URL) *TokenBasedClientConfig {
	return &TokenBasedClientConfig{
		clientID:            clientID,
		clientSecret:        clientSecret,
		hostURL:             hostURL,
		underlyingTransport: http.DefaultTransport,
	}
}

// WithinitialToken sets a bearer token to send before any token has been
// fetched, so an already-issued token can be used without first calling the
// token endpoint. It returns c to allow chaining.
func (c *TokenBasedClientConfig) WithinitialToken(token string) *TokenBasedClientConfig {
	c.initialToken = token
	return c
}

// WithUnderlyingTransport sets the RoundTripper used for both API requests and
// token requests. It returns c to allow chaining.
func (c *TokenBasedClientConfig) WithUnderlyingTransport(underlyingTransport http.RoundTripper) *TokenBasedClientConfig {
	c.underlyingTransport = underlyingTransport
	return c
}

// NewTokenBasedHTTPClient returns an *http.Client whose transport authenticates
// requests with bearer tokens according to config. A nil underlying transport
// falls back to http.DefaultTransport instead of panicking on first use.
func NewTokenBasedHTTPClient(config *TokenBasedClientConfig) *http.Client {
	underlying := config.underlyingTransport
	if underlying == nil {
		underlying = http.DefaultTransport
	}
	return &http.Client{Transport: &tokenBasedTransport{
		config:              config,
		underlyingTransport: underlying,
		token:               config.initialToken,
	}}
}

// RoundTrip implements http.RoundTripper.
//
// It sends a copy of req carrying the cached bearer token, if there is one. If
// the response status is 401 Unauthorized or 403 Forbidden, it does three things:
//   - discards that response;
//   - requests a new token from the auth endpoint (HTTP basic auth with the
//     configured client credentials);
//   - retries the request once with the new token.
//
// Failures while obtaining the token are returned as errors instead of the
// rejected response.
//
// req itself is never modified; its body is consumed and closed, as the
// RoundTripper contract requires. The body is buffered in memory so it can be
// replayed on the retry. Concurrent requests that are all rejected may each
// fetch a token; the last one stored wins.
func (t *tokenBasedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	var body []byte
	if req.Body != nil {
		b, err := ioutil.ReadAll(req.Body)
		// A RoundTripper owns req.Body and must close it, even on error.
		_ = req.Body.Close()
		if err != nil {
			return nil, err
		}
		body = b
	}

	res, err := t.underlyingTransport.RoundTrip(t.authorize(req, body, t.currentToken()))
	if res == nil || (res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden) {
		return res, err
	}
	// Drain (bounded) and close the rejected response so it does not leak and
	// its connection can be reused.
	_, _ = io.Copy(ioutil.Discard, io.LimitReader(res.Body, 64<<10))
	_ = res.Body.Close()

	token, err := t.fetchToken(req)
	if err != nil {
		return nil, err
	}
	t.setToken(token)
	return t.underlyingTransport.RoundTrip(t.authorize(req, body, token))
}

// authorize returns a shallow copy of req with its own header map. The copy
// carries a fresh reader over body when req had a body, and an
// "Authorization: Bearer <token>" header when token is non-empty. req is not
// modified.
func (t *tokenBasedTransport) authorize(req *http.Request, body []byte, token string) *http.Request {
	// WithContext yields a shallow copy; the header map is copied by hand so
	// the caller's map is never written (Request.Clone needs Go 1.13).
	out := req.WithContext(req.Context())
	out.Header = make(http.Header, len(req.Header))
	for k, v := range req.Header {
		out.Header[k] = append([]string(nil), v...)
	}
	if req.Body != nil {
		out.Body = ioutil.NopCloser(bytes.NewReader(body))
		out.ContentLength = int64(len(body))
	}
	if token != "" {
		// Set, not Add, so a stale or caller-provided value is replaced.
		out.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
	}
	return out
}

// fetchToken requests a new access token from the auth endpoint, using HTTP
// basic auth with the configured client ID and secret. The token request uses
// trigger's context, so cancelling the original request also cancels it.
func (t *tokenBasedTransport) fetchToken(trigger *http.Request) (string, error) {
	authRequest, err := http.NewRequest(http.MethodPost, t.config.hostURL.String()+tokenBasedAuthEndpoint, nil)
	if err != nil {
		return "", fmt.Errorf("could not build auth request, error: %v", err)
	}
	authRequest = authRequest.WithContext(trigger.Context())
	authRequest.SetBasicAuth(t.config.clientID, t.config.clientSecret)

	authResponse, err := t.underlyingTransport.RoundTrip(authRequest)
	if err != nil {
		return "", fmt.Errorf("error authenticating, error: %v", err)
	}
	defer authResponse.Body.Close()
	if authResponse.StatusCode != http.StatusOK {
		return "", fmt.Errorf("non 200 status code getting token, status code: %d", authResponse.StatusCode)
	}

	authBody, err := ioutil.ReadAll(authResponse.Body)
	if err != nil {
		return "", fmt.Errorf("could not read auth response body, error: %v", err)
	}
	var authJSON authJSONResponse
	if err := json.Unmarshal(authBody, &authJSON); err != nil {
		return "", fmt.Errorf("could not parse auth json response, error: %v", err)
	}
	if authJSON.AccessToken == "" {
		// Retrying without credentials would only produce another rejection.
		return "", fmt.Errorf("auth response did not contain an access_token")
	}
	return authJSON.AccessToken, nil
}

// currentToken returns the cached bearer token.
func (t *tokenBasedTransport) currentToken() string {
	t.mu.Lock()
	defer t.mu.Unlock()
	return t.token
}

// setToken replaces the cached bearer token.
func (t *tokenBasedTransport) setToken(token string) {
	t.mu.Lock()
	t.token = token
	t.mu.Unlock()
}