From b4239ad71928bfa5bb87204bf86a36644aeac08c Mon Sep 17 00:00:00 2001 From: David Pinheiro Date: Sun, 26 Jul 2020 17:02:23 -0700 Subject: [PATCH] Add support for ssh remote port forwarding This change adds support for a new type of ssh tunnel, the [remote port forwarding](https://www.ssh.com/ssh/tunneling/example#remote-forwarding). Other significant changes: * Adds a new TunnelType attribute to TunnelFlags so it can be passed to the Tunnel object, preparing it to support different types of port forwarding It also improves the description of `start local` command. * Hides how the `SSHChannel` objects are built under `tunnel.New`, opening room to create either ssh `local` or `remote` forwarding tunnels * Adds support for reading `RemoteForward` configuration from the ssh configuration file * Extends the ssh server implementation using for tests to handle "tcpip-forward" requests Partially implements #22. --- Makefile | 1 + alias/alias.go | 6 +- alias/alias_test.go | 32 ++-- cmd/add_alias.go | 4 +- cmd/root.go | 21 +-- cmd/start_local.go | 14 +- cmd/start_remote.go | 37 +++++ go.mod | 3 +- go.sum | 4 + test-env/README.md | 44 +++-- test-env/http-server/Dockerfile | 2 + test-env/ssh-server/Dockerfile | 3 +- tunnel/config.go | 52 +++--- tunnel/config_test.go | 29 +++- tunnel/example_test.go | 9 +- tunnel/key_test.go | 4 - tunnel/tunnel.go | 192 ++++++++++++++-------- tunnel/tunnel_test.go | 278 ++++++++++++++++++++------------ 18 files changed, 492 insertions(+), 243 deletions(-) create mode 100644 cmd/start_remote.go diff --git a/Makefile b/Makefile index ec1ca27..97363f7 100644 --- a/Makefile +++ b/Makefile @@ -48,6 +48,7 @@ mole-http: rm-mole-http --detach \ --network mole \ --ip 192.168.33.11 \ + --publish 8080:8080 \ --name mole_http mole_http:latest rm-mole-ssh: diff --git a/alias/alias.go b/alias/alias.go index 09dffe9..1924dca 100644 --- a/alias/alias.go +++ b/alias/alias.go @@ -34,6 +34,7 @@ const ( // TunnelFlags is a struct that holds all flags required to establish a ssh // port forwarding tunnel. type TunnelFlags struct { + TunnelType string Verbose bool Insecure bool Detach bool @@ -49,10 +50,10 @@ type TunnelFlags struct { } // ParseAlias translates a TunnelFlags object to an Alias object -func (tf TunnelFlags) ParseAlias(name, tunnelType string) *Alias { +func (tf TunnelFlags) ParseAlias(name string) *Alias { return &Alias{ Name: name, - TunnelType: tunnelType, + TunnelType: tf.TunnelType, Verbose: tf.Verbose, Insecure: tf.Insecure, Detach: tf.Detach, @@ -109,6 +110,7 @@ func (a Alias) ParseTunnelFlags() (*TunnelFlags, error) { tf := &TunnelFlags{} + tf.TunnelType = a.TunnelType tf.Verbose = a.Verbose tf.Insecure = a.Insecure tf.Detach = a.Detach diff --git a/alias/alias_test.go b/alias/alias_test.go index 9afa070..c0bd4d7 100644 --- a/alias/alias_test.go +++ b/alias/alias_test.go @@ -17,6 +17,7 @@ import ( func TestParseTunnelFlags(t *testing.T) { tests := []struct { + tunnelType string verbose bool insecure bool detach bool @@ -31,6 +32,7 @@ func TestParseTunnelFlags(t *testing.T) { timeout string }{ { + "local", true, true, true, @@ -45,6 +47,7 @@ func TestParseTunnelFlags(t *testing.T) { "1m0s", }, { + "local", true, false, true, @@ -62,6 +65,7 @@ func TestParseTunnelFlags(t *testing.T) { for id, test := range tests { ai := &alias.Alias{ + TunnelType: test.tunnelType, Verbose: test.verbose, Insecure: test.insecure, Detach: test.detach, @@ -81,58 +85,62 @@ func TestParseTunnelFlags(t *testing.T) { t.Errorf("%v\n", err) } + if test.tunnelType != tf.TunnelType { + t.Errorf("tunnelType doesn't match on test %d: expected: %s, value: %s", id, test.tunnelType, tf.TunnelType) + } + if test.verbose != tf.Verbose { - t.Errorf("verbose doesn't match for test %d: expected: %t, value: %t", id, test.verbose, tf.Verbose) + t.Errorf("verbose doesn't match on test %d: expected: %t, value: %t", id, test.verbose, tf.Verbose) } if test.insecure != tf.Insecure { - t.Errorf("insecure doesn't match for test %d: expected: %t, value: %t", id, test.insecure, tf.Insecure) + t.Errorf("insecure doesn't match on test %d: expected: %t, value: %t", id, test.insecure, tf.Insecure) } if test.detach != tf.Detach { - t.Errorf("detach doesn't match for test %d: expected: %t, value: %t", id, test.detach, tf.Detach) + t.Errorf("detach doesn't match on test %d: expected: %t, value: %t", id, test.detach, tf.Detach) } for i, tsrc := range test.source { src := tf.Source[i].String() if tsrc != src { - t.Errorf("source %d doesn't match for test %d: expected: %s, value: %s", id, i, tsrc, src) + t.Errorf("source %d doesn't match on test %d: expected: %s, value: %s", id, i, tsrc, src) } } for i, tdst := range test.destination { dst := tf.Destination[i].String() if tdst != dst { - t.Errorf("destination %d doesn't match for test %d: expected: %s, value: %s", id, i, tdst, dst) + t.Errorf("destination %d doesn't match on test %d: expected: %s, value: %s", id, i, tdst, dst) } } if test.server != tf.Server.String() { - t.Errorf("server doesn't match for test %d: expected: %s, value: %s", id, test.server, tf.Server.String()) + t.Errorf("server doesn't match on test %d: expected: %s, value: %s", id, test.server, tf.Server.String()) } if test.key != tf.Key { - t.Errorf("key doesn't match for test %d: expected: %s, value: %s", id, test.key, tf.Key) + t.Errorf("key doesn't match on test %d: expected: %s, value: %s", id, test.key, tf.Key) } if test.keepAliveInterval != tf.KeepAliveInterval.String() { - t.Errorf("keepAliveInterval doesn't match for test %d: expected: %s, value: %s", id, test.keepAliveInterval, tf.KeepAliveInterval.String()) + t.Errorf("keepAliveInterval doesn't match on test %d: expected: %s, value: %s", id, test.keepAliveInterval, tf.KeepAliveInterval.String()) } if test.connectionRetries != tf.ConnectionRetries { - t.Errorf("connectionRetries doesn't match for test %d: expected: %d, value: %d", id, test.connectionRetries, tf.ConnectionRetries) + t.Errorf("connectionRetries doesn't match on test %d: expected: %d, value: %d", id, test.connectionRetries, tf.ConnectionRetries) } if test.waitAndRetry != tf.WaitAndRetry.String() { - t.Errorf("waitAndRetry doesn't match for test %d: expected: %s, value: %s", id, test.waitAndRetry, tf.WaitAndRetry.String()) + t.Errorf("waitAndRetry doesn't match on test %d: expected: %s, value: %s", id, test.waitAndRetry, tf.WaitAndRetry.String()) } if test.sshAgent != tf.SshAgent { - t.Errorf("sshAgent doesn't match for test %d: expected: %s, value: %s", id, test.sshAgent, tf.SshAgent) + t.Errorf("sshAgent doesn't match on test %d: expected: %s, value: %s", id, test.sshAgent, tf.SshAgent) } if test.timeout != tf.Timeout.String() { - t.Errorf("timeout doesn't match for test %d: expected: %s, value: %s", id, test.timeout, tf.Timeout.String()) + t.Errorf("timeout doesn't match on test %d: expected: %s, value: %s", id, test.timeout, tf.Timeout.String()) } } diff --git a/cmd/add_alias.go b/cmd/add_alias.go index 12d6e6a..93c8208 100644 --- a/cmd/add_alias.go +++ b/cmd/add_alias.go @@ -26,13 +26,13 @@ The alias configuration file is saved to ".mole", under your home directory. return errors.New("alias name not provided") } - tunnelType = args[0] + tunnelFlags.TunnelType = args[0] aliasName = args[1] return nil }, Run: func(cmd *cobra.Command, arg []string) { - if err := alias.Add(tunnelFlags.ParseAlias(aliasName, "local")); err != nil { + if err := alias.Add(tunnelFlags.ParseAlias(aliasName)); err != nil { log.WithError(err).Error("failed to add tunnel alias") os.Exit(1) } diff --git a/cmd/root.go b/cmd/root.go index a723e4c..5e62ac7 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -18,7 +18,6 @@ import ( ) var ( - tunnelType string aliasName string tunnelFlags = &alias.TunnelFlags{} @@ -114,31 +113,23 @@ func start(alias string, tunnelFlags *alias.TunnelFlags) { log.Debugf("server: %s", s) - // TODO: rename local to source - local := make([]string, len(tunnelFlags.Source)) + source := make([]string, len(tunnelFlags.Source)) for i, r := range tunnelFlags.Source { - local[i] = r.String() + source[i] = r.String() } - // TODO: rename remote to destination - remote := make([]string, len(tunnelFlags.Destination)) + destination := make([]string, len(tunnelFlags.Destination)) for i, r := range tunnelFlags.Destination { if r.Port == "" { - err := fmt.Errorf("missing port in remote address: %s", r.String()) + err := fmt.Errorf("missing port in destination address: %s", r.String()) log.Error(err) os.Exit(1) } - remote[i] = r.String() + destination[i] = r.String() } - channels, err := tunnel.BuildSSHChannels(s.Name, local, remote) - if err != nil { - log.Error(err) - os.Exit(1) - } - - t, err := tunnel.New(s, channels) + t, err := tunnel.New(tunnelFlags.TunnelType, s, source, destination) if err != nil { log.Error(err) os.Exit(1) diff --git a/cmd/start_local.go b/cmd/start_local.go index 600e57d..0c26a49 100644 --- a/cmd/start_local.go +++ b/cmd/start_local.go @@ -10,7 +10,19 @@ import ( var localCmd = &cobra.Command{ Use: "local", Short: "Starts a ssh local port forwarding tunnel", - Long: "Starts a ssh local port forwarding tunnel", + Long: `Local Forwarding allows anyone to access outside services like they were +running locally on the source machine. + +This could be particular useful for accesing web sites, databases or any kind of +service the source machine does not have direct access to. + +Source endpoints are addresses on the same machine where mole is getting executed where clients can connect to access services on the corresponding destination endpoints. +Destination endpoints are adrresess that can be reached from the jump server. +`, + Args: func(cmd *cobra.Command, args []string) error { + tunnelFlags.TunnelType = "local" + return nil + }, Run: func(cmd *cobra.Command, arg []string) { start("", tunnelFlags) }, diff --git a/cmd/start_remote.go b/cmd/start_remote.go new file mode 100644 index 0000000..927dfbb --- /dev/null +++ b/cmd/start_remote.go @@ -0,0 +1,37 @@ +package cmd + +import ( + "os" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +var remoteCmd = &cobra.Command{ + Use: "remote", + Short: "Starts a ssh remote port forwarding tunnel", + Long: `Remote Forwarding allows anyone to expose a service running locally to a remote machine. + +This could be particular useful for giving someone on the outside access to an internal web application, for example. + +Source endpoints are addresses on the jump server where clients can connect to access services running on the corresponding destination endpoints. +Destination endpoints are addresses of services running on the same machine where mole is getting executed. +`, + Args: func(cmd *cobra.Command, args []string) error { + tunnelFlags.TunnelType = "remote" + return nil + }, + Run: func(cmd *cobra.Command, arg []string) { + start("", tunnelFlags) + }, +} + +func init() { + err := bindFlags(tunnelFlags, remoteCmd) + if err != nil { + log.WithError(err).Error("error parsing command line arguments") + os.Exit(1) + } + + startCmd.AddCommand(remoteCmd) +} diff --git a/go.mod b/go.mod index c517134..8215963 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/kevinburke/ssh_config v0.0.0-20190630040420-2e50c441276c github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/pelletier/go-buffruneio v0.2.0 // indirect + github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/prometheus/common v0.10.0 github.com/sevlyar/go-daemon v0.1.5 github.com/sirupsen/logrus v1.4.2 @@ -17,5 +18,5 @@ require ( github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.3.2 github.com/stretchr/testify v1.3.0 // indirect - golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 + golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de ) diff --git a/go.sum b/go.sum index a80d68d..be30821 100644 --- a/go.sum +++ b/go.sum @@ -52,6 +52,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/pelletier/go-buffruneio v0.2.0 h1:U4t4R6YkofJ5xHm3dJzuRpPZ0mr5MMCoAWooScCR7aA= github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 h1:JhzVVoYvbOACxoUmOs6V/G4D5nPVUW73rKvXxP4XUJc= +github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -94,6 +96,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de h1:ikNHVSjEfnvz6sxdSPCaPt572qowuyMDMJLLm3Db3ig= +golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/test-env/README.md b/test-env/README.md index 878b4f8..b915e89 100644 --- a/test-env/README.md +++ b/test-env/README.md @@ -86,13 +86,23 @@ The ssh authentication key files, `test-env/key` and `test-env/key,pub` will ```sh $ make test-env -$ mole --verbose --insecure --local :21112 --local :21113 --remote 192.168.33.11:80 --remote 192.168.33.11:8080 --server mole@127.0.0.1:22122 --key test-env/ssh-server/keys/key --keep-alive-interval 2s -INFO[0000] tunnel is ready local="127.0.0.1:21113" remote="192.168.33.11:8080" -INFO[0000] tunnel is ready local="127.0.0.1:21112" remote="192.168.33.11:80" -$ curl 127.0.0.1:21112 -:) -$ curl 127.0.0.1:21113 -:) +mole start local \ + --verbose \ + --insecure \ + --source :21112 \ + --source :21113 \ + --destination 192.168.33.11:80 \ + --destination 192.168.33.11:8080 \ + --server mole@127.0.0.1:22122 \ + --key test-env/ssh-server/keys/key \ + --keep-alive-interval 2s +DEBU[0000] using ssh config file from: /home/mole/.ssh/config +DEBU[0000] server: [name=127.0.0.1, address=127.0.0.1:22122, user=mole] +DEBU[0000] tunnel: [channels:[[source=127.0.0.1:21112, destination=192.168.33.11:80] [source=127.0.0.1:21113, destination=192.168.33.11:8080]], server:127.0.0.1:22122] +DEBU[0000] connection to the ssh server is established server="[name=127.0.0.1, address=127.0.0.1:22122, user=mole]" +DEBU[0000] start sending keep alive packets +INFO[0000] tunnel channel is waiting for connection destination="192.168.33.11:8080" source="127.0.0.1:21113" +INFO[0000] tunnel channel is waiting for connection destination="192.168.33.11:80" source="127.0.0.1:21112" ``` NOTE: If you're wondering about the smile face, that is the response from both @@ -116,9 +126,23 @@ $ make test-env 2. Start mole ```sh -$ mole --verbose --insecure --local :21112 --local :21113 --remote 192.168.33.11:80 --remote 192.168.33.11:8080 --server mole@127.0.0.1:22122 --key test-env/ssh-server/keys/key --keep-alive-interval 2s -INFO[0000] tunnel is ready local="127.0.0.1:21113" remote="192.168.33.11:8080" -INFO[0000] tunnel is ready local="127.0.0.1:21112" remote="192.168.33.11:80" +mole start local \ + --verbose \ + --insecure \ + --source :21112 \ + --source :21113 \ + --destination 192.168.33.11:80 \ + --destination 192.168.33.11:8080 \ + --server mole@127.0.0.1:22122 \ + --key test-env/ssh-server/keys/key \ + --keep-alive-interval 2s +DEBU[0000] using ssh config file from: /home/mole/.ssh/config +DEBU[0000] server: [name=127.0.0.1, address=127.0.0.1:22122, user=mole] +DEBU[0000] tunnel: [channels:[[source=127.0.0.1:21112, destination=192.168.33.11:80] [source=127.0.0.1:21113, destination=192.168.33.11:8080]], server:127.0.0.1:22122] +DEBU[0000] connection to the ssh server is established server="[name=127.0.0.1, address=127.0.0.1:22122, user=mole]" +DEBU[0000] start sending keep alive packets +INFO[0000] tunnel channel is waiting for connection destination="192.168.33.11:8080" source="127.0.0.1:21113" +INFO[0000] tunnel channel is waiting for connection destination="192.168.33.11:80" source="127.0.0.1:21112" ``` 3. Kill all ssh processes running on the container holding the ssh server diff --git a/test-env/http-server/Dockerfile b/test-env/http-server/Dockerfile index 438d1a5..f3e3d09 100644 --- a/test-env/http-server/Dockerfile +++ b/test-env/http-server/Dockerfile @@ -7,4 +7,6 @@ RUN mkdir -p /data/www COPY default.conf /etc/nginx/conf.d/ COPY index.html /data/www +EXPOSE 8080 + CMD nginx -g 'daemon off;' diff --git a/test-env/ssh-server/Dockerfile b/test-env/ssh-server/Dockerfile index 3bdff92..b07cb21 100644 --- a/test-env/ssh-server/Dockerfile +++ b/test-env/ssh-server/Dockerfile @@ -5,7 +5,8 @@ RUN apk update && apk add \ libcap \ openssh \ tcpdump \ - supervisor + supervisor \ + curl COPY sshd_config /etc/ssh/sshd_config COPY motd /etc/motd diff --git a/tunnel/config.go b/tunnel/config.go index 8b63389..171fb1f 100644 --- a/tunnel/config.go +++ b/tunnel/config.go @@ -62,9 +62,14 @@ func (r SSHConfigFile) Get(host string) *SSHHost { user = "" } - localForward, err := r.getLocalForward(host) + localForward, err := r.getForward("LocalForward", host) if err != nil { - log.Warningf("error reading LocalForward configuration from ssh config file: %v", err) + log.Warningf("error reading local forwarding configuration from ssh config file: %v", err) + } + + remoteForward, err := r.getForward("RemoteForward", host) + if err != nil { + log.Warningf("error reading remote configuration from ssh config file: %v", err) } key := r.getKey(host) @@ -81,6 +86,7 @@ func (r SSHConfigFile) Get(host string) *SSHHost { Key: key, IdentityAgent: identityAgent, LocalForward: localForward, + RemoteForward: remoteForward, } } @@ -93,10 +99,8 @@ func (r SSHConfigFile) getHostname(host string) string { return hostname } -func (r SSHConfigFile) getLocalForward(host string) (*LocalForward, error) { - var local, remote string - - c, err := r.sshConfig.Get(host, "LocalForward") +func (r SSHConfigFile) getForward(forwardType, host string) (*ForwardConfig, error) { + c, err := r.sshConfig.Get(host, forwardType) if err != nil { return nil, err } @@ -108,21 +112,21 @@ func (r SSHConfigFile) getLocalForward(host string) (*LocalForward, error) { l := strings.Fields(c) if len(l) < 2 { - return nil, fmt.Errorf("bad forwarding specification on ssh config file: %s", l) + return nil, fmt.Errorf("malformed forwarding configuration on ssh config file: %s", l) } - local = l[0] - remote = l[1] + source := l[0] + destination := l[1] - if strings.HasPrefix(local, ":") { - local = fmt.Sprintf("127.0.0.1%s", local) + if strings.HasPrefix(source, ":") { + source = fmt.Sprintf("127.0.0.1%s", source) } - if local != "" && !strings.Contains(local, ":") { - local = fmt.Sprintf("127.0.0.1:%s", local) + if source != "" && !strings.Contains(source, ":") { + source = fmt.Sprintf("127.0.0.1:%s", source) } - return &LocalForward{Local: local, Remote: remote}, nil + return &ForwardConfig{Source: source, Destination: destination}, nil } @@ -151,21 +155,23 @@ type SSHHost struct { User string Key string IdentityAgent string - LocalForward *LocalForward + LocalForward *ForwardConfig + RemoteForward *ForwardConfig } // String returns a string representation of a SSHHost. func (h SSHHost) String() string { - return fmt.Sprintf("[hostname=%s, port=%s, user=%s, key=%s, identity_agent=%s, local_forward=%s]", h.Hostname, h.Port, h.User, h.Key, h.IdentityAgent, h.LocalForward) + return fmt.Sprintf("[hostname=%s, port=%s, user=%s, key=%s, identity_agent=%s, local_forward=%s, remote_forward=%s]", h.Hostname, h.Port, h.User, h.Key, h.IdentityAgent, h.LocalForward, h.RemoteForward) } -// LocalForward represents a LocalForward configuration for SSHHost. -type LocalForward struct { - Local string - Remote string +// ForwardConfig represents either a LocalForward or a RemoteForward configuration +// for SSHHost. +type ForwardConfig struct { + Source string + Destination string } -// String returns a string representation of LocalForward. -func (f LocalForward) String() string { - return fmt.Sprintf("[local=%s, remote=%s]", f.Local, f.Remote) +// String returns a string representation of ForwardConfig. +func (f ForwardConfig) String() string { + return fmt.Sprintf("[source=%s, destination=%s]", f.Source, f.Destination) } diff --git a/tunnel/config_test.go b/tunnel/config_test.go index 17db4e1..ccb0a00 100644 --- a/tunnel/config_test.go +++ b/tunnel/config_test.go @@ -20,6 +20,11 @@ Host example2 LocalForward 8080 127.0.0.1:8080 Host example3 LocalForward 9090 127.0.0.1:9090 +Host example4 + RemoteForward 80 127.0.0.1:8080 +Host example5 + RemoteForward 192.168.1.100:80 my-server:8080 + ` c, _ := ssh_config.Decode(strings.NewReader(config)) @@ -46,7 +51,7 @@ Host example3 Port: "", User: "", Key: "", - LocalForward: &LocalForward{Local: "127.0.0.1:8080", Remote: "127.0.0.1:8080"}, + LocalForward: &ForwardConfig{Source: "127.0.0.1:8080", Destination: "127.0.0.1:8080"}, }, }, { @@ -56,7 +61,27 @@ Host example3 Port: "", User: "", Key: "", - LocalForward: &LocalForward{Local: "127.0.0.1:9090", Remote: "127.0.0.1:9090"}, + LocalForward: &ForwardConfig{Source: "127.0.0.1:9090", Destination: "127.0.0.1:9090"}, + }, + }, + { + "example4", + &SSHHost{ + Hostname: "", + Port: "", + User: "", + Key: "", + RemoteForward: &ForwardConfig{Source: "127.0.0.1:80", Destination: "127.0.0.1:8080"}, + }, + }, + { + "example5", + &SSHHost{ + Hostname: "", + Port: "", + User: "", + Key: "", + RemoteForward: &ForwardConfig{Source: "192.168.1.100:80", Destination: "my-server:8080"}, }, }, } diff --git a/tunnel/example_test.go b/tunnel/example_test.go index 8e0ba16..1942266 100644 --- a/tunnel/example_test.go +++ b/tunnel/example_test.go @@ -6,12 +6,13 @@ import ( "github.com/davrodpin/mole/tunnel" ) -// This example shows the basic usage of the package: define both the local and -// remote endpoints, the ssh server and then start the tunnel that will +// This example shows the basic usage of the package: define both the source and +// destination endpoints, the ssh server and then start the tunnel that will // exchange data from the local address to the remote address through the // established ssh channel. func Example() { - sshChan := &tunnel.SSHChannel{Local: "127.0.0.1:8080", Remote: "user@example.com:22"} + sourceEndpoints := []string{"127.0.0.1:8080"} + destinationEndpoints := []string{"user@example.com:80"} // Initialize the SSH Server configuration providing all values so // tunnel.NewServer will not try to lookup any value using $HOME/.ssh/config @@ -20,7 +21,7 @@ func Example() { log.Fatalf("error processing server options: %v\n", err) } - t, err := tunnel.New(server, []*tunnel.SSHChannel{sshChan}) + t, err := tunnel.New("local", server, sourceEndpoints, destinationEndpoints) if err != nil { log.Fatalf("error creating tunnel: %v\n", err) } diff --git a/tunnel/key_test.go b/tunnel/key_test.go index fd125a1..d4d90bf 100644 --- a/tunnel/key_test.go +++ b/tunnel/key_test.go @@ -5,10 +5,6 @@ import ( "testing" ) -func passwordHandler(password string) func() ([]byte, error) { - return func() ([]byte, error) { return []byte(password), nil } -} - func TestPemKey(t *testing.T) { tests := []struct { keyPath string diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 3a9678a..8566859 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -19,9 +19,9 @@ import ( ) const ( - HostMissing = "server host has to be provided as part of the server address" - RandomPortAddress = "127.0.0.1:0" - NoRemoteGiven = "cannot create a tunnel without at least one remote address" + HostMissing = "server host has to be provided as part of the server address" + RandomPortAddress = "127.0.0.1:0" + NoDestinationGiven = "cannot create a tunnel without at least one remote address" ) // Server holds the SSH Server attributes used for the client to connect to it. @@ -54,13 +54,12 @@ func NewServer(user, address, key, sshAgent string) (*Server, error) { c, err := NewSSHConfigFile() if err != nil { - // Ignore file doesnt exists if !errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("error accessing %s: %v", host, err) } } - // // If ssh config file doesnt exists, create an empty ssh config struct to avoid nil pointer deference + // If ssh config file doesnt exists, create an empty ssh config struct to avoid nil pointer deference if errors.Is(err, os.ErrNotExist) { c = NewEmptySSHConfigStruct() } @@ -121,19 +120,63 @@ func (s Server) String() string { } type SSHChannel struct { - Local string - Remote string - listener net.Listener - conn net.Conn + ChannelType string + Source string + Destination string + listener net.Listener + conn net.Conn } +// Listen creates tcp listeners for each channel defined. +func (ch *SSHChannel) Listen(serverClient *ssh.Client) error { + var l net.Listener + var err error + + if ch.listener == nil { + if ch.ChannelType == "local" { + l, err = net.Listen("tcp", ch.Source) + } else if ch.ChannelType == "remote" { + l, err = serverClient.Listen("tcp", ch.Source) + } else { + return fmt.Errorf("channel can't listen on endpoint: unknown channel type %s", ch.ChannelType) + } + + if err != nil { + return err + } + + ch.listener = l + + // update the endpoint value with assigned port for the cases where the user + // haven't explicitily specified one + ch.Source = l.Addr().String() + } + + return nil +} + +// Accept waits for and return the next connection to the SSHChannel. +func (ch *SSHChannel) Accept() error { + var err error + + if ch.conn, err = ch.listener.Accept(); err != nil { + return fmt.Errorf("error while establishing connection: %v", err) + } + + return nil +} + +// String returns a string representation of a SSHChannel func (ch SSHChannel) String() string { - return fmt.Sprintf("[local=%s, remote=%s]", ch.Local, ch.Remote) + return fmt.Sprintf("[source=%s, destination=%s]", ch.Source, ch.Destination) } // Tunnel represents the ssh tunnel and the channels connecting local and // remote endpoints. type Tunnel struct { + // Type tells what kind of port forwarding this tunnel will handle: local or remote + Type string + // Ready tells when the Tunnel is ready to accept connections Ready chan bool @@ -158,15 +201,23 @@ type Tunnel struct { } // New creates a new instance of Tunnel. -func New(server *Server, channels []*SSHChannel) (*Tunnel, error) { +func New(tunnelType string, server *Server, source, destination []string) (*Tunnel, error) { + var channels []*SSHChannel + var err error + + channels, err = buildSSHChannels(server.Name, tunnelType, source, destination) + if err != nil { + return nil, err + } for _, channel := range channels { - if channel.Local == "" || channel.Remote == "" { - return nil, fmt.Errorf("invalid ssh channel: local=%s, remote=%s", channel.Local, channel.Remote) + if channel.Source == "" || channel.Destination == "" { + return nil, fmt.Errorf("invalid ssh channel: source=%s, destination=%s", channel.Source, channel.Destination) } } return &Tunnel{ + Type: tunnelType, Ready: make(chan bool, 1), channels: channels, server: server, @@ -210,14 +261,8 @@ func (t *Tunnel) Start() error { // Listen creates tcp listeners for each channel defined. func (t *Tunnel) Listen() error { for _, ch := range t.channels { - if ch.listener == nil { - l, err := net.Listen("tcp", ch.Local) - if err != nil { - return err - } - - ch.listener = l - ch.Local = l.Addr().String() // update the value with assigned port is the given value is :0 + if err := ch.Listen(t.client); err != nil { + return err } } @@ -227,26 +272,35 @@ func (t *Tunnel) Listen() error { func (t *Tunnel) startChannel(channel *SSHChannel) error { var err error - channel.conn, err = channel.listener.Accept() + err = channel.Accept() if err != nil { - return fmt.Errorf("error while establishing local connection: %v", err) + return err } log.WithFields(log.Fields{ "channel": channel, - }).Debug("local connection established") + }).Debug("connection established") if t.client == nil { return fmt.Errorf("tunnel channel can't be established: missing connection to the ssh server") } - remoteConn, err := t.client.Dial("tcp", channel.Remote) + var destinationConn net.Conn + + if t.Type == "local" { + destinationConn, err = t.client.Dial("tcp", channel.Destination) + } else if t.Type == "remote" { + destinationConn, err = net.Dial("tcp", channel.Destination) + } else { + return fmt.Errorf("unknown tunnel type %s", t.Type) + } + if err != nil { - return fmt.Errorf("remote dial error: %s", err) + return fmt.Errorf("dial error: %s", err) } - go copyConn(channel.conn, remoteConn) - go copyConn(remoteConn, channel.conn) + go copyConn(channel.conn, destinationConn) + go copyConn(destinationConn, channel.conn) log.WithFields(log.Fields{ "channel": channel, @@ -292,7 +346,7 @@ func (t *Tunnel) dial() error { log.WithError(err).WithFields(log.Fields{ "server": t.server, "retries": retries, - }).Debugf("error while connecting to ssh server") + }).Error("error while connecting to ssh server") if t.ConnectionRetries < 0 { break @@ -325,13 +379,15 @@ func (t *Tunnel) waitAndReconnect() { } func (t *Tunnel) connect() { - err := t.Listen() + var err error + + err = t.dial() if err != nil { t.done <- err return } - err = t.dial() + err = t.Listen() if err != nil { t.done <- err return @@ -355,8 +411,8 @@ func (t *Tunnel) connect() { for { once.Do(func() { log.WithFields(log.Fields{ - "local": channel.Local, - "remote": channel.Remote, + "source": channel.Source, + "destination": channel.Destination, }).Info("tunnel channel is waiting for connection") waitgroup.Done() @@ -487,41 +543,39 @@ func expandAddress(address string) string { return address } -// BuildSSHChannels normalizes the given set of local and remote addresses, -// combining them to build a set of ssh channel objects. -func BuildSSHChannels(serverName string, local, remote []string) ([]*SSHChannel, error) { - // if not local and remote were given, try to find the addresses from the SSH - // configuration file. - if len(local) == 0 && len(remote) == 0 { - lf, err := getLocalForward(serverName) +func buildSSHChannels(serverName, channelType string, source, destination []string) ([]*SSHChannel, error) { + // if source and destination were not given, try to find the addresses from the + // SSH configuration file. + if len(source) == 0 && len(destination) == 0 { + f, err := getForward(channelType, serverName) if err != nil { return nil, err } - local = []string{lf.Local} - remote = []string{lf.Remote} + source = []string{f.Source} + destination = []string{f.Destination} } else { - lSize := len(local) - rSize := len(remote) + lSize := len(source) + rSize := len(destination) if lSize > rSize { - // if there are more local than remote addresses given, the additional + // if there are more source than destination addresses given, the additional // addresses must be removed. if rSize == 0 { - return nil, fmt.Errorf(NoRemoteGiven) + return nil, fmt.Errorf(NoDestinationGiven) } - local = local[0:rSize] + source = source[0:rSize] } else if lSize < rSize { - // if there are more remote than local addresses given, the missing local - // addresses should be configured as localhost with random ports. + // if there are more destination than source addresses given, the missing + // source addresses should be configured as localhost with random ports. nl := make([]string, rSize) - for i, _ := range remote { + for i := range destination { if i < lSize { - if local[i] != "" { - nl[i] = local[i] + if source[i] != "" { + nl[i] = source[i] } else { nl[i] = RandomPortAddress } @@ -530,27 +584,29 @@ func BuildSSHChannels(serverName string, local, remote []string) ([]*SSHChannel, } } - local = nl + source = nl } } - for i, addr := range local { - local[i] = expandAddress(addr) + for i, addr := range source { + source[i] = expandAddress(addr) } - for i, addr := range remote { - remote[i] = expandAddress(addr) + for i, addr := range destination { + destination[i] = expandAddress(addr) } - channels := make([]*SSHChannel, len(remote)) - for i, r := range remote { - channels[i] = &SSHChannel{Local: local[i], Remote: r} + channels := make([]*SSHChannel, len(destination)) + for i, d := range destination { + channels[i] = &SSHChannel{ChannelType: channelType, Source: source[i], Destination: d} } return channels, nil } -func getLocalForward(serverName string) (*LocalForward, error) { +func getForward(channelType, serverName string) (*ForwardConfig, error) { + var f *ForwardConfig + cfg, err := NewSSHConfigFile() if err != nil { return nil, fmt.Errorf("error reading ssh configuration file: %v", err) @@ -558,9 +614,17 @@ func getLocalForward(serverName string) (*LocalForward, error) { sh := cfg.Get(serverName) - if sh.LocalForward == nil { - return nil, fmt.Errorf("LocalForward could not be found or has invalid syntax for host %s", serverName) + if channelType == "local" { + f = sh.LocalForward + } else if channelType == "remote" { + f = sh.RemoteForward + } else { + return nil, fmt.Errorf("could not retrieve forwarding information from ssh configuration file: unsupported channel type %s", channelType) + } + + if f == nil { + return nil, fmt.Errorf("forward config could not be found or has invalid syntax for host %s", serverName) } - return sh.LocalForward, nil + return f, nil } diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go index 2f5810a..4624042 100644 --- a/tunnel/tunnel_test.go +++ b/tunnel/tunnel_test.go @@ -11,11 +11,12 @@ import ( "os" "path/filepath" "reflect" - "strings" "testing" "time" + "github.com/phayes/freeport" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" ) const NoSshRetries = -1 @@ -112,8 +113,29 @@ func TestServerOptions(t *testing.T) { } } -func TestTunnel(t *testing.T) { - tun, _, _ := prepareTunnel(1, false, NoSshRetries) +func TestLocalTunnel(t *testing.T) { + c := &tunnelConfig{t, "local", 1, false, NoSshRetries} + tun, _, _ := prepareTunnel(c) + + select { + case <-tun.Ready: + t.Log("tunnel is ready to accept connections") + case <-time.After(1 * time.Second): + t.Errorf("error waiting for tunnel to be ready") + return + } + + err := validateTunnelConnectivity(t, "ABC", tun) + if err != nil { + t.Errorf("%v", err) + } + + tun.Stop() +} + +func TestRemoteTunnel(t *testing.T) { + c := &tunnelConfig{t, "remote", 1, true, NoSshRetries} + tun, _, _ := prepareTunnel(c) select { case <-tun.Ready: @@ -132,7 +154,8 @@ func TestTunnel(t *testing.T) { } func TestTunnelInsecure(t *testing.T) { - tun, _, _ := prepareTunnel(1, true, NoSshRetries) + c := &tunnelConfig{t, "local", 1, true, NoSshRetries} + tun, _, _ := prepareTunnel(c) select { case <-tun.Ready: @@ -150,8 +173,9 @@ func TestTunnelInsecure(t *testing.T) { tun.Stop() } -func TestTunnelMultipleRemotes(t *testing.T) { - tun, _, _ := prepareTunnel(2, false, NoSshRetries) +func TestTunnelMultipleDestinations(t *testing.T) { + c := &tunnelConfig{t, "local", 2, false, NoSshRetries} + tun, _, _ := prepareTunnel(c) select { case <-tun.Ready: @@ -170,7 +194,8 @@ func TestTunnelMultipleRemotes(t *testing.T) { } func TestReconnectSSHServer(t *testing.T) { - tun, ssh, _ := prepareTunnel(1, false, 3) + c := &tunnelConfig{t, "local", 1, false, 3} + tun, ssh, _ := prepareTunnel(c) select { case <-tun.Ready: @@ -195,7 +220,7 @@ func TestReconnectSSHServer(t *testing.T) { return } - _, err = createSSHServer(ssh.Addr().String(), keyPath) + _, err = createSSHServer(t, ssh.Addr().String(), keyPath) if err != nil { t.Errorf("error while recreating ssh server: %s", err) return @@ -242,7 +267,12 @@ func validateTunnelConnectivity(t *testing.T, expected string, tun *Tunnel) erro } func TestMain(m *testing.M) { - prepareTestEnv() + err := prepareTestEnv() + if err != nil { + fmt.Printf("could not start test suite: %v\n", err) + os.RemoveAll(sshDir) + os.Exit(1) + } code := m.Run() @@ -254,64 +284,64 @@ func TestMain(m *testing.M) { func TestBuildSSHChannels(t *testing.T) { tests := []struct { serverName string - local []string - remote []string + source []string + destination []string expected int expectedError error }{ { serverName: "test", - local: []string{":3360"}, - remote: []string{":3360"}, + source: []string{":3360"}, + destination: []string{":3360"}, expected: 1, expectedError: nil, }, { serverName: "test", - local: []string{":3360", ":8080"}, - remote: []string{":3360", ":8080"}, + source: []string{":3360", ":8080"}, + destination: []string{":3360", ":8080"}, expected: 2, expectedError: nil, }, { serverName: "test", - local: []string{}, - remote: []string{":3360"}, + source: []string{}, + destination: []string{":3360"}, expected: 1, expectedError: nil, }, { serverName: "test", - local: []string{":3360"}, - remote: []string{":3360", ":8080"}, + source: []string{":3360"}, + destination: []string{":3360", ":8080"}, expected: 2, expectedError: nil, }, { serverName: "hostWithLocalForward", - local: []string{}, - remote: []string{}, + source: []string{}, + destination: []string{}, expected: 1, expectedError: nil, }, { serverName: "test", - local: []string{":3360", ":8080"}, - remote: []string{":3360"}, + source: []string{":3360", ":8080"}, + destination: []string{":3360"}, expected: 1, expectedError: nil, }, { serverName: "test", - local: []string{":3360"}, - remote: []string{}, + source: []string{":3360"}, + destination: []string{}, expected: 0, - expectedError: fmt.Errorf(NoRemoteGiven), + expectedError: fmt.Errorf(NoDestinationGiven), }, } for testId, test := range tests { - sshChannels, err := BuildSSHChannels(test.serverName, test.local, test.remote) + sshChannels, err := buildSSHChannels(test.serverName, "local", test.source, test.destination) if err != nil { if test.expectedError != nil { if test.expectedError.Error() != err.Error() { @@ -326,23 +356,23 @@ func TestBuildSSHChannels(t *testing.T) { t.Errorf("wrong number of ssh channel objects created for test %d: expected: %d, value: %d", testId, test.expected, len(sshChannels)) } - localSize := len(test.local) - remoteSize := len(test.remote) + sourceSize := len(test.source) + destinationSize := len(test.destination) - // check if the local addresses match only if any address is given - if localSize > 0 && remoteSize > 0 { + // check if the source addresses match only if any address is given + if sourceSize > 0 && destinationSize > 0 { for i, sshChannel := range sshChannels { - local := "" - if i < localSize { - local = test.local[i] + source := "" + if i < sourceSize { + source = test.source[i] } else { - local = RandomPortAddress + source = RandomPortAddress } - local = expandAddress(local) + source = expandAddress(source) - if sshChannel.Local != local { - t.Errorf("local address don't match for test %d: expected: %s, value: %s", testId, sshChannel.Local, local) + if sshChannel.Source != source { + t.Errorf("source address don't match for test %d: expected: %s, value: %s", testId, sshChannel.Source, source) } } @@ -350,44 +380,70 @@ func TestBuildSSHChannels(t *testing.T) { } } +type tunnelConfig struct { + T *testing.T + TunnelType string + + // Destinations indicates how many endpoints should be available through the + // tunnel. + Destinations int + + Insecure bool + ConnectionRetries int +} + // prepareTunnel creates a Tunnel object making sure all infrastructure // dependencies (ssh and http servers) are ready. // // The 'remotes' argument tells how many remote endpoints will be available // through the tunnel. -func prepareTunnel(remotes int, insecure bool, sshConnectionRetries int) (tun *Tunnel, ssh net.Listener, hss []*http.Server) { - hss = make([]*http.Server, remotes) +func prepareTunnel(config *tunnelConfig) (tun *Tunnel, ssh net.Listener, hss []*http.Server) { + hss = make([]*http.Server, config.Destinations) - ssh, err := createSSHServer("", keyPath) + ssh, err := createSSHServer(config.T, "", keyPath) if err != nil { - // FIXME: return the error - fmt.Printf("error while creating ssh server: %s", err) + config.T.Errorf("error while creating ssh server: %s", err) return } srv, _ := NewServer("mole", ssh.Addr().String(), "", "") - srv.Insecure = insecure + srv.Insecure = config.Insecure + + if !config.Insecure { + err = generateKnownHosts(ssh.Addr(), publicKeyPath, knownHostsPath) + if err != nil { + config.T.Errorf("error generating known hosts file for tests: %v\n", err) + return + } - if !insecure { - generateKnownHosts(ssh.Addr().String(), publicKeyPath, knownHostsPath) } - sshChannels := []*SSHChannel{} - for i := 1; i <= remotes; i++ { + source := make([]string, config.Destinations) + destination := make([]string, config.Destinations) + + for i := 0; i <= (config.Destinations - 1); i++ { l, hs := createHttpServer() - sshChannels = append(sshChannels, &SSHChannel{Local: "127.0.0.1:0", Remote: l.Addr().String()}) + if config.TunnelType == "local" { + source[i] = "127.0.0.1:0" + destination[i] = l.Addr().String() + } else if config.TunnelType == "remote" { + source[i] = l.Addr().String() + destination[i] = "127.0.0.1:0" + } else { + config.T.Errorf("could not configure destination endpoints for testing: %v\n", err) + return + } hss = append(hss, hs) } - tun, _ = New(srv, sshChannels) - tun.ConnectionRetries = sshConnectionRetries + tun, _ = New(config.TunnelType, srv, source, destination) + tun.ConnectionRetries = config.ConnectionRetries tun.WaitAndRetry = 3 * time.Second tun.KeepAliveInterval = 10 * time.Second go func(tun *Tunnel) { - var err error - err = tun.Start() + err := tun.Start() // FIXME: this message should be shown through *testing.t but using it here // would cause the message to be printed after the test ends (goroutine), // making the test to fail @@ -399,7 +455,7 @@ func prepareTunnel(remotes int, insecure bool, sshConnectionRetries int) (tun *T return tun, ssh, hss } -func prepareTestEnv() { +func prepareTestEnv() error { home := "testdata" fixtureDir := filepath.Join(home, "dotssh") testDir := filepath.Join(home, ".ssh") @@ -433,31 +489,22 @@ func prepareTestEnv() { }, } - os.Mkdir(testDir, os.ModeDir|os.ModePerm) + err := os.Mkdir(testDir, os.ModeDir|os.ModePerm) + if err != nil { + return err + } for _, f := range fixtures { - os.Link(f["from"], f["to"]) + err = os.Link(f["from"], f["to"]) + if err != nil { + return err + } } os.Setenv("HOME", home) os.Setenv("USERPROFILE", home) -} - -// get performs a http request using the given client appending the given -// resource to a hard-coded URL. -// -// The request performed by this function is designed to reach the other side -// through a pipe (net.Pipe()) and this is the reason the URL is hard-coded. -func get(client http.Client, resource string) (string, error) { - resp, err := client.Get(fmt.Sprintf("%s%s", "http://any-url-is.fine", resource)) - if err != nil { - return "", err - } - defer resp.Body.Close() - body, _ := ioutil.ReadAll(resp.Body) - - return string(body), nil + return nil } // createHttpServer spawns a new http server, listening on a random port. @@ -496,7 +543,7 @@ func createHttpServer() (net.Listener, *http.Server) { // References: // https://gist.github.com/jpillora/b480fde82bff51a06238 // https://tools.ietf.org/html/rfc4254#section-7.2 -func createSSHServer(address string, keyPath string) (net.Listener, error) { +func createSSHServer(t *testing.T, address string, keyPath string) (net.Listener, error) { conf := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { return &ssh.Permissions{}, nil @@ -519,6 +566,8 @@ func createSSHServer(address string, keyPath string) (net.Listener, error) { go func(listener net.Listener) { var conns []ssh.Conn for { + var err error + conn, err := listener.Accept() if err != nil { // closing all ssh connections if a new client can't connect to the server @@ -531,14 +580,53 @@ func createSSHServer(address string, keyPath string) (net.Listener, error) { serverConn, chans, reqs, _ := ssh.NewServerConn(conn, conf) conns = append(conns, serverConn) - go ssh.DiscardRequests(reqs) + // go routine to handle ssh client requests. In the context of mole's test, + // this is needed when a remote ssh forwarding listens to a port on the jump + // server and the port needs to be randomized (port is given as 0). + // The reply's needs to carry the port to be listened in its payload. + // All requests but "tcpip-forward" are discarded. + go func(reqs <-chan *ssh.Request) { + var err error + + for newReq := range reqs { + if newReq.Type != "tcpip-forward" { + err = newReq.Reply(false, nil) + if err != nil { + t.Errorf("error replying to tcpip-forward request: %v", err) + } + return + } + + if newReq.WantReply { + ports, err := freeport.GetFreePorts(1) + if err != nil { + t.Errorf("could not get a free port: %v", err) + return + } + port := make([]byte, 4) + binary.BigEndian.PutUint32(port, uint32(ports[0])) + err = newReq.Reply(true, port) + if err != nil { + t.Errorf("error replying to tcpip-forward request: %v", err) + return + } + } + } + }(reqs) + // go routine to handle requests to create new ssh channels. This particular + // implementation only supports "direct-tcpip", which is the identifier used + // for ssh port forwarding. go func(chans <-chan ssh.NewChannel) { for newChan := range chans { go func(newChan ssh.NewChannel) { + var err error - if t := newChan.ChannelType(); t != "direct-tcpip" { - newChan.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + if ct := newChan.ChannelType(); ct != "direct-tcpip" { + err = newChan.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", ct)) + if err != nil { + t.Errorf("error rejecting unsupported channel: %v", err) + } return } @@ -569,33 +657,19 @@ func createSSHServer(address string, keyPath string) (net.Listener, error) { // generateKnownHosts creates a new "known_hosts" file on a given path with a // single entry based on the given SSH server address and public key. -func generateKnownHosts(sshAddr, pubKeyPath, knownHostsPath string) { - i := strings.Split(sshAddr, ":")[0] - p := strings.Split(sshAddr, ":")[1] - - kc, _ := ioutil.ReadFile(pubKeyPath) - t := strings.Split(string(kc), " ")[0] - k := strings.Split(string(kc), " ")[1] +func generateKnownHosts(sshAddr net.Addr, pubKeyPath, knownHostsPath string) error { + d, err := ioutil.ReadFile(pubKeyPath) + if err != nil { + return err + } - c := fmt.Sprintf("[%s]:%s %s %s", i, p, t, k) - ioutil.WriteFile(knownHostsPath, []byte(c), 0600) -} + pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(d)) + if err != nil { + return err + } -type MockConn struct { - isConnectionOpen bool -} + l := knownhosts.Line([]string{sshAddr.String()}, pk) + ioutil.WriteFile(knownHostsPath, []byte(l), 0600) -func (c MockConn) User() string { return "" } -func (c MockConn) SessionID() []byte { return []byte{} } -func (c MockConn) ClientVersion() []byte { return []byte{} } -func (c MockConn) ServerVersion() []byte { return []byte{} } -func (c MockConn) RemoteAddr() net.Addr { return nil } -func (c MockConn) LocalAddr() net.Addr { return nil } -func (c MockConn) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { - return false, []byte{}, nil -} -func (c MockConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { - return nil, nil, nil + return nil } -func (c MockConn) Close() error { return nil } -func (c MockConn) Wait() error { return nil }