Skip to content

Commit

Permalink
Abort reverse tunnel connections early if the proxy is already claimed (
Browse files Browse the repository at this point in the history
#27683)

* Don't loop over authMethods in (*agentDialer)DialContext

* Make HostKeyCallbackConfig.OnCheckCert fallible

* Add a way to check if a proxy is claimed

* Abort reverse tunnel conns early if the proxy is already claimed

* Document the pre-closed global request channel
  • Loading branch information
espadolini committed Jun 9, 2023
1 parent 0b2cef8 commit 7ebaf6a
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 51 deletions.
2 changes: 1 addition & 1 deletion api/utils/sshutils/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type HostKeyCallbackConfig struct {
// FIPS allows to set FIPS mode which will validate algorithms.
FIPS bool
// OnCheckCert is called on SSH certificate validation.
OnCheckCert func(*ssh.Certificate)
OnCheckCert func(*ssh.Certificate) error
// Clock is used to set the Checker Time
Clock clockwork.Clock
}
Expand Down
10 changes: 7 additions & 3 deletions api/utils/sshutils/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type CertChecker struct {
FIPS bool

// OnCheckCert is called when validating host certificate.
OnCheckCert func(*ssh.Certificate)
OnCheckCert func(*ssh.Certificate) error
}

// Authenticate checks the validity of a user certificate.
Expand Down Expand Up @@ -67,7 +67,9 @@ func (c *CertChecker) CheckCert(principal string, cert *ssh.Certificate) error {
}

if c.OnCheckCert != nil {
c.OnCheckCert(cert)
if err := c.OnCheckCert(cert); err != nil {
return trace.Wrap(err)
}
}

return nil
Expand All @@ -86,7 +88,9 @@ func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key ssh.PublicK
}

if cert, ok := key.(*ssh.Certificate); ok && c.OnCheckCert != nil {
c.OnCheckCert(cert)
if err := c.OnCheckCert(cert); err != nil {
return trace.Wrap(err)
}
}

return nil
Expand Down
95 changes: 50 additions & 45 deletions lib/reversetunnel/agent_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,61 +41,66 @@ type agentDialer struct {
fips bool
options []proxy.DialerOptionFunc
log logrus.FieldLogger
isClaimed func(principals ...string) bool
}

