diff --git a/.changelog/unreleased/improvements/2434-jsonrpc-websocket-basic-auth.md b/.changelog/unreleased/improvements/2434-jsonrpc-websocket-basic-auth.md new file mode 100644 index 0000000000..e4db7c06c7 --- /dev/null +++ b/.changelog/unreleased/improvements/2434-jsonrpc-websocket-basic-auth.md @@ -0,0 +1 @@ +- `[jsonrpc]` enable HTTP basic auth in websocket client ([#2434](https://github.com/cometbft/cometbft/pull/2434)) diff --git a/rpc/jsonrpc/client/ws_client.go b/rpc/jsonrpc/client/ws_client.go index 77a189eb59..f93bdae0b0 100644 --- a/rpc/jsonrpc/client/ws_client.go +++ b/rpc/jsonrpc/client/ws_client.go @@ -2,6 +2,7 @@ package client import ( "context" + "encoding/base64" "encoding/json" "fmt" "net" @@ -35,7 +36,10 @@ type WSClient struct { //nolint: maligned Address string // IP:PORT or /path/to/socket Endpoint string // /websocket/url/endpoint - Dialer func(string, string) (net.Conn, error) + Username string + Password string + + Dialer func(string, string) (net.Conn, error) // Single user facing channel to read RPCResponses from, closed only when the // client is being stopped. @@ -96,6 +100,14 @@ func NewWS(remoteAddr, endpoint string, options ...func(*WSClient)) (*WSClient, parsedURL.Scheme = protoWS } + // extract username and password from URL if any + username := "" + password := "" + if parsedURL.User.String() != "" { + username = parsedURL.User.Username() + password, _ = parsedURL.User.Password() + } + dialFn, err := makeHTTPDialer(remoteAddr) if err != nil { return nil, err @@ -103,6 +115,8 @@ func NewWS(remoteAddr, endpoint string, options ...func(*WSClient)) (*WSClient, c := &WSClient{ Address: parsedURL.GetTrimmedHostWithPath(), + Username: username, + Password: password, Dialer: dialFn, Endpoint: endpoint, PingPongLatencyTimer: metrics.NewTimer(), @@ -267,6 +281,12 @@ func (c *WSClient) dial() error { Proxy: http.ProxyFromEnvironment, } rHeader := http.Header{} + + // Set basic auth header if username and password are provided + if c.Username != "" && c.Password != "" { + rHeader.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(c.Username+":"+c.Password))) + } + conn, _, err := dialer.Dial(c.protocol+"://"+c.Address+c.Endpoint, rHeader) //nolint:bodyclose if err != nil { return err