Skip to content

Commit

Permalink
Attempt ssh connections with and without mfa at the same time (#23865)
Browse files Browse the repository at this point in the history
* Attempt ssh connections with and without mfa at the same time

`tsh ssh` would fallback to doing the mfa ceremony if connecting
to the node with the already provisioned certificates failed with
an access denied error. This incurs the cost of a round trip to
the target host when per session mfa is required. To combat the
additional latency when per session mfa is required we can
attempt both the connection with the certs on hand AND start the
per session mfa flow at the same time. If per session mfa is not
required the client won't attempt the mfa ceremony which adds no
impact there. If per session mfa is required the initial connection
to the host is going to fail so the mfa ceremony will need to be
performed any how.

For this to work we need to ensure that users are not prompted for
mfa if completing the mfa ceremony will not actually help the user
gain access to the host. If users just flat out do not have access
to the host we don't want to confuse them by prompting them to
touch a hardware key. Since `tsh` first calls
`proto.AuthService/IsMFARequired` before initiating the mfa ceremony
we are guaranteed not to initiate the mfa ceremony when not required.

* fix: return an error if mfa is not required

* apply same connection racing to the web ui

* fix: prevent race on mfacheck

* fix: tests and return the correct errors

* wrap all uses of MFARequiredUnknown

* fix: changes to work correctly with ClusterClient
  • Loading branch information
rosstimothy committed Apr 5, 2023
1 parent 48951cf commit 3113cd2
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 151 deletions.
182 changes: 139 additions & 43 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1465,11 +1465,10 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, runLocally
}

