Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/websocket-proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,16 @@ async fn main() {
let (send, _rec) = broadcast::channel(args.message_buffer_size);
let sender = send.clone();

let listener = move |data: String| {
trace!(message = "received data", data = data);
let listener = move |data: Vec<u8>| {
trace!(message = "received data", data = ?data);
// Subtract one from receiver count, as we have to keep one receiver open at all times (see _rec)
// to avoid the channel being closed. However this is not an active client connection.
metrics_clone
.active_connections
.set((send.receiver_count() - 1) as f64);

let message_data = if args.enable_compression {
let data_bytes = data.as_bytes();
let data_bytes = data.as_slice();
let mut compressed_data_bytes = Vec::new();
{
let mut compressor =
Expand All @@ -267,7 +267,7 @@ async fn main() {
}
compressed_data_bytes
} else {
data.into_bytes()
data
};

match send.send(message_data.into()) {
Expand Down
65 changes: 40 additions & 25 deletions crates/websocket-proxy/src/subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl Default for SubscriberOptions {

pub struct WebsocketSubscriber<F>
where
F: Fn(String) + Send + Sync + 'static,
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
uri: Uri,
handler: F,
Expand All @@ -74,7 +74,7 @@ where

impl<F> WebsocketSubscriber<F>
where
F: Fn(String) + Send + Sync + 'static,
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
pub fn new(uri: Uri, handler: F, metrics: Arc<Metrics>, options: SubscriberOptions) -> Self {
let backoff = ExponentialBackoff {
Expand Down Expand Up @@ -258,14 +258,17 @@ where
);
self.metrics
.message_received_from_upstream(self.uri.to_string().as_str());
(self.handler)(text.to_string());
(self.handler)(text.as_bytes().to_vec());
}
Message::Binary(data) => {
warn!(
message = "received binary message, unsupported",
trace!(
message = "received binary message",
uri = self.uri.to_string(),
size = data.len()
payload = ?data.as_ref()
);
self.metrics
.message_received_from_upstream(self.uri.to_string().as_str());
(self.handler)(data.as_ref().to_vec());
}
Message::Pong(_) => {
trace!(
Expand Down Expand Up @@ -299,19 +302,19 @@ mod tests {
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast;
use tokio::time::{sleep, timeout, Duration};
use tokio_tungstenite::{accept_async, tungstenite::Message};
use tokio_tungstenite::accept_async;

struct MockServer {
addr: SocketAddr,
message_sender: broadcast::Sender<String>,
message_sender: broadcast::Sender<Vec<u8>>,
shutdown: CancellationToken,
}

impl MockServer {
async fn new() -> Self {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (tx, _) = broadcast::channel::<String>(100);
let (tx, _) = broadcast::channel::<Vec<u8>>(100);
let shutdown = CancellationToken::new();
let shutdown_clone = shutdown.clone();
let tx_clone = tx.clone();
Expand Down Expand Up @@ -350,7 +353,7 @@ mod tests {

async fn handle_connection(
stream: TcpStream,
tx: broadcast::Sender<String>,
tx: broadcast::Sender<Vec<u8>>,
shutdown: CancellationToken,
) {
let ws_stream = match accept_async(stream).await {
Expand All @@ -372,8 +375,8 @@ mod tests {
}
msg = rx.recv() => {
match msg {
Ok(text) => {
if let Err(e) = ws_sender.send(Message::Text(text.into())).await {
Ok(data) => {
if let Err(e) = ws_sender.send(data.into()).await {
eprintln!("Error sending message: {}", e);
break;
}
Expand All @@ -389,9 +392,9 @@ mod tests {

async fn send_message(
&self,
msg: &str,
) -> Result<usize, broadcast::error::SendError<String>> {
self.message_sender.send(msg.to_string())
msg: &[u8],
) -> Result<usize, broadcast::error::SendError<Vec<u8>>> {
self.message_sender.send(msg.to_vec())
}

async fn shutdown(self) {
Expand Down Expand Up @@ -443,7 +446,7 @@ mod tests {
}
});

let listener_fn = move |_data: String| {
let listener_fn = move |_data: Vec<u8>| {
// Handler for received messages - not needed for this test
};

Expand Down Expand Up @@ -485,7 +488,8 @@ mod tests {
let received_messages = Arc::new(Mutex::new(Vec::new()));
let received_clone = received_messages.clone();

let listener = move |data: String| {
// Create a listener function that will be shared by both subscribers
let listener = move |data: Vec<u8>| {
if let Ok(mut messages) = received_clone.lock() {
messages.push(data);
}
Expand Down Expand Up @@ -529,13 +533,23 @@ mod tests {

sleep(Duration::from_millis(500)).await;

let _ = server1.send_message("Message from server 1").await;
let _ = server2.send_message("Message from server 2").await;
// Send different messages from each server
let _ = server1
.send_message("Message from server 1".as_bytes())
.await;
let _ = server2
.send_message("Message from server 2".as_bytes())
.await;

sleep(Duration::from_millis(500)).await;

let _ = server1.send_message("Another message from server 1").await;
let _ = server2.send_message("Another message from server 2").await;
// Send more messages to ensure continuous operation
let _ = server1
.send_message("Another message from server 1".as_bytes())
.await;
let _ = server2
.send_message("Another message from server 2".as_bytes())
.await;

sleep(Duration::from_millis(500)).await;

Expand All @@ -554,10 +568,11 @@ mod tests {

assert_eq!(messages.len(), 4);

assert!(messages.contains(&"Message from server 1".to_string()));
assert!(messages.contains(&"Message from server 2".to_string()));
assert!(messages.contains(&"Another message from server 1".to_string()));
assert!(messages.contains(&"Another message from server 2".to_string()));
// Check that we received messages from both servers
assert!(messages.contains(&"Message from server 1".as_bytes().to_vec()));
assert!(messages.contains(&"Message from server 2".as_bytes().to_vec()));
assert!(messages.contains(&"Another message from server 1".as_bytes().to_vec()));
assert!(messages.contains(&"Another message from server 2".as_bytes().to_vec()));

assert!(!messages.is_empty());
}
Expand Down
Loading