Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ sha2 = "0.8.0"
stringprep = "0.1.2"
time = "0.1.42"
trust-dns-proto = "0.19.4"
trust-dns-resolver = "0.19.0"
trust-dns-resolver = "0.19.4"
typed-builder = "0.3.0"
version_check = "0.9.1"
webpki = "0.21.0"
Expand Down
54 changes: 45 additions & 9 deletions src/cmap/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{
};

use derivative::Derivative;
use futures::io::AsyncReadExt;

use self::wire::Message;
use super::ConnectionPoolInner;
Expand Down Expand Up @@ -41,6 +42,13 @@ pub struct ConnectionInfo {
pub address: StreamAddress,
}

/// Bookkeeping for a wire-protocol message that was only partially read from or
/// written to the connection's stream (e.g. because the future driving the I/O
/// was dropped mid-`await`). Checked before each new command is sent so the
/// stream can be brought back to a message boundary.
#[derive(Clone, Debug, Default)]
pub(crate) struct PartialMessageState {
    // Number of bytes of the current incoming message that have not yet been
    // read off the stream; drained (read and discarded) before a new command
    // is written.
    bytes_remaining: usize,
    // Set after a message is fully written and cleared once the response
    // header is read; if still set, the stale response must be read and
    // thrown away before issuing the next command.
    needs_response: bool,
    // Set while a message write is in flight and cleared when the write
    // completes; if still set when the connection is dropped, the connection
    // is closed instead of being checked back into the pool.
    unfinished_write: bool,
}

/// A wrapper around Stream that contains all the CMAP information needed to maintain a connection.
#[derive(Derivative)]
#[derivative(Debug)]
Expand All @@ -63,6 +71,8 @@ pub(crate) struct Connection {

stream: AsyncStream,

partial_message_state: PartialMessageState,

#[derivative(Debug = "ignore")]
handler: Option<Arc<dyn CmapEventHandler>>,
}
Expand Down Expand Up @@ -90,6 +100,7 @@ impl Connection {
address,
handler: options.and_then(|options| options.event_handler),
stream_description: None,
partial_message_state: Default::default(),
};

Ok(conn)
Expand Down Expand Up @@ -206,10 +217,26 @@ impl Connection {
command: Command,
request_id: impl Into<Option<i32>>,
) -> Result<CommandResponse> {
if self.partial_message_state.needs_response {
Message::read_from(&mut self.stream, &mut self.partial_message_state).await?;
}

if self.partial_message_state.bytes_remaining > 0 {
let mut bytes = vec![0u8; self.partial_message_state.bytes_remaining];
self.stream.read_exact(&mut bytes).await?;
self.partial_message_state.bytes_remaining = 0;
}

let message = Message::with_command(command, request_id.into());
message.write_to(&mut self.stream).await?;
message
.write_to(&mut self.stream, &mut self.partial_message_state)
.await?;

self.partial_message_state.unfinished_write = false;
self.partial_message_state.needs_response = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could the future be dropped after the write_to completes but before this needs_response gets set?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After the write_to completes, the future is not yielding to the runtime, so it will keep executing at least until it hits another await, meaning the boolean will always get set if the call to write_to completes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if the future gets dropped midway through write_to? We'll probably need a way to flush the unsent data before starting to send new stuff, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we send an only partially complete message to the server on a connection, then I'm not sure there's much we can do about that. I think the only two options are to finish sending the message on the next write and then read and throw away the response, or to just close the connection. My instinct is that an unfinished write will be a lot less common than an unfinished read, as there isn't going to be as much of a delay in writing the data as there is for waiting for the response to come back, so closing the connection in this case doesn't seem like quite as big a deal. I think there's also a case to be made that sending an unfinished message when the user stopped awaiting the future is a lot more likely to cause unexpected behavior for the user, since unlike reading the response of an already sent message, completing a sent message could have side effects on the database.

Because of this, I've updated the PR to drop the connection in the case of an unfinished write.


let response_message = Message::read_from(&mut self.stream).await?;
let response_message =
Message::read_from(&mut self.stream, &mut self.partial_message_state).await?;
CommandResponse::new(self.address.clone(), response_message)
}

