Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssh: allow dialing named services in addition to port numbers #235

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
28 changes: 17 additions & 11 deletions ssh/tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err
}
ch := make(chan connErr)
go func() {
conn, err := c.Dial(n, addr)
conn, err := c.dialContext(ctx, n, addr)
select {
case ch <- connErr{conn, err}:
case <-ctx.Done():
Expand All @@ -369,7 +369,13 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err

// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
// For TCP addresses the port section of the address can be a port number or a service name.
// Service names are resolved at the client side, domain names are resolved on the server.
func (c *Client) Dial(n, addr string) (net.Conn, error) {
return c.dialContext(context.Background(), n, addr)
}

func (c *Client) dialContext(ctx context.Context, n, addr string) (net.Conn, error) {
var ch Channel
switch n {
case "tcp", "tcp4", "tcp6":
Expand All @@ -378,11 +384,11 @@ func (c *Client) Dial(n, addr string) (net.Conn, error) {
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(portString, 10, 16)
port, err := net.DefaultResolver.LookupPort(ctx, n, portString)
if err != nil {
return nil, err
}
ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
ch, err = c.dial(net.IPv4zero.String(), 0, host, port)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -441,18 +447,18 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)

// RFC 4254 7.2
type channelOpenDirectMsg struct {
raddr string
rport uint32
laddr string
lport uint32
Addr string
Port uint32
OriginAddr string
OriginPort uint32
}

func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
msg := channelOpenDirectMsg{
raddr: raddr,
rport: uint32(rport),
laddr: laddr,
lport: uint32(lport),
Addr: raddr,
Port: uint32(rport),
OriginAddr: laddr,
OriginPort: uint32(lport),
}
ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
if err != nil {
Expand Down
79 changes: 79 additions & 0 deletions ssh/tcpip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package ssh

import (
"context"
"fmt"
"net"
"testing"
"time"
Expand Down Expand Up @@ -51,3 +52,81 @@ func TestClientDialContextWithDeadline(t *testing.T) {
t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
}
}

func TestDialNamedPort(t *testing.T) {
srvConn, clientConn, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer srvConn.Close()
defer clientConn.Close()

serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["rsa"])
srvErr := make(chan error, 10)
go func() {
defer close(srvErr)
_, chans, req, err := NewServerConn(srvConn, serverConf)
if err != nil {
srvErr <- fmt.Errorf("NewServerConn: %w", err)
return
}
go DiscardRequests(req)
for newChan := range chans {
if newChan.ChannelType() != "direct-tcpip" {
srvErr <- fmt.Errorf("expected direct-tcpip channel, got=%s", newChan.ChannelType())
if err := newChan.Reject(UnknownChannelType, "This test server only supports direct-tcpip"); err != nil {
srvErr <- err
}
continue
}
data := channelOpenDirectMsg{}
if err := Unmarshal(newChan.ExtraData(), &data); err != nil {
if err := newChan.Reject(ConnectionFailed, err.Error()); err != nil {
srvErr <- err
}
continue
}
// Below we dial for service `ssh` which should be translated to 22.
if data.Port != 22 {
if err := newChan.Reject(ConnectionFailed, fmt.Sprintf("expected port 22 got=%d", data.Port)); err != nil {
srvErr <- err
}
continue
}
ch, reqs, err := newChan.Accept()
if err != nil {
srvErr <- fmt.Errorf("Accept: %w", err)
continue
}
go DiscardRequests(reqs)
if err := ch.Close(); err != nil {
srvErr <- err
}
}
}()

clientConf := &ClientConfig{
User: "testuser",
HostKeyCallback: InsecureIgnoreHostKey(),
}
sshClientConn, newChans, reqs, err := NewClientConn(clientConn, "", clientConf)
if err != nil {
t.Fatal(err)
}
sshClient := NewClient(sshClientConn, newChans, reqs)

// The port section in the host:port string being a named service `ssh` is the main point of the test.
_, err = sshClient.Dial("tcp", "localhost:ssh")
if err != nil {
t.Error(err)
}

// Stop the ssh server.
clientConn.Close()
for err := range srvErr {
t.Errorf("ssh server: %s", err)
}
}