Skip to content

Commit

Permalink
Merge 847aa8a into 4f5d114
Browse files Browse the repository at this point in the history
  • Loading branch information
lionel-faber committed Sep 9, 2020
2 parents 4f5d114 + 847aa8a commit 9fe6da5
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 21 deletions.
27 changes: 13 additions & 14 deletions src/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use super::{
api::Message,
error::{Error, Result},
wire_msg::WireMsg,
wire_msg::{MsgHeader, WireMsg, HEADER_LEN},
};
use bytes::Bytes;
use futures::{lock::Mutex, stream::StreamExt};
Expand Down Expand Up @@ -233,13 +233,15 @@ impl SendStream {

// Helper to read the message's bytes from the provided stream
async fn read_bytes(recv: &mut quinn::RecvStream) -> Result<Bytes> {
let mut data_len: [u8; 8] = [0; 8];
recv.read_exact(&mut data_len).await?;
let data_len = usize::from_le_bytes(data_len);
let mut data: Vec<u8> = vec![0; data_len];
let mut header_bytes = [0; HEADER_LEN];
recv.read_exact(&mut header_bytes).await?;

let msg_header = MsgHeader::from_bytes(header_bytes);
let mut data: Vec<u8> = vec![0; msg_header.data_len()];

recv.read_exact(&mut data).await?;
trace!("Got new message with {} bytes.", data.len());
match WireMsg::from_raw(data)? {
match WireMsg::from_raw(data, msg_header.usr_msg_flag())? {
WireMsg::UserMsg(msg_bytes) => Ok(Bytes::copy_from_slice(&msg_bytes)),
WireMsg::EndpointEchoReq | WireMsg::EndpointEchoResp(_) => {
// TODO: handle the echo request/response message
Expand All @@ -253,20 +255,17 @@ async fn send_msg(send_stream: &mut quinn::SendStream, msg: Bytes) -> Result<()>
// Let's generate the message bytes
let wire_msg = WireMsg::UserMsg(msg);
let (msg_bytes, msg_flag) = wire_msg.into();
trace!("Sending message to remote peer ({} bytes)", msg_bytes.len());

trace!("Sending message to remote peer ({} bytes)", msg_bytes.len(),);
let msg_header = MsgHeader::new(&msg_bytes, msg_flag)?;
let header_bytes = msg_header.to_bytes();

// Send the length of the message + 1 (for the flag)
send_stream
.write_all(&(msg_bytes.len() + 1).to_le_bytes())
.await?;
// Send the message header
send_stream.write_all(&header_bytes).await?;

// Send message bytes over QUIC
send_stream.write_all(&msg_bytes[..]).await?;

// Then send message flag over QUIC
send_stream.write_all(&[msg_flag]).await?;

trace!("Message was sent to remote peer");

Ok(())
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ pub enum Error {
EmptyResponse,
#[error(display = "Type of the message received was not the expected one")]
UnexpectedMessageType,
#[error(display = "Maximum data length exceeded")]
MaxLengthExceeded,
#[error(display = "Unexpected: {}", 0)]
Unexpected(String),
}
80 changes: 73 additions & 7 deletions src/wire_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,73 @@ use serde::{Deserialize, Serialize};
use std::{fmt, net::SocketAddr};
use unwrap::unwrap;

pub(crate) const HEADER_LEN: usize = 9;
pub(crate) const VERSION: i16 = 0;
pub(crate) const MAX_DATA_LEN: usize = usize::from_be_bytes([0, 0, 0, 0, 255, 255, 255, 255]);

/// Message Header that is sent over the wire
/// Format of the message header is as follows
/// | version | message length | usr_msg_flag | reserved |
/// | 2 bytes | 4 bytes | 1 byte | 2 bytes |
pub(crate) struct MsgHeader {
version: i16,
data_len: usize,
usr_msg_flag: u8,
#[allow(unused)]
reserved: [u8; 2],
}

impl MsgHeader {
pub fn new(msg: &Bytes, usr_msg_flag: u8) -> Result<Self> {
let data_len = msg.len();
if data_len > MAX_DATA_LEN {
return Err(Error::MaxLengthExceeded);
}
Ok(Self {
version: VERSION,
data_len,
usr_msg_flag,
reserved: [0, 0],
})
}

pub fn data_len(&self) -> usize {
self.data_len
}

pub fn usr_msg_flag(&self) -> u8 {
self.usr_msg_flag
}

pub fn to_bytes(&self) -> [u8; HEADER_LEN] {
let version = self.version.to_be_bytes();
let data_len = self.data_len.to_be_bytes();
[
version[0],
version[1],
data_len[4],
data_len[5],
data_len[6],
data_len[7],
self.usr_msg_flag,
0,
0,
]
}

pub fn from_bytes(bytes: [u8; HEADER_LEN]) -> Self {
let version = i16::from_be_bytes([bytes[0], bytes[1]]);
let data_len = usize::from_be_bytes([0, 0, 0, 0, bytes[2], bytes[3], bytes[4], bytes[5]]);
let usr_msg_flag = bytes[6];
Self {
version,
data_len,
usr_msg_flag,
reserved: [0, 0],
}
}
}

/// Final type serialised and sent on the wire by QuicP2p
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum WireMsg {
Expand All @@ -39,16 +106,15 @@ impl Into<(Bytes, u8)> for WireMsg {
}

impl WireMsg {
pub fn from_raw(mut raw: Vec<u8>) -> Result<Self> {
pub fn from_raw(raw: Vec<u8>, msg_flag: u8) -> Result<Self> {
if raw.is_empty() {
Err(Error::EmptyResponse)
} else if msg_flag == USER_MSG_FLAG {
Ok(WireMsg::UserMsg(From::from(raw)))
} else if msg_flag == !USER_MSG_FLAG {
Ok(bincode::deserialize(&raw)?)
} else {
let msg_flag = raw.pop();
match msg_flag {
Some(flag) if flag == USER_MSG_FLAG => Ok(WireMsg::UserMsg(From::from(raw))),
Some(flag) if flag == !USER_MSG_FLAG => Ok(bincode::deserialize(&raw)?),
_ => Err(Error::InvalidWireMsgFlag),
}
Err(Error::InvalidWireMsgFlag)
}
}
}
Expand Down

0 comments on commit 9fe6da5

Please sign in to comment.