-
Notifications
You must be signed in to change notification settings - Fork 6
/
jwt.go
143 lines (125 loc) · 3.87 KB
/
jwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package daemonauth
import (
"context"
"crypto/rsa"
"fmt"
"net/http"
"os"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/go-chi/jwtauth/v5"
"github.com/shaj13/go-guardian/v2/auth"
"github.com/shaj13/go-guardian/v2/auth/strategies/token"
"golang.org/x/crypto/ssh"
)
// JWTCreator is a stateless type whose CreateUserToken method issues signed
// user tokens.
type JWTCreator struct{}

// apiClaims is the claim set decoded from an API token: the embedded standard
// claims plus the list carried under the "grant" key.
type apiClaims struct {
	Grant []string `json:"grant"`
	*jwt.StandardClaims
}

// JWTFiler groups the SignKeyFile and VerifyKeyFile methods, which give the
// paths of the key files used for JWT auth.
type JWTFiler interface {
	SignKeyFile() string
	VerifyKeyFile() string
}
// jwtAuth is the token encoder assigned by initJWT and used by
// CreateUserToken; it is nil until initJWT succeeds.
var jwtAuth *jwtauth.JWTAuth

// jwtVerifyKeySign is set by initAuthJWT to the legacy MD5 fingerprint of the
// jwt verify key, or to an error text when that key can't be converted to an
// ssh public key.
var jwtVerifyKeySign string
// initJWT builds the "jwt" authentication strategy.
//
// It loads the key pair through initAuthJWT (i must implement JWTFiler),
// stores the resulting *jwtauth.JWTAuth in the package-level jwtAuth, and
// returns the strategy name together with a token strategy that validates
// tokens against the verify key. The name is returned even when
// initialization fails; the strategy is nil in that case.
func initJWT(i interface{}) (string, auth.Strategy, error) {
	const name = "jwt"

	var (
		err       error
		verifyKey *rsa.PublicKey
	)
	// Assign the package-level jwtAuth directly so that it is reset to nil
	// when initialization fails.
	verifyKey, jwtAuth, err = initAuthJWT(i)
	if err != nil {
		return name, nil, err
	}

	// keyFunc hands the verify key to the parser, but only for RSA based
	// signing methods: those are the only ones an *rsa.PublicKey can verify,
	// so any other "alg" is rejected here with an explicit error.
	keyFunc := func(tk *jwt.Token) (interface{}, error) {
		switch tk.Method.(type) {
		case *jwt.SigningMethodRSA, *jwt.SigningMethodRSAPSS:
			return verifyKey, nil
		default:
			return nil, fmt.Errorf("unexpected jwt signing method: %v", tk.Header["alg"])
		}
	}

	// validate parses and verifies the raw token s and converts its claims
	// into an auth.Info plus the token expiration time.
	validate := func(ctx context.Context, r *http.Request, s string) (auth.Info, time.Time, error) {
		// Pre-allocate the embedded standard claims. The JSON decoder only
		// allocates an embedded pointer when the payload contains one of its
		// fields, so a token without any standard claim would otherwise leave
		// it nil and the claim validation and reads below would panic.
		claims := &apiClaims{StandardClaims: &jwt.StandardClaims{}}
		if _, err := jwt.ParseWithClaims(s, claims, keyFunc); err != nil {
			return nil, time.Time{}, err
		}
		// A token without "exp" has ExpiresAt == 0, which maps to the Unix epoch.
		exp := time.Unix(claims.ExpiresAt, 0)
		extensions := authenticatedExtensions("jwt", claims.Grant...)
		info := auth.NewUserInfo(claims.Subject, claims.Subject, nil, *extensions)
		return info, exp, nil
	}

	return name, token.New(validate, cache), nil
}
// initAuthJWT initializes JWT auth from the key files exposed by i and returns
// the verify key and a *jwtauth.JWTAuth configured for RS256.
//
// i must implement JWTFiler and both file paths must be non-empty. The sign
// key file is parsed as a PEM RSA private key and the verify key file as a PEM
// RSA public key. As a side effect jwtVerifyKeySign is set to the legacy MD5
// fingerprint of the verify key, or to an error text when the key can't be
// converted to an ssh public key.
func initAuthJWT(i interface{}) (*rsa.PublicKey, *jwtauth.JWTAuth, error) {
	filer, ok := i.(JWTFiler)
	if !ok {
		return nil, nil, fmt.Errorf("missing sign and verify files")
	}
	signPath, verifyPath := filer.SignKeyFile(), filer.VerifyKeyFile()

	switch {
	case signPath == "" && verifyPath == "":
		return nil, nil, fmt.Errorf("jwt undefined files: sign key and verify key")
	case signPath == "":
		// If we want to support less secure HMAC token from a static sign key:
		//    jwtAuth = jwtauth.New("HMAC", []byte(jwtSignKey), nil)
		return nil, nil, fmt.Errorf("jwt undefined file: sign key")
	case verifyPath == "":
		return nil, nil, fmt.Errorf("jwt undefined file: verify key")
	}

	// Both files are read before either is parsed, so a read failure is
	// reported ahead of any parse failure.
	signPEM, err := os.ReadFile(signPath)
	if err != nil {
		return nil, nil, fmt.Errorf("%w: jwt sign key file", err)
	}
	verifyPEM, err := os.ReadFile(verifyPath)
	if err != nil {
		return nil, nil, fmt.Errorf("%w: jwt verify key file", err)
	}

	privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(signPEM)
	if err != nil {
		return nil, nil, fmt.Errorf("%w: parse RSA private key from sign key file content", err)
	}
	publicKey, err := jwt.ParseRSAPublicKeyFromPEM(verifyPEM)
	if err != nil {
		return nil, nil, fmt.Errorf("%w: parse RSA public key from verify key file content", err)
	}

	// A fingerprint failure is not fatal: the error text is stored instead.
	if sshKey, fpErr := ssh.NewPublicKey(publicKey); fpErr == nil {
		jwtVerifyKeySign = ssh.FingerprintLegacyMD5(sshKey)
	} else {
		jwtVerifyKeySign = fmt.Sprintf("can't read public key:%s", fpErr)
	}

	return publicKey, jwtauth.New("RS256", privateKey, publicKey), nil
}
// CreateUserToken implements the CreateUserToken interface for JWTCreator.
//
// It encodes a token whose "sub" is the user name, whose "exp" is now plus
// duration and whose "grant" is the "grant" entry of the user extensions.
// Entries of xClaims are applied last, so they replace those claims on key
// collision. An empty token, zero time and nil error are returned when
// jwtAuth is not initialized.
func (*JWTCreator) CreateUserToken(userInfo auth.Info, duration time.Duration, xClaims map[string]interface{}) (tk string, expiredAt time.Time, err error) {
	if jwtAuth == nil {
		return "", time.Time{}, nil
	}

	expiredAt = time.Now().Add(duration)

	claims := make(map[string]interface{}, 3+len(xClaims))
	claims["sub"] = userInfo.GetUserName()
	claims["exp"] = expiredAt.Unix()
	claims["grant"] = userInfo.GetExtensions()["grant"]
	for key, value := range xClaims {
		claims[key] = value
	}

	_, tk, err = jwtAuth.Encode(claims)
	return tk, expiredAt, err
}