-
Notifications
You must be signed in to change notification settings - Fork 616
/
vault_client.go
209 lines (187 loc) · 5.16 KB
/
vault_client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
package cert
import (
"encoding/json"
"errors"
"io/ioutil"
"log"
"os"
"strings"
"sync"
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/consts"
)
// vaultClient wraps an *api.Client and takes care of token renewal
// automatically.
//
// The wrapped client is created lazily by the first call to Get and cached
// afterwards; Get also starts keepTokenAlive in a goroutine. When
// fetchVaultToken is set, Get re-reads the token source on every call to
// detect a changed token.
type vaultClient struct {
	addr  string // overrides the default config
	token string // overrides the VAULT_TOKEN environment variable

	// fetchVaultToken, if non-empty, says where to read the token from:
	// "file:<path>" or "env:<NAME>" (parsed by getVaultToken).
	fetchVaultToken string
	// prevFetchedToken is the last (whitespace-trimmed) token read via
	// fetchVaultToken; Get compares against it to notice a change.
	prevFetchedToken string

	client *api.Client // created and cached by Get
	mu     sync.Mutex  // held by Get for its whole duration
}
// NewVaultClient returns a vaultClient whose token is read from the source
// described by fetchVaultToken ("file:<path>" or "env:<NAME>", see
// getVaultToken). The underlying API client is only created by Get.
func NewVaultClient(fetchVaultToken string) *vaultClient {
	vc := new(vaultClient)
	vc.fetchVaultToken = fetchVaultToken
	return vc
}
// DefaultVaultClient is a shared vaultClient with no address override, no
// token override and no token source, so Get configures it purely from the
// environment.
var DefaultVaultClient = new(vaultClient)
// Get returns the Vault API client, creating it on the first call.
//
// On the first call the client is built from api.DefaultConfig plus
// Config.ReadEnvironment, with c.addr (if set) overriding the address. The
// token comes from fetchVaultToken (if it yields one), else c.token (if
// set), else whatever token api.NewClient configured by default. If that
// token is a response-wrapping token it is unwrapped and replaced by the
// client token it contains. On success the client is cached and
// keepTokenAlive is started in a goroutine.
//
// Later calls return the cached client. If fetchVaultToken is set, the
// token source is re-read each time and a changed token is installed on
// the cached client, unwrapping it first when it is a wrapping token.
//
// Calls are serialised by c.mu.
func (c *vaultClient) Get() (*api.Client, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.client != nil {
		if err := c.refreshFetchedToken(); err != nil {
			return nil, err
		}
		return c.client, nil
	}

	conf := api.DefaultConfig()
	if err := conf.ReadEnvironment(); err != nil {
		return nil, err
	}
	if c.addr != "" {
		conf.Address = c.addr
	}
	client, err := api.NewClient(conf)
	if err != nil {
		return nil, err
	}

	if c.fetchVaultToken != "" {
		token := strings.TrimSpace(getVaultToken(c.fetchVaultToken))
		log.Printf("[DEBUG] vault: fetching initial token")
		if token != c.prevFetchedToken {
			c.token = token
			c.prevFetchedToken = token
		}
	}
	if c.token != "" {
		client.SetToken(c.token)
	}

	token := client.Token()
	if token == "" {
		return nil, errors.New("vault: no token")
	}
	if err := unwrapVaultToken(client, token); err != nil {
		return nil, err
	}

	c.client = client
	go c.keepTokenAlive()
	return client, nil
}

// refreshFetchedToken re-reads the token from c.fetchVaultToken and, if it
// differs from the last one seen, installs it on the cached client.
// Wrapping tokens are unwrapped first. It does nothing when fetchVaultToken
// is unset. The caller must hold c.mu, and c.client must be non-nil.
func (c *vaultClient) refreshFetchedToken() error {
	if c.fetchVaultToken == "" {
		return nil
	}
	token := strings.TrimSpace(getVaultToken(c.fetchVaultToken))
	if token == c.prevFetchedToken {
		return nil
	}
	if token == "" {
		// The source could not be read (getVaultToken has already logged a
		// warning). Keep the current token instead of clearing it, and leave
		// prevFetchedToken untouched so the next call tries again.
		return nil
	}

	log.Printf("[DEBUG] vault: token has changed, setting new token")
	oldToken := c.client.Token()
	if err := unwrapVaultToken(c.client, token); err != nil {
		// Don't leave a token we could not validate on the shared client.
		c.client.SetToken(oldToken)
		return err
	}
	c.prevFetchedToken = token
	return nil
}

// unwrapVaultToken makes token the client's token. If Vault reports that it
// is a response-wrapping token, the token is replaced by the client token
// contained in the wrapped response. A token Vault reports as "not a
// wrapping token" is kept as-is; any other unwrap failure is returned.
func unwrapVaultToken(client *api.Client, token string) error {
	// Authenticate the unwrap call with the token itself, so an expired
	// previous token cannot make the check fail.
	client.SetToken(token)

	resp, err := client.Logical().Unwrap(token)
	switch {
	case err == nil:
		if resp == nil || resp.Auth == nil {
			return errors.New("vault: unwrapped response contains no auth token")
		}
		log.Printf("[INFO] vault: Unwrapped token %s", token)
		client.SetToken(resp.Auth.ClientToken)
		return nil
	case isVaultNotWrappedError(err):
		// Not wrapped: keep the token that was just set.
		return nil
	default:
		return err
	}
}

