From 92b933c9cfc546ce98d9b7824aa82f63e12a80c6 Mon Sep 17 00:00:00 2001 From: Saghm Rossi Date: Sun, 19 Apr 2020 13:51:54 -0400 Subject: [PATCH 1/8] RUST-384 Flush connection read buffer before sending new message --- Cargo.toml | 2 +- src/cmap/conn/mod.rs | 22 +++++++++++++-- src/cmap/conn/wire/message.rs | 8 +++--- src/test/client.rs | 40 +++++++++++++++++++++++++++ src/test/temp.rs | 51 +++++++++++++++++++++++++++++++++++ 5 files changed, 117 insertions(+), 6 deletions(-) create mode 100644 src/test/temp.rs diff --git a/Cargo.toml b/Cargo.toml index 185b9d18e..384c1ea88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ sha2 = "0.8.0" stringprep = "0.1.2" time = "0.1.42" trust-dns-proto = "0.19.4" -trust-dns-resolver = "0.19.0" +trust-dns-resolver = "0.19.4" typed-builder = "0.3.0" version_check = "0.9.1" webpki = "0.21.0" diff --git a/src/cmap/conn/mod.rs b/src/cmap/conn/mod.rs index 0d2063db6..c8bfddce5 100644 --- a/src/cmap/conn/mod.rs +++ b/src/cmap/conn/mod.rs @@ -3,6 +3,7 @@ mod stream_description; mod wire; use std::{ + collections::VecDeque, sync::{Arc, Weak}, time::{Duration, Instant}, }; @@ -63,6 +64,8 @@ pub(crate) struct Connection { stream: AsyncStream, + pending_request_ids: VecDeque, + #[derivative(Debug = "ignore")] handler: Option>, } @@ -90,6 +93,7 @@ impl Connection { address, handler: options.and_then(|options| options.event_handler), stream_description: None, + pending_request_ids: Default::default(), }; Ok(conn) @@ -207,9 +211,11 @@ impl Connection { request_id: impl Into>, ) -> Result { let message = Message::with_command(command, request_id.into()); - message.write_to(&mut self.stream).await?; + let request_id = message.write_to(&mut self.stream).await?; + self.pending_request_ids.push_back(request_id); let response_message = Message::read_from(&mut self.stream).await?; + self.pending_request_ids.pop_front(); CommandResponse::new(self.address.clone(), response_message) } @@ -247,6 +253,7 @@ impl Connection { stream: std::mem::replace(&mut self.stream, AsyncStream::Null), handler: self.handler.take(), stream_description: self.stream_description.take(), + pending_request_ids: self.pending_request_ids.drain(..).collect(), } } } @@ -263,8 +270,17 @@ impl Drop for Connection { // helper explicitly, so we don't add it back to the pool or emit any events. if let Some(ref weak_pool_ref) = self.pool { if let Some(strong_pool_ref) = weak_pool_ref.upgrade() { - let dropped_connection_state = self.take(); + let mut dropped_connection_state = self.take(); RUNTIME.execute(async move { + while dropped_connection_state + .pending_request_ids + .pop_front() + .is_some() + { + let _: Result<_> = + Message::read_from(&mut dropped_connection_state.stream).await; + } + strong_pool_ref .check_in(dropped_connection_state.into()) .await; @@ -295,6 +311,7 @@ struct DroppedConnectionState { #[derivative(Debug = "ignore")] handler: Option>, stream_description: Option, + pending_request_ids: VecDeque, } impl Drop for DroppedConnectionState { @@ -322,6 +339,7 @@ impl From for Connection { handler: state.handler.take(), stream_description: state.stream_description.take(), ready_and_available_time: None, + pending_request_ids: state.pending_request_ids.drain(..).collect(), pool: None, } } diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index 75bab3922..36450837f 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -117,7 +117,7 @@ impl Message { } /// Serializes the Message to bytes and writes them to `writer`. - pub(crate) async fn write_to(&self, writer: &mut AsyncStream) -> Result<()> { + pub(crate) async fn write_to(&self, writer: &mut AsyncStream) -> Result { let mut sections_bytes = Vec::new(); for section in &self.sections { @@ -133,9 +133,11 @@ impl Message { .map(std::mem::size_of_val) .unwrap_or(0); + let request_id = self.request_id.unwrap_or_else(super::util::next_request_id); + let header = Header { length: total_length as i32, - request_id: self.request_id.unwrap_or_else(super::util::next_request_id), + request_id, response_to: self.response_to, op_code: OpCode::Message, }; @@ -150,7 +152,7 @@ impl Message { writer.flush().await?; - Ok(()) + Ok(request_id) } } diff --git a/src/test/client.rs b/src/test/client.rs index 883a559a4..dd6ba21a1 100644 --- a/src/test/client.rs +++ b/src/test/client.rs @@ -13,6 +13,7 @@ use crate::{ selection_criteria::{ReadPreference, SelectionCriteria}, test::{util::TestClient, CLIENT_OPTIONS, LOCK}, Client, + RUNTIME, }; #[derive(Debug, Deserialize)] @@ -532,3 +533,42 @@ async fn saslprep_uri() { auth_test_uri("%E2%85%A8", "IV", None, true).await; auth_test_uri("%E2%85%A8", "I%C2%ADV", None, true).await; } + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn future_drop_corrupt_issue() { + let _guard = LOCK.run_concurrently().await; + + let options = CLIENT_OPTIONS.clone(); + + let client = Client::with_options(options.clone()).unwrap(); + let db = client.database("test"); + + db.collection("foo") + .insert_one(doc! { "x": 1 }, None) + .await + .unwrap(); + + let _: Result<_, _> = tokio::time::timeout( + Duration::from_millis(50), + db.run_command( + doc! { "count": "foo", + "query": { + "$where": "sleep(100) && true" + } + }, + None, + ), + ) + .await; + + RUNTIME.delay_for(Duration::from_millis(200)).await; + + let is_master_response = db.run_command(doc! { "isMaster": 1 }, None).await; + + // Ensure that the response to `isMaster` is read, not the response to `count`. + assert!(is_master_response + .ok() + .and_then(|value| value.get("ismaster").and_then(|value| value.as_bool())) + .is_some()); +} diff --git a/src/test/temp.rs b/src/test/temp.rs new file mode 100644 index 000000000..59c45e148 --- /dev/null +++ b/src/test/temp.rs @@ -0,0 +1,51 @@ +use std::time::Duration; + +use bson::doc; + +use crate::{ + test::{CLIENT_OPTIONS, LOCK}, + Client, + RUNTIME, +}; + +#[cfg_attr(feature = "tokio-runtime", tokio::test)] +#[cfg_attr(feature = "async-std-runtime", async_std::test)] +async fn future_drop_corrupt_issue() { + let _guard = LOCK.run_concurrently().await; + + let options = CLIENT_OPTIONS.clone(); + + let client = Client::with_options(options.clone()).unwrap(); + let db = client.database("test"); + + db.collection("foo") + .insert_one(doc! { "x": 1 }, None) + .await + .unwrap(); + + let _: Result<_, _> = tokio::time::timeout( + Duration::from_millis(50), + db.run_command( + doc! { "count": "foo", + "query": { + "$where": "sleep(100) && true" + } + }, + None, + ), + ) + .await; + + RUNTIME.delay_for(Duration::from_millis(200)).await; + + let is_master_response = db.run_command(doc! { "isMaster": 1 }, None).await; + + // it's going to fail because is_master_response contains response for this count command + // instead of isMaster + assert!(is_master_response + .ok() + .and_then(|value| dbg!(value) + .get("ismaster") + .and_then(|value| value.as_bool())) + .is_some()); +} From f7c87945caef6405584d03525a96a2bb2fb18d6a Mon Sep 17 00:00:00 2001 From: Saghm Rossi Date: Tue, 21 Apr 2020 14:40:42 -0400 Subject: [PATCH 2/8] code review feedback --- src/cmap/conn/mod.rs | 48 ++++++++++++++++++++--------------- src/cmap/conn/wire/message.rs | 33 ++++++++++++++---------- src/cmap/conn/wire/test.rs | 4 ++- src/test/client.rs | 4 +-- 4 files changed, 52 insertions(+), 37 deletions(-) diff --git a/src/cmap/conn/mod.rs b/src/cmap/conn/mod.rs index c8bfddce5..b72a2eaa7 100644 --- a/src/cmap/conn/mod.rs +++ b/src/cmap/conn/mod.rs @@ -3,12 +3,12 @@ mod stream_description; mod wire; use std::{ - collections::VecDeque, sync::{Arc, Weak}, time::{Duration, Instant}, }; use derivative::Derivative; +use futures::io::AsyncReadExt; use self::wire::Message; use super::ConnectionPoolInner; @@ -42,6 +42,12 @@ pub struct ConnectionInfo { pub address: StreamAddress, } +#[derive(Clone, Debug, Default)] +pub(crate) struct PartialMessageState { + bytes_remaining: usize, + needs_response: bool, +} + /// A wrapper around Stream that contains all the CMAP information needed to maintain a connection. #[derive(Derivative)] #[derivative(Debug)] @@ -64,7 +70,7 @@ pub(crate) struct Connection { stream: AsyncStream, - pending_request_ids: VecDeque, + partial_message_state: PartialMessageState, #[derivative(Debug = "ignore")] handler: Option>, @@ -93,7 +99,7 @@ impl Connection { address, handler: options.and_then(|options| options.event_handler), stream_description: None, - pending_request_ids: Default::default(), + partial_message_state: Default::default(), }; Ok(conn) @@ -210,12 +216,23 @@ impl Connection { command: Command, request_id: impl Into>, ) -> Result { + if self.partial_message_state.needs_response { + Message::read_from(&mut self.stream, &mut self.partial_message_state).await?; + } + + if self.partial_message_state.bytes_remaining > 0 { + let mut bytes = vec![0u8; self.partial_message_state.bytes_remaining]; + self.stream.read_exact(&mut bytes).await?; + self.partial_message_state.bytes_remaining = 0; + } + let message = Message::with_command(command, request_id.into()); - let request_id = message.write_to(&mut self.stream).await?; - self.pending_request_ids.push_back(request_id); + message.write_to(&mut self.stream).await?; - let response_message = Message::read_from(&mut self.stream).await?; - self.pending_request_ids.pop_front(); + self.partial_message_state.needs_response = true; + + let response_message = + Message::read_from(&mut self.stream, &mut self.partial_message_state).await?; CommandResponse::new(self.address.clone(), response_message) } @@ -253,7 +270,7 @@ impl Connection { stream: std::mem::replace(&mut self.stream, AsyncStream::Null), handler: self.handler.take(), stream_description: self.stream_description.take(), - pending_request_ids: self.pending_request_ids.drain(..).collect(), + partial_message_state: self.partial_message_state.clone(), } } } @@ -270,17 +287,8 @@ impl Drop for Connection { // helper explicitly, so we don't add it back to the pool or emit any events. if let Some(ref weak_pool_ref) = self.pool { if let Some(strong_pool_ref) = weak_pool_ref.upgrade() { - let mut dropped_connection_state = self.take(); + let dropped_connection_state = self.take(); RUNTIME.execute(async move { - while dropped_connection_state - .pending_request_ids - .pop_front() - .is_some() - { - let _: Result<_> = - Message::read_from(&mut dropped_connection_state.stream).await; - } - strong_pool_ref .check_in(dropped_connection_state.into()) .await; @@ -311,7 +319,7 @@ struct DroppedConnectionState { #[derivative(Debug = "ignore")] handler: Option>, stream_description: Option, - pending_request_ids: VecDeque, + partial_message_state: PartialMessageState, } impl Drop for DroppedConnectionState { @@ -339,8 +347,8 @@ impl From for Connection { handler: state.handler.take(), stream_description: state.stream_description.take(), ready_and_available_time: None, - pending_request_ids: state.pending_request_ids.drain(..).collect(), pool: None, + partial_message_state: state.partial_message_state.clone(), } } } diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index 36450837f..965c437b7 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -8,7 +8,7 @@ use super::{ }; use crate::{ bson_util::async_encoding, - cmap::conn::command::Command, + cmap::conn::{command::Command, PartialMessageState}, error::{ErrorKind, Result}, runtime::{AsyncLittleEndianRead, AsyncLittleEndianWrite, AsyncStream}, }; @@ -75,33 +75,40 @@ impl Message { } /// Reads bytes from `reader` and deserializes them into a Message. - pub(crate) async fn read_from(reader: &mut AsyncStream) -> Result { + pub(crate) async fn read_from( + reader: &mut AsyncStream, + partial_message_state: &mut PartialMessageState, + ) -> Result { let header = Header::read_from(reader).await?; - let mut length_remaining = header.length - Header::LENGTH as i32; + partial_message_state.bytes_remaining = header.length as usize - Header::LENGTH; + partial_message_state.needs_response = false; let flags = MessageFlags::from_bits_truncate(reader.read_u32().await?); - length_remaining -= std::mem::size_of::() as i32; + partial_message_state.bytes_remaining -= std::mem::size_of::(); let mut count_reader = CountReader::new(reader); let mut sections = Vec::new(); - while length_remaining - count_reader.bytes_read() as i32 > 4 { + while partial_message_state.bytes_remaining - count_reader.bytes_read() > 4 { sections.push(MessageSection::read(&mut count_reader).await?); } - length_remaining -= count_reader.bytes_read() as i32; + partial_message_state.bytes_remaining -= count_reader.bytes_read(); let mut checksum = None; - if length_remaining == 4 && flags.contains(MessageFlags::CHECKSUM_PRESENT) { + if partial_message_state.bytes_remaining == 4 + && flags.contains(MessageFlags::CHECKSUM_PRESENT) + { checksum = Some(reader.read_u32().await?); - } else if length_remaining != 0 { + } else if partial_message_state.bytes_remaining != 0 { return Err(ErrorKind::OperationError { message: format!( "The server indicated that the reply would be {} bytes long, but it instead \ was {}", header.length, - header.length - length_remaining + count_reader.bytes_read() as i32, + header.length as usize - partial_message_state.bytes_remaining + + count_reader.bytes_read(), ), } .into()); @@ -117,7 +124,7 @@ impl Message { } /// Serializes the Message to bytes and writes them to `writer`. - pub(crate) async fn write_to(&self, writer: &mut AsyncStream) -> Result { + pub(crate) async fn write_to(&self, writer: &mut AsyncStream) -> Result<()> { let mut sections_bytes = Vec::new(); for section in &self.sections { @@ -133,11 +140,9 @@ impl Message { .map(std::mem::size_of_val) .unwrap_or(0); - let request_id = self.request_id.unwrap_or_else(super::util::next_request_id); - let header = Header { length: total_length as i32, - request_id, + request_id: self.request_id.unwrap_or_else(super::util::next_request_id), response_to: self.response_to, op_code: OpCode::Message, }; @@ -152,7 +157,7 @@ impl Message { writer.flush().await?; - Ok(request_id) + Ok(()) } } diff --git a/src/cmap/conn/wire/test.rs b/src/cmap/conn/wire/test.rs index 59cbbbe70..8a8c000fe 100644 --- a/src/cmap/conn/wire/test.rs +++ b/src/cmap/conn/wire/test.rs @@ -35,7 +35,9 @@ async fn basic() { let mut stream = AsyncStream::connect(options).await.unwrap(); message.write_to(&mut stream).await.unwrap(); - let reply = Message::read_from(&mut stream).await.unwrap(); + let reply = Message::read_from(&mut stream, &mut Default::default()) + .await + .unwrap(); let response_doc = match reply.sections.into_iter().next().unwrap() { MessageSection::Document(doc) => doc, diff --git a/src/test/client.rs b/src/test/client.rs index dd6ba21a1..1d51e886e 100644 --- a/src/test/client.rs +++ b/src/test/client.rs @@ -536,7 +536,7 @@ async fn saslprep_uri() { #[cfg_attr(feature = "tokio-runtime", tokio::test)] #[cfg_attr(feature = "async-std-runtime", async_std::test)] -async fn future_drop_corrupt_issue() { +async fn future_drop_flush_response() { let _guard = LOCK.run_concurrently().await; let options = CLIENT_OPTIONS.clone(); @@ -567,7 +567,7 @@ async fn future_drop_corrupt_issue() { let is_master_response = db.run_command(doc! { "isMaster": 1 }, None).await; // Ensure that the response to `isMaster` is read, not the response to `count`. - assert!(is_master_response + assert!(dbg!(is_master_response) .ok() .and_then(|value| value.get("ismaster").and_then(|value| value.as_bool())) .is_some()); From 0a1ae23cd9570a7c76f77db79a778cb32fa8b28f Mon Sep 17 00:00:00 2001 From: Saghm Rossi Date: Fri, 24 Apr 2020 13:32:28 -0400 Subject: [PATCH 3/8] fix timeout --- src/test/client.rs | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/test/client.rs b/src/test/client.rs index 1d51e886e..114eafca2 100644 --- a/src/test/client.rs +++ b/src/test/client.rs @@ -549,25 +549,26 @@ async fn future_drop_flush_response() { .await .unwrap(); - let _: Result<_, _> = tokio::time::timeout( - Duration::from_millis(50), - db.run_command( - doc! { "count": "foo", - "query": { - "$where": "sleep(100) && true" - } - }, - None, - ), - ) - .await; + let _: Result<_, _> = RUNTIME + .timeout( + Duration::from_millis(50), + db.run_command( + doc! { "count": "foo", + "query": { + "$where": "sleep(100) && true" + } + }, + None, + ), + ) + .await; RUNTIME.delay_for(Duration::from_millis(200)).await; let is_master_response = db.run_command(doc! { "isMaster": 1 }, None).await; // Ensure that the response to `isMaster` is read, not the response to `count`. - assert!(dbg!(is_master_response) + assert!(is_master_response .ok() .and_then(|value| value.get("ismaster").and_then(|value| value.as_bool())) .is_some()); From 4ad5c3a2d3af4695a8c25699bdac72913e269a92 Mon Sep 17 00:00:00 2001 From: Saghm Rossi Date: Tue, 28 Apr 2020 13:39:15 -0400 Subject: [PATCH 4/8] remove accidentally-committed file --- src/test/temp.rs | 51 ------------------------------------------------ 1 file changed, 51 deletions(-) delete mode 100644 src/test/temp.rs diff --git a/src/test/temp.rs b/src/test/temp.rs deleted file mode 100644 index 59c45e148..000000000 --- a/src/test/temp.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::time::Duration; - -use bson::doc; - -use crate::{ - test::{CLIENT_OPTIONS, LOCK}, - Client, - RUNTIME, -}; - -#[cfg_attr(feature = "tokio-runtime", tokio::test)] -#[cfg_attr(feature = "async-std-runtime", async_std::test)] -async fn future_drop_corrupt_issue() { - let _guard = LOCK.run_concurrently().await; - - let options = CLIENT_OPTIONS.clone(); - - let client = Client::with_options(options.clone()).unwrap(); - let db = client.database("test"); - - db.collection("foo") - .insert_one(doc! { "x": 1 }, None) - .await - .unwrap(); - - let _: Result<_, _> = tokio::time::timeout( - Duration::from_millis(50), - db.run_command( - doc! { "count": "foo", - "query": { - "$where": "sleep(100) && true" - } - }, - None, - ), - ) - .await; - - RUNTIME.delay_for(Duration::from_millis(200)).await; - - let is_master_response = db.run_command(doc! { "isMaster": 1 }, None).await; - - // it's going to fail because is_master_response contains response for this count command - // instead of isMaster - assert!(is_master_response - .ok() - .and_then(|value| dbg!(value) - .get("ismaster") - .and_then(|value| value.as_bool())) - .is_some()); -} From d6c2298939ab2b306446bbcdebb247542b6869f5 Mon Sep 17 00:00:00 2001 From: Saghm Rossi Date: Tue, 28 Apr 2020 15:00:20 -0400 Subject: [PATCH 5/8] separate reading length from reading header --- src/cmap/conn/wire/header.rs | 4 ++-- src/cmap/conn/wire/message.rs | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/cmap/conn/wire/header.rs b/src/cmap/conn/wire/header.rs index 976618d49..12e22067f 100644 --- a/src/cmap/conn/wire/header.rs +++ b/src/cmap/conn/wire/header.rs @@ -51,9 +51,9 @@ impl Header { } /// Reads bytes from `r` and deserializes them into a header. - pub(crate) async fn read_from(stream: &mut AsyncStream) -> Result { + pub(crate) async fn read_from(length: i32, stream: &mut AsyncStream) -> Result { Ok(Self { - length: stream.read_i32().await?, + length, request_id: stream.read_i32().await?, response_to: stream.read_i32().await?, op_code: OpCode::from_i32(stream.read_i32().await?)?, diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index 965c437b7..164d11e1e 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -79,10 +79,11 @@ impl Message { reader: &mut AsyncStream, partial_message_state: &mut PartialMessageState, ) -> Result { - let header = Header::read_from(reader).await?; - partial_message_state.bytes_remaining = header.length as usize - Header::LENGTH; + let length = reader.read_i32().await?; + partial_message_state.bytes_remaining = length as usize - Header::LENGTH; partial_message_state.needs_response = false; + let header = Header::read_from(length, reader).await?; let flags = MessageFlags::from_bits_truncate(reader.read_u32().await?); partial_message_state.bytes_remaining -= std::mem::size_of::(); From 77d897d5564c3a0cec04a18127cf13d7504354ea Mon Sep 17 00:00:00 2001 From: Saghm Rossi Date: Fri, 1 May 2020 12:05:41 -0400 Subject: [PATCH 6/8] pass message state to header reading --- src/cmap/conn/wire/header.rs | 25 +++++++++++++++++++++---- src/cmap/conn/wire/message.rs | 5 +---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/cmap/conn/wire/header.rs b/src/cmap/conn/wire/header.rs index 12e22067f..887609547 100644 --- a/src/cmap/conn/wire/header.rs +++ b/src/cmap/conn/wire/header.rs @@ -1,6 +1,7 @@ use crate::runtime::{AsyncLittleEndianRead, AsyncLittleEndianWrite}; use crate::{ + cmap::conn::PartialMessageState, error::{ErrorKind, Result}, runtime::AsyncStream, }; @@ -51,12 +52,28 @@ impl Header { } /// Reads bytes from `r` and deserializes them into a header. - pub(crate) async fn read_from(length: i32, stream: &mut AsyncStream) -> Result { + pub(crate) async fn read_from( + stream: &mut AsyncStream, + partial_message_state: &mut PartialMessageState, + ) -> Result { + let length = stream.read_i32().await?; + partial_message_state.bytes_remaining = length as usize; + partial_message_state.needs_response = false; + + let request_id = stream.read_i32().await?; + partial_message_state.bytes_remaining -= 4; + + let response_to = stream.read_i32().await?; + partial_message_state.bytes_remaining -= 4; + + let op_code = stream.read_i32().await?; + partial_message_state.bytes_remaining -= 4; + Ok(Self { length, - request_id: stream.read_i32().await?, - response_to: stream.read_i32().await?, - op_code: OpCode::from_i32(stream.read_i32().await?)?, + request_id, + response_to, + op_code: OpCode::from_i32(op_code)?, }) } } diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index 164d11e1e..dc57fb2ec 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -79,11 +79,8 @@ impl Message { reader: &mut AsyncStream, partial_message_state: &mut PartialMessageState, ) -> Result { - let length = reader.read_i32().await?; - partial_message_state.bytes_remaining = length as usize - Header::LENGTH; - partial_message_state.needs_response = false; + let header = Header::read_from(reader, partial_message_state).await?; - let header = Header::read_from(length, reader).await?; let flags = MessageFlags::from_bits_truncate(reader.read_u32().await?); partial_message_state.bytes_remaining -= std::mem::size_of::(); From 67574d21e50aea556a805a40ba2514d9b8ce9763 Mon Sep 17 00:00:00 2001 From: Saghm Rossi Date: Fri, 1 May 2020 12:36:04 -0400 Subject: [PATCH 7/8] subtract length of length field --- src/cmap/conn/wire/header.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/cmap/conn/wire/header.rs b/src/cmap/conn/wire/header.rs index 887609547..16c3af9e3 100644 --- a/src/cmap/conn/wire/header.rs +++ b/src/cmap/conn/wire/header.rs @@ -57,17 +57,17 @@ impl Header { partial_message_state: &mut PartialMessageState, ) -> Result { let length = stream.read_i32().await?; - partial_message_state.bytes_remaining = length as usize; + partial_message_state.bytes_remaining = length as usize - std::mem::size_of::(); partial_message_state.needs_response = false; let request_id = stream.read_i32().await?; - partial_message_state.bytes_remaining -= 4; + partial_message_state.bytes_remaining -= std::mem::size_of::(); let response_to = stream.read_i32().await?; - partial_message_state.bytes_remaining -= 4; + partial_message_state.bytes_remaining -= std::mem::size_of::(); let op_code = stream.read_i32().await?; - partial_message_state.bytes_remaining -= 4; + partial_message_state.bytes_remaining -= std::mem::size_of::(); Ok(Self { length, From b995f00ca973328c6b8050e4827198bcea8c2379 Mon Sep 17 00:00:00 2001 From: Saghm Rossi Date: Fri, 1 May 2020 12:36:15 -0400 Subject: [PATCH 8/8] close connection if write unfinished --- src/cmap/conn/mod.rs | 26 ++++++++++++++++++-------- src/cmap/conn/wire/header.rs | 8 +++++++- src/cmap/conn/wire/message.rs | 8 ++++++-- src/cmap/conn/wire/test.rs | 5 ++++- 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/cmap/conn/mod.rs b/src/cmap/conn/mod.rs index b72a2eaa7..eb878b926 100644 --- a/src/cmap/conn/mod.rs +++ b/src/cmap/conn/mod.rs @@ -46,6 +46,7 @@ pub struct ConnectionInfo { pub(crate) struct PartialMessageState { bytes_remaining: usize, needs_response: bool, + unfinished_write: bool, } /// A wrapper around Stream that contains all the CMAP information needed to maintain a connection. @@ -227,8 +228,11 @@ impl Connection { } let message = Message::with_command(command, request_id.into()); - message.write_to(&mut self.stream).await?; + message + .write_to(&mut self.stream, &mut self.partial_message_state) + .await?; + self.partial_message_state.unfinished_write = false; self.partial_message_state.needs_response = true; let response_message = @@ -286,13 +290,19 @@ impl Drop for Connection { // dropped while it's not checked out. This means that the pool called the `close_and_drop` // helper explicitly, so we don't add it back to the pool or emit any events. if let Some(ref weak_pool_ref) = self.pool { - if let Some(strong_pool_ref) = weak_pool_ref.upgrade() { - let dropped_connection_state = self.take(); - RUNTIME.execute(async move { - strong_pool_ref - .check_in(dropped_connection_state.into()) - .await; - }); + // If there's an unfinished write on the connection, then we close the connection rather + // than returning it to the pool. We do this because finishing the write could + // potentially cause surprising side-effects for the user, who might expect + // the operation not to occur due to the future being dropped. + if !self.partial_message_state.unfinished_write { + if let Some(strong_pool_ref) = weak_pool_ref.upgrade() { + let dropped_connection_state = self.take(); + RUNTIME.execute(async move { + strong_pool_ref + .check_in(dropped_connection_state.into()) + .await; + }); + } } else { self.close(ConnectionClosedReason::PoolClosed); } diff --git a/src/cmap/conn/wire/header.rs b/src/cmap/conn/wire/header.rs index 16c3af9e3..4398424d3 100644 --- a/src/cmap/conn/wire/header.rs +++ b/src/cmap/conn/wire/header.rs @@ -42,8 +42,14 @@ impl Header { pub(crate) const LENGTH: usize = 4 * std::mem::size_of::(); /// Serializes the Header and writes the bytes to `w`. - pub(crate) async fn write_to(&self, stream: &mut AsyncStream) -> Result<()> { + pub(crate) async fn write_to( + &self, + stream: &mut AsyncStream, + partial_message_state: &mut PartialMessageState, + ) -> Result<()> { stream.write_i32(self.length).await?; + partial_message_state.unfinished_write = true; + stream.write_i32(self.request_id).await?; stream.write_i32(self.response_to).await?; stream.write_i32(self.op_code as i32).await?; diff --git a/src/cmap/conn/wire/message.rs b/src/cmap/conn/wire/message.rs index dc57fb2ec..a02585b82 100644 --- a/src/cmap/conn/wire/message.rs +++ b/src/cmap/conn/wire/message.rs @@ -122,7 +122,11 @@ impl Message { } /// Serializes the Message to bytes and writes them to `writer`. - pub(crate) async fn write_to(&self, writer: &mut AsyncStream) -> Result<()> { + pub(crate) async fn write_to( + &self, + writer: &mut AsyncStream, + partial_message_state: &mut PartialMessageState, + ) -> Result<()> { let mut sections_bytes = Vec::new(); for section in &self.sections { @@ -145,7 +149,7 @@ impl Message { op_code: OpCode::Message, }; - header.write_to(writer).await?; + header.write_to(writer, partial_message_state).await?; writer.write_u32(self.flags.bits()).await?; writer.write_all(§ions_bytes).await?; diff --git a/src/cmap/conn/wire/test.rs b/src/cmap/conn/wire/test.rs index 8a8c000fe..f94ae2512 100644 --- a/src/cmap/conn/wire/test.rs +++ b/src/cmap/conn/wire/test.rs @@ -33,7 +33,10 @@ async fn basic() { }; let mut stream = AsyncStream::connect(options).await.unwrap(); - message.write_to(&mut stream).await.unwrap(); + message + .write_to(&mut stream, &mut Default::default()) + .await + .unwrap(); let reply = Message::read_from(&mut stream, &mut Default::default()) .await