-
Notifications
You must be signed in to change notification settings - Fork 0
/
jwks.go
149 lines (131 loc) · 4.07 KB
/
jwks.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package oidc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/pquerna/cachecontrol"
"golang.org/x/oauth2"
)
// RemoteKeyStore stores an OIDC provider's JSON Web Keys (JWKs), fetched from
// JwksURI, and caches them until Expiry, which is computed from the HTTP
// caching headers of the JWKS response. ByUse and ById re-fetch the keys when
// the cached set expires within the next ten minutes.
type RemoteKeyStore struct {
	// JSONWebKeySet holds the cached keys (accessed as r.Keys).
	jose.JSONWebKeySet
	// Context is passed to every key refresh request. It is kept on the
	// struct because refreshes happen lazily during lookups.
	// NOTE(review): storing a context in a struct is discouraged in Go; if
	// this context is request-scoped, refreshes after it is cancelled will
	// fail — confirm callers pass a long-lived context.
	Context context.Context
	// JwksURI is the provider's jwks_uri endpoint the keys are fetched from.
	JwksURI string
	// Expiry is the time after which the cached keys are considered stale.
	Expiry time.Time
	// mutex guards updates to Keys and Expiry.
	// NOTE(review): every reader and writer of those fields must hold it to
	// stay free of data races.
	mutex sync.Mutex
}
// ByUse returns the first key in the store whose "use" parameter equals use.
// If the cached keys expire within the next ten minutes, they are re-fetched
// from JwksURI before the search.
//
// The store's mutex is held for the whole call (expiry check, refresh and
// search), so lookups never read Keys or Expiry while another goroutine is
// replacing them. A caller that waited on the lock sees the refreshed Expiry
// and skips a redundant fetch when it is far enough in the future. The
// returned key is a shallow copy of the stored entry.
//
// An error is returned if the refresh fails or no key has the requested use.
func (r *RemoteKeyStore) ByUse(use string) (*jose.JSONWebKey, error) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	// Refresh ahead of the real expiry so keys are renewed while the cached
	// set is still valid. The package-level fetcher is called directly: the
	// updateKeys method locks the mutex itself and would deadlock here.
	if time.Now().Add(10 * time.Minute).After(r.Expiry) {
		keys, expiry, err := updateKeys(r.Context, r.JwksURI)
		if err != nil {
			return nil, err
		}
		r.Keys = keys
		r.Expiry = expiry
	}

	for i := range r.Keys {
		if r.Keys[i].Use == use {
			key := r.Keys[i]
			return &key, nil
		}
	}
	return nil, errors.New("key not found")
}
// ById returns the first key in the store whose key ID ("kid") equals kid.
// If the cached keys expire within the next ten minutes, they are re-fetched
// from JwksURI before the search.
//
// The store's mutex is held for the whole call (expiry check, refresh and
// search), so lookups never read Keys or Expiry while another goroutine is
// replacing them. A caller that waited on the lock sees the refreshed Expiry
// and skips a redundant fetch when it is far enough in the future. The
// returned key is a shallow copy of the stored entry.
//
// An error is returned if the refresh fails or no key has the requested ID.
func (r *RemoteKeyStore) ById(kid string) (*jose.JSONWebKey, error) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	// Refresh ahead of the real expiry so keys are renewed while the cached
	// set is still valid. The package-level fetcher is called directly: the
	// updateKeys method locks the mutex itself and would deadlock here.
	if time.Now().Add(10 * time.Minute).After(r.Expiry) {
		keys, expiry, err := updateKeys(r.Context, r.JwksURI)
		if err != nil {
			return nil, err
		}
		r.Keys = keys
		r.Expiry = expiry
	}

	for i := range r.Keys {
		if r.Keys[i].KeyID == kid {
			key := r.Keys[i]
			return &key, nil
		}
	}
	return nil, errors.New("key not found")
}
// updateKeys re-fetches the key set from r.JwksURI using r.Context. On
// success it replaces the cached Keys and Expiry; on failure they are left
// unchanged and the fetch error is returned. The store's mutex is held for
// the duration of the fetch, so concurrent calls run one at a time.
func (r *RemoteKeyStore) updateKeys() error {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	fetched, validUntil, err := updateKeys(r.Context, r.JwksURI)
	if err == nil {
		r.Keys, r.Expiry = fetched, validUntil
	}
	return err
}
// providerRemoteKeys fetches the provider's JSON Web Key Set from jwksURI and
// returns a *RemoteKeyStore seeded with those keys and their cache expiry.
//
// ctx is used for the initial fetch and is also stored in the returned store,
// where it is reused for later refreshes. Because a cancelled context makes
// those refreshes fail, ctx should outlive the store rather than being scoped
// to a single request.
func providerRemoteKeys(ctx context.Context, jwksURI string) (*RemoteKeyStore, error) {
	keys, expiry, err := updateKeys(ctx, jwksURI)
	if err != nil {
		return nil, err
	}
	return &RemoteKeyStore{
		JSONWebKeySet: jose.JSONWebKeySet{Keys: keys},
		Context:       ctx,
		JwksURI:       jwksURI,
		Expiry:        expiry,
	}, nil
}
// updateKeys fetches the JSON Web Key Set served at jwksURI. It returns the
// keys together with the time until which they may be cached.
//
// The returned expiry is the later of two times: the time of the fetch, and
// the expiration that cachecontrol.CachableResponse computes from the
// response's caching headers. It is therefore never in the past.
//
// A non-200 response is reported as an error that includes the status and the
// response body. Errors from lower layers are wrapped with %w, so callers can
// inspect them with errors.Is / errors.As (for example, a context deadline).
func updateKeys(ctx context.Context, jwksURI string) ([]jose.JSONWebKey, time.Time, error) {
	req, err := http.NewRequest(http.MethodGet, jwksURI, nil)
	if err != nil {
		return nil, time.Time{}, fmt.Errorf("unable to create request: %w", err)
	}

	resp, err := doRequest(ctx, req)
	if err != nil {
		return nil, time.Time{}, fmt.Errorf("unable to fetch keys %w", err)
	}
	defer resp.Body.Close()

	// Read the body before checking the status so that error responses
	// can include the provider's message.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, time.Time{}, fmt.Errorf("unable to read response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, time.Time{}, fmt.Errorf("unable to get keys: %s %s", resp.Status, body)
	}

	var keySet jose.JSONWebKeySet
	if err := json.Unmarshal(body, &keySet); err != nil {
		return nil, time.Time{}, fmt.Errorf("unable to unmarshal keys: %w", err)
	}

	expiry := time.Now().UTC()
	_, exp, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{})
	if err != nil {
		return nil, time.Time{}, fmt.Errorf("unable to parse response cache headers: %w", err)
	}
	if exp.After(expiry) {
		expiry = exp
	}
	return keySet.Keys, expiry, nil
}
// doRequest sends request with ctx attached. It uses the *http.Client stored
// in ctx under oauth2.HTTPClient when one is present and non-nil, and
// http.DefaultClient otherwise.
//
// http.DefaultClient has no timeout. Without a deadline on ctx, a stalled
// JWKS endpoint can therefore block the caller indefinitely.
func doRequest(ctx context.Context, request *http.Request) (*http.Response, error) {
	client := http.DefaultClient
	// A typed nil *http.Client satisfies the type assertion but would panic
	// inside Do, so fall back to the default client in that case.
	if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok && c != nil {
		client = c
	}
	return client.Do(request.WithContext(ctx))
}