Skip to content

Commit

Permalink
Add -R flag to tsh ssh (#38829)
Browse files Browse the repository at this point in the history
  • Loading branch information
atburke committed Mar 16, 2024
1 parent 2ce7d3e commit 5f3d817
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 22 deletions.
79 changes: 57 additions & 22 deletions integration/port_forwarding_test.go
Expand Up @@ -21,6 +21,7 @@ package integration
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -77,6 +78,24 @@ func waitForSessionToBeEstablished(ctx context.Context, namespace string, site a
}
}

// testPingLocalServer checks whether or not an HTTP server is serving on
// localhost at the given port.
func testPingLocalServer(t *testing.T, port int, expectSuccess bool) {
addr := fmt.Sprintf("http://%s:%d/", "localhost", port)
r, err := http.Get(addr)

if r != nil {
r.Body.Close()
}

if expectSuccess {
require.NoError(t, err)
require.NotNil(t, r)
} else {
require.Error(t, err)
}
}

func testPortForwarding(t *testing.T, suite *integrationTestSuite) {
invalidOSLogin := uuid.NewString()[:12]
notFound := false
Expand Down Expand Up @@ -214,18 +233,33 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) {

site := instance.GetSiteAPI(helpers.Site)

// ...and a running dummy server
remoteSvr := httptest.NewServer(http.HandlerFunc(
// ...and a pair of running dummy servers
handler := http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, World"))
}))
})
remoteListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
remoteSvr := httptest.NewUnstartedServer(handler)
remoteSvr.Listener = remoteListener
remoteSvr.Start()
defer remoteSvr.Close()

localListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
localSvr := httptest.NewUnstartedServer(handler)
localSvr.Listener = localListener
localSvr.Start()
defer localSvr.Close()

// ... and a client connection that was launched with port
// forwarding enabled to that dummy server
localPort := newPortValue()
remotePort, err := extractPort(remoteSvr)
// forwarding enabled to the dummy servers
localClientPort := newPortValue()
remoteServerPort, err := extractPort(remoteSvr)
require.NoError(t, err)
remoteClientPort := newPortValue()
localServerPort, err := extractPort(localSvr)
require.NoError(t, err)

nodeSSHPort := helpers.Port(t, instance.SSH)
Expand All @@ -239,9 +273,17 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) {
cl.Config.LocalForwardPorts = []client.ForwardedPort{
{
SrcIP: "127.0.0.1",
SrcPort: localPort,
SrcPort: localClientPort,
DestHost: "localhost",
DestPort: remotePort,
DestPort: remoteServerPort,
},
}
cl.Config.RemoteForwardPorts = []client.ForwardedPort{
{
SrcIP: "localhost",
SrcPort: remoteClientPort,
DestHost: "127.0.0.1",
DestPort: localServerPort,
},
}
term := NewTerminal(250)
Expand All @@ -259,20 +301,13 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// When everything is *finally* set up, and I attempt to use the
// forwarded connection
localURL := fmt.Sprintf("http://%s:%d/", "localhost", localPort)
r, err := http.Get(localURL)

if r != nil {
r.Body.Close()
}

if tt.expectSuccess {
require.NoError(t, err)
require.NotNil(t, r)
} else {
require.Error(t, err)
}
// forwarded connections
t.Run("local forwarding", func(t *testing.T) {
testPingLocalServer(t, localClientPort, tt.expectSuccess)
})
t.Run("remote forwarding", func(t *testing.T) {
testPingLocalServer(t, remoteClientPort, tt.expectSuccess)
})
})
}
}
22 changes: 22 additions & 0 deletions lib/client/api.go
Expand Up @@ -128,6 +128,8 @@ const (
ForwardAgentLocal
)

const remoteForwardUnsupportedMessage = "ssh: tcpip-forward request denied by peer"

