diff --git a/api/bridge_handler.go b/api/bridge_handler.go index faad3cb..f7c4aac 100644 --- a/api/bridge_handler.go +++ b/api/bridge_handler.go @@ -7,8 +7,10 @@ import ( "log" "net/http" "sync" + "time" "github.com/go-chi/chi/v5" + "github.com/gorilla/websocket" "github.com/ledgerwatch/diagnostics" "github.com/ledgerwatch/diagnostics/api/internal" "github.com/ledgerwatch/diagnostics/internal/erigon_node" @@ -23,20 +25,40 @@ type BridgeHandler struct { cache sessions.CacheService } +const ( + wsReadBuffer = 1024 + wsWriteBuffer = 1024 + wsPingInterval = 60 * time.Second + wsPingWriteTimeout = 5 * time.Second + wsMessageSizeLimit = 32 * 1024 * 1024 +) + +var wsBufferPool = new(sync.Pool) + func (h BridgeHandler) Bridge(w http.ResponseWriter, r *http.Request) { //Sends a success Message to the Node client, to receive more information - flusher, _ := w.(http.Flusher) ctx, cancel := context.WithCancel(r.Context()) defer cancel() defer r.Body.Close() + upgrader := websocket.Upgrader{ + EnableCompression: true, + ReadBufferSize: wsReadBuffer, + WriteBufferSize: wsWriteBuffer, + WriteBufferPool: wsBufferPool, + } + // Update the request context with the connection context. // If the connection is closed by the server, it will also notify everything that waits on the request context. *r = *r.WithContext(ctx) - w.WriteHeader(http.StatusOK) - flusher.Flush() + conn, err := upgrader.Upgrade(w, r, nil) + + if err != nil { + internal.EncodeError(w, r, diagnostics.AsBadRequestErr(errors.Errorf("Error upgrading websocket: %v", err))) + return + } connectionInfo := struct { Version uint64 `json:"version"` @@ -44,14 +66,21 @@ func (h BridgeHandler) Bridge(w http.ResponseWriter, r *http.Request) { Nodes []*sessions.NodeInfo `json:"nodes"` }{} - err := json.NewDecoder(r.Body).Decode(&connectionInfo) + _, message, err := conn.ReadMessage() if err != nil { - log.Printf("Error reading connection info: %v\n", err) internal.EncodeError(w, r, diagnostics.AsBadRequestErr(errors.Errorf("Error reading connection info: %v", err))) return } + err = json.Unmarshal(message, &connectionInfo) + + if err != nil { + log.Printf("Error reading connection info: %v\n", err) + internal.EncodeError(w, r, diagnostics.AsBadRequestErr(errors.Errorf("Error unmarshaling connection info: %v", err))) + return + } + requestMap := map[string]*erigon_node.NodeRequest{} requestMutex := sync.Mutex{} @@ -97,12 +126,10 @@ func (h BridgeHandler) Bridge(w http.ResponseWriter, r *http.Request) { requestMap[rpcRequest.Id] = request requestMutex.Unlock() - if _, err := w.Write(bytes); err != nil { + if err := conn.WriteMessage(websocket.TextMessage, bytes); err != nil { requestMutex.Lock() delete(requestMap, rpcRequest.Id) requestMutex.Unlock() - - fmt.Println(request.Retries, err) request.Retries++ if request.Retries < 15 { select { @@ -119,18 +146,21 @@ func (h BridgeHandler) Bridge(w http.ResponseWriter, r *http.Request) { } continue } - - flusher.Flush() } }() } - decoder := json.NewDecoder(r.Body) - for { var response erigon_node.Response - if err = decoder.Decode(&response); err != nil { + _, message, err := conn.ReadMessage() + + if err != nil { + fmt.Printf("can't read response: %v\n", err) + continue + } + + if err = json.Unmarshal(message, &response); err != nil { fmt.Printf("can't read response: %v\n", err) continue } @@ -163,7 +193,7 @@ func NewBridgeHandler(cacheSvc sessions.CacheService) BridgeHandler { cache: cacheSvc, } - r.Post("/", r.Bridge) + r.Get("/", r.Bridge) return *r } diff --git a/cmd/diagnostics/main.go b/cmd/diagnostics/main.go index 42e99c9..dfa638b 100644 --- a/cmd/diagnostics/main.go +++ b/cmd/diagnostics/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "crypto/tls" "crypto/x509" "fmt" "log" @@ -50,11 +49,6 @@ func main() { certPool.AppendCertsFromPEM(caCert) } - tlsConfig := &tls.Config{ - RootCAs: certPool, - MinVersion: tls.VersionTLS12, - } - // Passing in the services to REST layer handlers := api.NewHandler( api.APIServices{ @@ -65,12 +59,13 @@ func main() { Addr: fmt.Sprintf("%s:%d", listenAddr, listenPort), Handler: handlers, MaxHeaderBytes: 1 << 20, - TLSConfig: tlsConfig, ReadHeaderTimeout: 1 * time.Minute, } go func() { - if err := srv.ListenAndServeTLS(serverCertFile, serverKeyFile); err != http.ErrServerClosed { + err := srv.ListenAndServe() + + if err != nil { log.Fatal(err) } }() diff --git a/internal/bridge/middleware.go b/internal/bridge/middleware.go index 7cc993f..cf0bd78 100644 --- a/internal/bridge/middleware.go +++ b/internal/bridge/middleware.go @@ -1,22 +1,13 @@ package bridge -import "net/http" +import ( + "net/http" +) var ErrHTTP2NotSupported = "HTTP2 not supported" func Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !r.ProtoAtLeast(2, 0) { - http.Error(w, ErrHTTP2NotSupported, http.StatusHTTPVersionNotSupported) - return - } - - _, ok := w.(http.Flusher) - if !ok { - http.Error(w, ErrHTTP2NotSupported, http.StatusHTTPVersionNotSupported) - return - } - next.ServeHTTP(w, r) }) }