Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v14] Simplify single destination reverse tunnels #36131

Merged
merged 1 commit into from Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 10 additions & 12 deletions lib/reversetunnel/agent.go
Expand Up @@ -60,10 +60,10 @@ const (
// AgentStateCallback is called when an agent's state changes.
type AgentStateCallback func(AgentState)

// transporter handles the creation of new transports over ssh.
type transporter interface {
// Transport creates a new transport.
transport(context.Context, ssh.Channel, <-chan *ssh.Request, sshutils.Conn) *transport
// transportHandler handles the creation of new transports over ssh.
type transportHandler interface {
// handleTransport runs the receiver of a teleport-transport channel.
handleTransport(context.Context, ssh.Channel, <-chan *ssh.Request, sshutils.Conn)
}

// sshDialer is an ssh dialer that returns an SSHClient
Expand Down Expand Up @@ -100,8 +100,8 @@ type agentConfig struct {
stateCallback AgentStateCallback
// sshDialer creates a new ssh connection.
sshDialer sshDialer
// transporter creates a new transport.
transporter transporter
// transportHandler handles teleport-transport channels.
transportHandler transportHandler
// versionGetter gets the connected auth server version.
versionGetter versionGetter
// tracker tracks existing proxies.
Expand All @@ -128,8 +128,8 @@ func (c *agentConfig) checkAndSetDefaults() error {
if c.sshDialer == nil {
return trace.BadParameter("missing parameter sshDialer")
}
if c.transporter == nil {
return trace.BadParameter("missing parameter transporter")
if c.transportHandler == nil {
return trace.BadParameter("missing parameter transportHandler")
}
if c.versionGetter == nil {
return trace.BadParameter("missing parameter versionGetter")
Expand Down Expand Up @@ -577,12 +577,10 @@ func (a *agent) handleDrainChannels() error {
continue
}

t := a.transporter.transport(a.ctx, ch, req, a.client)

a.drainWG.Add(1)
go func() {
t.start()
a.drainWG.Done()
defer a.drainWG.Done()
a.transportHandler.handleTransport(a.ctx, ch, req, a.client)
}()

}
Expand Down
17 changes: 8 additions & 9 deletions lib/reversetunnel/agent_test.go
Expand Up @@ -140,8 +140,7 @@ type mockAgentInjection struct {
client SSHClient
}

func (m *mockAgentInjection) transport(context.Context, ssh.Channel, <-chan *ssh.Request, sshutils.Conn) *transport {
return &transport{}
func (m *mockAgentInjection) handleTransport(context.Context, ssh.Channel, <-chan *ssh.Request, sshutils.Conn) {
}

func (m *mockAgentInjection) DialContext(context.Context, utils.NetAddr) (SSHClient, error) {
Expand Down Expand Up @@ -174,13 +173,13 @@ func testAgent(t *testing.T) (*agent, *mockSSHClient) {
}

agent, err := newAgent(agentConfig{
keepAlive: time.Millisecond * 100,
addr: addr,
transporter: inject,
sshDialer: inject,
versionGetter: inject,
tracker: tracker,
lease: lease,
keepAlive: time.Millisecond * 100,
addr: addr,
transportHandler: inject,
sshDialer: inject,
versionGetter: inject,
tracker: tracker,
lease: lease,
})
require.NoError(t, err, "Unexpected error during agent construction.")

Expand Down
69 changes: 64 additions & 5 deletions lib/reversetunnel/agentpool.go
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"sync"
Expand All @@ -35,6 +36,7 @@ import (
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/webclient"
"github.com/gravitational/teleport/api/defaults"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/api/utils/sshutils"
Expand Down Expand Up @@ -461,7 +463,7 @@ func (p *AgentPool) newAgent(ctx context.Context, tracker *track.Tracker, lease
addr: *addr,
keepAlive: p.runtimeConfig.keepAliveInterval,
sshDialer: dialer,
transporter: p,
transportHandler: p,
versionGetter: p,
tracker: tracker,
lease: lease,
Expand Down Expand Up @@ -512,8 +514,13 @@ func (p *AgentPool) getVersion(ctx context.Context) (string, error) {
return pong.ServerVersion, nil
}

// transport creates a new transport instance.
func (p *AgentPool) transport(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request, conn sshutils.Conn) *transport {
// handleTransport runs a new teleport-transport channel.
func (p *AgentPool) handleTransport(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request, conn sshutils.Conn) {
if !p.IsRemoteCluster {
p.handleLocalTransport(ctx, channel, requests, conn)
return
}

t := &transport{
closeContext: ctx,
component: p.Component,
Expand All @@ -537,11 +544,63 @@ func (p *AgentPool) transport(ctx context.Context, channel ssh.Channel, requests
// the leaf proxy to track sessions that are initiated via the root cluster. Without providing
// the user tracker the leaf cluster metrics will be incorrect and graceful shutdown will not
// wait for user sessions to be terminated prior to proceeding with the shutdown operation.
if p.IsRemoteCluster && p.ReverseTunnelServer != nil {
if p.ReverseTunnelServer != nil {
t.trackUserConnection = p.ReverseTunnelServer.TrackUserConnection
}

return t
t.start()
}

func (p *AgentPool) handleLocalTransport(ctx context.Context, channel ssh.Channel, reqC <-chan *ssh.Request, sconn sshutils.Conn) {
defer channel.Close()
go io.Copy(io.Discard, channel.Stderr())

// the only valid teleport-transport-dial request here is to reach the local service
var req *ssh.Request
select {
case <-ctx.Done():
go ssh.DiscardRequests(reqC)
return
case <-time.After(apidefaults.DefaultIOTimeout):
go ssh.DiscardRequests(reqC)
p.log.Warn("Timed out waiting for transport dial request.")
return
case r, ok := <-reqC:
if !ok {
return
}
go ssh.DiscardRequests(reqC)
req = r
}

// sconn should never be nil, but it's sourced from the agent state and
// starts as nil, and the original transport code checked against it
if sconn == nil || p.Server == nil {
p.log.Error("Missing client or server (this is a bug).")
fmt.Fprintf(channel.Stderr(), "internal server error")
req.Reply(false, nil)
return
}

if err := req.Reply(true, nil); err != nil {
p.log.Errorf("Failed to respond to dial request: %v.", err)
return
}

var conn net.Conn = sshutils.NewChConn(sconn, channel)

dialReq := parseDialReq(req.Payload)
switch dialReq.Address {
case reversetunnelclient.LocalNode, reversetunnelclient.LocalKubernetes, reversetunnelclient.LocalWindowsDesktop:
default:
p.log.WithField("address", dialReq.Address).
Warn("Received dial request for unexpected address, routing to the local service anyway.")
}
if src, err := utils.ParseAddr(dialReq.ClientSrcAddr); err == nil {
conn = utils.NewConnWithSrcAddr(conn, getTCPAddr(src))
}

p.Server.HandleConnection(conn)
}

// agentPoolRuntimeConfig contains configurations dynamically set and updated
Expand Down
97 changes: 81 additions & 16 deletions lib/reversetunnel/srv.go
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/breaker"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/retryutils"
apisshutils "github.com/gravitational/teleport/api/utils/sshutils"
Expand Down Expand Up @@ -667,28 +668,92 @@ func (s *server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont
}

func (s *server) handleTransport(sconn *ssh.ServerConn, nch ssh.NewChannel) {
s.log.Debugf("Transport request: %v.", nch.ChannelType())
channel, requestCh, err := nch.Accept()
s.log.Debug("Received transport request.")
channel, requestC, err := nch.Accept()
if err != nil {
sconn.Close()
// avoid WithError to reduce log spam on network errors
s.log.Warnf("Failed to accept request: %v.", err)
return
}

t := &transport{
log: s.log,
closeContext: s.ctx,
authClient: s.LocalAccessPoint,
authServers: s.LocalAuthAddresses,
channel: channel,
requestCh: requestCh,
component: teleport.ComponentReverseTunnelServer,
localClusterName: s.ClusterName,
emitter: s.Emitter,
proxySigner: s.proxySigner,
sconn: sconn,
}
go t.start()
go s.handleTransportChannel(sconn, channel, requestC)
}

func (s *server) handleTransportChannel(sconn *ssh.ServerConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
defer ch.Close()
go io.Copy(io.Discard, ch.Stderr())

// the only valid teleport-transport-dial request here is to reach the auth server
var req *ssh.Request
select {
case <-s.ctx.Done():
go ssh.DiscardRequests(reqC)
return
case <-time.After(apidefaults.DefaultIOTimeout):
go ssh.DiscardRequests(reqC)
s.log.Warn("Timed out waiting for transport dial request.")
return
case r, ok := <-reqC:
if !ok {
return
}
go ssh.DiscardRequests(reqC)
req = r
}

dialReq := parseDialReq(req.Payload)
if dialReq.Address != constants.RemoteAuthServer {
s.log.WithField("address", dialReq.Address).
Warn("Received dial request for unexpected address, routing to the auth server anyway.")
}

authAddress := utils.ChooseRandomString(s.LocalAuthAddresses)
if authAddress == "" {
s.log.Error("No auth servers configured.")
fmt.Fprint(ch.Stderr(), "internal server error")
req.Reply(false, nil)
return
}

var proxyHeader []byte
clientSrcAddr := sconn.RemoteAddr()
clientDstAddr := sconn.LocalAddr()
if s.proxySigner != nil && clientSrcAddr != nil && clientDstAddr != nil {
h, err := s.proxySigner.SignPROXYHeader(clientSrcAddr, clientDstAddr)
if err != nil {
s.log.WithError(err).Error("Failed to create signed PROXY header.")
fmt.Fprint(ch.Stderr(), "internal server error")
req.Reply(false, nil)
}
proxyHeader = h
}

d := net.Dialer{Timeout: apidefaults.DefaultIOTimeout}
conn, err := d.DialContext(s.ctx, "tcp", authAddress)
if err != nil {
s.log.Errorf("Failed to dial auth: %v.", err)
fmt.Fprint(ch.Stderr(), "failed to dial auth server")
req.Reply(false, nil)
return
}
defer conn.Close()

_ = conn.SetWriteDeadline(time.Now().Add(apidefaults.DefaultIOTimeout))
if _, err := conn.Write(proxyHeader); err != nil {
s.log.Errorf("Failed to send PROXY header: %v.", err)
fmt.Fprint(ch.Stderr(), "failed to dial auth server")
req.Reply(false, nil)
return
}
_ = conn.SetWriteDeadline(time.Time{})

if err := req.Reply(true, nil); err != nil {
s.log.Errorf("Failed to respond to dial request: %v.", err)
return
}

_ = utils.ProxyConn(s.ctx, ch, conn)
}

// TODO(awly): unit test this
Expand Down