@@ -2,8 +2,10 @@ package acp
22
33import (
44 "bufio"
5+ "bytes"
56 "context"
67 "encoding/json"
8+ "errors"
79 "io"
810 "sync"
911 "sync/atomic"
@@ -34,71 +36,72 @@ type Connection struct {
3436 nextID atomic.Uint64
3537 pending map [string ]* pendingResponse
3638
37- done chan struct {}
39+ ctx context.Context
40+ cancel context.CancelCauseFunc
3841}
3942
4043func NewConnection (handler MethodHandler , peerInput io.Writer , peerOutput io.Reader ) * Connection {
44+ ctx , cancel := context .WithCancelCause (context .Background ())
4145 c := & Connection {
4246 w : peerInput ,
4347 r : peerOutput ,
4448 handler : handler ,
4549 pending : make (map [string ]* pendingResponse ),
46- done : make (chan struct {}),
50+ ctx : ctx ,
51+ cancel : cancel ,
4752 }
4853 go c .receive ()
4954 return c
5055}
5156
5257func (c * Connection ) receive () {
58+ const (
59+ initialBufSize = 1024 * 1024
60+ maxBufSize = 10 * 1024 * 1024
61+ )
62+
5363 scanner := bufio .NewScanner (c .r )
54- // increase buffer if needed
55- buf := make ([] byte , 0 , 1024 * 1024 )
56- scanner . Buffer ( buf , 10 * 1024 * 1024 )
64+ buf := make ([] byte , 0 , initialBufSize )
65+ scanner . Buffer ( buf , maxBufSize )
66+
5767 for scanner .Scan () {
5868 line := scanner .Bytes ()
59- if len (bytesTrimSpace (line )) == 0 {
69+ if len (bytes . TrimSpace (line )) == 0 {
6070 continue
6171 }
72+
6273 var msg anyMessage
6374 if err := json .Unmarshal (line , & msg ); err != nil {
64- // ignore parse errors on inbound
6575 continue
6676 }
67- if msg .ID != nil && msg .Method == "" {
68- // response
69- idStr := string (* msg .ID )
70- c .mu .Lock ()
71- pr := c .pending [idStr ]
72- if pr != nil {
73- delete (c .pending , idStr )
74- }
75- c .mu .Unlock ()
76- if pr != nil {
77- pr .ch <- msg
78- }
79- continue
80- }
81- if msg .Method != "" {
82- // request or notification
77+
78+ switch {
79+ case msg .ID != nil && msg .Method == "" :
80+ c .handleResponse (& msg )
81+ case msg .Method != "" :
8382 go c .handleInbound (& msg )
8483 }
8584 }
86- // Signal completion on EOF or read error
85+
86+ c .cancel (errors .New ("peer connection closed" ))
87+ }
88+
89+ func (c * Connection ) handleResponse (msg * anyMessage ) {
90+ idStr := string (* msg .ID )
91+
8792 c .mu .Lock ()
88- if c . done != nil {
89- close ( c . done )
90- c . done = nil
93+ pr := c . pending [ idStr ]
94+ if pr != nil {
95+ delete ( c . pending , idStr )
9196 }
9297 c .mu .Unlock ()
98+
99+ if pr != nil {
100+ pr .ch <- * msg
101+ }
93102}
94103
95104func (c * Connection ) handleInbound (req * anyMessage ) {
96- // Context that cancels when the connection is closed
97- ctx , cancel := context .WithCancel (context .Background ())
98- go func () {
99- <- c .Done ()
100- cancel ()
101- }()
102105 res := anyMessage {JSONRPC : "2.0" }
103106 // copy ID if present
104107 if req .ID != nil {
@@ -112,7 +115,7 @@ func (c *Connection) handleInbound(req *anyMessage) {
112115 return
113116 }
114117
115- result , err := c .handler (ctx , req .Method , req .Params )
118+ result , err := c .handler (c . ctx , req .Method , req .Params )
116119 if req .ID == nil {
117120 // notification: nothing to send
118121 return
@@ -147,93 +150,101 @@ func (c *Connection) sendMessage(msg anyMessage) error {
147150// SendRequest sends a JSON-RPC request and returns a typed result.
148151// For methods that do not return a result, use SendRequestNoResult instead.
149152func SendRequest [T any ](c * Connection , ctx context.Context , method string , params any ) (T , error ) {
150- var zero T
151- // allocate id
152- id := c .nextID .Add (1 )
153- idRaw , _ := json .Marshal (id )
154- msg := anyMessage {
155- JSONRPC : "2.0" ,
156- ID : (* json .RawMessage )(& idRaw ),
157- Method : method ,
158- }
159- if params != nil {
160- b , err := json .Marshal (params )
161- if err != nil {
162- return zero , NewInvalidParams (map [string ]any {"error" : err .Error ()})
163- }
164- msg .Params = b
153+ var result T
154+
155+ msg , idKey , err := c .prepareRequest (method , params )
156+ if err != nil {
157+ return result , err
165158 }
159+
166160 pr := & pendingResponse {ch : make (chan anyMessage , 1 )}
167- idKey := string (idRaw )
168161 c .mu .Lock ()
169162 c .pending [idKey ] = pr
170163 c .mu .Unlock ()
164+
171165 if err := c .sendMessage (msg ); err != nil {
172- return zero , NewInternalError (map [string ]any {"error" : err .Error ()})
166+ c .cleanupPending (idKey )
167+ return result , NewInternalError (map [string ]any {"error" : err .Error ()})
173168 }
174- // wait for response or peer disconnect
175- var resp anyMessage
176- d := c .Done ()
177- select {
178- case resp = <- pr .ch :
179- case <- ctx .Done ():
180- // best-effort cleanup
181- c .mu .Lock ()
182- delete (c .pending , idKey )
183- c .mu .Unlock ()
184- return zero , NewInternalError (map [string ]any {"error" : ctx .Err ().Error ()})
185- case <- d :
186- return zero , NewInternalError (map [string ]any {"error" : "peer disconnected before response" })
169+
170+ resp , err := c .waitForResponse (ctx , pr , idKey )
171+ if err != nil {
172+ return result , err
187173 }
174+
188175 if resp .Error != nil {
189- return zero , resp .Error
176+ return result , resp .Error
190177 }
191- var out T
178+
192179 if len (resp .Result ) > 0 {
193- if err := json .Unmarshal (resp .Result , & out ); err != nil {
194- return zero , NewInternalError (map [string ]any {"error" : err .Error ()})
180+ if err := json .Unmarshal (resp .Result , & result ); err != nil {
181+ return result , NewInternalError (map [string ]any {"error" : err .Error ()})
195182 }
196183 }
197- return out , nil
184+ return result , nil
198185}
199186
200- // SendRequestNoResult sends a JSON-RPC request that returns no result payload.
201- func (c * Connection ) SendRequestNoResult (ctx context.Context , method string , params any ) error {
202- // allocate id
187+ func (c * Connection ) prepareRequest (method string , params any ) (anyMessage , string , error ) {
203188 id := c .nextID .Add (1 )
204189 idRaw , _ := json .Marshal (id )
190+
205191 msg := anyMessage {
206192 JSONRPC : "2.0" ,
207193 ID : (* json .RawMessage )(& idRaw ),
208194 Method : method ,
209195 }
196+
210197 if params != nil {
211198 b , err := json .Marshal (params )
212199 if err != nil {
213- return NewInvalidParams (map [string ]any {"error" : err .Error ()})
200+ return msg , "" , NewInvalidParams (map [string ]any {"error" : err .Error ()})
214201 }
215202 msg .Params = b
216203 }
204+
205+ return msg , string (idRaw ), nil
206+ }
207+
208+ func (c * Connection ) waitForResponse (ctx context.Context , pr * pendingResponse , idKey string ) (anyMessage , error ) {
209+ select {
210+ case resp := <- pr .ch :
211+ return resp , nil
212+ case <- ctx .Done ():
213+ c .cleanupPending (idKey )
214+ return anyMessage {}, NewInternalError (map [string ]any {"error" : context .Cause (ctx ).Error ()})
215+ case <- c .Done ():
216+ return anyMessage {}, NewInternalError (map [string ]any {"error" : "peer disconnected before response" })
217+ }
218+ }
219+
220+ func (c * Connection ) cleanupPending (idKey string ) {
221+ c .mu .Lock ()
222+ delete (c .pending , idKey )
223+ c .mu .Unlock ()
224+ }
225+
226+ // SendRequestNoResult sends a JSON-RPC request that returns no result payload.
227+ func (c * Connection ) SendRequestNoResult (ctx context.Context , method string , params any ) error {
228+ msg , idKey , err := c .prepareRequest (method , params )
229+ if err != nil {
230+ return err
231+ }
232+
217233 pr := & pendingResponse {ch : make (chan anyMessage , 1 )}
218- idKey := string (idRaw )
219234 c .mu .Lock ()
220235 c .pending [idKey ] = pr
221236 c .mu .Unlock ()
237+
222238 if err := c .sendMessage (msg ); err != nil {
239+ c .cleanupPending (idKey )
223240 return NewInternalError (map [string ]any {"error" : err .Error ()})
224241 }
225- var resp anyMessage
226- d := c .Done ()
227- select {
228- case resp = <- pr .ch :
229- case <- ctx .Done ():
230- c .mu .Lock ()
231- delete (c .pending , idKey )
232- c .mu .Unlock ()
233- return NewInternalError (map [string ]any {"error" : ctx .Err ().Error ()})
234- case <- d :
235- return NewInternalError (map [string ]any {"error" : "peer disconnected before response" })
242+
243+ resp , err := c .waitForResponse (ctx , pr , idKey )
244+ if err != nil {
245+ return err
236246 }
247+
237248 if resp .Error != nil {
238249 return resp .Error
239250 }
@@ -246,43 +257,37 @@ func (c *Connection) SendNotification(ctx context.Context, method string, params
246257 return NewInternalError (map [string ]any {"error" : ctx .Err ().Error ()})
247258 default :
248259 }
249- msg := anyMessage {JSONRPC : "2.0" , Method : method }
260+
261+ msg , err := c .prepareNotification (method , params )
262+ if err != nil {
263+ return err
264+ }
265+
266+ if err := c .sendMessage (msg ); err != nil {
267+ return NewInternalError (map [string ]any {"error" : err .Error ()})
268+ }
269+ return nil
270+ }
271+
272+ func (c * Connection ) prepareNotification (method string , params any ) (anyMessage , error ) {
273+ msg := anyMessage {
274+ JSONRPC : "2.0" ,
275+ Method : method ,
276+ }
277+
250278 if params != nil {
251279 b , err := json .Marshal (params )
252280 if err != nil {
253- return NewInvalidParams (map [string ]any {"error" : err .Error ()})
281+ return msg , NewInvalidParams (map [string ]any {"error" : err .Error ()})
254282 }
255283 msg .Params = b
256284 }
257- if err := c .sendMessage (msg ); err != nil {
258- return NewInternalError (map [string ]any {"error" : err .Error ()})
259- }
260- return nil
285+
286+ return msg , nil
261287}
262288
263289// Done returns a channel that is closed when the underlying reader loop exits
264290// (typically when the peer disconnects or the input stream is closed).
265291func (c * Connection ) Done () <- chan struct {} {
266- c .mu .Lock ()
267- d := c .done
268- c .mu .Unlock ()
269- return d
270- }
271-
272- // Helper: lightweight TrimSpace for []byte without importing bytes only for this.
273- func bytesTrimSpace (b []byte ) []byte {
274- i := 0
275- for ; i < len (b ); i ++ {
276- if b [i ] != ' ' && b [i ] != '\t' && b [i ] != '\r' && b [i ] != '\n' {
277- break
278- }
279- }
280- j := len (b )
281- for j > i {
282- if b [j - 1 ] != ' ' && b [j - 1 ] != '\t' && b [j - 1 ] != '\r' && b [j - 1 ] != '\n' {
283- break
284- }
285- j --
286- }
287- return b [i :j ]
292+ return c .ctx .Done ()
288293}
0 commit comments