Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v13] Verify expected token properties in WithProvisionTokenAuth #32215

Merged
merged 2 commits into from Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 26 additions & 4 deletions lib/web/apiserver.go
Expand Up @@ -3653,7 +3653,7 @@ func (h *Handler) WithProvisionTokenAuth(fn ProvisionTokenHandler) httprouter.Ha
return nil, trace.AccessDenied("need auth")
}

token, err := h.consumeTokenForAPICall(ctx, creds.Password)
token, err := consumeTokenForAPICall(ctx, h.GetProxyClient(), creds.Password)
if err != nil {
h.log.WithError(err).Warn("Failed to authenticate.")
return nil, trace.AccessDenied("need auth")
Expand All @@ -3676,17 +3676,39 @@ func (h *Handler) WithProvisionTokenAuth(fn ProvisionTokenHandler) httprouter.Ha
// This is possible because the latest call - DeleteToken - returns an error if the resource doesn't exist
// This is currently true for all the backends as explained here
// https://github.com/gravitational/teleport/commit/24fcadc375d8359e80790b3ebeaa36bd8dd2822f
func (h *Handler) consumeTokenForAPICall(ctx context.Context, tokenName string) (types.ProvisionToken, error) {
token, err := h.GetProxyClient().GetToken(ctx, tokenName)
func consumeTokenForAPICall(ctx context.Context, proxyClient auth.ClientI, tokenName string) (types.ProvisionToken, error) {
token, err := proxyClient.GetToken(ctx, tokenName)
if err != nil {
return nil, trace.Wrap(err)
}

if err := h.GetProxyClient().DeleteToken(ctx, token.GetName()); err != nil {
if token.GetJoinMethod() != types.JoinMethodToken {
return nil, trace.BadParameter("unexpected join method %q for token %q", token.GetJoinMethod(), token.GetSafeName())
}

if !checkTokenTTL(token) {
return nil, trace.BadParameter("expired token %q", token.GetSafeName())
}

if err := proxyClient.DeleteToken(ctx, token.GetName()); err != nil {
return nil, trace.Wrap(err)
}

return token, nil

}

// checkTokenTTL returns true if the token is still valid.
// This is similar to checkTokenTTL in auth.Server, but does not delete expired tokens.
func checkTokenTTL(tok types.ProvisionToken) bool {
// Always accept tokens without an expiry configured.
if tok.Expiry().IsZero() {
return true
}

now := time.Now().UTC()

return tok.Expiry().After(now)
}

type redirectHandlerFunc func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (redirectURL string)
Expand Down
114 changes: 114 additions & 0 deletions lib/web/apiserver_test.go
Expand Up @@ -9346,3 +9346,117 @@ func handleMFAWebauthnChallenge(t *testing.T, ws *websocket.Conn, dev *auth.Test
require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, envelopeBytes))

}

type proxyClientMock struct {
auth.ClientI
tokens map[string]types.ProvisionToken
}

// GetToken returns provisioning token
func (pc *proxyClientMock) GetToken(_ context.Context, token string) (types.ProvisionToken, error) {
tok, ok := pc.tokens[token]
if ok {
return tok, nil
}

return nil, trace.NotFound(token)
}

func (pc *proxyClientMock) DeleteToken(_ context.Context, token string) error {
_, ok := pc.tokens[token]
if ok {
delete(pc.tokens, token)
return nil
}
return trace.NotFound(token)
}

func Test_consumeTokenForAPICall(t *testing.T) {
pc := &proxyClientMock{tokens: map[string]types.ProvisionToken{}}

tests := []struct {
name string
getToken func() (string, types.ProvisionToken)
wantErr require.ErrorAssertionFunc
}{
{
name: "missing token is rejected",
getToken: func() (string, types.ProvisionToken) {
return "fake", nil
},
wantErr: func(t require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsNotFound(err))
},
},
{
name: "valid token is accepted",
getToken: func() (string, types.ProvisionToken) {
tok, err := types.NewProvisionToken(uuid.New().String(), []types.SystemRole{types.RoleDatabase}, time.Now().Add(time.Hour))
require.NoError(t, err)
pc.tokens[tok.GetName()] = tok
return tok.GetName(), tok
},
},
{
name: "token with no expiry is accepted",
getToken: func() (string, types.ProvisionToken) {
tok, err := types.NewProvisionToken(uuid.New().String(), []types.SystemRole{types.RoleDatabase}, time.Time{})
require.NoError(t, err)
pc.tokens[tok.GetName()] = tok
return tok.GetName(), tok
},
},
{
name: "expired token is rejected",
getToken: func() (string, types.ProvisionToken) {
tok, err := types.NewProvisionToken(uuid.New().String(), []types.SystemRole{types.RoleDatabase}, time.Now().Add(-time.Hour))
require.NoError(t, err)
pc.tokens[tok.GetName()] = tok
return tok.GetName(), tok
},
wantErr: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "expired token")
},
},
{
name: "token with invalid join type is rejected",
getToken: func() (string, types.ProvisionToken) {
tok, err := types.NewProvisionTokenFromSpec("ec2-token", time.Now().Add(time.Hour), types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleDatabase},
Allow: []*types.TokenRule{{AWSAccount: "1234"}},
JoinMethod: types.JoinMethodEC2,
})

require.NoError(t, err)
pc.tokens[tok.GetName()] = tok
return tok.GetName(), tok
},
wantErr: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unexpected join method \"ec2\" for token \"ec2-token\"")
},
},
}

tokenExists := func(tokenName string) bool {
tok, _ := pc.GetToken(context.Background(), tokenName)
return tok != nil
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokName, tok := tt.getToken()
tokenInitiallyPresent := tokenExists(tokName)
result, err := consumeTokenForAPICall(context.Background(), pc, tokName)
if tt.wantErr != nil {
tt.wantErr(t, err)
// verify that if token was present, then it has not been deleted.
require.Equal(t, tokenInitiallyPresent, tokenExists(tokName))
} else {
require.NoError(t, err)
require.Equal(t, tok, result)
// verify that token does not exist now, even if it did.
require.False(t, tokenExists(tokName))
}
})
}
}