Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abort reverse tunnel connections early if the proxy is already claimed #27683

Merged
merged 5 commits into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/utils/sshutils/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type HostKeyCallbackConfig struct {
// FIPS allows to set FIPS mode which will validate algorithms.
FIPS bool
// OnCheckCert is called on SSH certificate validation.
OnCheckCert func(*ssh.Certificate)
OnCheckCert func(*ssh.Certificate) error
// Clock is used to set the Checker Time
Clock clockwork.Clock
}
Expand Down
10 changes: 7 additions & 3 deletions api/utils/sshutils/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type CertChecker struct {
FIPS bool

// OnCheckCert is called when validating host certificate.
OnCheckCert func(*ssh.Certificate)
OnCheckCert func(*ssh.Certificate) error
}

// Authenticate checks the validity of a user certificate.
Expand Down Expand Up @@ -67,7 +67,9 @@ func (c *CertChecker) CheckCert(principal string, cert *ssh.Certificate) error {
}

if c.OnCheckCert != nil {
c.OnCheckCert(cert)
if err := c.OnCheckCert(cert); err != nil {
return trace.Wrap(err)
}
}

return nil
Expand All @@ -86,7 +88,9 @@ func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key ssh.PublicK
}

if cert, ok := key.(*ssh.Certificate); ok && c.OnCheckCert != nil {
c.OnCheckCert(cert)
if err := c.OnCheckCert(cert); err != nil {
return trace.Wrap(err)
}
}

return nil
Expand Down
92 changes: 47 additions & 45 deletions lib/reversetunnel/agent_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,61 +41,63 @@ type agentDialer struct {
fips bool
options []proxy.DialerOptionFunc
log logrus.FieldLogger
isClaimed func(principals ...string) bool
}

