diff --git a/server/websocket.go b/server/websocket.go index 1103f47..70172b5 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -13,8 +13,9 @@ import ( ) type wsConnection struct { - conn *websocket.Conn - writeMu sync.Mutex + conn *websocket.Conn + writeMu sync.Mutex + handlerSem chan struct{} } type validationError struct { @@ -24,10 +25,11 @@ type validationError struct { } const ( - wsMaxMessageSize = 64 * 1024 - wsWriteWait = 10 * time.Second - wsPongWait = 60 * time.Second - wsPingPeriod = (wsPongWait * 9) / 10 + wsMaxMessageSize = 64 * 1024 + wsWriteWait = 10 * time.Second + wsPongWait = 60 * time.Second + wsPingPeriod = (wsPongWait * 9) / 10 + wsMaxConcurrentHandlers = 4 jsonRPCVersion = "2.0" errMsgParseError = "expecting jsonrpc payload" @@ -131,7 +133,7 @@ func NewWebSocketHandler(enableCORS bool) http.HandlerFunc { } defer conn.Close() - wsConn := &wsConnection{conn: conn} + wsConn := &wsConnection{conn: conn, handlerSem: make(chan struct{}, wsMaxConcurrentHandlers)} configureConnection(conn) stopPing := startPingRoutine(wsConn) defer stopPing() @@ -207,14 +209,28 @@ func handleWSMethodCall(wsConn *wsConnection, req JSONRPCRequest) { return } - result, err := handler(req.Params) - if err != nil { - log.Printf("Error executing method %s: %v", req.Method, err) - wsConn.sendError(req.ID, ErrCodeServerError, "Server error", err.Error()) + // non-blocking acquire; reject immediately when all slots are taken + select { + case wsConn.handlerSem <- struct{}{}: + default: + wsConn.sendError(req.ID, ErrCodeServerError, "Server error", "too many concurrent requests") return } - wsConn.sendResponse(req.ID, result) + // run in a goroutine so the read loop stays unblocked and can process + // pong frames — without this, long-running handlers cause the read + // deadline to expire and the connection closes with 1006 + go func() { + defer func() { <-wsConn.handlerSem }() + result, err := handler(req.Params) + if err != nil { + log.Printf("Error executing method %s: %v", req.Method, err) + wsConn.sendError(req.ID, ErrCodeServerError, "Server error", err.Error()) + return + } + + wsConn.sendResponse(req.ID, result) + }() } func (wsc *wsConnection) sendResponse(id any, result any) error { @@ -242,5 +258,8 @@ func (wsc *wsConnection) sendError(id any, code int, message string, data any) e func (wsc *wsConnection) sendJSON(v any) error { wsc.writeMu.Lock() defer wsc.writeMu.Unlock() + if err := wsc.conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil { + return err + } return wsc.conn.WriteJSON(v) }