Skip to content

Commit

Permalink
refactor: read the expected number of bytes form recv stream
Browse files Browse the repository at this point in the history
- When reading from a stream, we first read the header which gives us
information about the exact number of bytes to read the reamining content,
thus when calling RecvStream::read_to_end we can use that number as a limit.
  • Loading branch information
bochaco committed Feb 22, 2023
1 parent 691f81c commit 1056286
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions src/wire_msg.rs
Expand Up @@ -40,40 +40,38 @@ impl WireMsg {

let msg_header = MsgHeader::from_bytes(header_bytes);

// https://github.com/rust-lang/rust/issues/70460 for work on a cleaner alternative:
#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
{
compile_error!("You need an architecture capable of addressing 32-bit pointers");
}
// we know we can convert without loss thanks to our assertions above
let header_length = msg_header.user_header_len() as usize;
let dst_length = msg_header.user_dst_len() as usize;
let payload_length = msg_header.user_payload_len() as usize;
let total_length = header_length + dst_length + payload_length;

let start = Instant::now();
let all_bytes = recv.read_to_end(1024 * 1024 * 100).await?;
let all_bytes = recv.read_to_end(total_length).await?;

let duration = start.elapsed();
trace!(
"Incoming new msg. Reading {:?} bytes took: {:?}",
"Incoming new msg. Reading {:?} bytes took: {duration:?}",
all_bytes.len(),
duration
);

if all_bytes.is_empty() {
return Err(RecvError::EmptyMsgPayload);
}

let mut bytes = Bytes::from(all_bytes);

// https://github.com/rust-lang/rust/issues/70460 for work on a cleaner alternative:
#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
{
compile_error!("You need an architecture capable of addressing 32-bit pointers");
}
// we know we can convert without loss thanks to our assertions above
let header_length = msg_header.user_header_len() as usize;
let dst_length = msg_header.user_dst_len() as usize;
let payload_length = msg_header.user_payload_len() as usize;

// Check we have all the data and we weren't cut short, otherwise
// the following would panic...
if bytes.len() != (header_length + dst_length + payload_length) {
if all_bytes.len() != total_length {
return Err(RecvError::NotEnoughBytes);
}

let mut bytes = Bytes::from(all_bytes);
let header_data = bytes.split_to(header_length);

let dst_data = bytes.split_to(dst_length);
let payload_data = bytes;

Expand Down

0 comments on commit 1056286

Please sign in to comment.