This repository has been archived by the owner on Dec 19, 2023. It is now read-only.
forked from marcinwyszynski/kmsjwt
/
kmsjwt.go
136 lines (110 loc) · 3.17 KB
/
kmsjwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package kmsjwt
import (
"context"
"crypto/subtle"
"encoding/base64"
"errors"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
"github.com/dgrijalva/jwt-go"
cache "github.com/patrickmn/go-cache"
)
// kmsAlgorighm is the algorithm name that Alg reports for this signing method.
// NOTE(review): the identifier is a misspelling of "kmsAlgorithm"; it is left
// unchanged here because Alg refers to it by this name.
const kmsAlgorighm = "KMS"
// ErrKmsVerification is returned by Verify when the KMS verify call fails for
// a reason other than context cancellation, or when KMS does not report the
// signature as valid.
var ErrKmsVerification = errors.New("kms: verification error")
// kmsClient is the KMS-backed signing method constructed by New. It embeds the
// KMS API client so that Sign and Verify can call SignWithContext and
// VerifyWithContext directly on the receiver.
type kmsClient struct {
// Embedded KMS API client used for the sign and verify calls.
kmsiface.KMSAPI
// cache maps a signing string to its raw signature bytes. It is nil when
// caching is disabled; New only creates it when withCache is true.
cache *cache.Cache
// kmsKeyID is sent as KeyId on every sign and verify request.
kmsKeyID string
// withCache controls whether New creates the cache; New defaults it to true.
withCache bool
// defaultExpiration is passed to cache.New as the default entry expiration;
// New defaults it to time.Hour.
defaultExpiration time.Duration
// cleanupInterval is passed to cache.New as the cleanup interval; New
// defaults it to time.Minute.
cleanupInterval time.Duration
// signingAlgorithm is sent as SigningAlgorithm on every sign and verify
// request; New defaults it to kms.SigningAlgorithmSpecRsassaPssSha512.
signingAlgorithm string
}
// New returns a jwt.SigningMethod that signs and verifies through the given
// KMS client using the key identified by kmsKeyID.
//
// Defaults, applied before any option runs: caching enabled, a cache
// expiration of one hour, a cache cleanup interval of one minute, and the
// RSASSA-PSS SHA-512 signing algorithm. Each option in opts is then applied in
// order, and the cache is created afterwards only if caching is still enabled.
func New(client kmsiface.KMSAPI, kmsKeyID string, opts ...Option) jwt.SigningMethod {
	method := &kmsClient{
		KMSAPI:            client,
		kmsKeyID:          kmsKeyID,
		withCache:         true,
		defaultExpiration: time.Hour,
		cleanupInterval:   time.Minute,
		signingAlgorithm:  kms.SigningAlgorithmSpecRsassaPssSha512,
	}

	for i := range opts {
		opts[i](method)
	}

	// Without caching the cache field stays nil, which Sign and Verify treat
	// as "no cache".
	if !method.withCache {
		return method
	}

	method.cache = cache.New(method.defaultExpiration, method.cleanupInterval)
	return method
}
// Alg returns the algorithm name of this signing method, which is the value
// of the kmsAlgorighm constant.
func (k *kmsClient) Alg() string {
return kmsAlgorighm
}
// Sign asks KMS to sign signingString and returns the signature encoded with
// standard (padded) base64.
//
// key must be a context.Context; it is used as the context of the KMS call.
// Any other value yields a "key is not a context" error.
//
// If the KMS call fails with an error matching context.Canceled, that error is
// returned as-is; every other KMS failure is reported as jwt.ErrInvalidKey.
// On success, when a cache is configured, the raw signature is stored under
// signingString so that a later Verify of the same string can be answered
// from the cache.
func (k *kmsClient) Sign(signingString string, key interface{}) (string, error) {
	ctx, isContext := key.(context.Context)
	if !isContext {
		return "", errors.New("key is not a context")
	}

	input := &kms.SignInput{
		KeyId:            aws.String(k.kmsKeyID),
		Message:          []byte(signingString),
		MessageType:      aws.String("RAW"),
		SigningAlgorithm: aws.String(k.signingAlgorithm),
	}

	result, err := k.SignWithContext(ctx, input)
	if err != nil {
		// Cancellation is surfaced unchanged so callers can tell it apart
		// from a signing failure.
		if errors.Is(err, context.Canceled) {
			return "", err
		}
		return "", jwt.ErrInvalidKey
	}

	if k.cache != nil {
		k.cache.SetDefault(signingString, result.Signature)
	}

	return base64.StdEncoding.EncodeToString(result.Signature), nil
}
// Verify checks that stringSignature is a valid signature of signingString.
//
// key must be a context.Context; it is used as the context of the KMS call.
// Any other value yields a "key is not a context" error. stringSignature must
// be standard (padded) base64; otherwise an "invalid signature encoding" error
// is returned.
//
// The cache is consulted first: if it holds a signature for signingString that
// equals the provided one, Verify returns nil without calling KMS. Otherwise
// KMS is asked to verify the signature. An error matching context.Canceled is
// returned as-is; any other KMS error, a missing response, or a response that
// does not mark the signature as valid yields ErrKmsVerification. A signature
// confirmed by KMS is stored in the cache when one is configured.
func (k *kmsClient) Verify(signingString, stringSignature string, key interface{}) error {
	ctx, ok := key.(context.Context)
	if !ok {
		return errors.New("key is not a context")
	}

	signature, err := base64.StdEncoding.DecodeString(stringSignature)
	if err != nil {
		return errors.New("invalid signature encoding")
	}

	if k.verifyCache(signingString, signature) {
		return nil
	}

	out, err := k.VerifyWithContext(ctx, &kms.VerifyInput{
		KeyId:            aws.String(k.kmsKeyID),
		Message:          []byte(signingString),
		MessageType:      aws.String("RAW"),
		Signature:        signature,
		SigningAlgorithm: aws.String(k.signingAlgorithm),
	})
	if err != nil {
		// Cancellation is surfaced unchanged so callers can tell it apart
		// from a failed verification.
		if errors.Is(err, context.Canceled) {
			return err
		}
		return ErrKmsVerification
	}

	// Guard against a nil response as well as a nil SignatureValid pointer:
	// a KMSAPI implementation (e.g. a test double) may return (nil, nil), and
	// dereferencing it would panic. aws.BoolValue treats a nil *bool as false.
	if out == nil || !aws.BoolValue(out.SignatureValid) {
		return ErrKmsVerification
	}

	if k.cache != nil {
		k.cache.SetDefault(signingString, signature)
	}

	return nil
}
// verifyCache reports whether the cache holds a signature for signingString
// that equals providedSignature.
//
// It returns false when no cache is configured, when there is no entry for
// signingString, or when the stored value is not a []byte. The byte comparison
// is done with subtle.ConstantTimeCompare.
func (k *kmsClient) verifyCache(signingString string, providedSignature []byte) bool {
	if k.cache == nil {
		return false
	}

	if entry, found := k.cache.Get(signingString); found {
		if cachedSignature, isBytes := entry.([]byte); isBytes {
			return subtle.ConstantTimeCompare(cachedSignature, providedSignature) == 1
		}
	}

	return false
}