diff --git a/lib/client/cluster_client.go b/lib/client/cluster_client.go index dd52d52e0716e..b97eb70724253 100644 --- a/lib/client/cluster_client.go +++ b/lib/client/cluster_client.go @@ -80,6 +80,20 @@ func (c *ClusterClient) Close() error { return trace.NewAggregate(c.AuthClient.Close(), c.ProxyClient.Close()) } +// ceremonyFailedErr indicates that the mfa ceremony was attempted unsuccessfully. +type ceremonyFailedErr struct { + err error +} + +// Error returns the error string of the wrapped error if one exists. +func (c ceremonyFailedErr) Error() string { + if c.err == nil { + return "" + } + + return c.err.Error() +} + // SessionSSHConfig returns the [ssh.ClientConfig] that should be used to connected to the // provided target for the provided user. If per session MFA is required to establish the // connection, then the MFA ceremony will be performed. @@ -147,7 +161,7 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe log.Debug("Issued single-use user certificate after an MFA check.") am, err := key.AsAuthMethod() if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(ceremonyFailedErr{err}) } sshConfig.Auth = []ssh.AuthMethod{am} @@ -339,7 +353,7 @@ func (c *ClusterClient) performMFACeremony(ctx context.Context, clt *ClusterClie mfaResp, err := clt.tc.PromptMFA(ctx, mfaChal) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(ceremonyFailedErr{err}) } err = stream.Send(&proto.UserSingleUseCertsRequest{Request: &proto.UserSingleUseCertsRequest_MFAResponse{MFAResponse: mfaResp}}) if err != nil { diff --git a/tool/tsh/common/tsh_test.go b/tool/tsh/common/tsh_test.go index c945b2c00744e..d2fb834e293c4 100644 --- a/tool/tsh/common/tsh_test.go +++ b/tool/tsh/common/tsh_test.go @@ -1286,6 +1286,10 @@ func TestSSHOnMultipleNodes(t *testing.T) { } } + abortedChallenge := func(ctx context.Context, realOrigin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + return nil, "", errors.New("aborted challenge") + } + cases := []struct { name string target string @@ -1499,6 +1503,30 @@ func TestSSHOnMultipleNodes(t *testing.T) { mfaPromptCount: 1, errAssertion: require.Error, }, + { + name: "aborted ceremony when role requires per session mfa", + authPreference: &types.AuthPreferenceV2{ + Spec: types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorOptional, + Webauthn: &types.Webauthn{ + RPID: "localhost", + }, + }, + }, + proxyAddr: rootProxyAddr.String(), + auth: rootAuth.GetAuthServer(), + target: sshHostID, + roles: []string{perSessionMFARole.GetName()}, + webauthnLogin: abortedChallenge, + stdoutAssertion: require.Empty, + stderrAssertion: func(t require.TestingT, v any, i ...any) { + out, ok := v.(string) + require.True(t, ok, i...) + require.Contains(t, out, "failed to authenticate using all MFA devices", i...) + }, + errAssertion: require.Error, + }, { name: "mfa ceremony prevented when using headless auth", authPreference: &types.AuthPreferenceV2{