var log = logrus.WithFields(logrus.Fields{
trace.Component: teleport.ComponentClient,
})
Expand Down Expand Up @@ -312,6 +314,10 @@ type Config struct {
// port forwarding (parameters to -D ssh flag).
DynamicForwardedPorts DynamicForwardedPorts

// RemoteForwardPorts are the list of ports the remote connection listens on
// for remote port forwarding (parameters to -R ssh flag).
RemoteForwardPorts ForwardedPorts

// HostKeyCallback will be called to check host keys of the remote
// node, if not specified will be using CheckHostSignature function
// that uses local cache to validate hosts
Expand Down Expand Up @@ -1953,6 +1959,22 @@ func (tc *TeleportClient) startPortForwarding(ctx context.Context, nodeClient *N
}
go nodeClient.dynamicListenAndForward(ctx, socket, addr)
}
for _, fp := range tc.Config.RemoteForwardPorts {
addr := net.JoinHostPort(fp.SrcIP, strconv.Itoa(fp.SrcPort))
socket, err := nodeClient.Client.Listen("tcp", addr)
if err != nil {
// We log the error here instead of returning it to be consistent with
// the other port forwarding methods, which don't stop the session
// if forwarding fails.
message := fmt.Sprintf("Failed to bind on remote host to %v: %v.", addr, err)
if strings.Contains(err.Error(), remoteForwardUnsupportedMessage) {
message = "Node does not support remote port forwarding (-R)."
}
log.Error(message)
} else {
go nodeClient.remoteListenAndForward(ctx, socket, net.JoinHostPort(fp.DestHost, strconv.Itoa(fp.DestPort)), addr)
}
}
return nil
}

Expand Down
25 changes: 25 additions & 0 deletions lib/client/client.go
Expand Up @@ -2016,6 +2016,31 @@ func (c *NodeClient) dynamicListenAndForward(ctx context.Context, ln net.Listene
log.WithError(ctx.Err()).Infof("Shutting down dynamic port forwarding.")
}

// remoteListenAndForward requests a listening socket and forwards all incoming
// commands to the local address through the SSH tunnel.
func (c *NodeClient) remoteListenAndForward(ctx context.Context, ln net.Listener, localAddr, remoteAddr string) {
defer ln.Close()
log := log.WithField("localAddr", localAddr).WithField("remoteAddr", remoteAddr)
log.Infof("Starting remote port forwarding")

for ctx.Err() == nil {
conn, err := acceptWithContext(ctx, ln)
if err != nil {
if ctx.Err() == nil {
log.WithError(err).Errorf("Remote port forwarding failed.")
}
continue
}

go func() {
if err := proxyConnection(ctx, conn, localAddr, &net.Dialer{}); err != nil {
log.WithError(err).Warnf("Failed to proxy connection")
}
}()
}
log.WithError(ctx.Err()).Infof("Shutting down remote port forwarding.")
}

// GetRemoteTerminalSize fetches the terminal size of a given SSH session.
func (c *NodeClient) GetRemoteTerminalSize(ctx context.Context, sessionID string) (*term.Winsize, error) {
ctx, span := c.Tracer.Start(
Expand Down
11 changes: 11 additions & 0 deletions tool/tsh/common/tsh.go
Expand Up @@ -177,6 +177,8 @@ type CLIConf struct {
// DynamicForwardedPorts is port forwarding using SOCKS5. It is similar to
// "ssh -D 8080 example.com".
DynamicForwardedPorts []string
// -R flag for ssh. Remote port forwarding like 'ssh -R 80:localhost:80 -R 443:localhost:443'
RemoteForwardPorts []string
// ForwardAgent agent to target node. Equivalent of -A for OpenSSH.
ForwardAgent bool
// ProxyJump is an optional -J flag pointing to the list of jumphosts,
Expand Down Expand Up @@ -740,6 +742,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
ssh.Flag("forward-agent", "Forward agent to target node").Short('A').BoolVar(&cf.ForwardAgent)
ssh.Flag("forward", "Forward localhost connections to remote server").Short('L').StringsVar(&cf.LocalForwardPorts)
ssh.Flag("dynamic-forward", "Forward localhost connections to remote server using SOCKS5").Short('D').StringsVar(&cf.DynamicForwardedPorts)
ssh.Flag("remote-forward", "Forward remote connections to localhost").Short('R').StringsVar(&cf.RemoteForwardPorts)
ssh.Flag("local", "Execute command on localhost after connecting to SSH node").Default("false").BoolVar(&cf.LocalExec)
ssh.Flag("tty", "Allocate TTY").Short('t').BoolVar(&cf.Interactive)
ssh.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName)
Expand Down Expand Up @@ -3639,6 +3642,11 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err
return nil, trace.Wrap(err)
}

rPorts, err := client.ParsePortForwardSpec(cf.RemoteForwardPorts)
if err != nil {
return nil, trace.Wrap(err)
}

// 1: start with the defaults
c := client.MakeDefaultConfig()

Expand Down Expand Up @@ -3784,6 +3792,9 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err
if len(dPorts) > 0 {
c.DynamicForwardedPorts = dPorts
}
if len(rPorts) > 0 {
c.RemoteForwardPorts = rPorts
}
if cf.SiteName != "" {
c.SiteName = cf.SiteName
}
Expand Down

0 comments on commit 5f3d817

Please sign in to comment.