@@ -3,12 +3,10 @@ package agent
33import (
44 "bytes"
55 "context"
6- "encoding/binary"
76 "encoding/json"
87 "errors"
98 "fmt"
109 "io"
11- "net"
1210 "net/http"
1311 "net/netip"
1412 "os"
@@ -216,8 +214,8 @@ type agent struct {
216214 portCacheDuration time.Duration
217215 subsystems []codersdk.AgentSubsystem
218216
219- reconnectingPTYs sync.Map
220217 reconnectingPTYTimeout time.Duration
218+ reconnectingPTYServer * reconnectingpty.Server
221219
222220 // we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
223221 // to start gracefully shutting down and "hard" which is Done when it is time to close
@@ -252,8 +250,6 @@ type agent struct {
252250 statsReporter * statsReporter
253251 logSender * agentsdk.LogSender
254252
255- connCountReconnectingPTY atomic.Int64
256-
257253 prometheusRegistry * prometheus.Registry
258254 // metrics are prometheus registered metrics that will be collected and
259255 // labeled in Coder with the agent + workspace.
@@ -297,6 +293,13 @@ func (a *agent) init() {
297293 // Register runner metrics. If the prom registry is nil, the metrics
298294 // will not report anywhere.
299295 a .scriptRunner .RegisterMetrics (a .prometheusRegistry )
296+
297+ a .reconnectingPTYServer = reconnectingpty .NewServer (
298+ a .logger .Named ("reconnecting-pty" ),
299+ a .sshServer ,
300+ a .metrics .connectionsTotal , a .metrics .reconnectingPTYErrors ,
301+ a .reconnectingPTYTimeout ,
302+ )
300303 go a .runLoop ()
301304}
302305
@@ -1181,55 +1184,12 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
11811184 }
11821185 }()
11831186 if err = a .trackGoroutine (func () {
1184- logger := a .logger .Named ("reconnecting-pty" )
1185- var wg sync.WaitGroup
1186- for {
1187- conn , err := reconnectingPTYListener .Accept ()
1188- if err != nil {
1189- if ! a .isClosed () {
1190- logger .Debug (ctx , "accept pty failed" , slog .Error (err ))
1191- }
1192- break
1193- }
1194- clog := logger .With (
1195- slog .F ("remote" , conn .RemoteAddr ().String ()),
1196- slog .F ("local" , conn .LocalAddr ().String ()))
1197- clog .Info (ctx , "accepted conn" )
1198- wg .Add (1 )
1199- closed := make (chan struct {})
1200- go func () {
1201- select {
1202- case <- closed :
1203- case <- a .hardCtx .Done ():
1204- _ = conn .Close ()
1205- }
1206- wg .Done ()
1207- }()
1208- go func () {
1209- defer close (closed )
1210- // This cannot use a JSON decoder, since that can
1211- // buffer additional data that is required for the PTY.
1212- rawLen := make ([]byte , 2 )
1213- _ , err = conn .Read (rawLen )
1214- if err != nil {
1215- return
1216- }
1217- length := binary .LittleEndian .Uint16 (rawLen )
1218- data := make ([]byte , length )
1219- _ , err = conn .Read (data )
1220- if err != nil {
1221- return
1222- }
1223- var msg workspacesdk.AgentReconnectingPTYInit
1224- err = json .Unmarshal (data , & msg )
1225- if err != nil {
1226- logger .Warn (ctx , "failed to unmarshal init" , slog .F ("raw" , data ))
1227- return
1228- }
1229- _ = a .handleReconnectingPTY (ctx , clog , msg , conn )
1230- }()
1187+ rPTYServeErr := a .reconnectingPTYServer .Serve (a .gracefulCtx , a .hardCtx , reconnectingPTYListener ) // Kept in its own variable, distinct from the enclosing err.
1188+ if rPTYServeErr != nil &&
1189+ a .gracefulCtx .Err () == nil && // Stay quiet once graceful shutdown has begun.
1190+ ! strings .Contains (rPTYServeErr .Error (), "use of closed network connection" ) { // Stay quiet when the listener was simply closed.
1191+ a .logger .Error (ctx , "error serving reconnecting PTY" , slog .Error (rPTYServeErr )) // Log the Serve error itself; the captured err is the outer trackGoroutine result, not this failure.
12311192 }
1232- wg .Wait ()
12331193 }); err != nil {
12341194 return nil , err
12351195 }
@@ -1308,9 +1268,9 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
13081268 _ = server .Close ()
13091269 }()
13101270
1311- err := server .Serve (apiListener )
1312- if err != nil && ! xerrors .Is (err , http .ErrServerClosed ) && ! strings .Contains (err .Error (), "use of closed network connection" ) {
1313- a .logger .Critical (ctx , "serve HTTP API server" , slog .Error (err ))
1271+ apiServErr := server .Serve (apiListener )
1272+ if apiServErr != nil && ! xerrors .Is (apiServErr , http .ErrServerClosed ) && ! strings .Contains (apiServErr .Error (), "use of closed network connection" ) {
1273+ a .logger .Critical (ctx , "serve HTTP API server" , slog .Error (apiServErr ))
13141274 }
13151275 }); err != nil {
13161276 return nil , err
@@ -1394,87 +1354,6 @@ func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.D
13941354 }
13951355}
13961356
1397- func (a * agent ) handleReconnectingPTY (ctx context.Context , logger slog.Logger , msg workspacesdk.AgentReconnectingPTYInit , conn net.Conn ) (retErr error ) { // Serves one connection: joins the PTY registered under msg.ID or creates it, then attaches conn to it.
1398- defer conn .Close () // conn is always closed when this handler returns.
1399- a .metrics .connectionsTotal .Add (1 )
1400-
1401- a .connCountReconnectingPTY .Add (1 ) // Live-connection count: incremented here, decremented on return; Collect reads it for stats.
1402- defer a .connCountReconnectingPTY .Add (- 1 )
1403-
1404- connectionID := uuid .NewString () // Fresh ID per connection; added to the log fields and passed to Attach.
1405- connLogger := logger .With (slog .F ("message_id" , msg .ID ), slog .F ("connection_id" , connectionID ))
1406- connLogger .Debug (ctx , "starting handler" )
1407-
1408- defer func () { // Runs on return: reports the named result retErr if set, then logs the close.
1409- if err := retErr ; err != nil {
1410- a .closeMutex .Lock ()
1411- closed := a .isClosed ()
1412- a .closeMutex .Unlock ()
1413-
1414- // If the agent is closed, we don't want to
1415- // log this as an error since it's expected.
1416- if closed {
1417- connLogger .Info (ctx , "reconnecting pty failed with attach error (agent closed)" , slog .Error (err ))
1418- } else {
1419- connLogger .Error (ctx , "reconnecting pty failed with attach error" , slog .Error (err ))
1420- }
1421- }
1422- connLogger .Info (ctx , "reconnecting pty connection closed" )
1423- }()
1424-
1425- var rpty reconnectingpty.ReconnectingPTY
1426- sendConnected := make (chan reconnectingpty.ReconnectingPTY , 1 ) // Capacity 1 so the PTY can sit in the channel between connections.
1427- // On store, reserve this ID to prevent multiple concurrent new connections.
1428- waitReady , ok := a .reconnectingPTYs .LoadOrStore (msg .ID , sendConnected )
1429- if ok { // An entry for msg.ID already existed; reuse it instead of creating a PTY.
1430- close (sendConnected ) // Unused.
1431- connLogger .Debug (ctx , "connecting to existing reconnecting pty" )
1432- c , ok := waitReady .(chan reconnectingpty.ReconnectingPTY )
1433- if ! ok {
1434- return xerrors .Errorf ("found invalid type in reconnecting pty map: %T" , waitReady )
1435- }
1436- rpty , ok = <- c // Blocks until the creating connection sends the PTY or closes the channel on failure.
1437- if ! ok || rpty == nil {
1438- return xerrors .Errorf ("reconnecting pty closed before connection" )
1439- }
1440- c <- rpty // Put it back for the next reconnect.
1441- } else {
1442- connLogger .Debug (ctx , "creating new reconnecting pty" )
1443-
1444- connected := false
1445- defer func () { // On failure before the PTY is published, drop the reservation and close the channel so receivers see it closed.
1446- if ! connected && retErr != nil {
1447- a .reconnectingPTYs .Delete (msg .ID )
1448- close (sendConnected )
1449- }
1450- }()
1451-
1452- // Empty command will default to the users shell!
1453- cmd , err := a .sshServer .CreateCommand (ctx , msg .Command , nil )
1454- if err != nil {
1455- a .metrics .reconnectingPTYErrors .WithLabelValues ("create_command" ).Add (1 )
1456- return xerrors .Errorf ("create command: %w" , err )
1457- }
1458-
1459- rpty = reconnectingpty .New (ctx , cmd , & reconnectingpty.Options {
1460- Timeout : a .reconnectingPTYTimeout ,
1461- Metrics : a .metrics .reconnectingPTYErrors ,
1462- }, logger .With (slog .F ("message_id" , msg .ID )))
1463-
1464- if err = a .trackGoroutine (func () { // Tracked goroutine: after rpty.Wait() returns, remove the map entry for msg.ID.
1465- rpty .Wait ()
1466- a .reconnectingPTYs .Delete (msg .ID )
1467- }); err != nil {
1468- rpty .Close (err ) // The goroutine could not be started; close the PTY with that error.
1469- return xerrors .Errorf ("start routine: %w" , err )
1470- }
1471-
1472- connected = true
1473- sendConnected <- rpty // Publish the PTY through the channel stored under msg.ID.
1474- }
1475- return rpty .Attach (ctx , connectionID , conn , msg .Height , msg .Width , connLogger ) // Attach conn using the height and width from the init message.
1476- }
1477-
14781357// Collect collects additional stats from the agent
14791358func (a * agent ) Collect (ctx context.Context , networkStats map [netlogtype.Connection ]netlogtype.Counts ) * proto.Stats {
14801359 a .logger .Debug (context .Background (), "computing stats report" )
@@ -1496,7 +1375,7 @@ func (a *agent) Collect(ctx context.Context, networkStats map[netlogtype.Connect
14961375 stats .SessionCountVscode = sshStats .VSCode
14971376 stats .SessionCountJetbrains = sshStats .JetBrains
14981377
1499- stats .SessionCountReconnectingPty = a .connCountReconnectingPTY . Load ()
1378+ stats .SessionCountReconnectingPty = a .reconnectingPTYServer . ConnCount ()
15001379
15011380 // Compute the median connection latency!
15021381 a .logger .Debug (ctx , "starting peer latency measurement for stats" )
0 commit comments