Skip to content

Commit 630c0dc

Browse files
committed
refactor: improve connection handling with context-based cancellation and reduced code duplication
Change-Id: I808b16216e09a94eca58cd8e95c99b3bb927596f Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent 913fa9c commit 630c0dc

File tree

2 files changed

+139
-118
lines changed

2 files changed

+139
-118
lines changed

go/connection.go

Lines changed: 118 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package acp
22

33
import (
44
"bufio"
5+
"bytes"
56
"context"
67
"encoding/json"
8+
"errors"
79
"io"
810
"sync"
911
"sync/atomic"
@@ -34,71 +36,72 @@ type Connection struct {
3436
nextID atomic.Uint64
3537
pending map[string]*pendingResponse
3638

37-
done chan struct{}
39+
ctx context.Context
40+
cancel context.CancelCauseFunc
3841
}
3942

4043
func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
44+
ctx, cancel := context.WithCancelCause(context.Background())
4145
c := &Connection{
4246
w: peerInput,
4347
r: peerOutput,
4448
handler: handler,
4549
pending: make(map[string]*pendingResponse),
46-
done: make(chan struct{}),
50+
ctx: ctx,
51+
cancel: cancel,
4752
}
4853
go c.receive()
4954
return c
5055
}
5156

5257
func (c *Connection) receive() {
58+
const (
59+
initialBufSize = 1024 * 1024
60+
maxBufSize = 10 * 1024 * 1024
61+
)
62+
5363
scanner := bufio.NewScanner(c.r)
54-
// increase buffer if needed
55-
buf := make([]byte, 0, 1024*1024)
56-
scanner.Buffer(buf, 10*1024*1024)
64+
buf := make([]byte, 0, initialBufSize)
65+
scanner.Buffer(buf, maxBufSize)
66+
5767
for scanner.Scan() {
5868
line := scanner.Bytes()
59-
if len(bytesTrimSpace(line)) == 0 {
69+
if len(bytes.TrimSpace(line)) == 0 {
6070
continue
6171
}
72+
6273
var msg anyMessage
6374
if err := json.Unmarshal(line, &msg); err != nil {
64-
// ignore parse errors on inbound
6575
continue
6676
}
67-
if msg.ID != nil && msg.Method == "" {
68-
// response
69-
idStr := string(*msg.ID)
70-
c.mu.Lock()
71-
pr := c.pending[idStr]
72-
if pr != nil {
73-
delete(c.pending, idStr)
74-
}
75-
c.mu.Unlock()
76-
if pr != nil {
77-
pr.ch <- msg
78-
}
79-
continue
80-
}
81-
if msg.Method != "" {
82-
// request or notification
77+
78+
switch {
79+
case msg.ID != nil && msg.Method == "":
80+
c.handleResponse(&msg)
81+
case msg.Method != "":
8382
go c.handleInbound(&msg)
8483
}
8584
}
86-
// Signal completion on EOF or read error
85+
86+
c.cancel(errors.New("peer connection closed"))
87+
}
88+
89+
func (c *Connection) handleResponse(msg *anyMessage) {
90+
idStr := string(*msg.ID)
91+
8792
c.mu.Lock()
88-
if c.done != nil {
89-
close(c.done)
90-
c.done = nil
93+
pr := c.pending[idStr]
94+
if pr != nil {
95+
delete(c.pending, idStr)
9196
}
9297
c.mu.Unlock()
98+
99+
if pr != nil {
100+
pr.ch <- *msg
101+
}
93102
}
94103

95104
func (c *Connection) handleInbound(req *anyMessage) {
96-
// Context that cancels when the connection is closed
97-
ctx, cancel := context.WithCancel(context.Background())
98-
go func() {
99-
<-c.Done()
100-
cancel()
101-
}()
102105
res := anyMessage{JSONRPC: "2.0"}
103106
// copy ID if present
104107
if req.ID != nil {
@@ -112,7 +115,7 @@ func (c *Connection) handleInbound(req *anyMessage) {
112115
return
113116
}
114117

115-
result, err := c.handler(ctx, req.Method, req.Params)
118+
result, err := c.handler(c.ctx, req.Method, req.Params)
116119
if req.ID == nil {
117120
// notification: nothing to send
118121
return
@@ -147,93 +150,101 @@ func (c *Connection) sendMessage(msg anyMessage) error {
147150
// SendRequest sends a JSON-RPC request and returns a typed result.
148151
// For methods that do not return a result, use SendRequestNoResult instead.
149152
func SendRequest[T any](c *Connection, ctx context.Context, method string, params any) (T, error) {
150-
var zero T
151-
// allocate id
152-
id := c.nextID.Add(1)
153-
idRaw, _ := json.Marshal(id)
154-
msg := anyMessage{
155-
JSONRPC: "2.0",
156-
ID: (*json.RawMessage)(&idRaw),
157-
Method: method,
158-
}
159-
if params != nil {
160-
b, err := json.Marshal(params)
161-
if err != nil {
162-
return zero, NewInvalidParams(map[string]any{"error": err.Error()})
163-
}
164-
msg.Params = b
153+
var result T
154+
155+
msg, idKey, err := c.prepareRequest(method, params)
156+
if err != nil {
157+
return result, err
165158
}
159+
166160
pr := &pendingResponse{ch: make(chan anyMessage, 1)}
167-
idKey := string(idRaw)
168161
c.mu.Lock()
169162
c.pending[idKey] = pr
170163
c.mu.Unlock()
164+
171165
if err := c.sendMessage(msg); err != nil {
172-
return zero, NewInternalError(map[string]any{"error": err.Error()})
166+
c.cleanupPending(idKey)
167+
return result, NewInternalError(map[string]any{"error": err.Error()})
173168
}
174-
// wait for response or peer disconnect
175-
var resp anyMessage
176-
d := c.Done()
177-
select {
178-
case resp = <-pr.ch:
179-
case <-ctx.Done():
180-
// best-effort cleanup
181-
c.mu.Lock()
182-
delete(c.pending, idKey)
183-
c.mu.Unlock()
184-
return zero, NewInternalError(map[string]any{"error": ctx.Err().Error()})
185-
case <-d:
186-
return zero, NewInternalError(map[string]any{"error": "peer disconnected before response"})
169+
170+
resp, err := c.waitForResponse(ctx, pr, idKey)
171+
if err != nil {
172+
return result, err
187173
}
174+
188175
if resp.Error != nil {
189-
return zero, resp.Error
176+
return result, resp.Error
190177
}
191-
var out T
178+
192179
if len(resp.Result) > 0 {
193-
if err := json.Unmarshal(resp.Result, &out); err != nil {
194-
return zero, NewInternalError(map[string]any{"error": err.Error()})
180+
if err := json.Unmarshal(resp.Result, &result); err != nil {
181+
return result, NewInternalError(map[string]any{"error": err.Error()})
195182
}
196183
}
197-
return out, nil
184+
return result, nil
198185
}
199186

200-
// SendRequestNoResult sends a JSON-RPC request that returns no result payload.
201-
func (c *Connection) SendRequestNoResult(ctx context.Context, method string, params any) error {
202-
// allocate id
187+
func (c *Connection) prepareRequest(method string, params any) (anyMessage, string, error) {
203188
id := c.nextID.Add(1)
204189
idRaw, _ := json.Marshal(id)
190+
205191
msg := anyMessage{
206192
JSONRPC: "2.0",
207193
ID: (*json.RawMessage)(&idRaw),
208194
Method: method,
209195
}
196+
210197
if params != nil {
211198
b, err := json.Marshal(params)
212199
if err != nil {
213-
return NewInvalidParams(map[string]any{"error": err.Error()})
200+
return msg, "", NewInvalidParams(map[string]any{"error": err.Error()})
214201
}
215202
msg.Params = b
216203
}
204+
205+
return msg, string(idRaw), nil
206+
}
207+
208+
func (c *Connection) waitForResponse(ctx context.Context, pr *pendingResponse, idKey string) (anyMessage, error) {
209+
select {
210+
case resp := <-pr.ch:
211+
return resp, nil
212+
case <-ctx.Done():
213+
c.cleanupPending(idKey)
214+
return anyMessage{}, NewInternalError(map[string]any{"error": context.Cause(ctx).Error()})
215+
case <-c.Done():
216+
return anyMessage{}, NewInternalError(map[string]any{"error": "peer disconnected before response"})
217+
}
218+
}
219+
220+
func (c *Connection) cleanupPending(idKey string) {
221+
c.mu.Lock()
222+
delete(c.pending, idKey)
223+
c.mu.Unlock()
224+
}
225+
226+
// SendRequestNoResult sends a JSON-RPC request that returns no result payload.
227+
func (c *Connection) SendRequestNoResult(ctx context.Context, method string, params any) error {
228+
msg, idKey, err := c.prepareRequest(method, params)
229+
if err != nil {
230+
return err
231+
}
232+
217233
pr := &pendingResponse{ch: make(chan anyMessage, 1)}
218-
idKey := string(idRaw)
219234
c.mu.Lock()
220235
c.pending[idKey] = pr
221236
c.mu.Unlock()
237+
222238
if err := c.sendMessage(msg); err != nil {
239+
c.cleanupPending(idKey)
223240
return NewInternalError(map[string]any{"error": err.Error()})
224241
}
225-
var resp anyMessage
226-
d := c.Done()
227-
select {
228-
case resp = <-pr.ch:
229-
case <-ctx.Done():
230-
c.mu.Lock()
231-
delete(c.pending, idKey)
232-
c.mu.Unlock()
233-
return NewInternalError(map[string]any{"error": ctx.Err().Error()})
234-
case <-d:
235-
return NewInternalError(map[string]any{"error": "peer disconnected before response"})
242+
243+
resp, err := c.waitForResponse(ctx, pr, idKey)
244+
if err != nil {
245+
return err
236246
}
247+
237248
if resp.Error != nil {
238249
return resp.Error
239250
}
@@ -246,43 +257,37 @@ func (c *Connection) SendNotification(ctx context.Context, method string, params
246257
return NewInternalError(map[string]any{"error": ctx.Err().Error()})
247258
default:
248259
}
249-
msg := anyMessage{JSONRPC: "2.0", Method: method}
260+
261+
msg, err := c.prepareNotification(method, params)
262+
if err != nil {
263+
return err
264+
}
265+
266+
if err := c.sendMessage(msg); err != nil {
267+
return NewInternalError(map[string]any{"error": err.Error()})
268+
}
269+
return nil
270+
}
271+
272+
func (c *Connection) prepareNotification(method string, params any) (anyMessage, error) {
273+
msg := anyMessage{
274+
JSONRPC: "2.0",
275+
Method: method,
276+
}
277+
250278
if params != nil {
251279
b, err := json.Marshal(params)
252280
if err != nil {
253-
return NewInvalidParams(map[string]any{"error": err.Error()})
281+
return msg, NewInvalidParams(map[string]any{"error": err.Error()})
254282
}
255283
msg.Params = b
256284
}
257-
if err := c.sendMessage(msg); err != nil {
258-
return NewInternalError(map[string]any{"error": err.Error()})
259-
}
260-
return nil
285+
286+
return msg, nil
261287
}
262288

263289
// Done returns a channel that is closed when the underlying reader loop exits
264290
// (typically when the peer disconnects or the input stream is closed).
265291
func (c *Connection) Done() <-chan struct{} {
266-
c.mu.Lock()
267-
d := c.done
268-
c.mu.Unlock()
269-
return d
270-
}
271-
272-
// Helper: lightweight TrimSpace for []byte without importing bytes only for this.
273-
func bytesTrimSpace(b []byte) []byte {
274-
i := 0
275-
for ; i < len(b); i++ {
276-
if b[i] != ' ' && b[i] != '\t' && b[i] != '\r' && b[i] != '\n' {
277-
break
278-
}
279-
}
280-
j := len(b)
281-
for j > i {
282-
if b[j-1] != ' ' && b[j-1] != '\t' && b[j-1] != '\r' && b[j-1] != '\n' {
283-
break
284-
}
285-
j--
286-
}
287-
return b[i:j]
292+
return c.ctx.Done()
288293
}

0 commit comments

Comments
 (0)