From 3b91930f84b5e9616e23b4937d464a0c821a4f42 Mon Sep 17 00:00:00 2001 From: Haardik H Date: Wed, 29 Oct 2025 13:24:01 -0400 Subject: [PATCH] feat: handle upstream binary messages to allow being run as a middleman --- crates/websocket-proxy/src/main.rs | 8 ++-- crates/websocket-proxy/src/subscriber.rs | 59 ++++++++++++++---------- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/crates/websocket-proxy/src/main.rs b/crates/websocket-proxy/src/main.rs index 789e7f45..a496b91e 100644 --- a/crates/websocket-proxy/src/main.rs +++ b/crates/websocket-proxy/src/main.rs @@ -242,8 +242,8 @@ async fn main() { let (send, _rec) = broadcast::channel(args.message_buffer_size); let sender = send.clone(); - let listener = move |data: String| { - trace!(message = "received data", data = data); + let listener = move |data: Vec| { + trace!(message = "received data", data = ?data); // Subtract one from receiver count, as we have to keep one receiver open at all times (see _rec) // to avoid the channel being closed. However this is not an active client connection. metrics_clone @@ -251,7 +251,7 @@ async fn main() { .set((send.receiver_count() - 1) as f64); let message_data = if args.enable_compression { - let data_bytes = data.as_bytes(); + let data_bytes = data.as_slice(); let mut compressed_data_bytes = Vec::new(); { let mut compressor = @@ -260,7 +260,7 @@ async fn main() { } compressed_data_bytes } else { - data.into_bytes() + data }; match send.send(message_data.into()) { diff --git a/crates/websocket-proxy/src/subscriber.rs b/crates/websocket-proxy/src/subscriber.rs index 268f6293..23f6af9e 100644 --- a/crates/websocket-proxy/src/subscriber.rs +++ b/crates/websocket-proxy/src/subscriber.rs @@ -63,7 +63,7 @@ impl Default for SubscriberOptions { pub struct WebsocketSubscriber where - F: Fn(String) + Send + Sync + 'static, + F: Fn(Vec) + Send + Sync + 'static, { uri: Uri, handler: F, @@ -74,7 +74,7 @@ where impl WebsocketSubscriber where - F: Fn(String) + Send + Sync + 'static, + F: Fn(Vec) + Send + Sync + 'static, { pub fn new(uri: Uri, handler: F, metrics: Arc, options: SubscriberOptions) -> Self { let backoff = ExponentialBackoff { @@ -255,14 +255,17 @@ where ); self.metrics .message_received_from_upstream(self.uri.to_string().as_str()); - (self.handler)(text.to_string()); + (self.handler)(text.as_bytes().to_vec()); } Message::Binary(data) => { - warn!( - message = "received binary message, unsupported", + trace!( + message = "received binary message", uri = self.uri.to_string(), - size = data.len() + payload = ?data.as_ref() ); + self.metrics + .message_received_from_upstream(self.uri.to_string().as_str()); + (self.handler)(data.as_ref().to_vec()); } Message::Pong(_) => { trace!( @@ -300,7 +303,7 @@ mod tests { struct MockServer { addr: SocketAddr, - message_sender: broadcast::Sender, + message_sender: broadcast::Sender>, shutdown: CancellationToken, } @@ -308,7 +311,7 @@ mod tests { async fn new() -> Self { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let (tx, _) = broadcast::channel::(100); + let (tx, _) = broadcast::channel::>(100); let shutdown = CancellationToken::new(); let shutdown_clone = shutdown.clone(); let tx_clone = tx.clone(); @@ -347,7 +350,7 @@ mod tests { async fn handle_connection( stream: TcpStream, - tx: broadcast::Sender, + tx: broadcast::Sender>, shutdown: CancellationToken, ) { let ws_stream = match accept_async(stream).await { @@ -369,8 +372,8 @@ mod tests { } msg = rx.recv() => { match msg { - Ok(text) => { - if let Err(e) = ws_sender.send(Message::Text(text.into())).await { + Ok(data) => { + if let Err(e) = ws_sender.send(data.into()).await { eprintln!("Error sending message: {}", e); break; } @@ -386,9 +389,9 @@ mod tests { async fn send_message( &self, - msg: &str, - ) -> Result> { - self.message_sender.send(msg.to_string()) + msg: &[u8], + ) -> Result>> { + self.message_sender.send(msg.to_vec()) } async fn shutdown(self) { @@ -440,7 +443,7 @@ mod tests { } }); - let listener_fn = move |_data: String| { + let listener_fn = move |_data: Vec| { // Handler for received messages - not needed for this test }; @@ -482,7 +485,7 @@ mod tests { let received_messages = Arc::new(Mutex::new(Vec::new())); let received_clone = received_messages.clone(); - let listener = move |data: String| { + let listener = move |data: Vec| { if let Ok(mut messages) = received_clone.lock() { messages.push(data); } @@ -526,13 +529,21 @@ mod tests { sleep(Duration::from_millis(500)).await; - let _ = server1.send_message("Message from server 1").await; - let _ = server2.send_message("Message from server 2").await; + let _ = server1 + .send_message("Message from server 1".as_bytes()) + .await; + let _ = server2 + .send_message("Message from server 2".as_bytes()) + .await; sleep(Duration::from_millis(500)).await; - let _ = server1.send_message("Another message from server 1").await; - let _ = server2.send_message("Another message from server 2").await; + let _ = server1 + .send_message("Another message from server 1".as_bytes()) + .await; + let _ = server2 + .send_message("Another message from server 2".as_bytes()) + .await; // Wait for messages to be processed sleep(Duration::from_millis(500)).await; @@ -552,10 +563,10 @@ mod tests { assert_eq!(messages.len(), 4); - assert!(messages.contains(&"Message from server 1".to_string())); - assert!(messages.contains(&"Message from server 2".to_string())); - assert!(messages.contains(&"Another message from server 1".to_string())); - assert!(messages.contains(&"Another message from server 2".to_string())); + assert!(messages.contains(&"Message from server 1".as_bytes().to_vec())); + assert!(messages.contains(&"Message from server 2".as_bytes().to_vec())); + assert!(messages.contains(&"Another message from server 1".as_bytes().to_vec())); + assert!(messages.contains(&"Another message from server 2".as_bytes().to_vec())); assert!(!messages.is_empty()); }