forked from oapi-codegen/oapi-codegen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
jwt_authenticator.go
136 lines (116 loc) · 4.07 KB
/
jwt_authenticator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package server
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"github.com/foorester/oapi-codegen/pkg/middleware"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/lestrrat-go/jwx/jwt"
)
// JWSValidator checks a compact-serialized JWS. When the payload is valid it
// returns the JWT carried inside; otherwise it returns a non-nil error.
type JWSValidator interface {
	ValidateJWS(jws string) (token jwt.Token, err error)
}
// JWTClaimsContextKey is the echo context key under which Authenticate stores
// the validated jwt.Token, so request handlers can read its claims.
const JWTClaimsContextKey = "jwt_claims"

// ErrNoAuthHeader is returned when the request carries no Authorization header.
var ErrNoAuthHeader = errors.New("Authorization header is missing")

// ErrInvalidAuthHeader is returned when an Authorization header is present but
// is not of the form "Bearer <token>".
var ErrInvalidAuthHeader = errors.New("Authorization header is malformed")

// ErrClaimsInvalid is returned when the token lacks at least one of the scopes
// required by the operation.
var ErrClaimsInvalid = errors.New("Provided claims do not match expected scopes")
// GetJWSFromRequest extracts a JWS string from an "Authorization: Bearer <jws>"
// header.
//
// It returns ErrNoAuthHeader when the header is absent or empty. It returns
// ErrInvalidAuthHeader when the header does not start with the Bearer scheme
// followed by a single space, or when nothing follows that prefix.
func GetJWSFromRequest(req *http.Request) (string, error) {
	authHdr := req.Header.Get("Authorization")
	// Check for the Authorization header.
	if authHdr == "" {
		return "", ErrNoAuthHeader
	}
	// We expect a header value of the form "Bearer <token>", with one space
	// after the scheme. Authentication scheme names are case-insensitive
	// (RFC 7235 §2.1), so "bearer <token>" must be accepted too.
	const prefix = "Bearer "
	if len(authHdr) < len(prefix) || !strings.EqualFold(authHdr[:len(prefix)], prefix) {
		return "", ErrInvalidAuthHeader
	}
	jws := authHdr[len(prefix):]
	// A bare "Bearer " carries no credentials. Report it as malformed rather
	// than handing an empty string to the validator with a nil error.
	if jws == "" {
		return "", ErrInvalidAuthHeader
	}
	return jws, nil
}
// NewAuthenticator adapts Authenticate into an openapi3filter.AuthenticationFunc.
// Every invocation of the returned function authenticates with the validator v.
func NewAuthenticator(v JWSValidator) openapi3filter.AuthenticationFunc {
	authFunc := func(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
		return Authenticate(v, ctx, input)
	}
	return authFunc
}
// Authenticate uses the specified validator to ensure a JWT is valid, then makes
// sure that the claims provided by the JWT match the scopes as required in the API.
//
// Only the "BearerAuth" security scheme is handled; any other scheme name is an
// error. On success, the validated token is stored on the echo context under
// JWTClaimsContextKey. If ctx carries no echo context, an error is returned
// rather than panicking on a nil context.
func Authenticate(v JWSValidator, ctx context.Context, input *openapi3filter.AuthenticationInput) error {
	// Our security scheme is named BearerAuth, ensure this is the case
	if input.SecuritySchemeName != "BearerAuth" {
		return fmt.Errorf("security scheme %s != 'BearerAuth'", input.SecuritySchemeName)
	}
	// Now, we need to get the JWS from the request, to match the request expectations
	// against request contents.
	jws, err := GetJWSFromRequest(input.RequestValidationInput.Request)
	if err != nil {
		return fmt.Errorf("getting jws: %w", err)
	}
	// if the JWS is valid, we have a JWT, which will contain a bunch of claims.
	token, err := v.ValidateJWS(jws)
	if err != nil {
		return fmt.Errorf("validating JWS: %w", err)
	}
	// We've got a valid token now, and we can look into its claims to see whether
	// they match. Every single scope must be present in the claims.
	err = CheckTokenClaims(input.Scopes, token)
	if err != nil {
		return fmt.Errorf("token claims don't match: %w", err)
	}
	// Set the property on the echo context so the handler is able to
	// access the claims data we generate in here. GetEchoContext yields nil
	// when ctx was not produced by the echo OAPI middleware; calling Set on
	// that nil interface would panic.
	eCtx := middleware.GetEchoContext(ctx)
	if eCtx == nil {
		return errors.New("no echo context in request context; cannot store JWT claims")
	}
	eCtx.Set(JWTClaimsContextKey, token)
	return nil
}
// GetClaimsFromToken returns a list of claims from the token. We store these
// as a list under the PermissionsClaim ("perms", short for permissions) claim,
// to keep the token shorter.
//
// A token without that claim yields an empty, non-nil list: it has already
// passed signature validation, so it is valid but grants nothing. The claim may
// be a []interface{} (the untyped list produced by JSON decoding) or a
// []string (e.g. a token populated in-process). Any other type, or a list
// element that is not a string, is reported as an error.
func GetClaimsFromToken(t jwt.Token) ([]string, error) {
	rawPerms, found := t.Get(PermissionsClaim)
	if !found {
		// If the perms aren't found, it means that the token has none, but it has
		// passed signature validation by now, so it's a valid token, so we return
		// the empty list.
		return make([]string, 0), nil
	}
	switch perms := rawPerms.(type) {
	case []string:
		// Copy so callers cannot mutate the slice held inside the token.
		claims := make([]string, len(perms))
		copy(claims, perms)
		return claims, nil
	case []interface{}:
		// An untyped JSON list: every element must be a string.
		claims := make([]string, len(perms))
		for i, rawClaim := range perms {
			s, ok := rawClaim.(string)
			if !ok {
				return nil, fmt.Errorf("%s[%d] is not a string", PermissionsClaim, i)
			}
			claims[i] = s
		}
		return claims, nil
	default:
		return nil, fmt.Errorf("'%s' claim has unexpected type %T", PermissionsClaim, rawPerms)
	}
}
// CheckTokenClaims verifies that every scope in expectedClaims appears among
// the permission claims carried by t.
//
// It returns a wrapped error if the claims cannot be read from the token. It
// returns ErrClaimsInvalid as soon as one expected scope is missing, and nil
// when all are present.
func CheckTokenClaims(expectedClaims []string, t jwt.Token) error {
	granted, err := GetClaimsFromToken(t)
	if err != nil {
		return fmt.Errorf("getting claims from token: %w", err)
	}
	// Index the granted claims as a set so each expected scope is one lookup.
	grantedSet := make(map[string]struct{}, len(granted))
	for _, claim := range granted {
		grantedSet[claim] = struct{}{}
	}
	for _, want := range expectedClaims {
		if _, ok := grantedSet[want]; !ok {
			return ErrClaimsInvalid
		}
	}
	return nil
}