Skip to content

Commit

Permalink
refactor(wiremsg): minor refactor in WireMsg mod to make MsgHeader pr…
Browse files Browse the repository at this point in the history
…ivate
  • Loading branch information
bochaco committed Sep 11, 2020
1 parent 033f70f commit ecf9a0f
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 96 deletions.
4 changes: 2 additions & 2 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl QuicP2p {
/// config.port = Some(3000);
/// let mut quic_p2p = QuicP2p::with_config(Some(config.clone()), Default::default(), true)?;
/// let endpoint = quic_p2p.new_endpoint()?;
/// let peer_addr = endpoint.local_address();
/// let peer_addr = endpoint.our_endpoint()?;
///
/// config.port = Some(3001);
/// let hcc = vec![peer_addr];
Expand Down Expand Up @@ -254,7 +254,7 @@ impl QuicP2p {
/// config.ip = Some(IpAddr::V4(Ipv4Addr::LOCALHOST));
/// let mut quic_p2p = QuicP2p::with_config(Some(config.clone()), Default::default(), true)?;
/// let peer_1 = quic_p2p.new_endpoint()?;
/// let peer1_addr = peer_1.local_address();
/// let peer1_addr = peer_1.our_endpoint()?;
///
/// let (peer_2, connection) = quic_p2p.connect_to(&peer1_addr).await?;
/// Ok(())
Expand Down
33 changes: 7 additions & 26 deletions src/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use super::{
api::Message,
error::{Error, Result},
wire_msg::{MsgHeader, WireMsg, HEADER_LEN},
wire_msg::WireMsg,
};
use bytes::Bytes;
use futures::{lock::Mutex, stream::StreamExt};
Expand Down Expand Up @@ -49,7 +49,7 @@ impl Connection {
/// config.ip = Some(IpAddr::V4(Ipv4Addr::LOCALHOST));
/// let mut quic_p2p = QuicP2p::with_config(Some(config.clone()), Default::default(), true)?;
/// let peer_1 = quic_p2p.new_endpoint()?;
/// let peer1_addr = peer_1.local_address();
/// let peer1_addr = peer_1.our_endpoint()?;
///
/// let (peer_2, connection) = quic_p2p.connect_to(&peer1_addr).await?;
/// assert_eq!(connection.remote_address(), peer1_addr);
Expand All @@ -75,7 +75,7 @@ impl Connection {
/// config.ip = Some(IpAddr::V4(Ipv4Addr::LOCALHOST));
/// let mut quic_p2p = QuicP2p::with_config(Some(config.clone()), Default::default(), true)?;
/// let peer_1 = quic_p2p.new_endpoint()?;
/// let peer1_addr = peer_1.local_address();
/// let peer1_addr = peer_1.our_endpoint()?;
///
/// let (peer_2, connection) = quic_p2p.connect_to(&peer1_addr).await?;
/// let (send_stream, recv_stream) = connection.open_bi_stream().await?;
Expand Down Expand Up @@ -275,16 +275,8 @@ impl SendStream {

// Helper to read the message's bytes from the provided stream
async fn read_bytes(recv: &mut quinn::RecvStream) -> Result<Bytes> {
let mut header_bytes = [0; HEADER_LEN];
recv.read_exact(&mut header_bytes).await?;

let msg_header = MsgHeader::from_bytes(header_bytes);
let mut data: Vec<u8> = vec![0; msg_header.data_len()];

recv.read_exact(&mut data).await?;
trace!("Got new message with {} bytes.", data.len());
match WireMsg::from_raw(data, msg_header.usr_msg_flag())? {
WireMsg::UserMsg(msg_bytes) => Ok(Bytes::copy_from_slice(&msg_bytes)),
match WireMsg::read_from_stream(recv).await? {
WireMsg::UserMsg(msg_bytes) => Ok(msg_bytes),
WireMsg::EndpointEchoReq | WireMsg::EndpointEchoResp(_) => {
// TODO: handle the echo request/response message
unimplemented!("echo message type not supported yet");
Expand All @@ -293,20 +285,9 @@ async fn read_bytes(recv: &mut quinn::RecvStream) -> Result<Bytes> {
}

// Helper to send bytes to peer using the provided stream.
async fn send_msg(send_stream: &mut quinn::SendStream, msg: Bytes) -> Result<()> {
// Let's generate the message bytes
async fn send_msg(mut send_stream: &mut quinn::SendStream, msg: Bytes) -> Result<()> {
let wire_msg = WireMsg::UserMsg(msg);
let (msg_bytes, msg_flag) = wire_msg.into();
trace!("Sending message to remote peer ({} bytes)", msg_bytes.len());

let msg_header = MsgHeader::new(&msg_bytes, msg_flag)?;
let header_bytes = msg_header.to_bytes();

// Send the message header
send_stream.write_all(&header_bytes).await?;

// Send message bytes over QUIC
send_stream.write_all(&msg_bytes[..]).await?;
wire_msg.write_to_stream(&mut send_stream).await?;

trace!("Message was sent to remote peer");

Expand Down
11 changes: 6 additions & 5 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ impl Endpoint {
})
}

/// Endpoint local address
pub fn local_address(&self) -> SocketAddr {
self.local_addr
}

/// Get our connection adddress to give to others for them to connect to us.
///
/// Attempts to use UPnP to automatically find the public endpoint and forward a port.
Expand All @@ -62,6 +57,12 @@ impl Endpoint {
Ok(self.local_addr)
}

/// Endpoint local address
#[cfg(not(feature = "upnp"))]
pub fn our_endpoint(&self) -> Result<SocketAddr> {
Ok(self.local_addr)
}

/// Connect to another peer
pub async fn connect_to(&self, node_addr: &SocketAddr) -> Result<Connection> {
let quinn_connecting = self.quic_endpoint.connect_with(
Expand Down
140 changes: 82 additions & 58 deletions src/wire_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,92 @@ use crate::{
utils,
};
use bytes::Bytes;
use log::trace;
use serde::{Deserialize, Serialize};
use std::{fmt, net::SocketAddr};
use unwrap::unwrap;

pub(crate) const HEADER_LEN: usize = 9;
pub(crate) const VERSION: i16 = 0;
pub(crate) const MAX_DATA_LEN: usize = usize::from_be_bytes([0, 0, 0, 0, 255, 255, 255, 255]);
const MSG_HEADER_LEN: usize = 9;
const MSG_PROTOCOL_VERSION: u16 = 0x0001;

/// Final type serialised and sent on the wire by QuicP2p
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum WireMsg {
EndpointEchoReq,
EndpointEchoResp(SocketAddr),
UserMsg(Bytes),
}

const USER_MSG_FLAG: u8 = 0x00;
const ECHO_SRVC_MSG_FLAG: u8 = 0x01;

impl WireMsg {
// Read a message's bytes from the provided stream
pub async fn read_from_stream(recv: &mut quinn::RecvStream) -> Result<Self> {
let mut header_bytes = [0; MSG_HEADER_LEN];
recv.read_exact(&mut header_bytes).await?;

let msg_header = MsgHeader::from_bytes(header_bytes);
let mut data: Vec<u8> = vec![0; msg_header.data_len()];
let msg_flag = msg_header.usr_msg_flag();

recv.read_exact(&mut data).await?;
trace!("Got new message with {} bytes.", data.len());

if data.is_empty() {
Err(Error::EmptyResponse)
} else if msg_flag == USER_MSG_FLAG {
Ok(WireMsg::UserMsg(From::from(data)))
} else if msg_flag == ECHO_SRVC_MSG_FLAG {
Ok(bincode::deserialize(&data)?)
} else {
Err(Error::InvalidWireMsgFlag)
}
}

// Helper to write WireMsg bytes to the provided stream.
pub async fn write_to_stream(&self, send_stream: &mut quinn::SendStream) -> Result<()> {
// Let's generate the message bytes
let (msg_bytes, msg_flag) = match self {
WireMsg::UserMsg(ref m) => (m.clone(), USER_MSG_FLAG),
_ => (
From::from(unwrap!(bincode::serialize(&self))),
!USER_MSG_FLAG,
),
};
trace!("Sending message to remote peer ({} bytes)", msg_bytes.len());

let msg_header = MsgHeader::new(&msg_bytes, msg_flag)?;
let header_bytes = msg_header.to_bytes();

// Send the header bytes over QUIC
send_stream.write_all(&header_bytes).await?;

// Send message bytes over QUIC
send_stream.write_all(&msg_bytes[..]).await?;

Ok(())
}
}

impl fmt::Display for WireMsg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
WireMsg::UserMsg(ref m) => {
write!(f, "WireMsg::UserMsg({})", utils::bin_data_format(&*m))
}
WireMsg::EndpointEchoReq => write!(f, "WireMsg::EndpointEchoReq"),
WireMsg::EndpointEchoResp(ref sa) => write!(f, "WireMsg::EndpointEchoResp({})", sa),
}
}
}

/// Message Header that is sent over the wire
/// Format of the message header is as follows
/// | version | message length | usr_msg_flag | reserved |
/// | 2 bytes | 4 bytes | 1 byte | 2 bytes |
pub(crate) struct MsgHeader {
version: i16,
struct MsgHeader {
version: u16,
data_len: usize,
usr_msg_flag: u8,
#[allow(unused)]
Expand All @@ -35,11 +107,11 @@ pub(crate) struct MsgHeader {
impl MsgHeader {
pub fn new(msg: &Bytes, usr_msg_flag: u8) -> Result<Self> {
let data_len = msg.len();
if data_len > MAX_DATA_LEN {
if data_len > u32::MAX as usize {
return Err(Error::MaxLengthExceeded);
}
Ok(Self {
version: VERSION,
version: MSG_PROTOCOL_VERSION,
data_len,
usr_msg_flag,
reserved: [0, 0],
Expand All @@ -54,7 +126,7 @@ impl MsgHeader {
self.usr_msg_flag
}

pub fn to_bytes(&self) -> [u8; HEADER_LEN] {
pub fn to_bytes(&self) -> [u8; MSG_HEADER_LEN] {
let version = self.version.to_be_bytes();
let data_len = self.data_len.to_be_bytes();
[
Expand All @@ -70,8 +142,8 @@ impl MsgHeader {
]
}

pub fn from_bytes(bytes: [u8; HEADER_LEN]) -> Self {
let version = i16::from_be_bytes([bytes[0], bytes[1]]);
pub fn from_bytes(bytes: [u8; MSG_HEADER_LEN]) -> Self {
let version = u16::from_be_bytes([bytes[0], bytes[1]]);
let data_len = usize::from_be_bytes([0, 0, 0, 0, bytes[2], bytes[3], bytes[4], bytes[5]]);
let usr_msg_flag = bytes[6];
Self {
Expand All @@ -82,51 +154,3 @@ impl MsgHeader {
}
}
}

/// Final type serialised and sent on the wire by QuicP2p
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum WireMsg {
EndpointEchoReq,
EndpointEchoResp(SocketAddr),
UserMsg(Bytes),
}

const USER_MSG_FLAG: u8 = 0;

impl Into<(Bytes, u8)> for WireMsg {
fn into(self) -> (Bytes, u8) {
match self {
WireMsg::UserMsg(ref m) => (m.clone(), USER_MSG_FLAG),
_ => (
From::from(unwrap!(bincode::serialize(&self))),
!USER_MSG_FLAG,
),
}
}
}

impl WireMsg {
pub fn from_raw(raw: Vec<u8>, msg_flag: u8) -> Result<Self> {
if raw.is_empty() {
Err(Error::EmptyResponse)
} else if msg_flag == USER_MSG_FLAG {
Ok(WireMsg::UserMsg(From::from(raw)))
} else if msg_flag == !USER_MSG_FLAG {
Ok(bincode::deserialize(&raw)?)
} else {
Err(Error::InvalidWireMsgFlag)
}
}
}

impl fmt::Display for WireMsg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
WireMsg::UserMsg(ref m) => {
write!(f, "WireMsg::UserMsg({})", utils::bin_data_format(&*m))
}
WireMsg::EndpointEchoReq => write!(f, "WireMsg::EndpointEchoReq"),
WireMsg::EndpointEchoResp(ref sa) => write!(f, "WireMsg::EndpointEchoResp({})", sa),
}
}
}
10 changes: 5 additions & 5 deletions tests/quic_p2p.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ fn random_msg() -> Bytes {
async fn successful_connection() -> Result<()> {
let qp2p = new_qp2p();
let peer1 = qp2p.new_endpoint()?;
let peer1_addr = peer1.local_address();
let peer1_addr = peer1.our_endpoint()?;

let peer2 = qp2p.new_endpoint()?;
let _connection = peer2.connect_to(&peer1_addr).await?;
Expand All @@ -45,7 +45,7 @@ async fn successful_connection() -> Result<()> {
.await
.ok_or_else(|| Error::Unexpected("No incoming connection".to_string()))?;

assert_eq!(incoming_messages.remote_addr(), peer2.local_address());
assert_eq!(incoming_messages.remote_addr(), peer2.our_endpoint()?);

Ok(())
}
Expand All @@ -54,7 +54,7 @@ async fn successful_connection() -> Result<()> {
async fn bi_directional_streams() -> Result<()> {
let qp2p = new_qp2p();
let peer1 = qp2p.new_endpoint()?;
let peer1_addr = peer1.local_address();
let peer1_addr = peer1.our_endpoint()?;

let peer2 = qp2p.new_endpoint()?;
let connection = peer2.connect_to(&peer1_addr).await?;
Expand Down Expand Up @@ -110,11 +110,11 @@ async fn bi_directional_streams() -> Result<()> {
async fn uni_directional_streams() -> Result<()> {
let qp2p = new_qp2p();
let peer1 = qp2p.new_endpoint()?;
let peer1_addr = peer1.local_address();
let peer1_addr = peer1.our_endpoint()?;
let mut incoming_conn_peer1 = peer1.listen()?;

let peer2 = qp2p.new_endpoint()?;
let peer2_addr = peer2.local_address();
let peer2_addr = peer2.our_endpoint()?;
let mut incoming_conn_peer2 = peer2.listen()?;

// Peer 2 sends a message
Expand Down

0 comments on commit ecf9a0f

Please sign in to comment.