Skip to content

Commit

Permalink
feat: Separate workspace agent for tests (#567)
Browse files Browse the repository at this point in the history
This adds tests for Google Cloud authentication, and lays
the ground-work for future agent auth types in the future.
  • Loading branch information
kylecarbs committed Mar 25, 2022
1 parent 21fdb80 commit 6be949a
Show file tree
Hide file tree
Showing 11 changed files with 333 additions and 208 deletions.
48 changes: 9 additions & 39 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,36 +26,6 @@ import (
"golang.org/x/xerrors"
)

func DialSSH(conn *peer.Conn) (net.Conn, error) {
channel, err := conn.Dial(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
return nil, err
}
return channel.NetConn(), nil
}

func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) {
netConn, err := DialSSH(conn)
if err != nil {
return nil, err
}
sshConn, channels, requests, err := gossh.NewClientConn(netConn, "localhost:22", &gossh.ClientConfig{
Config: gossh.Config{
Ciphers: []string{"arcfour"},
},
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
// #nosec
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
})
if err != nil {
return nil, err
}
return gossh.NewClient(sshConn, channels, requests), nil
}

type Options struct {
Logger slog.Logger
}
Expand All @@ -64,7 +34,7 @@ type Dialer func(ctx context.Context, options *peer.ConnOptions) (*peerbroker.Li

func New(dialer Dialer, options *peer.ConnOptions) io.Closer {
ctx, cancelFunc := context.WithCancel(context.Background())
server := &server{
server := &agent{
clientDialer: dialer,
options: options,
closeCancel: cancelFunc,
Expand All @@ -74,7 +44,7 @@ func New(dialer Dialer, options *peer.ConnOptions) io.Closer {
return server
}

type server struct {
type agent struct {
clientDialer Dialer
options *peer.ConnOptions

Expand All @@ -86,7 +56,7 @@ type server struct {
sshServer *ssh.Server
}

func (s *server) run(ctx context.Context) {
func (s *agent) run(ctx context.Context) {
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
Expand All @@ -103,7 +73,7 @@ func (s *server) run(ctx context.Context) {
s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
s.options.Logger.Debug(context.Background(), "connected")
s.options.Logger.Info(context.Background(), "connected")
break
}
select {
Expand All @@ -129,7 +99,7 @@ func (s *server) run(ctx context.Context) {
}
}

func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) {
func (s *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go func() {
<-conn.Closed()
s.connCloseWait.Done()
Expand All @@ -156,7 +126,7 @@ func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) {
}
}

func (s *server) init(ctx context.Context) {
func (s *agent) init(ctx context.Context) {
// Clients' should ignore the host key when connecting.
// The agent needs to authenticate with coderd to SSH,
// so SSH authentication doesn't improve security.
Expand Down Expand Up @@ -221,7 +191,7 @@ func (s *server) init(ctx context.Context) {
go s.run(ctx)
}

func (*server) handleSSHSession(session ssh.Session) error {
func (*agent) handleSSHSession(session ssh.Session) error {
var (
command string
args = []string{}
Expand Down Expand Up @@ -316,7 +286,7 @@ func (*server) handleSSHSession(session ssh.Session) error {
}

// isClosed returns whether the API is closed or not.
func (s *server) isClosed() bool {
func (s *agent) isClosed() bool {
select {
case <-s.closed:
return true
Expand All @@ -325,7 +295,7 @@ func (s *server) isClosed() bool {
}
}

func (s *server) Close() error {
func (s *agent) Close() error {
s.closeMutex.Lock()
defer s.closeMutex.Unlock()
if s.isClosed() {
Expand Down
6 changes: 4 additions & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ func TestAgent(t *testing.T) {
t.Cleanup(func() {
_ = conn.Close()
})
sshClient, err := agent.DialSSHClient(conn)
client := agent.Conn{conn}
sshClient, err := client.SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
Expand All @@ -64,7 +65,8 @@ func TestAgent(t *testing.T) {
t.Cleanup(func() {
_ = conn.Close()
})
sshClient, err := agent.DialSSHClient(conn)
client := &agent.Conn{conn}
sshClient, err := client.SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
Expand Down
50 changes: 50 additions & 0 deletions agent/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package agent

import (
"context"
"net"

"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"

"github.com/coder/coder/peer"
)

// Conn wraps a peer connection with helper functions to
// communicate with the agent.
type Conn struct {
*peer.Conn
}

// SSH dials the built-in SSH server.
func (c *Conn) SSH() (net.Conn, error) {
channel, err := c.Dial(context.Background(), "ssh", &peer.ChannelOptions{
Protocol: "ssh",
})
if err != nil {
return nil, xerrors.Errorf("dial: %w", err)
}
return channel.NetConn(), nil
}

// SSHClient calls SSH to create a client that uses a weak cipher
// for high throughput.
func (c *Conn) SSHClient() (*ssh.Client, error) {
netConn, err := c.SSH()
if err != nil {
return nil, xerrors.Errorf("ssh: %w", err)
}
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
Config: ssh.Config{
Ciphers: []string{"arcfour"},
},
// SSH host validation isn't helpful, because obtaining a peer
// connection already signifies user-intent to dial a workspace.
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
if err != nil {
return nil, xerrors.Errorf("ssh conn: %w", err)
}
return ssh.NewClient(sshConn, channels, requests), nil
}
106 changes: 61 additions & 45 deletions cli/ssh.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package cli

import (
"fmt"
"os"

"github.com/pion/webrtc/v3"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
"golang.org/x/xerrors"

"github.com/coder/coder/agent"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
)

func workspaceSSH() *cobra.Command {
Expand All @@ -26,58 +24,76 @@ func workspaceSSH() *cobra.Command {
if err != nil {
return err
}
if workspace.LatestBuild.Transition == database.WorkspaceTransitionDelete {
return xerrors.New("workspace is deleting...")
}
resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID)
if err != nil {
return err
}

resourceByAddress := make(map[string]codersdk.WorkspaceResource)
for _, resource := range resources {
_, _ = fmt.Printf("Got resource: %+v\n", resource)
if resource.Agent == nil {
continue
}

dialed, err := client.DialWorkspaceAgent(cmd.Context(), resource.ID)
if err != nil {
return err
}
stream, err := dialed.NegotiateConnection(cmd.Context())
if err != nil {
return err
resourceByAddress[resource.Address] = resource
}
var resourceAddress string
if len(args) >= 2 {
resourceAddress = args[1]
} else {
// No resource name was provided!
if len(resourceByAddress) > 1 {
// List available resources to connect into?
return xerrors.Errorf("multiple agents")
}
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
URLs: []string{"stun:stun.l.google.com:19302"},
}}, &peer.ConnOptions{})
if err != nil {
return err
for _, resource := range resourceByAddress {
resourceAddress = resource.Address
break
}
client, err := agent.DialSSHClient(conn)
if err != nil {
return err
}
resource, exists := resourceByAddress[resourceAddress]
if !exists {
resourceKeys := make([]string, 0)
for resourceKey := range resourceByAddress {
resourceKeys = append(resourceKeys, resourceKey)
}
return xerrors.Errorf("no sshable agent with address %q: %+v", resourceAddress, resourceKeys)
}
if resource.Agent.LastConnectedAt == nil {
return xerrors.Errorf("agent hasn't connected yet")
}

session, err := client.NewSession()
if err != nil {
return err
}
_, _ = term.MakeRaw(int(os.Stdin.Fd()))
err = session.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{
ssh.OCRNL: 1,
})
if err != nil {
return err
}
session.Stdin = os.Stdin
session.Stdout = os.Stdout
session.Stderr = os.Stderr
err = session.Shell()
if err != nil {
return err
}
err = session.Wait()
if err != nil {
return err
}
conn, err := client.DialWorkspaceAgent(cmd.Context(), resource.ID, nil, nil)
if err != nil {
return err
}
sshClient, err := conn.SSHClient()
if err != nil {
return err
}

sshSession, err := sshClient.NewSession()
if err != nil {
return err
}
_, _ = term.MakeRaw(int(os.Stdin.Fd()))
err = sshSession.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{
ssh.OCRNL: 1,
})
if err != nil {
return err
}
sshSession.Stdin = os.Stdin
sshSession.Stdout = os.Stdout
sshSession.Stderr = os.Stderr
err = sshSession.Shell()
if err != nil {
return err
}
err = sshSession.Wait()
if err != nil {
return err
}

return nil
Expand Down

0 comments on commit 6be949a

Please sign in to comment.