forked from drone/go-scm
-
Notifications
You must be signed in to change notification settings - Fork 84
/
refresh.go
137 lines (119 loc) · 3.07 KB
/
refresh.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/jenkins-x/go-scm/scm"
)
// expiryDelta determines how much earlier than its actual expiration
// time a token is considered expired. Refreshing slightly early avoids
// using a token that expires in flight or that the server already
// considers expired because of client/server clock skew.
const expiryDelta = time.Minute
// Refresher wraps a token Source and refreshes OAuth tokens that
// are expired. It fetches the current token from Source. When the
// token is expired, Refresher exchanges its refresh token at
// Endpoint for a new access token before returning it.
//
// IMPORTANT: the Refresher is NOT safe for concurrent use by
// multiple goroutines.
type Refresher struct {
	// ClientID is sent as the HTTP basic-auth username on
	// refresh requests.
	ClientID string

	// ClientSecret is sent as the HTTP basic-auth password on
	// refresh requests.
	ClientSecret string

	// Endpoint is the URL that refresh requests are POSTed to.
	Endpoint string

	// Source supplies the current token.
	Source scm.TokenSource

	// Client performs refresh requests. If nil,
	// http.DefaultClient is used.
	Client *http.Client
}
// Token returns the current token from the underlying Source. If
// that token is expired (see expired), it is refreshed in place
// before being returned. The refresh request is bound to ctx, so
// cancelling ctx aborts an in-flight refresh.
//
// The refresh mutates the *scm.Token returned by Source. Whether
// the new credentials persist across calls therefore depends on
// whether Source hands out the same pointer each time.
func (t *Refresher) Token(ctx context.Context) (*scm.Token, error) {
	token, err := t.Source.Token(ctx)
	if err != nil {
		return nil, err
	}
	if !expired(token) {
		return token, nil
	}
	if err := t.refresh(ctx, token); err != nil {
		return nil, err
	}
	return token, nil
}

// Refresh refreshes the expired token in place using its refresh
// token. It uses context.Background(), so the request cannot be
// cancelled. Prefer Token when a context is available.
func (t *Refresher) Refresh(token *scm.Token) error {
	return t.refresh(context.Background(), token)
}

// refresh exchanges token.Refresh for a new access token at
// t.Endpoint. It uses grant_type=refresh_token and sends client
// credentials via HTTP basic auth, then updates token in place.
// token is modified only when the exchange succeeds.
func (t *Refresher) refresh(ctx context.Context, token *scm.Token) error {
	values := url.Values{}
	values.Set("grant_type", "refresh_token")
	values.Set("refresh_token", token.Refresh)

	body := strings.NewReader(values.Encode())
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.Endpoint, body)
	if err != nil {
		return err
	}
	req.SetBasicAuth(t.ClientID, t.ClientSecret)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	res, err := t.client().Do(req)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	if res.StatusCode > 299 {
		out := new(tokenError)
		// Error bodies are not guaranteed to be JSON (proxies often
		// return HTML), and they may carry neither field. Fall back
		// to the HTTP status so the caller always gets a useful error.
		if err := json.NewDecoder(res.Body).Decode(out); err != nil || (out.Code == "" && out.Message == "") {
			return fmt.Errorf("oauth2: token refresh failed: %s", res.Status)
		}
		return out
	}

	out := new(tokenGrant)
	if err := json.NewDecoder(res.Body).Decode(out); err != nil {
		return err
	}
	// Reject a malformed grant rather than wiping the caller's token.
	if out.Access == "" {
		return errors.New("oauth2: token refresh response missing access_token")
	}
	token.Token = out.Access
	// RFC 6749 section 6: the server MAY omit a new refresh token, in
	// which case the existing one stays valid. Overwriting it with ""
	// would make expired() report false forever, and the token would
	// never be refreshed again.
	if out.Refresh != "" {
		token.Refresh = out.Refresh
	}
	// NOTE(review): a response without expires_in yields Expires == now,
	// so the token is refreshed again on the next call.
	token.Expires = time.Now().Add(time.Duration(out.Expires) * time.Second)
	return nil
}
// client picks the *http.Client used to reach the token endpoint.
// It returns the configured Client when one is set and falls back
// to http.DefaultClient otherwise.
func (t *Refresher) client() *http.Client {
	c := t.Client
	if c == nil {
		c = http.DefaultClient
	}
	return c
}
// expired reports whether token should be refreshed. The rules are:
//   - A token with no refresh token is never treated as expired,
//     since it could not be renewed anyway.
//   - A token that has an access token but no expiry time is
//     treated as valid indefinitely.
//   - Any other token counts as expired once the current time passes
//     Expires minus expiryDelta.
func expired(token *scm.Token) bool {
	switch {
	case token.Refresh == "":
		return false
	case token.Token != "" && token.Expires.IsZero():
		return false
	default:
		deadline := token.Expires.Add(-expiryDelta)
		return time.Now().After(deadline)
	}
}
// tokenGrant is the JSON body decoded from a successful (2xx)
// response of the refresh endpoint.
type tokenGrant struct {
	// Access is the new access token ("access_token").
	Access string `json:"access_token"`

	// Refresh is the refresh token returned by the server
	// ("refresh_token"). It may be empty.
	Refresh string `json:"refresh_token"`

	// Expires is the token lifetime in seconds ("expires_in"). It
	// is converted with time.Second when computing the expiry time.
	Expires int64 `json:"expires_in"`
}
// tokenError is the error returned when the token endpoint
// returns a non-2XX HTTP status code. Code and Message hold the
// "error" and "error_description" members of the JSON error body.
type tokenError struct {
	Code    string `json:"error"`
	Message string `json:"error_description"`
}

// Error implements the error interface. It returns the description
// when present and otherwise falls back to the error code. The
// description is optional in OAuth 2.0 error responses, and an
// empty error string would give callers nothing to report.
func (t *tokenError) Error() string {
	switch {
	case t.Message != "":
		return t.Message
	case t.Code != "":
		return t.Code
	default:
		return "oauth2: token endpoint returned an error"
	}
}