From 87bffeb68530529e68d2bbd93ff79835877f6466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Skrz=C4=99tnicki?= Date: Tue, 19 Sep 2023 11:02:46 +0200 Subject: [PATCH] Verify expected token properties in WithProvisionTokenAuth. The token must have join method "token" and be non-expired. --- lib/web/apiserver.go | 30 ++++++++-- lib/web/apiserver_test.go | 114 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 4 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 484ecfdfc7822..2728ebc821cbc 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3652,7 +3652,7 @@ func (h *Handler) WithProvisionTokenAuth(fn ProvisionTokenHandler) httprouter.Ha return nil, trace.AccessDenied("need auth") } - token, err := h.consumeTokenForAPICall(ctx, creds.Password) + token, err := consumeTokenForAPICall(ctx, h.GetProxyClient(), creds.Password) if err != nil { h.log.WithError(err).Warn("Failed to authenticate.") return nil, trace.AccessDenied("need auth") @@ -3675,17 +3675,39 @@ func (h *Handler) WithProvisionTokenAuth(fn ProvisionTokenHandler) httprouter.Ha // This is possible because the latest call - DeleteToken - returns an error if the resource doesn't exist // This is currently true for all the backends as explained here // https://github.com/gravitational/teleport/commit/24fcadc375d8359e80790b3ebeaa36bd8dd2822f -func (h *Handler) consumeTokenForAPICall(ctx context.Context, tokenName string) (types.ProvisionToken, error) { - token, err := h.GetProxyClient().GetToken(ctx, tokenName) +func consumeTokenForAPICall(ctx context.Context, proxyClient auth.ClientI, tokenName string) (types.ProvisionToken, error) { + token, err := proxyClient.GetToken(ctx, tokenName) if err != nil { return nil, trace.Wrap(err) } - if err := h.GetProxyClient().DeleteToken(ctx, token.GetName()); err != nil { + if token.GetJoinMethod() != types.JoinMethodToken { + return nil, trace.BadParameter("unexpected join method %q for token %q", token.GetJoinMethod(), token.GetSafeName()) + } + + if !checkTokenTTL(token) { + return nil, trace.BadParameter("expired token %q", token.GetSafeName()) + } + + if err := proxyClient.DeleteToken(ctx, token.GetName()); err != nil { return nil, trace.Wrap(err) } return token, nil + +} + +// checkTokenTTL returns true if the token is still valid. +// This is similar to checkTokenTTL in auth.Server, but does not delete expired tokens. +func checkTokenTTL(tok types.ProvisionToken) bool { + // Always accept tokens without an expiry configured. + if tok.Expiry().IsZero() { + return true + } + + now := time.Now().UTC() + + return tok.Expiry().After(now) } type redirectHandlerFunc func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (redirectURL string) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 935821f79d543..b757aa2b7c162 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -9346,3 +9346,117 @@ func handleMFAWebauthnChallenge(t *testing.T, ws *websocket.Conn, dev *auth.Test require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, envelopeBytes)) } + +type proxyClientMock struct { + auth.ClientI + tokens map[string]types.ProvisionToken +} + +// GetToken returns provisioning token +func (pc *proxyClientMock) GetToken(_ context.Context, token string) (types.ProvisionToken, error) { + tok, ok := pc.tokens[token] + if ok { + return tok, nil + } + + return nil, trace.NotFound(token) +} + +func (pc *proxyClientMock) DeleteToken(_ context.Context, token string) error { + _, ok := pc.tokens[token] + if ok { + delete(pc.tokens, token) + return nil + } + return trace.NotFound(token) +} + +func Test_consumeTokenForAPICall(t *testing.T) { + pc := &proxyClientMock{tokens: map[string]types.ProvisionToken{}} + + tests := []struct { + name string + getToken func() (string, types.ProvisionToken) + wantErr require.ErrorAssertionFunc + }{ + { + name: "missing token is rejected", + getToken: func() (string, types.ProvisionToken) { + return "fake", nil + }, + wantErr: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsNotFound(err)) + }, + }, + { + name: "valid token is accepted", + getToken: func() (string, types.ProvisionToken) { + tok, err := types.NewProvisionToken(uuid.New().String(), []types.SystemRole{types.RoleDatabase}, time.Now().Add(time.Hour)) + require.NoError(t, err) + pc.tokens[tok.GetName()] = tok + return tok.GetName(), tok + }, + }, + { + name: "token with no expiry is accepted", + getToken: func() (string, types.ProvisionToken) { + tok, err := types.NewProvisionToken(uuid.New().String(), []types.SystemRole{types.RoleDatabase}, time.Time{}) + require.NoError(t, err) + pc.tokens[tok.GetName()] = tok + return tok.GetName(), tok + }, + }, + { + name: "expired token is rejected", + getToken: func() (string, types.ProvisionToken) { + tok, err := types.NewProvisionToken(uuid.New().String(), []types.SystemRole{types.RoleDatabase}, time.Now().Add(-time.Hour)) + require.NoError(t, err) + pc.tokens[tok.GetName()] = tok + return tok.GetName(), tok + }, + wantErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "expired token") + }, + }, + { + name: "token with invalid join type is rejected", + getToken: func() (string, types.ProvisionToken) { + tok, err := types.NewProvisionTokenFromSpec("ec2-token", time.Now().Add(time.Hour), types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleDatabase}, + Allow: []*types.TokenRule{{AWSAccount: "1234"}}, + JoinMethod: types.JoinMethodEC2, + }) + + require.NoError(t, err) + pc.tokens[tok.GetName()] = tok + return tok.GetName(), tok + }, + wantErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "unexpected join method \"ec2\" for token \"ec2-token\"") + }, + }, + } + + tokenExists := func(tokenName string) bool { + tok, _ := pc.GetToken(context.Background(), tokenName) + return tok != nil + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokName, tok := tt.getToken() + tokenInitiallyPresent := tokenExists(tokName) + result, err := consumeTokenForAPICall(context.Background(), pc, tokName) + if tt.wantErr != nil { + tt.wantErr(t, err) + // verify that if token was present, then it has not been deleted. + require.Equal(t, tokenInitiallyPresent, tokenExists(tokName)) + } else { + require.NoError(t, err) + require.Equal(t, tok, result) + // verify that token does not exist now, even if it did. + require.False(t, tokenExists(tokName)) + } + }) + } +}