@@ -12,12 +12,15 @@ import (
1212
1313 "cdr.dev/slog"
1414 "cdr.dev/slog/sloggers/sloghuman"
15+ "github.com/fatih/color"
1516 "github.com/pion/webrtc/v3"
1617 "github.com/spf13/cobra"
18+ "golang.org/x/crypto/ssh"
1719 "golang.org/x/xerrors"
1820
1921 "cdr.dev/coder-cli/coder-sdk"
2022 "cdr.dev/coder-cli/internal/x/xcobra"
23+ "cdr.dev/coder-cli/pkg/clog"
2124 "cdr.dev/coder-cli/wsnet"
2225)
2326
@@ -59,20 +62,34 @@ coder tunnel my-dev 3000 3000
5962 }
6063 baseURL := sdk .BaseURL ()
6164
62- workspaces , err := getWorkspaces (ctx , sdk , coder .Me )
65+ workspace , err := findWorkspace (ctx , sdk , args [ 0 ] , coder .Me )
6366 if err != nil {
6467 return xerrors .Errorf ("get workspaces: %w" , err )
6568 }
6669
67- var workspaceID string
68- for _ , workspace := range workspaces {
69- if workspace .Name == args [0 ] {
70- workspaceID = workspace .ID
71- break
70+ if workspace .LatestStat .ContainerStatus != coder .WorkspaceOn {
71+ color .NoColor = false
72+ notAvailableError := clog .Error ("workspace not available" ,
73+ fmt .Sprintf ("current status: %q" , workspace .LatestStat .ContainerStatus ),
74+ clog .BlankLine ,
75+ clog .Tipf ("use \" coder workspaces rebuild %s\" to rebuild this workspace" , workspace .Name ),
76+ )
77+ // If we're attempting to forward our remote SSH port,
78+ // we want to communicate with the OpenSSH protocol so
79+ // SSH clients can properly display output to our users.
80+ if remotePort == 12213 {
81+ rawKey , err := sdk .SSHKey (ctx )
82+ if err != nil {
83+ return xerrors .Errorf ("get ssh key: %w" , err )
84+ }
85+ err = discardSSHConnection (& stdioConn {}, rawKey .PrivateKey , notAvailableError .String ())
86+ if err != nil {
87+ return err
88+ }
89+ return nil
7290 }
73- }
74- if workspaceID == "" {
75- return xerrors .Errorf ("No workspace found by name '%s'" , args [0 ])
91+
92+ return notAvailableError
7693 }
7794
7895 iceServers , err := sdk .ICEServers (ctx )
@@ -82,14 +99,14 @@ coder tunnel my-dev 3000 3000
8299 log .Debug (ctx , "got ICE servers" , slog .F ("ice" , iceServers ))
83100
84101 c := & tunnneler {
85- log : log ,
86- brokerAddr : & baseURL ,
87- token : sdk .Token (),
88- workspaceID : workspaceID ,
89- iceServers : iceServers ,
90- stdio : args [2 ] == "stdio" ,
91- localPort : uint16 (localPort ),
92- remotePort : uint16 (remotePort ),
102+ log : log ,
103+ brokerAddr : & baseURL ,
104+ token : sdk .Token (),
105+ workspace : workspace ,
106+ iceServers : iceServers ,
107+ stdio : args [2 ] == "stdio" ,
108+ localPort : uint16 (localPort ),
109+ remotePort : uint16 (remotePort ),
93110 }
94111
95112 err = c .start (ctx )
@@ -105,14 +122,14 @@ coder tunnel my-dev 3000 3000
105122}
106123
107124type tunnneler struct {
108- log slog.Logger
109- brokerAddr * url.URL
110- token string
111- workspaceID string
112- iceServers []webrtc.ICEServer
113- remotePort uint16
114- localPort uint16
115- stdio bool
125+ log slog.Logger
126+ brokerAddr * url.URL
127+ token string
128+ workspace * coder. Workspace
129+ iceServers []webrtc.ICEServer
130+ remotePort uint16
131+ localPort uint16
132+ stdio bool
116133}
117134
118135func (c * tunnneler ) start (ctx context.Context ) error {
@@ -121,7 +138,7 @@ func (c *tunnneler) start(ctx context.Context) error {
121138 dialLog := c .log .Named ("wsnet" )
122139 wd , err := wsnet .DialWebsocket (
123140 ctx ,
124- wsnet .ConnectEndpoint (c .brokerAddr , c .workspaceID , c .token ),
141+ wsnet .ConnectEndpoint (c .brokerAddr , c .workspace . ID , c .token ),
125142 & wsnet.DialOptions {
126143 Log : & dialLog ,
127144 TURNProxyAuthToken : c .token ,
@@ -156,7 +173,7 @@ func (c *tunnneler) start(ctx context.Context) error {
156173 return
157174 case <- ticker .C :
158175 // silently ignore failures so we don't spam the console
159- _ = sdk .UpdateLastConnectionAt (ctx , c .workspaceID )
176+ _ = sdk .UpdateLastConnectionAt (ctx , c .workspace . ID )
160177 }
161178 }
162179 }()
@@ -203,3 +220,78 @@ func (c *tunnneler) start(ctx context.Context) error {
203220 }()
204221 }
205222}
223+
224+ // Used to treat stdio like a connection for proxying SSH.
225+ type stdioConn struct {}
226+
227+ func (s * stdioConn ) Read (b []byte ) (n int , err error ) {
228+ return os .Stdin .Read (b )
229+ }
230+
231+ func (s * stdioConn ) Write (b []byte ) (n int , err error ) {
232+ return os .Stdout .Write (b )
233+ }
234+
235+ func (s * stdioConn ) Close () error {
236+ return nil
237+ }
238+
239+ func (s * stdioConn ) LocalAddr () net.Addr {
240+ return nil
241+ }
242+
243+ func (s * stdioConn ) RemoteAddr () net.Addr {
244+ return nil
245+ }
246+
247+ func (s * stdioConn ) SetDeadline (t time.Time ) error {
248+ return nil
249+ }
250+
251+ func (s * stdioConn ) SetReadDeadline (t time.Time ) error {
252+ return nil
253+ }
254+
255+ func (s * stdioConn ) SetWriteDeadline (t time.Time ) error {
256+ return nil
257+ }
258+
259+ // discardSSHConnection accepts a connection then outputs the message provided
260+ // to any channel opened, immediately closing the connection afterwards.
261+ //
262+ // Used to provide status to connecting clients while still aligning with the
263+ // native SSH protocol.
264+ func discardSSHConnection (nc net.Conn , privateKey string , msg string ) error {
265+ config := & ssh.ServerConfig {
266+ NoClientAuth : true ,
267+ }
268+ key , err := ssh .ParseRawPrivateKey ([]byte (privateKey ))
269+ if err != nil {
270+ return fmt .Errorf ("parse private key: %w" , err )
271+ }
272+ signer , err := ssh .NewSignerFromKey (key )
273+ if err != nil {
274+ return fmt .Errorf ("signer from private key: %w" , err )
275+ }
276+ config .AddHostKey (signer )
277+ conn , chans , reqs , err := ssh .NewServerConn (nc , config )
278+ if err != nil {
279+ return fmt .Errorf ("create server conn: %w" , err )
280+ }
281+ go ssh .DiscardRequests (reqs )
282+ ch , req , err := (<- chans ).Accept ()
283+ if err != nil {
284+ return fmt .Errorf ("accept channel: %w" , err )
285+ }
286+ go ssh .DiscardRequests (req )
287+
288+ _ , err = ch .Write ([]byte (msg ))
289+ if err != nil {
290+ return fmt .Errorf ("write channel: %w" , err )
291+ }
292+ err = ch .Close ()
293+ if err != nil {
294+ return fmt .Errorf ("close channel: %w" , err )
295+ }
296+ return conn .Close ()
297+ }
0 commit comments