Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions cmd/forward.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package cmd

import (
"context"
"strconv"
"strings"

"github.com/apex/log"
"github.com/m-lab/bmctool/tunnel"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)

var (
sshUser string
ports []string
tunnelHost string

defaultPorts = []string{"4443:443", "5900"}
)

var forwardCmd = &cobra.Command{
Use: "forward <host>",
Short: "Forward ports via an SSH tunnel",
Long: `This command creates an SSH tunnel to a given <host>.

Ports to be forwarded can be specified with the (repeatable) --port flag.
Local and remote ports can be specified with the following syntax:

--port src[:dest]

e.g.:

bmctool forward <host> --port 4443:443 --port 5900

If dest is unspecified, it'll be the same as src.

The host to use for tunneling can be specified via the --tunnel-host flag,
or the BMCTUNNELHOST environment variable.

The username to use to connect to the intermediate host can be specified via
the --username flag, or the BMCTUNNELUSER environment variable.`,
Args: cobra.MinimumNArgs(1),
Run: func(cmd *cobra.Command, args []string) {
dstHost := args[0]
forward(dstHost)
},
}

func init() {
rootCmd.AddCommand(forwardCmd)

viper.AutomaticEnv()

forwardCmd.Flags().StringArrayVar(&ports, "port", defaultPorts, "source:destination")
forwardCmd.Flags().StringVar(&tunnelHost, "tunnel-host",
viper.GetString("BMCTUNNELHOST"), "intermediate host")
forwardCmd.Flags().StringVar(&sshUser, "username",
viper.GetString("BMCTUNNELUSER"), "username for intermediate host")

}

// splitPorts takes a string containing either a "local:remote" ports pair
// or just "port" and returns local/remote as separate variables. If the string
// contains a single port, it returns the same port for local and remote.
func splitPorts(ports string) (int32, int32, error) {
split := strings.Split(ports, ":")

srcPort, err := strconv.ParseInt(split[0], 10, 32)
if err != nil {
return 0, 0, err
}

if len(split) == 1 {
return int32(srcPort), int32(srcPort), nil
}

dstPort, err := strconv.ParseInt(split[1], 10, 32)
if err != nil {
return 0, 0, err
}

return int32(srcPort), int32(dstPort), nil
}