// DialContext creates an ssh connection to the given address.
func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHClient, error) {
// Create a dialer (that respects HTTP proxies) and connect to remote host.
dialer := proxy.DialerFromEnvironment(addr.Addr, d.options...)
pconn, err := dialer.DialTimeout(ctx, addr.AddrNetwork, addr.Addr, apidefaults.DefaultIOTimeout)
if err != nil {
d.log.WithError(err).Debugf("Failed to dial %s.", addr.Addr)
return nil, trace.Wrap(err)
}

for _, authMethod := range d.authMethods {
// Create a dialer (that respects HTTP proxies) and connect to remote host.
dialer := proxy.DialerFromEnvironment(addr.Addr, d.options...)
pconn, err := dialer.DialTimeout(ctx, addr.AddrNetwork, addr.Addr, apidefaults.DefaultIOTimeout)
if err != nil {
d.log.WithError(err).Debugf("Failed to dial %s.", addr.Addr)
continue
}

principals := make([]string, 0)
callback, err := apisshutils.NewHostKeyCallback(
apisshutils.HostKeyCallbackConfig{
GetHostCheckers: d.hostCheckerFunc(ctx),
OnCheckCert: func(c *ssh.Certificate) {
principals = c.ValidPrincipals
},
FIPS: d.fips,
})
if err != nil {
d.log.Debugf("Failed to create host key callback for %v: %v.", addr.Addr, err)
continue
}

// Build a new client connection. This is done to get access to incoming
// global requests which dialer.Dial would not provide.
conn, chans, reqs, err := tracessh.NewClientConn(ctx, pconn, addr.Addr, &ssh.ClientConfig{
User: d.username,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: callback,
Timeout: apidefaults.DefaultIOTimeout,
var principals []string
callback, err := apisshutils.NewHostKeyCallback(
apisshutils.HostKeyCallbackConfig{
GetHostCheckers: d.hostCheckerFunc(ctx),
OnCheckCert: func(c *ssh.Certificate) error {
if d.isClaimed != nil && d.isClaimed(c.ValidPrincipals...) {
d.log.Debugf("Aborting SSH handshake because the proxy %q is already claimed by some other agent.", c.ValidPrincipals[0])
return trace.Errorf("proxy already claimed")
}

principals = c.ValidPrincipals
return nil
},
FIPS: d.fips,
})
if err != nil {
d.log.WithError(err).Debugf("Failed to create client to %v.", addr.Addr)
continue
}
if err != nil {
d.log.Debugf("Failed to create host key callback for %v: %v.", addr.Addr, err)
return nil, trace.Wrap(err)
}

emptyRequests := make(chan *ssh.Request)
close(emptyRequests)
// Build a new client connection. This is done to get access to incoming
// global requests which dialer.Dial would not provide.
conn, chans, reqs, err := tracessh.NewClientConn(ctx, pconn, addr.Addr, &ssh.ClientConfig{
User: d.username,
Auth: d.authMethods,
HostKeyCallback: callback,
Timeout: apidefaults.DefaultIOTimeout,
})
if err != nil {
d.log.WithError(err).Debugf("Failed to create client to %v.", addr.Addr)
return nil, trace.Wrap(err)
}

client := tracessh.NewClient(conn, chans, emptyRequests)
// ssh.NewClient will loop over the global requests channel in a goroutine,
// rejecting all requests; we want to handle the global requests ourselves,
// so we feed it a closed channel to have the goroutine exit immediately.
emptyRequests := make(chan *ssh.Request)
close(emptyRequests)

return &sshClient{
Client: client,
requests: reqs,
newChannels: chans,
principals: principals,
}, nil
}
client := tracessh.NewClient(conn, chans, emptyRequests)

return nil, trace.BadParameter("failed to dial: all auth methods failed")
return &sshClient{
Client: client,
requests: reqs,
newChannels: chans,
principals: principals,
}, nil
}

// hostCheckerFunc wraps a apisshutils.CheckersGetter function with a context.
Expand Down
1 change: 1 addition & 0 deletions lib/reversetunnel/agentpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease
options: options,
username: p.HostUUID,
log: p.log,
isClaimed: p.tracker.IsClaimed,
}

agent, err := newAgent(agentConfig{
Expand Down
20 changes: 19 additions & 1 deletion lib/reversetunnel/track/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ func (t *Tracker) tick() {
}
t.wp.Set(uint64(count))
}

}

func (t *Tracker) getOrCreate() *proxySet {
Expand Down Expand Up @@ -210,6 +209,20 @@ func (t *Tracker) release(principals ...string) {
t.sets.release(principals...)
}

// IsClaimed returns true if the proxy identified by the given principals is
// already claimed by some other agent at the time of the call. Keep in mind
// that a return value of false doesn't imply that a subsequent call to Claim is
// guaranteed to succeed, as other goroutines might claim the same proxy between
// IsClaimed and Claim.
func (t *Tracker) IsClaimed(principals ...string) bool {
t.mu.Lock()
defer t.mu.Unlock()
if t.sets == nil {
return false
}
return t.sets.isClaimed(principals...)
}

type entry struct {
lastSeen time.Time
claimed bool
Expand Down Expand Up @@ -251,6 +264,11 @@ func (p *proxySet) release(principals ...string) {
}
}

func (p *proxySet) isClaimed(principals ...string) bool {
proxy := p.resolveName(principals)
return p.proxies[proxy].claimed
}

func (p *proxySet) markSeen(t time.Time, proxy string) {
e, ok := p.proxies[proxy]
if !ok {
Expand Down
27 changes: 26 additions & 1 deletion lib/reversetunnel/track/tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ func TestUUIDHandling(t *testing.T) {

t.Logf("Successfully claimed proxy")
<-ctx.Done()

}()
// Wait for proxy to be claimed
Wait:
Expand Down Expand Up @@ -311,3 +310,29 @@ Wait:
}
}
}

func TestIsClaimed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

tracker, err := New(ctx, Config{ClusterName: "test-cluster"})
require.NoError(t, err)

tracker.Start()
t.Cleanup(tracker.StopAll)

tracker.TrackExpected("proxy1", "proxy2")
require.False(t, tracker.IsClaimed("proxy1.test-cluster"))

unclaim, ok := tracker.Claim("proxy1.test-cluster")
require.True(t, ok)

require.True(t, tracker.IsClaimed("proxy1"))
require.True(t, tracker.IsClaimed("proxy1.test-cluster"))
require.False(t, tracker.IsClaimed("proxy2"))

unclaim()

require.False(t, tracker.IsClaimed("proxy1"))
require.False(t, tracker.IsClaimed("proxy2"))
}

0 comments on commit 7ebaf6a

Please sign in to comment.