-
Notifications
You must be signed in to change notification settings - Fork 654
/
token.go
137 lines (113 loc) · 4.71 KB
/
token.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package auth
import (
"context"
"encoding/json"
"strings"
"time"
"github.com/coreos/go-oidc"
grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
"golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/util/sets"
"github.com/flyteorg/flyte/flyteadmin/auth/interfaces"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/errors"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
const (
	// ErrRefreshingToken is the error code attached when exchanging a refresh token for a new token fails.
	ErrRefreshingToken = errors.ErrorCode("TOKEN_REFRESH_FAILURE")
	// ErrTokenExpired is the error code attached when ID-token verification fails because the token is expired.
	ErrTokenExpired = errors.ErrorCode("JWT_EXPIRED")
	// ErrJwtValidation is the error code attached when a token cannot be read from metadata, is blank,
	// or fails verification.
	ErrJwtValidation = errors.ErrorCode("JWT_VERIFICATION_FAILED")
)
// GetRefreshedToken asks the OAuth2 config's token source for a new token using refreshToken.
//
// The supplied access token is wrapped with an expiry one minute in the past so the token source
// cannot hand it back as still valid and must perform a refresh instead. On failure the error is
// logged and returned wrapped with ErrRefreshingToken.
func GetRefreshedToken(ctx context.Context, oauth *oauth2.Config, accessToken, refreshToken string) (*oauth2.Token, error) {
	logger.Debugf(ctx, "Attempting to refresh token")

	// Backdate the expiry so the token is always treated as stale.
	staleExpiry := time.Now().Add(-1 * time.Minute)
	stale := &oauth2.Token{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		Expiry:       staleExpiry,
	}

	refreshed, refreshErr := oauth.TokenSource(ctx, stale).Token()
	if refreshErr == nil {
		return refreshed, nil
	}

	logger.Errorf(ctx, "Error refreshing token %s", refreshErr)
	return nil, errors.Wrapf(ErrRefreshingToken, refreshErr, "Error refreshing token")
}
// ParseIDTokenAndValidate verifies rawIDToken with a verifier obtained from provider and returns the
// parsed ID token.
//
// The verifier is configured with clientID. When clientID is empty, the client ID, issuer and
// expiry checks are all switched off. Any verification failure is wrapped in ErrJwtValidation.
// If the underlying error message contains "token is expired", that wrapped error is additionally
// wrapped in ErrTokenExpired.
func ParseIDTokenAndValidate(ctx context.Context, clientID, rawIDToken string, provider *oidc.Provider) (*oidc.IDToken, error) {
	// NOTE(review): an empty clientID relaxes more than the audience check; issuer and expiry are
	// skipped too. Confirm this is intended.
	noClientID := len(clientID) == 0
	verifier := provider.Verifier(&oidc.Config{
		ClientID:          clientID,
		SkipClientIDCheck: noClientID,
		SkipIssuerCheck:   noClientID,
		SkipExpiryCheck:   noClientID,
	})

	idToken, verifyErr := verifier.Verify(ctx, rawIDToken)
	if verifyErr == nil {
		return idToken, nil
	}

	logger.Debugf(ctx, "JWT parsing with claims failed %s", verifyErr)
	wrapped := errors.Wrapf(ErrJwtValidation, verifyErr, "jwt parse with claims failed")

	// Expiry is detected by matching the message text, since no typed error is used here.
	// TODO: Contribute an errors package to the go-oidc library for proper error handling
	if !strings.Contains(verifyErr.Error(), "token is expired") {
		return idToken, wrapped
	}
	return idToken, errors.Wrapf(ErrTokenExpired, wrapped, "token is expired")
}
// GRPCGetIdentityFromAccessToken extracts a bearer access token (BearerScheme) from the incoming gRPC
// metadata and validates it with authCtx's OAuth2 resource server. The expected audience is the public
// URL derived from authCtx's options.
//
// A failure to read the token from metadata, or a bearer header whose token part is empty, is
// returned as an ErrJwtValidation error. Otherwise the result of ValidateAccessToken is passed up
// unchanged.
func GRPCGetIdentityFromAccessToken(ctx context.Context, authCtx interfaces.AuthenticationContext) (
	interfaces.IdentityContext, error) {
	tokenStr, err := grpcauth.AuthFromMD(ctx, BearerScheme)
	if err != nil {
		logger.Debugf(ctx, "Could not retrieve bearer token from metadata %v", err)
		return nil, errors.Wrapf(ErrJwtValidation, err, "Could not retrieve bearer token from metadata")
	}

	// Presumably reachable when the header carries the scheme with nothing after it. Verify against
	// grpcauth.AuthFromMD.
	if tokenStr == "" {
		logger.Debugf(ctx, "Found Bearer scheme but token was blank")
		// Name the scheme this function actually requested (BearerScheme), not IDTokenScheme.
		return nil, errors.Errorf(ErrJwtValidation, "%v token is blank", BearerScheme)
	}

	expectedAudience := GetPublicURL(ctx, nil, authCtx.Options()).String()
	return authCtx.OAuth2ResourceServer().ValidateAccessToken(ctx, expectedAudience, tokenStr)
}
// GRPCGetIdentityFromIDToken extracts an ID token (IDTokenScheme) from the incoming gRPC metadata.
// It optionally decodes JSON user info carried under UserInfoMDKey, then builds an identity context
// via IdentityContextFromIDTokenToken.
//
// A failure to read the token from metadata, or an empty token, is returned as an ErrJwtValidation
// error. A user-info value that fails to decode is logged and otherwise ignored (best-effort).
func GRPCGetIdentityFromIDToken(ctx context.Context, clientID string, provider *oidc.Provider) (
	interfaces.IdentityContext, error) {
	tokenStr, err := grpcauth.AuthFromMD(ctx, IDTokenScheme)
	if err != nil {
		logger.Debugf(ctx, "Could not retrieve id token from metadata %v", err)
		return nil, errors.Wrapf(ErrJwtValidation, err, "Could not retrieve id token from metadata")
	}

	if tokenStr == "" {
		// Log the scheme this function actually requested; this path never looks at a Bearer header.
		logger.Debugf(ctx, "Found %v scheme but token was blank", IDTokenScheme)
		return nil, errors.Errorf(ErrJwtValidation, "%v token is blank", IDTokenScheme)
	}

	// User info is optional. On a decode error the struct may be left partially populated;
	// it is still passed on, matching the best-effort intent.
	userInfo := &service.UserInfoResponse{}
	if userInfoDecoded := metautils.ExtractIncoming(ctx).Get(UserInfoMDKey); len(userInfoDecoded) > 0 {
		if err := json.Unmarshal([]byte(userInfoDecoded), userInfo); err != nil {
			logger.Infof(ctx, "Could not unmarshal user info from metadata %v", err)
		}
	}

	return IdentityContextFromIDTokenToken(ctx, tokenStr, clientID, provider, userInfo)
}
// IdentityContextFromIDTokenToken validates tokenStr as an ID token via ParseIDTokenAndValidate and
// builds an identity context from it.
//
// It passes the token's first audience, subject and issued-at time, the ScopeAll scope, userInfo and
// the token's decoded claims to NewIdentityContext. Validation errors are returned as-is. A token
// without any audience entry is rejected with an ErrJwtValidation error. Claim decoding is
// best-effort: a failure is logged and the (possibly nil or incomplete) claims map is still used.
func IdentityContextFromIDTokenToken(ctx context.Context, tokenStr, clientID string, provider *oidc.Provider,
	userInfo *service.UserInfoResponse) (interfaces.IdentityContext, error) {
	idToken, err := ParseIDTokenAndValidate(ctx, clientID, tokenStr, provider)
	if err != nil {
		return nil, err
	}

	// With an empty clientID the verifier skips the client ID (audience) check, so a verified token
	// may carry no audience at all. Indexing Audience[0] would then panic.
	if len(idToken.Audience) == 0 {
		return nil, errors.Errorf(ErrJwtValidation, "id token has no audience")
	}

	var claims map[string]interface{}
	if err = idToken.Claims(&claims); err != nil {
		logger.Infof(ctx, "Failed to unmarshal claims from id token, err: %v", err)
	}

	// TODO: Document why automatically specify "all" scope
	return NewIdentityContext(idToken.Audience[0], idToken.Subject, "", idToken.IssuedAt,
		sets.NewString(ScopeAll), userInfo, claims)
}