diff --git a/Cargo.lock b/Cargo.lock index b7396d431..3ea161459 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -150,6 +150,28 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.51", +] + [[package]] name = "async-trait" version = "0.1.77" @@ -1952,6 +1974,7 @@ dependencies = [ "tokio-native-tls", "tokio-rustls", "tokio-stream", + "tokio-test", "tokio-util", "url", "ws_stream_tungstenite", @@ -2557,6 +2580,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.10" diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index 1045cfcf1..f2e80ca6d 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Process multiple outgoing client requests before flushing the network buffer (reduces number of system calls) * `size()` method on `Packet` calculates size once serialized. * `read()` and `write()` methods on `Packet`. diff --git a/rumqttc/Cargo.toml b/rumqttc/Cargo.toml index bba64822c..b3a3f2165 100644 --- a/rumqttc/Cargo.toml +++ b/rumqttc/Cargo.toml @@ -57,6 +57,7 @@ matches = "0.1" pretty_assertions = "1" pretty_env_logger = "0.5" serde = { version = "1", features = ["derive"] } +tokio-test = "0.4.4" [[example]] name = "tls" diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index a9b1ce8c5..621a2830a 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -141,8 +141,8 @@ impl EventLoop { pub async fn poll(&mut self) -> Result { if self.network.is_none() { let (network, connack) = match time::timeout( - Duration::from_secs(self.network_options.connection_timeout()), - connect(&self.mqtt_options, self.network_options.clone()), + self.network_options.connection_timeout(), + connect(&self.mqtt_options, self.network_options()), ) .await { @@ -173,7 +173,7 @@ impl EventLoop { // let await_acks = self.state.await_acks; let inflight_full = self.state.inflight >= self.mqtt_options.inflight; let collision = self.state.collision.is_some(); - let network_timeout = Duration::from_secs(self.network_options.connection_timeout()); + let network_timeout = self.network_options.connection_timeout(); // Read buffered events from previous polls before calling a new poll if let Some(event) = self.state.events.pop_front() { @@ -258,10 +258,12 @@ impl EventLoop { } } - pub fn network_options(&self) -> NetworkOptions { - self.network_options.clone() + /// Get network options + pub fn network_options(&self) -> &NetworkOptions { + &self.network_options } + /// Set network options pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self { self.network_options = network_options; self @@ -293,7 +295,7 @@ impl EventLoop { /// between re-connections so that cancel semantics can be used during this sleep async fn connect( mqtt_options: &MqttOptions, - network_options: NetworkOptions, + network_options: &NetworkOptions, ) -> Result<(Network, Incoming), ConnectionError> { // connect to the broker let mut network = network_connect(mqtt_options, network_options).await?; @@ -306,7 +308,7 @@ async fn connect( pub(crate) async fn socket_connect( host: String, - network_options: NetworkOptions, + network_options: &NetworkOptions, ) -> io::Result { let addrs = lookup_host(host).await?; let mut last_err = None; @@ -352,7 +354,7 @@ pub(crate) async fn socket_connect( async fn network_connect( options: &MqttOptions, - network_options: NetworkOptions, + network_options: &NetworkOptions, ) -> Result { // Process Unix files early, as proxy is not supported for them. #[cfg(unix)] diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 29cad1a34..7e28c4d3b 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -365,11 +365,12 @@ impl From for TlsConfiguration { } /// Provides a way to configure low level network connection configurations -#[derive(Clone, Default)] +#[derive(Clone, Debug, Default)] pub struct NetworkOptions { tcp_send_buffer_size: Option, tcp_recv_buffer_size: Option, - conn_timeout: u64, + /// Connection timeout + connection_timeout: Duration, #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] bind_device: Option, } @@ -379,7 +380,7 @@ impl NetworkOptions { NetworkOptions { tcp_send_buffer_size: None, tcp_recv_buffer_size: None, - conn_timeout: 5, + connection_timeout: Duration::from_secs(5), #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] bind_device: None, } @@ -393,15 +394,15 @@ impl NetworkOptions { self.tcp_recv_buffer_size = Some(size); } - /// set connection timeout in secs - pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { - self.conn_timeout = timeout; + /// Set connection timeout + pub fn set_connection_timeout(&mut self, timeout: Duration) -> &mut Self { + self.connection_timeout = timeout; self } - /// get timeout in secs - pub fn connection_timeout(&self) -> u64 { - self.conn_timeout + /// Get connection timeout + pub fn connection_timeout(&self) -> Duration { + self.connection_timeout } /// bind connection to a specific network device by name @@ -443,7 +444,7 @@ pub struct MqttOptions { /// request (publish, subscribe) channel capacity request_channel_capacity: usize, /// Max internal request batching - max_request_batch: usize, + max_batch_size: usize, /// Minimum delay time between consecutive outgoing packets /// while retransmitting pending packets pending_throttle: Duration, @@ -483,7 +484,7 @@ impl MqttOptions { max_incoming_packet_size: 10 * 1024, max_outgoing_packet_size: 10 * 1024, request_channel_capacity: 10, - max_request_batch: 0, + max_batch_size: 0, pending_throttle: Duration::from_micros(0), inflight: 100, last_will: None, @@ -734,7 +735,7 @@ pub enum OptionError { RequestChannelCapacity, #[error("Invalid max-request-batch value.")] - MaxRequestBatch, + MaxBatchSize, #[error("Invalid pending-throttle value.")] PendingThrottle, @@ -842,12 +843,12 @@ impl std::convert::TryFrom for MqttOptions { options.request_channel_capacity = request_channel_capacity; } - if let Some(max_request_batch) = queries - .remove("max_request_batch_num") - .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) + if let Some(max_batch_size) = queries + .remove("max_batch_size") + .map(|v| v.parse::().map_err(|_| OptionError::MaxBatchSize)) .transpose()? { - options.max_request_batch = max_request_batch; + options.max_batch_size = max_batch_size; } if let Some(pending_throttle) = queries @@ -887,7 +888,7 @@ impl Debug for MqttOptions { .field("credentials", &self.credentials) .field("max_packet_size", &self.max_incoming_packet_size) .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) + .field("max_batch_size", &self.max_batch_size) .field("pending_throttle", &self.pending_throttle) .field("inflight", &self.inflight) .field("last_will", &self.last_will) @@ -970,8 +971,8 @@ mod test { OptionError::RequestChannelCapacity ); assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), - OptionError::MaxRequestBatch + err("mqtt://host:42?client_id=foo&max_batch_size=foo"), + OptionError::MaxBatchSize ); assert_eq!( err("mqtt://host:42?client_id=foo&pending_throttle_usecs=foo"), diff --git a/rumqttc/src/proxy.rs b/rumqttc/src/proxy.rs index 94c7aabd3..4d976df99 100644 --- a/rumqttc/src/proxy.rs +++ b/rumqttc/src/proxy.rs @@ -45,7 +45,7 @@ impl Proxy { self, broker_addr: &str, broker_port: u16, - network_options: NetworkOptions, + network_options: &NetworkOptions, ) -> Result, ProxyError> { let proxy_addr = format!("{}:{}", self.addr, self.port); diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index f8629b8c5..6b56640f5 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -11,7 +11,7 @@ use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; use crate::valid_topic; use bytes::Bytes; -use flume::{SendError, Sender, TrySendError}; +use flume::{bounded, SendError, Sender, TrySendError}; use futures_util::FutureExt; use tokio::runtime::{self, Runtime}; use tokio::time::timeout; @@ -54,8 +54,8 @@ impl AsyncClient { /// /// `cap` specifies the capacity of the bounded async channel. pub fn new(options: MqttOptions, cap: usize) -> (AsyncClient, EventLoop) { - let eventloop = EventLoop::new(options, cap); - let request_tx = eventloop.requests_tx.clone(); + let (request_tx, request_rx) = bounded(cap); + let eventloop = EventLoop::new(options, request_rx); let client = AsyncClient { request_tx }; @@ -479,15 +479,15 @@ impl Client { /// /// `cap` specifies the capacity of the bounded async channel. pub fn new(options: MqttOptions, cap: usize) -> (Client, Connection) { - let (client, eventloop) = AsyncClient::new(options, cap); - let client = Client { client }; - let runtime = runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); + let (client, eventloop) = runtime.block_on(async { AsyncClient::new(options, cap) }); + let client = Client { client }; let connection = Connection::new(eventloop, runtime); + (client, connection) } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index a59094807..ac338b16f 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -1,19 +1,23 @@ -use super::framed::Network; -use super::mqttbytes::v5::*; -use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport}; -use crate::eventloop::socket_connect; -use crate::framed::AsyncReadWrite; - -use flume::{bounded, Receiver, Sender}; -use tokio::select; -use tokio::time::{self, error::Elapsed, Instant, Sleep}; +use super::{ + framed::Network, mqttbytes::v5::ConnectReturnCode, mqttbytes::v5::*, Incoming, MqttOptions, + MqttState, Outgoing, Request, StateError, Transport, +}; +use crate::{eventloop::socket_connect, framed::AsyncReadWrite}; -use std::collections::VecDeque; -use std::io; -use std::pin::Pin; -use std::time::Duration; +use flume::Receiver; +use futures_util::{Stream, StreamExt}; +use tokio::{ + select, + time::{self, error::Elapsed, timeout, Interval}, +}; -use super::mqttbytes::v5::ConnectReturnCode; +use std::{ + collections::VecDeque, + io, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] use crate::tls; @@ -51,6 +55,8 @@ pub enum ConnectionError { Io(#[from] io::Error), #[error("Connection refused, return code: `{0:?}`")] ConnectionRefused(ConnectReturnCode), + #[error("Connection closed")] + ConnectionClosed, #[error("Expected ConnAck packet, received: {0:?}")] NotConnAck(Box), #[error("Requests done")] @@ -73,15 +79,13 @@ pub struct EventLoop { /// Current state of the connection pub state: MqttState, /// Request stream - requests_rx: Receiver, - /// Requests handle to send requests - pub(crate) requests_tx: Sender, - /// Pending packets from last session - pub pending: VecDeque, + requests: Receiver, + /// Pending requests from the last session + pending: IntervalQueue, /// Network connection to the broker network: Option, /// Keep alive time - keepalive_timeout: Option>>, + keepalive_interval: Interval, } /// Events which can be yielded by the event loop @@ -93,23 +97,21 @@ pub enum Event { impl EventLoop { /// New MQTT `EventLoop` - /// - /// When connection encounters critical errors (like auth failure), user has a choice to - /// access and update `options`, `state` and `requests`. - pub fn new(options: MqttOptions, cap: usize) -> EventLoop { - let (requests_tx, requests_rx) = bounded(cap); - let pending = VecDeque::new(); + pub(crate) fn new(options: MqttOptions, requests: Receiver) -> EventLoop { let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX); let manual_acks = options.manual_acks; + let pending = IntervalQueue::new(options.pending_throttle); + let state = MqttState::new(inflight_limit, manual_acks); + assert!(!options.keep_alive.is_zero()); + let keepalive_interval = time::interval(options.keep_alive()); EventLoop { options, - state: MqttState::new(inflight_limit, manual_acks), - requests_tx, - requests_rx, + state, + requests, pending, network: None, - keepalive_timeout: None, + keepalive_interval, } } @@ -122,12 +124,10 @@ impl EventLoop { /// > For this reason we recommend setting [`AsycClient`](super::AsyncClient)'s channel capacity to `0`. pub fn clean(&mut self) { self.network = None; - self.keepalive_timeout = None; self.pending.extend(self.state.clean()); // drain requests from channel which weren't yet received - let requests_in_channel = self.requests_rx.drain(); - self.pending.extend(requests_in_channel); + self.pending.extend(self.requests.drain()); } /// Yields Next notification or outgoing request and periodically pings @@ -136,126 +136,183 @@ impl EventLoop { /// **NOTE** Don't block this while iterating pub async fn poll(&mut self) -> Result { if self.network.is_none() { - let (network, connack) = time::timeout( - Duration::from_secs(self.options.connection_timeout()), - connect(&mut self.options), - ) - .await??; + let connect_timeout = self.options.connect_timeout(); + let (network, connack) = timeout(connect_timeout, connect(&mut self.options)).await??; self.network = Some(network); - if self.keepalive_timeout.is_none() { - self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive))); - } - + // A connack never produces a response packet. Safe to ignore the return value + // of `handle_incoming_packet` self.state.handle_incoming_packet(connack)?; + + self.pending.reset_immediately(); + self.keepalive_interval.reset(); } - match self.select().await { - Ok(v) => Ok(v), - Err(e) => { - self.clean(); - Err(e) + // Read buffered events from previous polls before calling a new poll + if let Some(event) = self.state.events.pop_front() { + Ok(event) + } else { + match self.poll_process().await { + Ok(v) => Ok(v), + Err(e) => { + self.clean(); + Err(e) + } } } } /// Select on network and requests and generate keepalive pings when necessary - async fn select(&mut self) -> Result { + async fn poll_process(&mut self) -> Result { let network = self.network.as_mut().unwrap(); - // let await_acks = self.state.await_acks; + let network_timeout = self.options.network_options().connection_timeout(); + + for _ in 0..self.options.max_batch_size { + let inflight_full = self.state.is_inflight_full(); + let collision = self.state.has_collision(); + + select! { + // Handles pending and new requests. + // If available, prioritises pending requests from previous session. + // Else, pulls next request from user requests channel. + // If conditions in the below branch are for flow control. + // The branch is disabled if there's no pending messages and new user requests + // cannot be serviced due flow control. + // We read next user user request only when inflight messages are < configured inflight + // and there are no collisions while handling previous outgoing requests. + // + // Flow control is based on ack count. If inflight packet count in the buffer is + // less than max_inflight setting, next outgoing request will progress. For this + // to work correctly, broker should ack in sequence (a lot of brokers won't) + // + // E.g If max inflight = 5, user requests will be blocked when inflight queue + // looks like this -> [1, 2, 3, 4, 5]. + // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. + // This pulls next user request. But because max packet id = max_inflight, next + // user request's packet id will roll to 1. This replaces existing packet id 1. + // Resulting in a collision + // + // Eventloop can stop receiving outgoing user requests when previous outgoing + // request collided. I.e collision state. Collision state will be cleared only + // when correct ack is received + // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. + // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. + // After collision with pkid 1 -> [1b ,2, x, 4, 5]. + // 1a is saved to state and event loop is set to collision mode stopping new + // outgoing requests (along with 1b). + Some(request) = self.pending.next(), if !inflight_full && !collision => { + if let Some(packet) = self.state.handle_outgoing_packet(request)? { + timeout(network_timeout, network.write(packet)).await??; + } + }, + request = self.requests.recv_async(), if self.pending.is_empty() && !inflight_full && !collision => { + let request = request.map_err(|_| ConnectionError::RequestsDone)?; + if let Some(packet) = self.state.handle_outgoing_packet(request)? { + timeout(network_timeout, network.write(packet)).await??; + } + }, + // Process next packet from io + packet = network.read() => { + // Reset keepalive interval due to packet reception + self.keepalive_interval.reset(); + match packet? { + Some(packet) => if let Some(packet) = self.state.handle_incoming_packet(packet)? { + let flush = matches!(packet, Packet::PingResp(_)); + timeout(network_timeout, network.write(packet)).await??; + if flush { + break; + } + } + None => return Err(ConnectionError::ConnectionClosed), + } + }, + // Send a ping request on each interval tick + _ = self.keepalive_interval.tick() => { + if let Some(packet) = self.state.handle_outgoing_packet(Request::PingReq)? { + timeout(network_timeout, network.write(packet)).await??; + } + } + else => unreachable!("Eventloop select is exhaustive"), + } - let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight; - let collision = self.state.collision.is_some(); + // Break early if there is no request pending and no more incoming bytes polled into the read buffer + // This implementation is suboptimal: The loop is *not* broken if a incomplete packets resides in the + // rx buffer of `Network`. Until that frame is complete the outgoing queue is *not* flushed. + // Since the incomplete packet is already started to appear in the buffer it should be fine to await + // more data on the stream before flushing. + if self.pending.is_empty() + && self.requests.is_empty() + && network.read_buffer_remaining() == 0 + { + break; + } + } - // Read buffered events from previous polls before calling a new poll - if let Some(event) = self.state.events.pop_front() { - return Ok(event); + timeout(network_timeout, network.flush()).await??; + + self.state + .events + .pop_front() + .ok_or_else(|| unreachable!("empty event queue")) + } +} + +/// Pending items yielded with a configured rate. If the queue is empty the stream will yield pending. +struct IntervalQueue { + /// Interval + interval: Option, + /// Pending requests + queue: VecDeque, +} + +impl IntervalQueue { + /// Construct a new Pending instance + pub fn new(interval: Duration) -> Self { + let interval = (!interval.is_zero()).then(|| time::interval(interval)); + IntervalQueue { + interval, + queue: VecDeque::new(), } + } - // this loop is necessary since self.incoming.pop_front() might return None. In that case, - // instead of returning a None event, we try again. - select! { - // Handles pending and new requests. - // If available, prioritises pending requests from previous session. - // Else, pulls next request from user requests channel. - // If conditions in the below branch are for flow control. - // The branch is disabled if there's no pending messages and new user requests - // cannot be serviced due flow control. - // We read next user user request only when inflight messages are < configured inflight - // and there are no collisions while handling previous outgoing requests. - // - // Flow control is based on ack count. If inflight packet count in the buffer is - // less than max_inflight setting, next outgoing request will progress. For this - // to work correctly, broker should ack in sequence (a lot of brokers won't) - // - // E.g If max inflight = 5, user requests will be blocked when inflight queue - // looks like this -> [1, 2, 3, 4, 5]. - // If broker acking 2 instead of 1 -> [1, x, 3, 4, 5]. - // This pulls next user request. But because max packet id = max_inflight, next - // user request's packet id will roll to 1. This replaces existing packet id 1. - // Resulting in a collision - // - // Eventloop can stop receiving outgoing user requests when previous outgoing - // request collided. I.e collision state. Collision state will be cleared only - // when correct ack is received - // Full inflight queue will look like -> [1a, 2, 3, 4, 5]. - // If 3 is acked instead of 1 first -> [1a, 2, x, 4, 5]. - // After collision with pkid 1 -> [1b ,2, x, 4, 5]. - // 1a is saved to state and event loop is set to collision mode stopping new - // outgoing requests (along with 1b). - o = Self::next_request( - &mut self.pending, - &self.requests_rx, - self.options.pending_throttle - ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o { - Ok(request) => { - if let Some(outgoing) = self.state.handle_outgoing_packet(request)? { - network.write(outgoing).await?; - } - network.flush().await?; - Ok(self.state.events.pop_front().unwrap()) - } - Err(_) => Err(ConnectionError::RequestsDone), - }, - // Pull a bunch of packets from network, reply in bunch and yield the first item - o = network.readb(&mut self.state) => { - o?; - // flush all the acks and return first incoming packet - network.flush().await?; - Ok(self.state.events.pop_front().unwrap()) - }, - // We generate pings irrespective of network activity. This keeps the ping logic - // simple. We can change this behavior in future if necessary (to prevent extra pings) - _ = self.keepalive_timeout.as_mut().unwrap() => { - let timeout = self.keepalive_timeout.as_mut().unwrap(); - timeout.as_mut().reset(Instant::now() + self.options.keep_alive); - - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { - network.write(outgoing).await?; - } - network.flush().await?; - Ok(self.state.events.pop_front().unwrap()) - } + /// Returns true this queue is not empty + pub fn is_empty(&self) -> bool { + self.queue.is_empty() + } + + /// Extend the request queue + pub fn extend(&mut self, requests: impl IntoIterator) { + self.queue.extend(requests); + } + + /// Reset the pending interval tick. Next tick yields immediately + pub fn reset_immediately(&mut self) { + if let Some(interval) = self.interval.as_mut() { + interval.reset_immediately(); } } +} - async fn next_request( - pending: &mut VecDeque, - rx: &Receiver, - pending_throttle: Duration, - ) -> Result { - if !pending.is_empty() { - time::sleep(pending_throttle).await; - // We must call .next() AFTER sleep() otherwise .next() would - // advance the iterator but the future might be canceled before return - Ok(pending.pop_front().unwrap()) +impl Stream for IntervalQueue { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.is_empty() { + Poll::Pending } else { - match rx.recv_async().await { - Ok(r) => Ok(r), - Err(_) => Err(ConnectionError::RequestsDone), + match self.interval.as_mut() { + Some(interval) => match interval.poll_tick(cx) { + Poll::Ready(_) => Poll::Ready(self.queue.pop_front()), + Poll::Pending => Poll::Pending, + }, + None => Poll::Ready(self.queue.pop_front()), } } } + + fn size_hint(&self) -> (usize, Option) { + (self.queue.len(), Some(self.queue.len())) + } } /// This stream internally processes requests from the request stream provided to the eventloop @@ -269,12 +326,6 @@ async fn connect(options: &mut MqttOptions) -> Result<(Network, Incoming), Conne // make MQTT connection request (which internally awaits for ack) let packet = mqtt_connect(options, &mut network).await?; - - // Last session might contain packets which aren't acked. MQTT says these packets should be - // republished in the next session - // move pending messages from state to eventloop - // let pending = self.state.clean(); - // self.pending = pending.into_iter(); Ok((network, packet)) } @@ -391,6 +442,7 @@ async fn mqtt_connect( let keep_alive = options.keep_alive().as_secs() as u16; let clean_start = options.clean_start(); let client_id = options.client_id(); + let connect_timeout = options.connect_timeout(); let properties = options.connect_properties(); let connect = Connect { @@ -401,21 +453,90 @@ async fn mqtt_connect( }; // send mqtt connect packet - network.connect(connect, options).await?; + let last_will = options.last_will(); + let login = options.credentials(); + let connect = Packet::Connect(connect, last_will, login); + timeout(connect_timeout, async { + network.write(connect).await?; + network.flush().await + }) + .await??; // validate connack match network.read().await? { - Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => { - // Override local keep_alive value if set by server. + Some(Incoming::ConnAck(connack)) if connack.code == ConnectReturnCode::Success => { if let Some(props) = &connack.properties { + // Override local keep_alive value if set by server. if let Some(keep_alive) = props.server_keep_alive { options.keep_alive = Duration::from_secs(keep_alive as u64); } + + // Override max packet size network.set_max_outgoing_size(props.max_packet_size); } Ok(Packet::ConnAck(connack)) } - Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)), - packet => Err(ConnectionError::NotConnAck(Box::new(packet))), + Some(Incoming::ConnAck(connack)) => Err(ConnectionError::ConnectionRefused(connack.code)), + Some(packet) => Err(ConnectionError::NotConnAck(Box::new(packet))), + None => Err(ConnectionError::ConnectionClosed), } } + +#[tokio::test(start_paused = true)] +async fn connect_and_receive_connack() { + let mut options = MqttOptions::new("", "", 0); + + // Prepare a connect packet that is expected to be received. + let mut connect = bytes::BytesMut::new(); + Packet::Connect( + Connect { + keep_alive: options.keep_alive().as_secs() as u16, + client_id: options.client_id(), + clean_start: options.clean_start(), + properties: options.connect_properties(), + }, + options.last_will(), + options.credentials(), + ) + .write(&mut connect, None) + .ok(); + + // Prepare connect ack + let mut connect_ack = bytes::BytesMut::new(); + Packet::ConnAck(ConnAck { + session_present: false, + code: ConnectReturnCode::Success, + properties: None, + }) + .write(&mut connect_ack, None) + .ok(); + + // IO will assume a connect packet and *not* reply with a connack. + let io = tokio_test::io::Builder::new() + .write(&connect) + .read(&connect_ack) + .build(); + let mut network = Network::new(io, None); + + // Operation should timeout because io flush will not resolve. + let result = mqtt_connect(&mut options, &mut network).await; + + assert!(matches!(dbg!(result), Ok(Packet::ConnAck(ConnAck { .. })))); +} + +#[tokio::test(start_paused = true)] +async fn connect_timeouts_connect_packet_write() { + let mut options = MqttOptions::new("", "", 0); + options.set_connect_timeout(Duration::from_secs(10)); + + // IO will not accept the connect packet write + let io = tokio_test::io::Builder::new() + .wait(Duration::from_secs(30)) + .build(); + let mut network = Network::new(io, None); + + // Operation should timeout because io flush will not resolve. + let result = mqtt_connect(&mut options, &mut network).await; + + assert!(matches!(result, Err(ConnectionError::Timeout(_)))); +} diff --git a/rumqttc/src/v5/framed.rs b/rumqttc/src/v5/framed.rs index c7e06a250..ced4ebb62 100644 --- a/rumqttc/src/v5/framed.rs +++ b/rumqttc/src/v5/framed.rs @@ -1,12 +1,14 @@ -use futures_util::{FutureExt, SinkExt}; +use bytes::Buf; +use futures_util::SinkExt; use tokio_stream::StreamExt; use tokio_util::codec::Framed; use crate::framed::AsyncReadWrite; -use super::mqttbytes::v5::Packet; -use super::{mqttbytes, Codec, Connect, MqttOptions, MqttState}; -use super::{Incoming, StateError}; +use super::{ + mqttbytes::v5::Packet, + Codec, {Incoming, StateError}, +}; /// Network transforms packets <-> frames efficiently. It takes /// advantage of pre-allocation, buffering and vectorization when @@ -14,8 +16,6 @@ use super::{Incoming, StateError}; pub struct Network { /// Frame MQTT packets from network connection framed: Framed, Codec>, - /// Maximum readv count - max_readb_count: usize, } impl Network { pub fn new(socket: impl AsyncReadWrite + 'static, max_incoming_size: Option) -> Network { @@ -26,58 +26,26 @@ impl Network { }; let framed = Framed::new(socket, codec); - Network { - framed, - max_readb_count: 10, - } + Network { framed } } pub fn set_max_outgoing_size(&mut self, max_outgoing_size: Option) { self.framed.codec_mut().max_outgoing_size = max_outgoing_size; } + pub fn read_buffer_remaining(&self) -> usize { + self.framed.read_buffer().remaining() + } + /// Reads and returns a single packet from network - pub async fn read(&mut self) -> Result { + pub async fn read(&mut self) -> Result, StateError> { match self.framed.next().await { - Some(Ok(packet)) => Ok(packet), - Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(), + Some(Ok(packet)) => Ok(Some(packet)), Some(Err(e)) => Err(StateError::Deserialization(e)), - None => Err(StateError::ConnectionAborted), + None => Ok(None), } } - /// Read packets in bulk. This allow replies to be in bulk. This method is used - /// after the connection is established to read a bunch of incoming packets - pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> { - // wait for the first read - let mut res = self.framed.next().await; - let mut count = 1; - loop { - match res { - Some(Ok(packet)) => { - if let Some(outgoing) = state.handle_incoming_packet(packet)? { - self.write(outgoing).await?; - } - - count += 1; - if count >= self.max_readb_count { - break; - } - } - Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(), - Some(Err(e)) => return Err(StateError::Deserialization(e)), - None => return Err(StateError::ConnectionAborted), - } - // do not wait for subsequent reads - match self.framed.next().now_or_never() { - Some(r) => res = r, - _ => break, - }; - } - - Ok(()) - } - /// Serializes packet into write buffer pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> { self.framed @@ -86,19 +54,7 @@ impl Network { .map_err(StateError::Deserialization) } - pub async fn connect( - &mut self, - connect: Connect, - options: &MqttOptions, - ) -> Result<(), StateError> { - let last_will = options.last_will(); - let login = options.credentials(); - self.write(Packet::Connect(connect, last_will, login)) - .await?; - - self.flush().await - } - + /// Flush the outgoing sink pub async fn flush(&mut self) -> Result<(), StateError> { self.framed .flush() diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 44499cde2..f71f66330 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -78,15 +78,15 @@ pub struct MqttOptions { credentials: Option, /// request (publish, subscribe) channel capacity request_channel_capacity: usize, - /// Max internal request batching - max_request_batch: usize, + /// Max batch processing size + max_batch_size: usize, /// Minimum delay time between consecutive outgoing packets /// while retransmitting pending packets pending_throttle: Duration, /// Last will that will be issued on unexpected disconnect last_will: Option, - /// Connection timeout - conn_timeout: u64, + /// Connect timeout + connect_timeout: Duration, /// Default value of for maximum incoming packet size. /// Used when `max_incomming_size` in `connect_properties` is NOT available. default_max_incoming_size: u32, @@ -95,6 +95,7 @@ pub struct MqttOptions { /// If set to `true` MQTT acknowledgements are not sent automatically. /// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method. manual_acks: bool, + /// Network options network_options: NetworkOptions, #[cfg(feature = "proxy")] /// Proxy configuration. @@ -126,10 +127,10 @@ impl MqttOptions { client_id: id.into(), credentials: None, request_channel_capacity: 10, - max_request_batch: 0, + max_batch_size: 10, pending_throttle: Duration::from_micros(0), last_will: None, - conn_timeout: 5, + connect_timeout: Duration::from_secs(5), default_max_incoming_size: 10 * 1024, connect_properties: None, manual_acks: false, @@ -290,15 +291,15 @@ impl MqttOptions { self.pending_throttle } - /// set connection timeout in secs - pub fn set_connection_timeout(&mut self, timeout: u64) -> &mut Self { - self.conn_timeout = timeout; + /// set connect timeout + pub fn set_connect_timeout(&mut self, timeout: Duration) -> &mut Self { + self.connect_timeout = timeout; self } - /// get timeout in secs - pub fn connection_timeout(&self) -> u64 { - self.conn_timeout + /// get connect timeout + pub fn connect_timeout(&self) -> Duration { + self.connect_timeout } /// set connection properties @@ -494,10 +495,12 @@ impl MqttOptions { self.manual_acks } - pub fn network_options(&self) -> NetworkOptions { - self.network_options.clone() + /// get network options + pub fn network_options(&self) -> &NetworkOptions { + &self.network_options } + /// set network options pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self { self.network_options = network_options; self @@ -654,12 +657,12 @@ impl std::convert::TryFrom for MqttOptions { options.request_channel_capacity = request_channel_capacity; } - if let Some(max_request_batch) = queries - .remove("max_request_batch_num") + if let Some(max_batch_size) = queries + .remove("max_batch_size") .map(|v| v.parse::().map_err(|_| OptionError::MaxRequestBatch)) .transpose()? { - options.max_request_batch = max_request_batch; + options.max_batch_size = max_batch_size; } if let Some(pending_throttle) = queries @@ -676,11 +679,12 @@ impl std::convert::TryFrom for MqttOptions { .transpose()?; if let Some(conn_timeout) = queries - .remove("conn_timeout_secs") + .remove("connect_timeout_secs") .map(|v| v.parse::().map_err(|_| OptionError::ConnTimeout)) .transpose()? { - options.set_connection_timeout(conn_timeout); + let conn_timeout = Duration::from_secs(conn_timeout); + options.set_connect_timeout(conn_timeout); } if let Some((opt, _)) = queries.into_iter().next() { @@ -704,11 +708,12 @@ impl Debug for MqttOptions { .field("client_id", &self.client_id) .field("credentials", &self.credentials) .field("request_channel_capacity", &self.request_channel_capacity) - .field("max_request_batch", &self.max_request_batch) + .field("max_request_batch", &self.max_batch_size) .field("pending_throttle", &self.pending_throttle) .field("last_will", &self.last_will) - .field("conn_timeout", &self.conn_timeout) + .field("connect_timeout", &self.connect_timeout) .field("manual_acks", &self.manual_acks) + .field("network_options", &self.network_options) .field("connect properties", &self.connect_properties) .finish() } @@ -785,7 +790,7 @@ mod test { OptionError::RequestChannelCapacity ); assert_eq!( - err("mqtt://host:42?client_id=foo&max_request_batch_num=foo"), + err("mqtt://host:42?client_id=foo&max_batch_size=foo"), OptionError::MaxRequestBatch ); assert_eq!( @@ -797,7 +802,7 @@ mod test { OptionError::Inflight ); assert_eq!( - err("mqtt://host:42?client_id=foo&conn_timeout_secs=foo"), + err("mqtt://host:42?client_id=foo&connect_timeout_secs=foo"), OptionError::ConnTimeout ); } diff --git a/rumqttc/src/v5/mqttbytes/v5/codec.rs b/rumqttc/src/v5/mqttbytes/v5/codec.rs index 76909d62d..832ceaefe 100644 --- a/rumqttc/src/v5/mqttbytes/v5/codec.rs +++ b/rumqttc/src/v5/mqttbytes/v5/codec.rs @@ -3,8 +3,8 @@ use tokio_util::codec::{Decoder, Encoder}; use super::{Error, Packet}; -/// MQTT v4 codec -#[derive(Debug, Clone)] +/// MQTT v5 codec +#[derive(Default, Debug, Clone)] pub struct Codec { /// Maximum packet size allowed by client pub max_incoming_size: Option, @@ -33,16 +33,14 @@ impl Encoder for Codec { type Error = Error; fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { - item.write(dst, self.max_outgoing_size)?; - - Ok(()) + item.write(dst, self.max_outgoing_size).map(drop) } } #[cfg(test)] mod tests { - use bytes::BytesMut; - use tokio_util::codec::Encoder; + use bytes::{Buf, BytesMut}; + use tokio_util::codec::{Decoder, Encoder}; use super::Codec; use crate::v5::{ @@ -73,4 +71,55 @@ mod tests { _ => unreachable!(), } } + + #[test] + fn encode_decode_multiple_packets() { + let mut buf = BytesMut::new(); + let mut codec = Codec::default(); + let publish = Packet::Publish(Publish::new( + "hello/world", + QoS::AtMostOnce, + vec![1; 10], + None, + )); + + // Encode a fixed number of publications into `buf` + for _ in 0..100 { + codec + .encode(publish.clone(), &mut buf) + .expect("failed to encode"); + } + + // Decode a fixed number of packets from `buf` + for _ in 0..100 { + let result = codec.decode(&mut buf).expect("failed to encode"); + assert!(matches!(result, Some(p) if p == publish)); + } + + assert_eq!(buf.remaining(), 0); + } + + #[test] + fn decode_insufficient() { + let mut buf = BytesMut::new(); + let mut codec = Codec::default(); + let publish = Packet::Publish(Publish::new( + "hello/world", + QoS::AtMostOnce, + vec![1; 100], + None, + )); + + // Encode packet into `buf` + codec + .encode(publish.clone(), &mut buf) + .expect("failed to encode"); + let result = codec.decode(&mut buf); + assert!(matches!(result, Ok(Some(p)) if p == publish)); + + buf.resize(buf.remaining() / 2, 0); + + let result = codec.decode(&mut buf); + assert!(matches!(result, Ok(None))); + } } diff --git a/rumqttc/src/v5/mqttbytes/v5/mod.rs b/rumqttc/src/v5/mqttbytes/v5/mod.rs index 342278596..d2cab8a8e 100644 --- a/rumqttc/src/v5/mqttbytes/v5/mod.rs +++ b/rumqttc/src/v5/mqttbytes/v5/mod.rs @@ -129,8 +129,10 @@ impl Packet { } pub fn write(&self, write: &mut BytesMut, max_size: Option) -> Result { + let size = self.size(); + if let Some(max_size) = max_size { - if self.size() > max_size as usize { + if size > max_size as usize { return Err(Error::OutgoingPacketTooLarge { pkt_size: self.size() as u32, max: max_size, @@ -138,6 +140,9 @@ impl Packet { } } + // Ensure that `write` can take the serialized packet + write.reserve(size); + match self { Self::Publish(publish) => publish.write(write), Self::Subscribe(subscription) => subscription.write(write), diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 854aa7b0f..ef1e4939c 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -9,14 +9,11 @@ use super::{Event, Incoming, Outgoing, Request}; use bytes::Bytes; use std::collections::{HashMap, VecDeque}; -use std::{io, time::Instant}; +use std::time::Instant; /// Errors during state handling #[derive(Debug, thiserror::Error)] pub enum StateError { - /// Io Error while state is passed to network - #[error("Io error: {0:?}")] - Io(#[from] io::Error), #[error("Conversion error {0:?}")] Coversion(#[from] core::num::TryFromIntError), /// Invalid state for a given operation @@ -64,8 +61,6 @@ pub enum StateError { PubCompFail { reason: PubCompReason }, #[error("Connection failed with reason '{reason:?}' ")] ConnFail { reason: ConnectReturnCode }, - #[error("Connection closed by peer abruptly")] - ConnectionAborted } impl From for StateError { @@ -185,6 +180,16 @@ impl MqttState { self.inflight } + /// Returns true if the inflight limit is reached + pub fn is_inflight_full(&self) -> bool { + self.inflight >= self.max_outgoing_inflight + } + + /// Returns true if the state has a unresolved collision + pub fn has_collision(&self) -> bool { + self.collision.is_some() + } + /// Consolidates handling of all outgoing mqtt packet logic. Returns a packet which should /// be put on to the network by the eventloop pub fn handle_outgoing_packet(