Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions src/canister-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub struct WebsocketMessage {
}

/// Element of the list of messages returned to the WS Gateway after polling.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
#[derive(Debug, CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
pub struct CanisterOutputMessage {
/// The client that the gateway will forward the message to or that sent the message.
pub client_key: ClientKey,
Expand Down Expand Up @@ -122,7 +122,7 @@ pub enum CanisterServiceMessage {
}

/// List of messages returned to the WS Gateway after polling.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
#[derive(Debug, CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
pub struct CanisterOutputCertifiedMessages {
pub messages: Vec<CanisterOutputMessage>, // List of messages.
#[serde(with = "serde_bytes")]
Expand Down
2 changes: 1 addition & 1 deletion src/ic-websocket-gateway/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ic_websocket_gateway"
version = "1.3.1"
version = "1.3.2"
edition.workspace = true
rust-version.workspace = true
repository.workspace = true
Expand Down
100 changes: 61 additions & 39 deletions src/ic-websocket-gateway/src/canister_poller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@ use canister_utils::{
use gateway_state::{CanisterEntry, CanisterPrincipal, ClientSender, GatewayState, PollerState};
use ic_agent::{Agent, AgentError};
use std::{sync::Arc, time::Duration};
use tokio::sync::mpsc::Sender;
use tokio::{sync::mpsc::Sender, time::timeout};
use tracing::{error, span, trace, warn, Instrument, Level, Span};

enum PollingStatus {
pub(crate) const POLLING_TIMEOUT_MS: u64 = 5_000;

type PollingTimeout = Duration;

#[derive(Debug, PartialEq, Eq)]
pub(crate) enum PollingStatus {
NoMessagesPolled,
MessagesPolled(CanisterOutputCertifiedMessages),
PollerTimedOut,
}

/// Poller which periodically queries a canister for new messages and relays them to the client
Expand Down Expand Up @@ -59,7 +65,7 @@ impl CanisterPoller {
// initially set to None as the first iteration will not have a previous span
let mut previous_polling_iteration_span: Option<Span> = None;
loop {
let polling_iteration_span = span!(Level::TRACE, "Polling Iteration", canister_id = %self.canister_id, polling_iteration = self.polling_iteration);
let polling_iteration_span = span!(Level::TRACE, "Polling Iteration", canister_id = %self.canister_id, polling_iteration = self.polling_iteration, cargo_version = env!("CARGO_PKG_VERSION"));
if let Some(previous_polling_iteration_span) = previous_polling_iteration_span {
// create a follow from relationship between the current and previous polling iteration
// this makes it possible to crawl polling iterations in reverse chronological order
Expand Down Expand Up @@ -97,33 +103,39 @@ impl CanisterPoller {
pub async fn poll_and_relay(&mut self) -> Result<(), String> {
let start_polling_instant = tokio::time::Instant::now();

if let PollingStatus::MessagesPolled(certified_canister_output) =
self.poll_canister().await?
{
let relay_messages_span =
span!(parent: &Span::current(), Level::TRACE, "Relay Canister Messages");
let end_of_queue_reached = {
match certified_canister_output.is_end_of_queue {
Some(is_end_of_queue_reached) => is_end_of_queue_reached,
// if 'is_end_of_queue' is None, the CDK version is < 0.3.1 and does not have such a field
// in this case, assume that the queue is fully drained and therefore will be polled again
// after waiting for 'polling_interval_ms'
None => true,
match self.poll_canister().await? {
PollingStatus::MessagesPolled(certified_canister_output) => {
let relay_messages_span =
span!(parent: &Span::current(), Level::TRACE, "Relay Canister Messages");
let end_of_queue_reached = {
match certified_canister_output.is_end_of_queue {
Some(is_end_of_queue_reached) => is_end_of_queue_reached,
// if 'is_end_of_queue' is None, the CDK version is < 0.3.1 and does not have such a field
// in this case, assume that the queue is fully drained and therefore will be polled again
// after waiting for 'polling_interval_ms'
None => true,
}
};
self.update_nonce(&certified_canister_output)?;
// relaying of messages cannot be done in a separate task for each polling iteration
// as they might interleave and break the correct ordering of messages
// TODO: create a separate task dedicated to relaying messages which receives the messages from the poller via a queue
// and relays them in FIFO order
self.relay_messages(certified_canister_output)
.instrument(relay_messages_span)
.await;
if !end_of_queue_reached {
// if the queue is not fully drained, return immediately so that the next polling iteration can be started
warn!("Canister queue is not fully drained. Polling immediately");
return Ok(());
}
};
self.update_nonce(&certified_canister_output)?;
// relaying of messages cannot be done in a separate task for each polling iteration
// as they might interleave and break the correct ordering of messages
// TODO: create a separate task dedicated to relaying messages which receives the messages from the poller via a queue
// and relays them in FIFO order
self.relay_messages(certified_canister_output)
.instrument(relay_messages_span)
.await;
if !end_of_queue_reached {
// if the queue is not fully drained, return immediately so that the next polling iteration can be started
warn!("Canister queue is not fully drained. Polling immediately");
},
PollingStatus::PollerTimedOut => {
// if the poller timed out, it already waited way too long... return immediately so that the next polling iteration can be started
warn!("Poller timed out. Polling immediately");
return Ok(());
}
},
PollingStatus::NoMessagesPolled => (),
}

// compute the amount of time to sleep for before polling again
Expand All @@ -135,20 +147,26 @@ impl CanisterPoller {
}

/// Polls the canister for messages
async fn poll_canister(&mut self) -> Result<PollingStatus, String> {
pub(crate) async fn poll_canister(&mut self) -> Result<PollingStatus, String> {
trace!("Started polling iteration");

// get messages to be relayed to clients from canister (starting from 'message_nonce')
match ws_get_messages(
&self.agent,
&self.canister_id,
CanisterWsGetMessagesArguments {
nonce: self.next_message_nonce,
},
// the response timeout of the IC CDK is 2 minutes, which implies that the poller would be stuck for that long waiting for a response
// to prevent this, we set a timeout of 'POLLING_TIMEOUT_MS' (5 seconds): if the poller does not receive a response in time, it polls again immediately
// in case of a timeout, the message nonce is not updated, so no messages are lost by polling immediately again
match timeout(
PollingTimeout::from_millis(POLLING_TIMEOUT_MS),
ws_get_messages(
&self.agent,
&self.canister_id,
CanisterWsGetMessagesArguments {
nonce: self.next_message_nonce,
},
),
)
.await
{
Ok(certified_canister_output) => {
Ok(Ok(certified_canister_output)) => {
let number_of_polled_messages = certified_canister_output.messages.len();
if number_of_polled_messages == 0 {
trace!("No messages polled from canister");
Expand All @@ -161,7 +179,7 @@ impl CanisterPoller {
Ok(PollingStatus::MessagesPolled(certified_canister_output))
}
},
Err(IcError::Agent(e)) => {
Ok(Err(IcError::Agent(e))) => {
if is_recoverable_error(&e) {
// if the error is due to a replica which is either actively malicious or simply unavailable
// or to a malfunctioning boundary node,
Expand All @@ -174,8 +192,12 @@ impl CanisterPoller {
Err(format!("Unrecoverable agent error: {:?}", e))
}
},
Err(IcError::Candid(e)) => Err(format!("Unrecoverable candid error: {:?}", e)),
Err(IcError::Cdk(e)) => Err(format!("Unrecoverable CDK error: {:?}", e)),
Ok(Err(IcError::Candid(e))) => Err(format!("Unrecoverable candid error: {:?}", e)),
Ok(Err(IcError::Cdk(e))) => Err(format!("Unrecoverable CDK error: {:?}", e)),
Err(e) => {
warn!("Poller took too long to retrieve messages: {:?}", e);
Ok(PollingStatus::PollerTimedOut)
},
}
}

Expand Down
1 change: 1 addition & 0 deletions src/ic-websocket-gateway/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ async fn main() -> Result<(), String> {

// must be printed after initializing tracing to ensure that the info is captured
info!("Deployment info: {:?}", deployment_info);
info!("Cargo version: {}", env!("CARGO_PKG_VERSION"));
info!("Gateway Agent principal: {}", gateway_principal);

let tls_config = if deployment_info.tls_certificate_pem_path.is_some()
Expand Down
54 changes: 51 additions & 3 deletions src/ic-websocket-gateway/src/tests/canister_poller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ mod test {
use lazy_static::lazy_static;
use std::{
sync::{Arc, Mutex},
thread,
time::Duration,
};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tracing::Span;

use crate::canister_poller::{get_nonce_from_message, CanisterPoller};
use crate::canister_poller::{
get_nonce_from_message, CanisterPoller, PollingStatus, POLLING_TIMEOUT_MS,
};

struct MockCanisterOutputCertifiedMessages(CanisterOutputCertifiedMessages);

Expand Down Expand Up @@ -237,7 +240,7 @@ mod test {
let end_polling_instant = tokio::time::Instant::now();
let elapsed = end_polling_instant - start_polling_instant;
// run 'cargo test -- --nocapture' to see the elapsed time
println!("Elapsed: {:?}", elapsed);
println!("Elapsed after relaying (should not sleep): {:?}", elapsed);
assert!(
elapsed > Duration::from_millis(polling_interval_ms)
// Reasonable to expect that the time it takes to sleep
Expand Down Expand Up @@ -294,7 +297,7 @@ mod test {
poller.poll_and_relay().await.expect("Failed to poll");
let end_polling_instant = tokio::time::Instant::now();
let elapsed = end_polling_instant - start_polling_instant;
println!("Elapsed: {:?}", elapsed);
println!("Elapsed after relaying (should sleep): {:?}", elapsed);
assert!(
// The `poll_and_relay` function should not sleep for `polling_interval_ms`
// if the queue is not empty.
Expand All @@ -319,6 +322,51 @@ mod test {
drop(guard);
}

/// Checks the poller's behavior when the canister does not answer within `POLLING_TIMEOUT_MS`:
/// - `poll_canister` must report `PollingStatus::PollerTimedOut`;
/// - `poll_and_relay` must return right after the timeout, without additionally sleeping
///   for the polling interval.
#[tokio::test]
async fn should_not_sleep_after_timeout() {
    let server = &*MOCK_SERVER;
    let path = "/ws_get_messages";
    // do not drop the guard until the end of this test to make sure that no other test interleaves and overwrites the mock response
    let mut guard = server.lock().unwrap();
    // the mock delays its (empty) body for slightly longer than the polling timeout
    // it is expected to be hit twice: once by 'poll_canister' and once by 'poll_and_relay'
    let mock = guard
        .mock("GET", path)
        .with_chunked_body(|w| {
            thread::sleep(Duration::from_millis(POLLING_TIMEOUT_MS + 10));
            // nothing to send back: an empty slice is enough, no need to allocate a Vec
            w.write_all(&[])
        })
        .expect(2)
        .create_async()
        .await;

    let polling_interval_ms = 100;
    // the receiving half is not needed as no message is expected to be relayed
    let (client_channel_tx, _): (Sender<IcWsCanisterMessage>, Receiver<IcWsCanisterMessage>) =
        mpsc::channel(100);

    let mut poller = create_poller(polling_interval_ms, client_channel_tx);

    // check that the poller times out
    assert_eq!(
        Ok(PollingStatus::PollerTimedOut),
        poller.poll_canister().await
    );

    // check that the poller does not wait for a polling interval after timing out
    let start_polling_instant = tokio::time::Instant::now();
    poller.poll_and_relay().await.expect("Failed to poll");
    let end_polling_instant = tokio::time::Instant::now();
    let elapsed = end_polling_instant - start_polling_instant;
    // run 'cargo test -- --nocapture' to see the elapsed time
    println!("Elapsed due to timeout: {:?}", elapsed);
    assert!(
        // The `poll_and_relay` function should not sleep for `polling_interval_ms`
        // after the poller times out: the elapsed time must stay below
        // the polling timeout plus one polling interval.
        elapsed < Duration::from_millis(POLLING_TIMEOUT_MS + polling_interval_ms)
    );

    mock.assert_async().await;
    // just to make it explicit that the guard should be kept for the whole duration of the test
    drop(guard);
}

#[tokio::test]
async fn should_terminate_polling_with_error() {
let server = &*MOCK_SERVER;
Expand Down
3 changes: 2 additions & 1 deletion src/ic-websocket-gateway/src/ws_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ impl WsListener {
Level::DEBUG,
"Accept Connection",
client_addr = ?client_addr.ip(),
client_id = self.next_client_id
client_id = self.next_client_id,
cargo_version = env!("CARGO_PKG_VERSION"),
);
let client_id = self.next_client_id;
let tls_acceptor = self.tls_acceptor.clone();
Expand Down