Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/websocket-proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,16 @@ async fn main() {
let (send, _rec) = broadcast::channel(args.message_buffer_size);
let sender = send.clone();

let listener = move |data: String| {
trace!(message = "received data", data = data);
let listener = move |data: Vec<u8>| {
trace!(message = "received data", data = ?data);
// Subtract one from receiver count, as we have to keep one receiver open at all times (see _rec)
// to avoid the channel being closed. However this is not an active client connection.
metrics_clone
.active_connections
.set((send.receiver_count() - 1) as f64);

let message_data = if args.enable_compression {
let data_bytes = data.as_bytes();
let data_bytes = data.as_slice();
let mut compressed_data_bytes = Vec::new();
{
let mut compressor =
Expand All @@ -260,7 +260,7 @@ async fn main() {
}
compressed_data_bytes
} else {
data.into_bytes()
data
};

match send.send(message_data.into()) {
Expand Down
59 changes: 35 additions & 24 deletions crates/websocket-proxy/src/subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl Default for SubscriberOptions {

pub struct WebsocketSubscriber<F>
where
F: Fn(String) + Send + Sync + 'static,
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
uri: Uri,
handler: F,
Expand All @@ -74,7 +74,7 @@ where

impl<F> WebsocketSubscriber<F>
where
F: Fn(String) + Send + Sync + 'static,
F: Fn(Vec<u8>) + Send + Sync + 'static,
{
pub fn new(uri: Uri, handler: F, metrics: Arc<Metrics>, options: SubscriberOptions) -> Self {
let backoff = ExponentialBackoff {
Expand Down Expand Up @@ -255,14 +255,17 @@ where
);
self.metrics
.message_received_from_upstream(self.uri.to_string().as_str());
(self.handler)(text.to_string());
(self.handler)(text.as_bytes().to_vec());
}
Message::Binary(data) => {
warn!(
message = "received binary message, unsupported",
trace!(
message = "received binary message",
uri = self.uri.to_string(),
size = data.len()
payload = ?data.as_ref()
);
self.metrics
.message_received_from_upstream(self.uri.to_string().as_str());
(self.handler)(data.as_ref().to_vec());
}
Message::Pong(_) => {
trace!(
Expand Down Expand Up @@ -300,15 +303,15 @@ mod tests {

struct MockServer {
addr: SocketAddr,
message_sender: broadcast::Sender<String>,
message_sender: broadcast::Sender<Vec<u8>>,
shutdown: CancellationToken,
}

impl MockServer {
async fn new() -> Self {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (tx, _) = broadcast::channel::<String>(100);
let (tx, _) = broadcast::channel::<Vec<u8>>(100);
let shutdown = CancellationToken::new();
let shutdown_clone = shutdown.clone();
let tx_clone = tx.clone();
Expand Down Expand Up @@ -347,7 +350,7 @@ mod tests {

async fn handle_connection(
stream: TcpStream,
tx: broadcast::Sender<String>,
tx: broadcast::Sender<Vec<u8>>,
shutdown: CancellationToken,
) {
let ws_stream = match accept_async(stream).await {
Expand All @@ -369,8 +372,8 @@ mod tests {
}
msg = rx.recv() => {
match msg {
Ok(text) => {
if let Err(e) = ws_sender.send(Message::Text(text.into())).await {
Ok(data) => {
if let Err(e) = ws_sender.send(data.into()).await {
eprintln!("Error sending message: {}", e);
break;
}
Expand All @@ -386,9 +389,9 @@ mod tests {

async fn send_message(
&self,
msg: &str,
) -> Result<usize, broadcast::error::SendError<String>> {
self.message_sender.send(msg.to_string())
msg: &[u8],
) -> Result<usize, broadcast::error::SendError<Vec<u8>>> {
self.message_sender.send(msg.to_vec())
}

async fn shutdown(self) {
Expand Down Expand Up @@ -440,7 +443,7 @@ mod tests {
}
});

let listener_fn = move |_data: String| {
let listener_fn = move |_data: Vec<u8>| {
// Handler for received messages - not needed for this test
};

Expand Down Expand Up @@ -482,7 +485,7 @@ mod tests {
let received_messages = Arc::new(Mutex::new(Vec::new()));
let received_clone = received_messages.clone();

let listener = move |data: String| {
let listener = move |data: Vec<u8>| {
if let Ok(mut messages) = received_clone.lock() {
messages.push(data);
}
Expand Down Expand Up @@ -526,13 +529,21 @@ mod tests {

sleep(Duration::from_millis(500)).await;

let _ = server1.send_message("Message from server 1").await;
let _ = server2.send_message("Message from server 2").await;
let _ = server1
.send_message("Message from server 1".as_bytes())
.await;
let _ = server2
.send_message("Message from server 2".as_bytes())
.await;

sleep(Duration::from_millis(500)).await;

let _ = server1.send_message("Another message from server 1").await;
let _ = server2.send_message("Another message from server 2").await;
let _ = server1
.send_message("Another message from server 1".as_bytes())
.await;
let _ = server2
.send_message("Another message from server 2".as_bytes())
.await;

// Wait for messages to be processed
sleep(Duration::from_millis(500)).await;
Expand All @@ -552,10 +563,10 @@ mod tests {

assert_eq!(messages.len(), 4);

assert!(messages.contains(&"Message from server 1".to_string()));
assert!(messages.contains(&"Message from server 2".to_string()));
assert!(messages.contains(&"Another message from server 1".to_string()));
assert!(messages.contains(&"Another message from server 2".to_string()));
assert!(messages.contains(&"Message from server 1".as_bytes().to_vec()));
assert!(messages.contains(&"Message from server 2".as_bytes().to_vec()));
assert!(messages.contains(&"Another message from server 1".as_bytes().to_vec()));
assert!(messages.contains(&"Another message from server 2".as_bytes().to_vec()));

assert!(!messages.is_empty());
}
Expand Down
Loading