// ConnectToNode attempts to establish a connection to the node resolved to by the provided
// NodeDetails. If the connection fails due to an Access Denied error, Auth is queried to
// determine if per-session MFA is required for the node. If it is required then the MFA
// ceremony is performed and another connection is attempted with the freshly minted
// certificates. If it is not required, then the original Access Denied error from the node
// is returned.
// NodeDetails. Connecting is attempted both with the already provisioned certificates and
// if per session mfa is required, after completing the mfa ceremony. In the event that both
// fail the error from the connection attempt with the already provisioned certificates will
// be returned. The client from whichever attempt succeeds first will be returned.
func (tc *TeleportClient) ConnectToNode(ctx context.Context, clt *ClusterClient, nodeDetails NodeDetails, user string) (*NodeClient, error) {
node := nodeName(nodeDetails.Addr)
ctx, span := tc.Tracer.Start(
Expand All @@ -1483,67 +1482,164 @@ func (tc *TeleportClient) ConnectToNode(ctx context.Context, clt *ClusterClient,
)
defer span.End()

sshConfig := clt.ProxyClient.SSHConfig(user)

// if mfa is required generate new config after
// performing the mfa ceremony
// if per-session mfa is required, perform the mfa ceremony to get
// new certificates and use them to connect.
if nodeDetails.MFACheck != nil && nodeDetails.MFACheck.Required {
cfg, err := clt.SessionSSHConfig(ctx, user, nodeDetails)
clt, err := tc.connectToNodeWithMFA(ctx, clt, nodeDetails, user)
return clt, trace.Wrap(err)
}

type clientRes struct {
clt *NodeClient
err error
}

directResultC := make(chan clientRes, 1)
mfaResultC := make(chan clientRes, 1)

// use a child context so the goroutines can terminate the other if they succeed
directCtx, directCancel := context.WithCancel(ctx)
mfaCtx, mfaCancel := context.WithCancel(ctx)
go func() {
ctx, span := tc.Tracer.Start(
directCtx,
"teleportClient/connectToNode",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
attribute.String("cluster", nodeDetails.Cluster),
attribute.String("node", node),
),
)
defer span.End()

// try connecting to the node with the certs we already have
conn, details, err := clt.ProxyClient.DialHost(ctx, nodeDetails.Addr, nodeDetails.Cluster, tc.localAgent.ExtendedAgent)
if err != nil {
return nil, trace.Wrap(err)
directResultC <- clientRes{err: err}
return
}

sshConfig = cfg
sshConfig := clt.ProxyClient.SSHConfig(user)
clt, err := NewNodeClient(ctx, sshConfig, conn, nodeDetails.ProxyFormat(), nodeDetails.Addr, tc, details.FIPS)
directResultC <- clientRes{clt: clt, err: err}
}()

go func() {
// try performing mfa and then connecting with the single use certs
clt, err := tc.connectToNodeWithMFA(mfaCtx, clt, nodeDetails, user)
mfaResultC <- clientRes{clt: clt, err: err}
}()

var directErr, mfaErr error
for i := 0; i < 2; i++ {
select {
case <-ctx.Done():
mfaCancel()
directCancel()
return nil, ctx.Err()
case res := <-directResultC:
if res.clt != nil {
mfaCancel()
res.clt.AddCancel(directCancel)
return res.clt, nil
}

directErr = res.err
case res := <-mfaResultC:
if res.clt != nil {
directCancel()
res.clt.AddCancel(mfaCancel)
return res.clt, nil
}

mfaErr = res.err
}
}

// try connecting to the node
conn, details, err := clt.ProxyClient.DialHost(ctx, nodeDetails.Addr, nodeDetails.Cluster, tc.localAgent.ExtendedAgent)
if err != nil {
return nil, trace.Wrap(err)
mfaCancel()
directCancel()

// Only return the error from connecting with mfa if the error
// originates from the mfa ceremony. If mfa is not required then
// the error from the direct connection to the node must be returned.
if mfaErr != nil && !errors.Is(mfaErr, MFARequiredUnknownErr{}) {
return nil, trace.Wrap(mfaErr)
}

nodeClient, connectErr := NewNodeClient(ctx, sshConfig, conn, nodeDetails.ProxyFormat(), nodeDetails.Addr, tc, details.FIPS)
switch {
case connectErr == nil: // no error return client
return nodeClient, nil
case nodeDetails.MFACheck != nil: // per-session mfa ceremony was already performed, return the results
return nodeClient, trace.Wrap(connectErr)
case connectErr != nil && !trace.IsAccessDenied(connectErr): // catastrophic error, return it
return nil, trace.Wrap(connectErr)
return nil, trace.Wrap(directErr)
}

// MFARequiredUnknownErr indicates that connections to an instance failed
// due to being unable to determine if mfa is required
type MFARequiredUnknownErr struct {
err error
}

// MFARequiredUnknown creates a new MFARequiredUnknownErr that wraps the
// error encountered attempting to determine if the mfa ceremony should proceed.
func MFARequiredUnknown(err error) error {
return MFARequiredUnknownErr{err: err}
}

// Error returns the error string of the wrapped error if one exists.
func (m MFARequiredUnknownErr) Error() string {
if m.err == nil {
return ""
}

// access was denied, determine if it was because per-session mfa is required
nodeDetails.MFACheck, err = clt.AuthClient.IsMFARequired(ctx, &proto.IsMFARequiredRequest{
Target: &proto.IsMFARequiredRequest_Node{
Node: &proto.NodeLogin{
Node: node,
Login: tc.HostLogin,
},
},
})
if err != nil {
log.Warnf("Unable to determine if session mfa is required: %v", err)
return nil, trace.Wrap(connectErr)
return m.err.Error()
}

// Unwrap returns the underlying error from checking if an mfa
// ceremony should have been performed.
func (m MFARequiredUnknownErr) Unwrap() error {
return m.err
}

// Is determines if the provided error is an MFARequiredUnknownErr.
func (m MFARequiredUnknownErr) Is(err error) bool {
switch err.(type) {
case MFARequiredUnknownErr:
return true
case *MFARequiredUnknownErr:
return true
default:
return false
}
}

// connectToNodeWithMFA checks if per session mfa is required to connect to the target host, and
// if it is required, then the mfa ceremony is attempted. The target host is dialed once the ceremony
// completes and new certificates are retrieved.
func (tc *TeleportClient) connectToNodeWithMFA(ctx context.Context, clt *ClusterClient, nodeDetails NodeDetails, user string) (*NodeClient, error) {
node := nodeName(nodeDetails.Addr)
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/connectToNodeWithMFA",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
attribute.String("cluster", nodeDetails.Cluster),
attribute.String("node", node),
),
)
defer span.End()

// per-session mfa isn't required, the user simply does not
// have access to the provided node
if !nodeDetails.MFACheck.Required {
return nil, trace.Wrap(connectErr)
if nodeDetails.MFACheck != nil && !nodeDetails.MFACheck.Required {
return nil, trace.Wrap(MFARequiredUnknown(trace.AccessDenied("no access to %s", nodeDetails.Addr)))
}

// generate new config after performing the mfa ceremony
// per-session mfa is required, perform the mfa ceremony
cfg, err := clt.SessionSSHConfig(ctx, user, nodeDetails)
if err != nil {
return nil, trace.Wrap(err)
}

conn, details, err = clt.ProxyClient.DialHost(ctx, nodeDetails.Addr, nodeDetails.Cluster, tc.localAgent.ExtendedAgent)
conn, details, err := clt.ProxyClient.DialHost(ctx, nodeDetails.Addr, nodeDetails.Cluster, tc.localAgent.ExtendedAgent)
if err != nil {
return nil, trace.Wrap(err)
}

nodeClient, err = NewNodeClient(ctx, cfg, conn, nodeDetails.ProxyFormat(), nodeDetails.Addr, tc, details.FIPS)
nodeClient, err := NewNodeClient(ctx, cfg, conn, nodeDetails.ProxyFormat(), nodeDetails.Addr, tc, details.FIPS)
return nodeClient, trace.Wrap(err)
}

Expand Down
45 changes: 44 additions & 1 deletion lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -88,6 +89,36 @@ type NodeClient struct {
TC *TeleportClient
OnMFA func()
FIPSEnabled bool

mu sync.Mutex
closers []io.Closer
}

// AddCloser adds an [io.Closer] that will be closed when the
// client is closed.
func (c *NodeClient) AddCloser(closer io.Closer) {
c.mu.Lock()
defer c.mu.Unlock()

c.closers = append(c.closers, closer)
}

type closerFunc func() error

func (f closerFunc) Close() error {
return f()
}

// AddCancel adds a [context.CancelFunc] that will be canceled when the
// client is closed.
func (c *NodeClient) AddCancel(cancel context.CancelFunc) {
c.mu.Lock()
defer c.mu.Unlock()

c.closers = append(c.closers, closerFunc(func() error {
cancel()
return nil
}))
}

// ClusterName returns the name of the cluster the proxy is a member of.
Expand Down Expand Up @@ -2037,7 +2068,19 @@ func (c *NodeClient) GetRemoteTerminalSize(ctx context.Context, sessionID string

// Close closes client and it's operations
func (c *NodeClient) Close() error {
return c.Client.Close()
c.mu.Lock()
defer c.mu.Unlock()

var errors []error
for _, closer := range c.closers {
errors = append(errors, closer.Close())
}

c.closers = nil

errors = append(errors, c.Client.Close())

return trace.NewAggregate(errors...)
}

// localAgent returns for the Teleport client's local agent.
Expand Down
9 changes: 2 additions & 7 deletions lib/client/cluster_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,7 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe

key, err := c.tc.localAgent.GetKey(target.Cluster, WithAllCerts...)
if err != nil {
if trace.IsNotFound(err) {
// Either running inside the web UI in a proxy or using an identity
// file. Fall back to whatever AuthMethod we currently have.
return sshConfig, nil
}
return nil, trace.Wrap(err)
return nil, trace.Wrap(MFARequiredUnknown(err))
}

params := ReissueParams{
Expand All @@ -93,7 +88,7 @@ func (c *ClusterClient) SessionSSHConfig(ctx context.Context, user string, targe
if target.MFACheck == nil {
check, err := c.AuthClient.IsMFARequired(ctx, params.isMFARequiredRequest(c.tc.HostLogin))
if err != nil {
return nil, trace.Wrap(err)
return nil, trace.Wrap(MFARequiredUnknown(err))
}
target.MFACheck = check
}
Expand Down
1 change: 1 addition & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3849,6 +3849,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}

tlscfg := serverTLSConfig.Clone()
setupTLSConfigClientCAsForCluster(tlscfg, accessPoint, clusterName)
tlscfg.ClientAuth = tls.RequireAndVerifyClientCert
if lib.IsInsecureDevMode() {
tlscfg.InsecureSkipVerify = true
Expand Down

0 comments on commit 3113cd2

Please sign in to comment.