Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 130 additions & 19 deletions crates/rollup-boost/src/flashblocks/inbound.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,36 @@
use super::primitives::FlashblocksPayloadV1;
use crate::flashblocks::metrics::FlashblocksWsInboundMetrics;
use futures::StreamExt;
use tokio::sync::mpsc;
use std::time::Duration;

use super::{metrics::FlashblocksWsInboundMetrics, primitives::FlashblocksPayloadV1};
use futures::{SinkExt, StreamExt};
use tokio::{sync::mpsc, time::interval};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tracing::{error, info};
use url::Url;

#[derive(Debug, thiserror::Error)]
enum FlashblocksReceiverError {
#[error("WebSocket connection failed: {0}")]
Connection(#[from] tokio_tungstenite::tungstenite::Error),

#[error("Ping failed")]
PingFailed,

#[error("Read timeout")]
ReadTimeout,

#[error("Connection error: {0}")]
ConnectionError(String),

#[error("Connection closed")]
ConnectionClosed,

#[error("Task panicked: {0}")]
TaskPanic(String),

#[error("Failed to send message to sender: {0}")]
SendError(#[from] Box<tokio::sync::mpsc::error::SendError<FlashblocksPayloadV1>>),
}

pub struct FlashblocksReceiverService {
url: Url,
sender: mpsc::Sender<FlashblocksPayloadV1>,
Expand Down Expand Up @@ -36,23 +61,74 @@ impl FlashblocksReceiverService {
}
}

async fn connect_and_handle(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
async fn connect_and_handle(&self) -> Result<(), FlashblocksReceiverError> {
let (ws_stream, _) = connect_async(self.url.as_str()).await?;
let (_, mut read) = ws_stream.split();
let (mut write, mut read) = ws_stream.split();

info!("Connected to Flashblocks receiver at {}", self.url);
self.metrics.connection_status.set(1);

while let Some(msg) = read.next().await {
if let Message::Text(text) = msg? {
self.metrics.messages_received.increment(1);
if let Ok(flashblocks_msg) = serde_json::from_str::<FlashblocksPayloadV1>(&text) {
self.sender.send(flashblocks_msg).await?;
let ping_task = tokio::spawn(async move {
let mut ping_interval = interval(Duration::from_millis(500));

loop {
tokio::select! {
_ = ping_interval.tick() => {
if write.send(Message::Ping(Default::default())).await.is_err() {
return Err(FlashblocksReceiverError::PingFailed);
}
}
}
}
}
});

Ok(())
let sender = self.sender.clone();
let metrics = self.metrics.clone();

let read_timeout = Duration::from_millis(500);
let message_handle = tokio::spawn(async move {
loop {
let result = tokio::time::timeout(read_timeout, read.next())
.await
.map_err(|_| FlashblocksReceiverError::ReadTimeout)?;

match result {
Some(Ok(msg)) => match msg {
Message::Text(text) => {
metrics.messages_received.increment(1);
if let Ok(flashblocks_msg) =
serde_json::from_str::<FlashblocksPayloadV1>(&text)
{
sender.send(flashblocks_msg).await.map_err(|e| {
FlashblocksReceiverError::SendError(Box::new(e))
})?;
}
}
Message::Close(_) => {
return Err(FlashblocksReceiverError::ConnectionClosed);
}
_ => {}
},
Some(Err(e)) => {
return Err(FlashblocksReceiverError::ConnectionError(e.to_string()));
}
None => {
return Err(FlashblocksReceiverError::ReadTimeout);
}
};
}
});

let result = tokio::select! {
result = message_handle => {
result.map_err(|e| FlashblocksReceiverError::TaskPanic(e.to_string()))?
},
result = ping_task => {
result.map_err(|e| FlashblocksReceiverError::TaskPanic(e.to_string()))?
},
};

result
}
}

Expand All @@ -70,10 +146,12 @@ mod tests {
) -> eyre::Result<(
watch::Sender<bool>,
mpsc::Sender<FlashblocksPayloadV1>,
mpsc::Receiver<()>,
url::Url,
)> {
let (term_tx, mut term_rx) = watch::channel(false);
let (send_tx, mut send_rx) = mpsc::channel::<FlashblocksPayloadV1>(100);
let (send_ping_tx, send_ping_rx) = mpsc::channel::<()>(100);

let listener = TcpListener::bind(addr)?;
let url = Url::parse(&format!("ws://{addr}"))?;
Expand All @@ -98,15 +176,26 @@ mod tests {
match result {
Ok((connection, _addr)) => {
match accept_async(connection).await {
Ok(mut ws_stream) => {
Ok(ws_stream) => {
let (mut write, mut read) = ws_stream.split();

loop {
tokio::select! {
Some(msg) = send_rx.recv() => {
let serialized = serde_json::to_string(&msg).unwrap();
let utf8_bytes = Utf8Bytes::from(serialized);

ws_stream.send(Message::Text(utf8_bytes)).await.unwrap();
write.send(Message::Text(utf8_bytes)).await.unwrap();
},
msg = read.next() => {
match msg {
// we need to read for the library to handle pong messages
Some(Ok(Message::Ping(_))) => {
send_ping_tx.send(()).await.unwrap();
},
_ => {}
}
}
_ = term_rx.changed() => {
if *term_rx.borrow() {
return;
Expand All @@ -132,13 +221,13 @@ mod tests {
}
});

Ok((term_tx, send_tx, url))
Ok((term_tx, send_tx, send_ping_rx, url))
}

#[tokio::test]
async fn test_flashblocks_receiver_service() -> eyre::Result<()> {
let addr = "127.0.0.1:8080".parse::<SocketAddr>().unwrap();
let (term, send_msg, url) = start(addr).await?;
let (term, send_msg, _, url) = start(addr).await?;

let (tx, mut rx) = mpsc::channel(100);

Expand All @@ -164,8 +253,8 @@ mod tests {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;

// start a new server with the same address
let (term, send_msg, _url) = start(addr).await?;
send_msg
let (term, send_msg, _, _url) = start(addr).await?;
let _ = send_msg
.send(FlashblocksPayloadV1::default())
.await
.expect("Failed to send message");
Expand All @@ -176,4 +265,26 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_flashblocks_receiver_service_ping_pong() -> eyre::Result<()> {
// test that if the builder is not sending any messages back, the service will send
// ping messages to test the connection periodically

let addr = "127.0.0.1:8081".parse::<SocketAddr>().unwrap();
let (_term, _send_msg, mut ping_rx, url) = start(addr).await?;

let (tx, _rx) = mpsc::channel(100);
let service = FlashblocksReceiverService::new(url, tx, 100);
let _ = tokio::spawn(async move {
service.run().await;
});

// even if we do not send any messages, we should receive pings to keep the connection alive
for _ in 0..10 {
let _ = ping_rx.recv().await.expect("Failed to receive ping");
}

Ok(())
}
}
2 changes: 1 addition & 1 deletion crates/rollup-boost/src/flashblocks/metrics.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use metrics::{Counter, Gauge};
use metrics_derive::Metrics;

#[derive(Metrics)]
#[derive(Metrics, Clone)]
#[metrics(scope = "flashblocks.ws_inbound")]
pub struct FlashblocksWsInboundMetrics {
/// Total number of WebSocket reconnection attempts
Expand Down
Loading