/
jwt.go
212 lines (188 loc) · 4.88 KB
/
jwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
package jwt
import (
"context"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
)
// ValidatorFunc is a custom check run against a JWT by (*JWT).Validate.
// It returns a non-nil error to reject the token; Validate stops at the
// first such error and returns it to the caller.
type ValidatorFunc func(jot *JWT) error
var (
	// ErrEmptyAuthorization indicates that the "Authorization" header
	// doesn't have a token and, thus, extracting a token is impossible.
	ErrEmptyAuthorization = errors.New("jwt: no token could be extracted from header")
	// ErrMalformedToken indicates a token doesn't have
	// a valid format, as per RFC 7519, section 7.2.
	ErrMalformedToken = errors.New("jwt: malformed token")
	// ErrNilCtxKey indicates that a nil key was given for retrieving a JWT
	// from a context object. Passing a non-nil key resolves this error.
	ErrNilCtxKey = errors.New("jwt: JWT context key is a nil value")
	// ErrNilCtxValue indicates the context holds a nil value for the key.
	// Reporting it as an error, instead of returning a nil *JWT, mitigates
	// nil pointer dereferences and spares callers from checking the pointer.
	ErrNilCtxValue = errors.New("jwt: context value is nil")
	// ErrCtxAssertion indicates a JWT could not be extracted from a context
	// object because the value it holds is neither a token string nor a
	// JWT pointer.
	ErrCtxAssertion = errors.New("jwt: unable to assert context value into JWT pointer")
)
// JWT is a JSON Web Token.
//
// All fields are unexported; in this file a JWT is produced by FromString
// (directly or via FromContext, FromCookie and FromRequest).
type JWT struct {
	// header holds the JSON decoded from the first dot-separated segment
	// of raw (filled in by build).
	header *header
	// claims holds the JSON decoded from the second dot-separated segment
	// of raw (filled in by build).
	claims *claims
	// raw is the complete token string; FromString stores its input here.
	raw string
	// sep is the index in raw of the '.' that ends the claims segment:
	// Verify treats raw[:sep] as the signing input and raw[sep+1:] as the
	// encoded signature.
	sep int
}
// FromContext extracts a JWT stored in ctx under key.
//
// The stored value may be either a *JWT, which is returned as is, or a
// token string, which is parsed with FromString.
//
// It returns ErrNilCtxKey if key is nil, ErrNilCtxValue if ctx holds no
// value for key (or holds a nil *JWT), and ErrCtxAssertion if the value has
// any other type.
func FromContext(ctx context.Context, key interface{}) (*JWT, error) {
	if key == nil {
		return nil, ErrNilCtxKey
	}
	switch v := ctx.Value(key).(type) {
	case nil:
		return nil, ErrNilCtxValue
	case string:
		return FromString(v)
	case *JWT:
		// A typed nil pointer is a non-nil interface value, so it slips past
		// the nil case above; treat it as missing so callers never get
		// (nil, nil) back.
		if v == nil {
			return nil, ErrNilCtxValue
		}
		return v, nil
	default:
		return nil, ErrCtxAssertion
	}
}
// FromCookie builds a JWT from the value of cookie c using FromString.
//
// A nil cookie (for example, the result of an ignored http.ErrNoCookie) is
// treated like a cookie with an empty value, which FromString rejects with
// ErrMalformedToken instead of panicking.
func FromCookie(c *http.Cookie) (*JWT, error) {
	var token string
	if c != nil {
		token = c.Value
	}
	return FromString(token)
}
// FromRequest builds a JWT from the token contained in the request's
// "Authorization" header.
//
// The header is expected to look like "<scheme> <token>", for example
// "Bearer <token>"; the scheme itself is not checked. Whitespace around the
// token is ignored. ErrEmptyAuthorization is returned when the header has no
// space-separated token part or that part is blank; otherwise the result of
// FromString is returned.
func FromRequest(r *http.Request) (*JWT, error) {
	auth := r.Header.Get("Authorization")
	i := strings.IndexByte(auth, ' ')
	if i < 0 {
		return nil, ErrEmptyAuthorization
	}
	// Tolerate extra spaces between scheme and token, and report a header
	// such as "Bearer " as lacking a token rather than as malformed.
	token := strings.TrimSpace(auth[i+1:])
	if token == "" {
		return nil, ErrEmptyAuthorization
	}
	return FromString(token)
}
// FromString builds a JWT from the compact serialization of a JSON Web
// Token: three '.'-separated segments holding the encoded header, claims
// and signature.
//
// It returns ErrMalformedToken when s does not consist of exactly three
// segments, or when the header or claims segment decodes to JSON null.
// Errors from decoding or unmarshaling the header and claims are returned
// unchanged. The signature is not checked; use Verify for that.
func FromString(s string) (*JWT, error) {
	// A signed token in compact form has exactly two separators. Rejecting
	// extra ones here (e.g. a 5-part JWE) keeps '.' out of the signature.
	if strings.Count(s, ".") != 2 {
		return nil, ErrMalformedToken
	}
	jot := &JWT{raw: s, sep: strings.LastIndexByte(s, '.')}
	if err := jot.build(); err != nil {
		return nil, err
	}
	// json.Unmarshal sets a pointer to nil, without error, when the input
	// is JSON null; reject such tokens so the accessors never dereference
	// a nil header or claims.
	if jot.header == nil || jot.claims == nil {
		return nil, ErrMalformedToken
	}
	return jot, nil
}
// Algorithm returns the "alg" (algorithm) parameter
// from the JWT's header.
func (j *JWT) Algorithm() string {
	return j.header.Algorithm
}