func forward(dstHost string) {

sshConfig := &ssh.ClientConfig{
User: sshUser,
Auth: []ssh.AuthMethod{
tunnel.SSHAgent(),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}

serverEndpoint := &tunnel.Endpoint{
Host: tunnelHost,
Port: 22,
}

errs, _ := errgroup.WithContext(context.Background())

for _, port := range ports {
srcPort, dstPort, err := splitPorts(port)
if err != nil {
log.Errorf("Cannot parse provided ports: %v", err)
osExit(1)
}

localEndpoint := &tunnel.Endpoint{
Host: "localhost",
Port: srcPort,
}

remoteEndpoint := &tunnel.Endpoint{
Host: dstHost,
Port: dstPort,
}

tunnel := &tunnel.SSHTunnel{
Config: sshConfig,
Local: localEndpoint,
Server: serverEndpoint,
Remote: remoteEndpoint,
}

log.Infof("Forwarding %s -> %s -> %s", localEndpoint, serverEndpoint, remoteEndpoint)
errs.Go(tunnel.Start)

}

errs.Wait()
}
12 changes: 12 additions & 0 deletions tunnel/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package tunnel

import "fmt"

type Endpoint struct {
Host string
Port int32
}

func (ep *Endpoint) String() string {
return fmt.Sprintf("%s:%d", ep.Host, ep.Port)
}
77 changes: 77 additions & 0 deletions tunnel/tunnel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package tunnel

import (
"io"
"net"
"os"

"github.com/apex/log"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

// SSHTunnel represents an SSH tunnel.
type SSHTunnel struct {
// Local server endpoint
Local *Endpoint

// Intermediate server endpoint
Server *Endpoint

// Remote server endpoint
Remote *Endpoint

// Client configuration
Config *ssh.ClientConfig
}

// Start initializes the SSH tunnel.
func (tunnel *SSHTunnel) Start() error {
listener, err := net.Listen("tcp", tunnel.Local.String())
if err != nil {
log.Errorf("Cannot listen on %s: %v", tunnel.Local, err)
return err
}
defer listener.Close()

for {
conn, err := listener.Accept()
if err != nil {
log.Errorf("Cannot accept connection: %v", err)
return err
}
go tunnel.forward(conn)
}
}

func (tunnel *SSHTunnel) forward(localConn net.Conn) {
serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
if err != nil {
log.Errorf("Server dial error: %s", err)
return
}

remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
if err != nil {
log.Errorf("Remote dial error: %s", err)
return
}

copyConn := func(writer, reader net.Conn) {
_, err := io.Copy(writer, reader)
if err != nil {
log.Debugf("io.Copy error: %s", err)
}
}

go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
}

// SSHAgent gets a ssh.AuthMethod from the local ssh-agent instance (if any).
func SSHAgent() ssh.AuthMethod {
if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
}
return nil
}
111 changes: 111 additions & 0 deletions tunnel/tunnel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package tunnel

import (
"fmt"
"io"
"log"
"net"
"testing"
"time"

sshserver "github.com/gliderlabs/ssh"
"golang.org/x/crypto/ssh"
)

func TestSSHTunnel_Start(t *testing.T) {
handlerFunc := func(s sshserver.Session) {
io.WriteString(s, "test")
}

// Create intermediate SSH server.
bounceSSHListener, err := net.Listen("tcp", ":0")
bounceSSHServer := &sshserver.Server{
Handler: handlerFunc,
LocalPortForwardingCallback: func(ctx sshserver.Context,
destinationHost string, destinationPort uint32) bool {
return true
},
}
if err != nil {
t.Fatalf("Cannot create listener: %v", err)
}
go func() {
log.Fatal(bounceSSHServer.Serve(bounceSSHListener))
}()

// Create destination SSH server.
destSSHServer, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Cannot create listener: %v", err)
}
go func() {
log.Fatal(sshserver.Serve(destSSHServer, handlerFunc))
}()

sshConfig := &ssh.ClientConfig{
User: "test",
Auth: []ssh.AuthMethod{
SSHAgent(),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}

tun := &SSHTunnel{
Local: &Endpoint{
Host: "127.0.0.1",
Port: int32(destSSHServer.Addr().(*net.TCPAddr).Port) + 1,
},
Server: &Endpoint{
Host: "127.0.0.1",
Port: int32(bounceSSHListener.Addr().(*net.TCPAddr).Port),
},
Remote: &Endpoint{
Host: "127.0.0.1",
Port: int32(destSSHServer.Addr().(*net.TCPAddr).Port),
},
Config: sshConfig,
}

go func() { tun.Start() }()

time.Sleep(2 * time.Second)

// Connect to the tunnel and verify that the received message is the
// expected one from the remote server.
cl, err := ssh.Dial("tcp", tun.Local.String(), sshConfig)
if err != nil {
t.Fatalf("Cannot connect to the local endpoint: %v", err)
}

sess, err := cl.NewSession()
if err != nil {
t.Fatalf("Cannot create SSH session: %v", err)
}

sshout, err := sess.StdoutPipe()
if err != nil {
t.Fatalf("Cannot pipe stdout: %v", err)
}

err = sess.Shell()
if err != nil {
t.Fatalf("Cannot start shell: %v", err)
}

output := readBuffForString(sshout)
if output != "test" {
t.Fatalf("Unexpected output: %s", output)
}
fmt.Println("Done.")

}

func readBuffForString(sshOut io.Reader) string {
buf := make([]byte, 1000)
n, err := sshOut.Read(buf) //this reads the ssh terminal
waitingString := ""
if err == nil {
waitingString = string(buf[:n])
}
return waitingString
}