diff --git a/cmd/forward.go b/cmd/forward.go new file mode 100644 index 0000000..0fe151a --- /dev/null +++ b/cmd/forward.go @@ -0,0 +1,135 @@ +package cmd + +import ( + "context" + "strconv" + "strings" + + "github.com/apex/log" + "github.com/m-lab/bmctool/tunnel" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "golang.org/x/crypto/ssh" + "golang.org/x/sync/errgroup" +) + +var ( + sshUser string + ports []string + tunnelHost string + + defaultPorts = []string{"4443:443", "5900"} +) + +var forwardCmd = &cobra.Command{ + Use: "forward ", + Short: "Forward ports via an SSH tunnel", + Long: `This command creates an SSH tunnel to a given . + +Ports to be forwarded can be specified with the (repeatable) --port flag. +Local and remote ports can be specified with the following syntax: + +--port src[:dest] + +e.g.: + +bmctool forward --port 4443:443 --port 5900 + +If dest is unspecified, it'll be the same as src. + +The host to use for tunneling can be specified via the --tunnel-host flag, +or the BMCTUNNELHOST environment variable. + +The username to use to connect to the intermediate host can be specified via +the --username flag, or the BMCTUNNELUSER environment variable.`, + Args: cobra.MinimumNArgs(1), + Run: func(cmd *cobra.Command, args []string) { + dstHost := args[0] + forward(dstHost) + }, +} + +func init() { + rootCmd.AddCommand(forwardCmd) + + viper.AutomaticEnv() + + forwardCmd.Flags().StringArrayVar(&ports, "port", defaultPorts, "source:destination") + forwardCmd.Flags().StringVar(&tunnelHost, "tunnel-host", + viper.GetString("BMCTUNNELHOST"), "intermediate host") + forwardCmd.Flags().StringVar(&sshUser, "username", + viper.GetString("BMCTUNNELUSER"), "username for intermediate host") + +} + +// splitPorts takes a string containing either a "local:remote" ports pair +// or just "port" and returns local/remote as separate variables. If the string +// contains a single port, it returns the same port for local and remote. +func splitPorts(ports string) (int32, int32, error) { + split := strings.Split(ports, ":") + + srcPort, err := strconv.ParseInt(split[0], 10, 32) + if err != nil { + return 0, 0, err + } + + if len(split) == 1 { + return int32(srcPort), int32(srcPort), nil + } + + dstPort, err := strconv.ParseInt(split[1], 10, 32) + if err != nil { + return 0, 0, err + } + + return int32(srcPort), int32(dstPort), nil +} + +func forward(dstHost string) { + + sshConfig := &ssh.ClientConfig{ + User: sshUser, + Auth: []ssh.AuthMethod{ + tunnel.SSHAgent(), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + serverEndpoint := &tunnel.Endpoint{ + Host: tunnelHost, + Port: 22, + } + + errs, _ := errgroup.WithContext(context.Background()) + + for _, port := range ports { + srcPort, dstPort, err := splitPorts(port) + if err != nil { + log.Errorf("Cannot parse provided ports: %v", err) + osExit(1) + } + + localEndpoint := &tunnel.Endpoint{ + Host: "localhost", + Port: srcPort, + } + + remoteEndpoint := &tunnel.Endpoint{ + Host: dstHost, + Port: dstPort, + } + + tunnel := &tunnel.SSHTunnel{ + Config: sshConfig, + Local: localEndpoint, + Server: serverEndpoint, + Remote: remoteEndpoint, + } + + log.Infof("Forwarding %s -> %s -> %s", localEndpoint, serverEndpoint, remoteEndpoint) + errs.Go(tunnel.Start) + + } + + errs.Wait() +} diff --git a/tunnel/endpoint.go b/tunnel/endpoint.go new file mode 100644 index 0000000..e7f61c1 --- /dev/null +++ b/tunnel/endpoint.go @@ -0,0 +1,12 @@ +package tunnel + +import "fmt" + +type Endpoint struct { + Host string + Port int32 +} + +func (ep *Endpoint) String() string { + return fmt.Sprintf("%s:%d", ep.Host, ep.Port) +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go new file mode 100644 index 0000000..fdf57fc --- /dev/null +++ b/tunnel/tunnel.go @@ -0,0 +1,77 @@ +package tunnel + +import ( + "io" + "net" + "os" + + "github.com/apex/log" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +// SSHTunnel represents an SSH tunnel. +type SSHTunnel struct { + // Local server endpoint + Local *Endpoint + + // Intermediate server endpoint + Server *Endpoint + + // Remote server endpoint + Remote *Endpoint + + // Client configuration + Config *ssh.ClientConfig +} + +// Start initializes the SSH tunnel. +func (tunnel *SSHTunnel) Start() error { + listener, err := net.Listen("tcp", tunnel.Local.String()) + if err != nil { + log.Errorf("Cannot listen on %s: %v", tunnel.Local, err) + return err + } + defer listener.Close() + + for { + conn, err := listener.Accept() + if err != nil { + log.Errorf("Cannot accept connection: %v", err) + return err + } + go tunnel.forward(conn) + } +} + +func (tunnel *SSHTunnel) forward(localConn net.Conn) { + serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) + if err != nil { + log.Errorf("Server dial error: %s", err) + return + } + + remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String()) + if err != nil { + log.Errorf("Remote dial error: %s", err) + return + } + + copyConn := func(writer, reader net.Conn) { + _, err := io.Copy(writer, reader) + if err != nil { + log.Debugf("io.Copy error: %s", err) + } + } + + go copyConn(localConn, remoteConn) + go copyConn(remoteConn, localConn) +} + +// SSHAgent gets a ssh.AuthMethod from the local ssh-agent instance (if any). +func SSHAgent() ssh.AuthMethod { + if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) + } + return nil +} diff --git a/tunnel/tunnel_test.go b/tunnel/tunnel_test.go new file mode 100644 index 0000000..b914a00 --- /dev/null +++ b/tunnel/tunnel_test.go @@ -0,0 +1,111 @@ +package tunnel + +import ( + "fmt" + "io" + "log" + "net" + "testing" + "time" + + sshserver "github.com/gliderlabs/ssh" + "golang.org/x/crypto/ssh" +) + +func TestSSHTunnel_Start(t *testing.T) { + handlerFunc := func(s sshserver.Session) { + io.WriteString(s, "test") + } + + // Create intermediate SSH server. + bounceSSHListener, err := net.Listen("tcp", ":0") + bounceSSHServer := &sshserver.Server{ + Handler: handlerFunc, + LocalPortForwardingCallback: func(ctx sshserver.Context, + destinationHost string, destinationPort uint32) bool { + return true + }, + } + if err != nil { + t.Fatalf("Cannot create listener: %v", err) + } + go func() { + log.Fatal(bounceSSHServer.Serve(bounceSSHListener)) + }() + + // Create destination SSH server. + destSSHServer, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Cannot create listener: %v", err) + } + go func() { + log.Fatal(sshserver.Serve(destSSHServer, handlerFunc)) + }() + + sshConfig := &ssh.ClientConfig{ + User: "test", + Auth: []ssh.AuthMethod{ + SSHAgent(), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + tun := &SSHTunnel{ + Local: &Endpoint{ + Host: "127.0.0.1", + Port: int32(destSSHServer.Addr().(*net.TCPAddr).Port) + 1, + }, + Server: &Endpoint{ + Host: "127.0.0.1", + Port: int32(bounceSSHListener.Addr().(*net.TCPAddr).Port), + }, + Remote: &Endpoint{ + Host: "127.0.0.1", + Port: int32(destSSHServer.Addr().(*net.TCPAddr).Port), + }, + Config: sshConfig, + } + + go func() { tun.Start() }() + + time.Sleep(2 * time.Second) + + // Connect to the tunnel and verify that the received message is the + // expected one from the remote server. + cl, err := ssh.Dial("tcp", tun.Local.String(), sshConfig) + if err != nil { + t.Fatalf("Cannot connect to the local endpoint: %v", err) + } + + sess, err := cl.NewSession() + if err != nil { + t.Fatalf("Cannot create SSH session: %v", err) + } + + sshout, err := sess.StdoutPipe() + if err != nil { + t.Fatalf("Cannot pipe stdout: %v", err) + } + + err = sess.Shell() + if err != nil { + t.Fatalf("Cannot start shell: %v", err) + } + + output := readBuffForString(sshout) + if output != "test" { + t.Fatalf("Unexpected output: %s", output) + } + fmt.Println("Done.") + +} + +func readBuffForString(sshOut io.Reader) string { + buf := make([]byte, 1000) + n, err := sshOut.Read(buf) //this reads the ssh terminal + waitingString := "" + if err == nil { + waitingString = string(buf[:n]) + } + return waitingString +}