// isVaultNotWrappedError reports whether err, as returned by
// Logical().Unwrap, means the token simply is not a response-wrapping token.
// It recognises both the structured ErrInvalidWrappingToken response and the
// "no value found at ..." error from the api package's legacy fallback.
func isVaultNotWrappedError(err error) bool {
	var respErr *api.ResponseError
	if errors.As(err, &respErr) {
		for _, msg := range respErr.Errors {
			if msg == consts.ErrInvalidWrappingToken.Error() {
				return true
			}
		}
	}
	return strings.HasPrefix(err.Error(), "no value found at")
}
// dropNotRenewableWarning, when true, silences the "Token is not renewable"
// warning that keepTokenAlive emits for non-renewable tokens with an expiry
// time. It is intended for tests, where such a token is expected; in
// production it should remain false.
var (
	dropNotRenewableWarning bool
)
// keepTokenAlive renews the client's token for as long as Vault reports it
// as renewable. Get starts it as a goroutine after c.client has been set.
//
// It first looks up the current token (lookup-self).
//   - A token that is not renewable is never renewed. If it has an expiry
//     time, a warning with the remaining lifetime is logged (unless
//     dropNotRenewableWarning is set), and the function returns.
//   - A renewable token is renewed at half its TTL, requesting the token's
//     creation TTL as the increment. Failed renewals are retried every
//     second.
//
// The loop ends when a renewal reports the token as no longer renewable or
// with a zero lease duration, or when the renewal response carries no auth
// data.
func (c *vaultClient) keepTokenAlive() {
	resp, err := c.client.Auth().Token().LookupSelf()
	if err != nil {
		log.Printf("[WARN] vault: lookup-self failed, token renewal is disabled: %s", err)
		return
	}
	if resp == nil {
		log.Printf("[WARN] vault: lookup-self returned no data, token renewal is disabled")
		return
	}

	// resp.Data is a generic map; round-trip it through JSON to pull the
	// fields we need into a typed struct.
	b, err := json.Marshal(resp.Data)
	if err != nil {
		log.Printf("[WARN] vault: lookup-self failed, token renewal is disabled: %s", err)
		return
	}
	var data struct {
		TTL         int       `json:"ttl"`
		CreationTTL int       `json:"creation_ttl"`
		Renewable   bool      `json:"renewable"`
		ExpireTime  time.Time `json:"expire_time"`
	}
	if err := json.Unmarshal(b, &data); err != nil {
		log.Printf("[WARN] vault: lookup-self failed, token renewal is disabled: %s", err)
		return
	}

	switch {
	case data.Renewable:
		// fall through to the renewal loop below
	case data.ExpireTime.IsZero():
		// token doesn't expire
		return
	case dropNotRenewableWarning:
		return
	default:
		ttl := time.Until(data.ExpireTime)
		ttl = ttl / time.Second * time.Second // truncate to seconds
		log.Printf("[WARN] vault: Token is not renewable and will expire %s from now at %s",
			ttl, data.ExpireTime.Format(time.RFC3339))
		return
	}

	ttl := time.Duration(data.TTL) * time.Second
	timer := time.NewTimer(ttl / 2)
	for range timer.C {
		resp, err := c.client.Auth().Token().RenewSelf(data.CreationTTL)
		if err != nil {
			log.Printf("[WARN] vault: Failed to renew token: %s", err)
			timer.Reset(time.Second) // TODO: backoff? abort after N consecutive failures?
			continue
		}
		if resp == nil || resp.Auth == nil {
			// Without auth data we cannot tell the new lease; stop rather
			// than dereference a nil pointer.
			log.Printf("[WARN] vault: token renewal returned no auth data, token renewal is disabled")
			return
		}
		if !resp.Auth.Renewable || resp.Auth.LeaseDuration == 0 {
			// token isn't renewable anymore, we're done.
			return
		}
		ttl = time.Duration(resp.Auth.LeaseDuration) * time.Second
		timer.Reset(ttl / 2)
	}
}
// getVaultToken reads a Vault token from the source described by c, which
// has the form "<scheme>:<location>"; whitespace around c is ignored.
// Supported sources:
//
//	file:<path>  the contents of the file at <path>
//	env:<NAME>   the value of environment variable <NAME>
//
// The token is returned exactly as read. File contents are not trimmed, so
// callers should strip trailing whitespace.
//
// A malformed spec, an unknown scheme or an unreadable file logs a warning
// and returns "". An empty or whitespace-only token also logs a warning; in
// that case the raw value read is returned.
func getVaultToken(c string) string {
	c = strings.TrimSpace(c)
	parts := strings.SplitN(c, ":", 2)
	if len(parts) < 2 {
		log.Printf("[WARN] vault: vaultfetchtoken not properly set")
		return ""
	}

	var token string
	switch scheme, location := parts[0], parts[1]; scheme {
	case "file":
		b, err := ioutil.ReadFile(location) // just pass the file name
		if err != nil {
			log.Printf("[WARN] vault: Failed to fetch token from %s: %s", c, err)
			return ""
		}
		token = string(b)
	case "env":
		token = os.Getenv(location)
	default:
		log.Printf("[WARN] vault: vaultfetchtoken not properly set")
		return ""
	}

	// An empty token is a failed fetch for both sources, not a success.
	if strings.TrimSpace(token) == "" {
		log.Printf("[WARN] vault: Failed to fetch token from %s", c)
		return token
	}
	log.Printf("[DEBUG] vault: Successfully fetched token from %s", c)
	return token
}