// DialContext creates an ssh connection to the given address.
func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHClient, error) {
// Create a dialer (that respects HTTP proxies) and connect to remote host.
dialer := proxy.DialerFromEnvironment(addr.Addr, d.options...)
pconn, err := dialer.DialTimeout(ctx, addr.AddrNetwork, addr.Addr, apidefaults.DefaultIOTimeout)
if err != nil {
d.log.WithError(err).Debugf("Failed to dial %s.", addr.Addr)
return nil, trace.Wrap(err)
}

for _, authMethod := range d.authMethods {
// Create a dialer (that respects HTTP proxies) and connect to remote host.
dialer := proxy.DialerFromEnvironment(addr.Addr, d.options...)
pconn, err := dialer.DialTimeout(ctx, addr.AddrNetwork, addr.Addr, apidefaults.DefaultIOTimeout)
if err != nil {
d.log.WithError(err).Debugf("Failed to dial %s.", addr.Addr)
continue
}

principals := make([]string, 0)
callback, err := apisshutils.NewHostKeyCallback(
apisshutils.HostKeyCallbackConfig{
GetHostCheckers: d.hostCheckerFunc(ctx),
OnCheckCert: func(c *ssh.Certificate) {
principals = c.ValidPrincipals
},
FIPS: d.fips,
})
if err != nil {
d.log.Debugf("Failed to create host key callback for %v: %v.", addr.Addr, err)
continue
}

// Build a new client connection. This is done to get access to incoming
// global requests which dialer.Dial would not provide.
conn, chans, reqs, err := tracessh.NewClientConn(ctx, pconn, addr.Addr, &ssh.ClientConfig{
User: d.username,
Auth: []ssh.AuthMethod{authMethod},
HostKeyCallback: callback,
Timeout: apidefaults.DefaultIOTimeout,
var principals []string
callback, err := apisshutils.NewHostKeyCallback(
apisshutils.HostKeyCallbackConfig{
GetHostCheckers: d.hostCheckerFunc(ctx),
OnCheckCert: func(c *ssh.Certificate) error {
if d.isClaimed != nil && d.isClaimed(c.ValidPrincipals...) {
d.log.Debugf("Aborting SSH handshake because the proxy %q is already claimed by some other agent.", c.ValidPrincipals[0])
return trace.Errorf("proxy already claimed")
}

principals = c.ValidPrincipals
return nil
},
FIPS: d.fips,
})
if err != nil {
d.log.WithError(err).Debugf("Failed to create client to %v.", addr.Addr)
continue
}
if err != nil {
d.log.Debugf("Failed to create host key callback for %v: %v.", addr.Addr, err)
return nil, trace.Wrap(err)
}

emptyRequests := make(chan *ssh.Request)
close(emptyRequests)
// Build a new client connection. This is done to get access to incoming
// global requests which dialer.Dial would not provide.
conn, chans, reqs, err := tracessh.NewClientConn(ctx, pconn, addr.Addr, &ssh.ClientConfig{
User: d.username,
Auth: d.authMethods,
HostKeyCallback: callback,
Timeout: apidefaults.DefaultIOTimeout,
})
if err != nil {
d.log.WithError(err).Debugf("Failed to create client to %v.", addr.Addr)
return nil, trace.Wrap(err)
}

client := tracessh.NewClient(conn, chans, emptyRequests)
emptyRequests := make(chan *ssh.Request)
close(emptyRequests)
espadolini marked this conversation as resolved.
Show resolved Hide resolved

return &sshClient{
Client: client,
requests: reqs,
newChannels: chans,
principals: principals,
}, nil
}
client := tracessh.NewClient(conn, chans, emptyRequests)

return nil, trace.BadParameter("failed to dial: all auth methods failed")
return &sshClient{
Client: client,
requests: reqs,
newChannels: chans,
principals: principals,
}, nil
}

// hostCheckerFunc wraps a apisshutils.CheckersGetter function with a context.
Expand Down
1 change: 1 addition & 0 deletions lib/reversetunnel/agentpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease
options: options,
username: p.HostUUID,
log: p.log,
isClaimed: p.tracker.IsClaimed,
}

agent, err := newAgent(agentConfig{
Expand Down
20 changes: 19 additions & 1 deletion lib/reversetunnel/track/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ func (t *Tracker) tick() {
}
t.wp.Set(uint64(count))
}

}

func (t *Tracker) getOrCreate() *proxySet {
Expand Down Expand Up @@ -210,6 +209,20 @@ func (t *Tracker) release(principals ...string) {
t.sets.release(principals...)
}

// IsClaimed returns true if the proxy identified by the given principals is
// already claimed by some other agent at the time of the call. Keep in mind
// that a return value of false doesn't imply that a subsequent call to Claim is
// guaranteed to succeed, as other goroutines might claim the same proxy between
// IsClaimed and Claim.
func (t *Tracker) IsClaimed(principals ...string) bool {
t.mu.Lock()
defer t.mu.Unlock()
if t.sets == nil {
return false
}
return t.sets.isClaimed(principals...)
}

type entry struct {
lastSeen time.Time
claimed bool
Expand Down Expand Up @@ -251,6 +264,11 @@ func (p *proxySet) release(principals ...string) {
}
}

func (p *proxySet) isClaimed(principals ...string) bool {
proxy := p.resolveName(principals)
return p.proxies[proxy].claimed
}

func (p *proxySet) markSeen(t time.Time, proxy string) {
e, ok := p.proxies[proxy]
if !ok {
Expand Down
27 changes: 26 additions & 1 deletion lib/reversetunnel/track/tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ func TestUUIDHandling(t *testing.T) {

t.Logf("Successfully claimed proxy")
<-ctx.Done()

}()
// Wait for proxy to be claimed
Wait:
Expand Down Expand Up @@ -311,3 +310,29 @@ Wait:
}
}
}

func TestIsClaimed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

tracker, err := New(ctx, Config{ClusterName: "test-cluster"})
require.NoError(t, err)

tracker.Start()
t.Cleanup(tracker.StopAll)

tracker.TrackExpected("proxy1", "proxy2")
require.False(t, tracker.IsClaimed("proxy1.test-cluster"))

unclaim, ok := tracker.Claim("proxy1.test-cluster")
require.True(t, ok)

require.True(t, tracker.IsClaimed("proxy1"))
require.True(t, tracker.IsClaimed("proxy1.test-cluster"))
require.False(t, tracker.IsClaimed("proxy2"))

unclaim()

require.False(t, tracker.IsClaimed("proxy1"))
require.False(t, tracker.IsClaimed("proxy2"))
}