// Audience returns the "aud" claim
// from the JWT's payload.
func (j *JWT) Audience() string {
	return j.claims.aud
}

// Bytes returns the JWT's raw token string as a byte slice.
// A fresh copy is allocated on every call, so callers may modify it.
func (j *JWT) Bytes() []byte {
	return []byte(j.raw)
}

// ExpirationTime returns the "exp" claim
// from the JWT's payload.
func (j *JWT) ExpirationTime() time.Time {
	return j.claims.exp
}

// IssuedAt returns the "iat" claim
// from the JWT's payload.
func (j *JWT) IssuedAt() time.Time {
	return j.claims.iat
}

// Issuer returns the "iss" claim
// from the JWT's payload.
func (j *JWT) Issuer() string {
	return j.claims.iss
}

// ID returns the "jti" claim
// from the JWT's payload.
func (j *JWT) ID() string {
	return j.claims.jti
}

// KeyID returns the "kid" (key ID) parameter
// from the JWT's header.
func (j *JWT) KeyID() string {
	return j.header.KeyID
}

// NotBefore returns the "nbf" claim
// from the JWT's payload.
func (j *JWT) NotBefore() time.Time {
	return j.claims.nbf
}

// Public returns all public claims set in the JWT's payload.
//
// The map returned is the JWT's own map, not a copy: modifying it changes
// what later calls to Public return (the raw token is unaffected).
func (j *JWT) Public() map[string]interface{} {
	return j.claims.pub
}

// Subject returns the "sub" claim
// from the JWT's payload.
func (j *JWT) Subject() string {
	return j.claims.sub
}

// String returns the JWT's raw token string, implementing fmt.Stringer.
func (j *JWT) String() string {
	return j.raw
}
// Validate runs the given validators against the JWT one by one, in order.
// It stops at the first validator that reports an error and returns that
// error; when all validators pass, or none are given, it returns nil.
func (j *JWT) Validate(vfuncs ...ValidatorFunc) error {
	var err error
	for i := 0; err == nil && i < len(vfuncs); i++ {
		err = vfuncs[i](j)
	}
	return err
}
// Verify checks the token's signature with s.
//
// The encoded signature (the text after the separator at j.sep) is decoded
// first; a decoding error is returned unchanged. Otherwise s.Verify is
// called with the signing input (everything before j.sep) and the decoded
// signature, and its result is returned.
func (j *JWT) Verify(s Signer) error {
	encodedSig := j.raw[j.sep+1:]
	sig, err := decode(encodedSig)
	if err != nil {
		return err
	}
	signingInput := []byte(j.raw[:j.sep])
	return s.Verify(signingInput, sig)
}
// build decodes the header and claims segments of j.raw and unmarshals
// their JSON into j.header and j.claims, in that order. It returns the
// first decoding or unmarshaling error encountered, or nil.
//
// NOTE: a segment whose JSON is null unmarshals without error but leaves
// the corresponding pointer nil.
func (j *JWT) build() error {
	headerSeg, claimsSeg := j.parts()

	headerJSON, err := decode(headerSeg)
	if err != nil {
		return err
	}
	if err = json.Unmarshal(headerJSON, &j.header); err != nil {
		return err
	}

	claimsJSON, err := decode(claimsSeg)
	if err != nil {
		return err
	}
	return json.Unmarshal(claimsJSON, &j.claims)
}
// parts splits the signing input j.raw[:j.sep] at its first '.' and returns
// the encoded header segment (before it) and the encoded claims segment
// (after it).
func (j *JWT) parts() (string, string) {
	signingInput := j.raw[:j.sep]
	dot := strings.IndexByte(signingInput, '.')
	return signingInput[:dot], signingInput[dot+1:]
}