Expand Down Expand Up @@ -247,6 +274,7 @@ impl Connection {
stream: std::mem::replace(&mut self.stream, AsyncStream::Null),
handler: self.handler.take(),
stream_description: self.stream_description.take(),
partial_message_state: self.partial_message_state.clone(),
}
}
}
Expand All @@ -262,13 +290,19 @@ impl Drop for Connection {
// dropped while it's not checked out. This means that the pool called the `close_and_drop`
// helper explicitly, so we don't add it back to the pool or emit any events.
if let Some(ref weak_pool_ref) = self.pool {
if let Some(strong_pool_ref) = weak_pool_ref.upgrade() {
let dropped_connection_state = self.take();
RUNTIME.execute(async move {
strong_pool_ref
.check_in(dropped_connection_state.into())
.await;
});
// If there's an unfinished write on the connection, then we close the connection rather
// than returning it to the pool. We do this because finishing the write could
// potentially cause surprising side-effects for the user, who might expect
// the operation not to occur due to the future being dropped.
if !self.partial_message_state.unfinished_write {
if let Some(strong_pool_ref) = weak_pool_ref.upgrade() {
let dropped_connection_state = self.take();
RUNTIME.execute(async move {
strong_pool_ref
.check_in(dropped_connection_state.into())
.await;
});
}
} else {
self.close(ConnectionClosedReason::PoolClosed);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this closed reason needs to be changed, maybe to Error, or perhaps to new closed reason case (is this okay with CMAP?). The PoolClosed one needs to move to an else on the upgrade attempt.

}
Expand All @@ -295,6 +329,7 @@ struct DroppedConnectionState {
#[derivative(Debug = "ignore")]
handler: Option<Arc<dyn CmapEventHandler>>,
stream_description: Option<StreamDescription>,
partial_message_state: PartialMessageState,
}

impl Drop for DroppedConnectionState {
Expand Down Expand Up @@ -323,6 +358,7 @@ impl From<DroppedConnectionState> for Connection {
stream_description: state.stream_description.take(),
ready_and_available_time: None,
pool: None,
partial_message_state: state.partial_message_state.clone(),
}
}
}
35 changes: 29 additions & 6 deletions src/cmap/conn/wire/header.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::runtime::{AsyncLittleEndianRead, AsyncLittleEndianWrite};

use crate::{
cmap::conn::PartialMessageState,
error::{ErrorKind, Result},
runtime::AsyncStream,
};
Expand Down Expand Up @@ -41,8 +42,14 @@ impl Header {
pub(crate) const LENGTH: usize = 4 * std::mem::size_of::<i32>();

/// Serializes the Header and writes the bytes to `w`.
pub(crate) async fn write_to(&self, stream: &mut AsyncStream) -> Result<()> {
pub(crate) async fn write_to(
&self,
stream: &mut AsyncStream,
partial_message_state: &mut PartialMessageState,
) -> Result<()> {
stream.write_i32(self.length).await?;
partial_message_state.unfinished_write = true;

stream.write_i32(self.request_id).await?;
stream.write_i32(self.response_to).await?;
stream.write_i32(self.op_code as i32).await?;
Expand All @@ -51,12 +58,28 @@ impl Header {
}

/// Reads bytes from `r` and deserializes them into a header.
pub(crate) async fn read_from(stream: &mut AsyncStream) -> Result<Self> {
pub(crate) async fn read_from(
stream: &mut AsyncStream,
partial_message_state: &mut PartialMessageState,
) -> Result<Self> {
let length = stream.read_i32().await?;
partial_message_state.bytes_remaining = length as usize - std::mem::size_of::<i32>();
partial_message_state.needs_response = false;

let request_id = stream.read_i32().await?;
partial_message_state.bytes_remaining -= std::mem::size_of::<i32>();

let response_to = stream.read_i32().await?;
partial_message_state.bytes_remaining -= std::mem::size_of::<i32>();

let op_code = stream.read_i32().await?;
partial_message_state.bytes_remaining -= std::mem::size_of::<i32>();

Ok(Self {
length: stream.read_i32().await?,
request_id: stream.read_i32().await?,
response_to: stream.read_i32().await?,
op_code: OpCode::from_i32(stream.read_i32().await?)?,
length,
request_id,
response_to,
op_code: OpCode::from_i32(op_code)?,
})
}
}
33 changes: 21 additions & 12 deletions src/cmap/conn/wire/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::{
};
use crate::{
bson_util::async_encoding,
cmap::conn::command::Command,
cmap::conn::{command::Command, PartialMessageState},
error::{ErrorKind, Result},
runtime::{AsyncLittleEndianRead, AsyncLittleEndianWrite, AsyncStream},
};
Expand Down Expand Up @@ -75,33 +75,38 @@ impl Message {
}

/// Reads bytes from `reader` and deserializes them into a Message.
pub(crate) async fn read_from(reader: &mut AsyncStream) -> Result<Self> {
let header = Header::read_from(reader).await?;
let mut length_remaining = header.length - Header::LENGTH as i32;
pub(crate) async fn read_from(
reader: &mut AsyncStream,
partial_message_state: &mut PartialMessageState,
) -> Result<Self> {
let header = Header::read_from(reader, partial_message_state).await?;

let flags = MessageFlags::from_bits_truncate(reader.read_u32().await?);
length_remaining -= std::mem::size_of::<u32>() as i32;
partial_message_state.bytes_remaining -= std::mem::size_of::<u32>();

let mut count_reader = CountReader::new(reader);
let mut sections = Vec::new();

while length_remaining - count_reader.bytes_read() as i32 > 4 {
while partial_message_state.bytes_remaining - count_reader.bytes_read() > 4 {
sections.push(MessageSection::read(&mut count_reader).await?);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the state needs to be propagated down here too. I wonder if this would be simpler if we just spawned a task via RUNTIME.spawn in execute_operation to guarantee everything got polled to completion. Doing all this partial reading and passing of state around seems to be pretty error-prone.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate on what you mean by spawning a task in execute operation? Do you mean running send_command in a background task and then using a channel to get the result back?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, except a channel wouldn't be needed. We could just use our AsyncJoinHandle.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it more, it would probably be safest to do the entirety of execute_operation's body in a task, so that the various things it needs to do based on operation execution happen (like updating the cluster time, clearing the pool on certain errors, etc.)

Copy link
Contributor Author

@saghm saghm May 4, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried to play around with this; it runs into some issues due to the fact that we can only spawn futures that are 'static, which won't be the case if we need to capture over any references. When trying to work around this, I ran into issues with the cursor's getMore operation, so I tried rebasing your sessions PR to see if the new cursor implementation avoided this issue, but this actually made things even more difficult due to the need to capture a reference to the session, which doesn't have the static lifetime.

We might be able to make this less error-prone by doing it on our AsyncStream type itself, which would make it impossible to accidentally read or write without checking the partial state.

}

length_remaining -= count_reader.bytes_read() as i32;
partial_message_state.bytes_remaining -= count_reader.bytes_read();

let mut checksum = None;

if length_remaining == 4 && flags.contains(MessageFlags::CHECKSUM_PRESENT) {
if partial_message_state.bytes_remaining == 4
&& flags.contains(MessageFlags::CHECKSUM_PRESENT)
{
checksum = Some(reader.read_u32().await?);
} else if length_remaining != 0 {
} else if partial_message_state.bytes_remaining != 0 {
return Err(ErrorKind::OperationError {
message: format!(
"The server indicated that the reply would be {} bytes long, but it instead \
was {}",
header.length,
header.length - length_remaining + count_reader.bytes_read() as i32,
header.length as usize - partial_message_state.bytes_remaining
+ count_reader.bytes_read(),
),
}
.into());
Expand All @@ -117,7 +122,11 @@ impl Message {
}

/// Serializes the Message to bytes and writes them to `writer`.
pub(crate) async fn write_to(&self, writer: &mut AsyncStream) -> Result<()> {
pub(crate) async fn write_to(
&self,
writer: &mut AsyncStream,
partial_message_state: &mut PartialMessageState,
) -> Result<()> {
let mut sections_bytes = Vec::new();

for section in &self.sections {
Expand All @@ -140,7 +149,7 @@ impl Message {
op_code: OpCode::Message,
};

header.write_to(writer).await?;
header.write_to(writer, partial_message_state).await?;
writer.write_u32(self.flags.bits()).await?;
writer.write_all(&sections_bytes).await?;

Expand Down
11 changes: 8 additions & 3 deletions src/cmap/conn/wire/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,14 @@ async fn basic() {
};

let mut stream = AsyncStream::connect(options).await.unwrap();
message.write_to(&mut stream).await.unwrap();

let reply = Message::read_from(&mut stream).await.unwrap();
message
.write_to(&mut stream, &mut Default::default())
.await
.unwrap();

let reply = Message::read_from(&mut stream, &mut Default::default())
.await
.unwrap();

let response_doc = match reply.sections.into_iter().next().unwrap() {
MessageSection::Document(doc) => doc,
Expand Down
41 changes: 41 additions & 0 deletions src/test/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
selection_criteria::{ReadPreference, SelectionCriteria},
test::{util::TestClient, CLIENT_OPTIONS, LOCK},
Client,
RUNTIME,
};

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -532,3 +533,43 @@ async fn saslprep_uri() {
auth_test_uri("%E2%85%A8", "IV", None, true).await;
auth_test_uri("%E2%85%A8", "I%C2%ADV", None, true).await;
}

// Regression test: if a future driving a command is dropped after the request
// is written but before the response is read, the leftover response bytes must
// be flushed from the connection before the next command's response is parsed.
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn future_drop_flush_response() {
    let _guard = LOCK.run_concurrently().await;

    let options = CLIENT_OPTIONS.clone();

    let client = Client::with_options(options.clone()).unwrap();
    let db = client.database("test");

    // Seed a document so the `count` below has something to scan.
    db.collection("foo")
        .insert_one(doc! { "x": 1 }, None)
        .await
        .unwrap();

    // Run a command whose server-side `$where` sleeps longer than the timeout,
    // so the timeout fires and the future is dropped after the request was
    // sent but before its response arrives. The result is intentionally
    // discarded — the timeout error is expected.
    let _: Result<_, _> = RUNTIME
        .timeout(
            Duration::from_millis(50),
            db.run_command(
                doc! { "count": "foo",
                    "query": {
                        "$where": "sleep(100) && true"
                    }
                },
                None,
            ),
        )
        .await;

    // Give the server time to finish the slow command and send its (now
    // orphaned) response down the connection.
    RUNTIME.delay_for(Duration::from_millis(200)).await;

    let is_master_response = db.run_command(doc! { "isMaster": 1 }, None).await;

    // Ensure that the response to `isMaster` is read, not the response to `count`.
    assert!(is_master_response
        .ok()
        .and_then(|value| value.get("ismaster").and_then(|value| value.as_bool()))
        .is_some());
}