From ecf9a0fbff961780dbc1a28ecee011e19e72e4a1 Mon Sep 17 00:00:00 2001 From: bochaco Date: Fri, 11 Sep 2020 16:09:17 -0500 Subject: [PATCH] refactor(wiremsg): minor refactor in WireMsg mod to make MsgHeader private --- src/api.rs | 4 +- src/connections.rs | 33 +++-------- src/endpoint.rs | 11 ++-- src/wire_msg.rs | 140 ++++++++++++++++++++++++++------------------- tests/quic_p2p.rs | 10 ++-- 5 files changed, 102 insertions(+), 96 deletions(-) diff --git a/src/api.rs b/src/api.rs index be195f32..c5d9773b 100644 --- a/src/api.rs +++ b/src/api.rs @@ -188,7 +188,7 @@ impl QuicP2p { /// config.port = Some(3000); /// let mut quic_p2p = QuicP2p::with_config(Some(config.clone()), Default::default(), true)?; /// let endpoint = quic_p2p.new_endpoint()?; - /// let peer_addr = endpoint.local_address(); + /// let peer_addr = endpoint.our_endpoint()?; /// /// config.port = Some(3001); /// let hcc = vec![peer_addr]; @@ -254,7 +254,7 @@ impl QuicP2p { /// config.ip = Some(IpAddr::V4(Ipv4Addr::LOCALHOST)); /// let mut quic_p2p = QuicP2p::with_config(Some(config.clone()), Default::default(), true)?; /// let peer_1 = quic_p2p.new_endpoint()?; - /// let peer1_addr = peer_1.local_address(); + /// let peer1_addr = peer_1.our_endpoint()?; /// /// let (peer_2, connection) = quic_p2p.connect_to(&peer1_addr).await?; /// Ok(()) diff --git a/src/connections.rs b/src/connections.rs index af476ad5..63e29791 100644 --- a/src/connections.rs +++ b/src/connections.rs @@ -10,7 +10,7 @@ use super::{ api::Message, error::{Error, Result}, - wire_msg::{MsgHeader, WireMsg, HEADER_LEN}, + wire_msg::WireMsg, }; use bytes::Bytes; use futures::{lock::Mutex, stream::StreamExt}; @@ -49,7 +49,7 @@ impl Connection { /// config.ip = Some(IpAddr::V4(Ipv4Addr::LOCALHOST)); /// let mut quic_p2p = QuicP2p::with_config(Some(config.clone()), Default::default(), true)?; /// let peer_1 = quic_p2p.new_endpoint()?; - /// let peer1_addr = peer_1.local_address(); + /// let peer1_addr = peer_1.our_endpoint()?; /// /// let (peer_2, connection) = quic_p2p.connect_to(&peer1_addr).await?; /// assert_eq!(connection.remote_address(), peer1_addr); @@ -75,7 +75,7 @@ impl Connection { /// config.ip = Some(IpAddr::V4(Ipv4Addr::LOCALHOST)); /// let mut quic_p2p = QuicP2p::with_config(Some(config.clone()), Default::default(), true)?; /// let peer_1 = quic_p2p.new_endpoint()?; - /// let peer1_addr = peer_1.local_address(); + /// let peer1_addr = peer_1.our_endpoint()?; /// /// let (peer_2, connection) = quic_p2p.connect_to(&peer1_addr).await?; /// let (send_stream, recv_stream) = connection.open_bi_stream().await?; @@ -275,16 +275,8 @@ impl SendStream { // Helper to read the message's bytes from the provided stream async fn read_bytes(recv: &mut quinn::RecvStream) -> Result { - let mut header_bytes = [0; HEADER_LEN]; - recv.read_exact(&mut header_bytes).await?; - - let msg_header = MsgHeader::from_bytes(header_bytes); - let mut data: Vec = vec![0; msg_header.data_len()]; - - recv.read_exact(&mut data).await?; - trace!("Got new message with {} bytes.", data.len()); - match WireMsg::from_raw(data, msg_header.usr_msg_flag())? { - WireMsg::UserMsg(msg_bytes) => Ok(Bytes::copy_from_slice(&msg_bytes)), + match WireMsg::read_from_stream(recv).await? { + WireMsg::UserMsg(msg_bytes) => Ok(msg_bytes), WireMsg::EndpointEchoReq | WireMsg::EndpointEchoResp(_) => { // TODO: handle the echo request/response message unimplemented!("echo message type not supported yet"); @@ -293,20 +285,9 @@ async fn read_bytes(recv: &mut quinn::RecvStream) -> Result { } // Helper to send bytes to peer using the provided stream. -async fn send_msg(send_stream: &mut quinn::SendStream, msg: Bytes) -> Result<()> { - // Let's generate the message bytes +async fn send_msg(mut send_stream: &mut quinn::SendStream, msg: Bytes) -> Result<()> { let wire_msg = WireMsg::UserMsg(msg); - let (msg_bytes, msg_flag) = wire_msg.into(); - trace!("Sending message to remote peer ({} bytes)", msg_bytes.len()); - - let msg_header = MsgHeader::new(&msg_bytes, msg_flag)?; - let header_bytes = msg_header.to_bytes(); - - // Send the message header - send_stream.write_all(&header_bytes).await?; - - // Send message bytes over QUIC - send_stream.write_all(&msg_bytes[..]).await?; + wire_msg.write_to_stream(&mut send_stream).await?; trace!("Message was sent to remote peer"); diff --git a/src/endpoint.rs b/src/endpoint.rs index c9780aa3..deefc7d0 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -44,11 +44,6 @@ impl Endpoint { }) } - /// Endpoint local address - pub fn local_address(&self) -> SocketAddr { - self.local_addr - } - /// Get our connection adddress to give to others for them to connect to us. /// /// Attempts to use UPnP to automatically find the public endpoint and forward a port. @@ -62,6 +57,12 @@ impl Endpoint { Ok(self.local_addr) } + /// Endpoint local address + #[cfg(not(feature = "upnp"))] + pub fn our_endpoint(&self) -> Result { + Ok(self.local_addr) + } + /// Connect to another peer pub async fn connect_to(&self, node_addr: &SocketAddr) -> Result { let quinn_connecting = self.quic_endpoint.connect_with( diff --git a/src/wire_msg.rs b/src/wire_msg.rs index 390c1839..d3a1ea0d 100644 --- a/src/wire_msg.rs +++ b/src/wire_msg.rs @@ -12,20 +12,92 @@ use crate::{ utils, }; use bytes::Bytes; +use log::trace; use serde::{Deserialize, Serialize}; use std::{fmt, net::SocketAddr}; use unwrap::unwrap; -pub(crate) const HEADER_LEN: usize = 9; -pub(crate) const VERSION: i16 = 0; -pub(crate) const MAX_DATA_LEN: usize = usize::from_be_bytes([0, 0, 0, 0, 255, 255, 255, 255]); +const MSG_HEADER_LEN: usize = 9; +const MSG_PROTOCOL_VERSION: u16 = 0x0001; + +/// Final type serialised and sent on the wire by QuicP2p +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum WireMsg { + EndpointEchoReq, + EndpointEchoResp(SocketAddr), + UserMsg(Bytes), +} + +const USER_MSG_FLAG: u8 = 0x00; +const ECHO_SRVC_MSG_FLAG: u8 = 0x01; + +impl WireMsg { + // Read a message's bytes from the provided stream + pub async fn read_from_stream(recv: &mut quinn::RecvStream) -> Result { + let mut header_bytes = [0; MSG_HEADER_LEN]; + recv.read_exact(&mut header_bytes).await?; + + let msg_header = MsgHeader::from_bytes(header_bytes); + let mut data: Vec = vec![0; msg_header.data_len()]; + let msg_flag = msg_header.usr_msg_flag(); + + recv.read_exact(&mut data).await?; + trace!("Got new message with {} bytes.", data.len()); + + if data.is_empty() { + Err(Error::EmptyResponse) + } else if msg_flag == USER_MSG_FLAG { + Ok(WireMsg::UserMsg(From::from(data))) + } else if msg_flag == ECHO_SRVC_MSG_FLAG { + Ok(bincode::deserialize(&data)?) + } else { + Err(Error::InvalidWireMsgFlag) + } + } + + // Helper to write WireMsg bytes to the provided stream. + pub async fn write_to_stream(&self, send_stream: &mut quinn::SendStream) -> Result<()> { + // Let's generate the message bytes + let (msg_bytes, msg_flag) = match self { + WireMsg::UserMsg(ref m) => (m.clone(), USER_MSG_FLAG), + _ => ( + From::from(unwrap!(bincode::serialize(&self))), + !USER_MSG_FLAG, + ), + }; + trace!("Sending message to remote peer ({} bytes)", msg_bytes.len()); + + let msg_header = MsgHeader::new(&msg_bytes, msg_flag)?; + let header_bytes = msg_header.to_bytes(); + + // Send the header bytes over QUIC + send_stream.write_all(&header_bytes).await?; + + // Send message bytes over QUIC + send_stream.write_all(&msg_bytes[..]).await?; + + Ok(()) + } +} + +impl fmt::Display for WireMsg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + WireMsg::UserMsg(ref m) => { + write!(f, "WireMsg::UserMsg({})", utils::bin_data_format(&*m)) + } + WireMsg::EndpointEchoReq => write!(f, "WireMsg::EndpointEchoReq"), + WireMsg::EndpointEchoResp(ref sa) => write!(f, "WireMsg::EndpointEchoResp({})", sa), + } + } +} /// Message Header that is sent over the wire /// Format of the message header is as follows /// | version | message length | usr_msg_flag | reserved | /// | 2 bytes | 4 bytes | 1 byte | 2 bytes | -pub(crate) struct MsgHeader { - version: i16, +struct MsgHeader { + version: u16, data_len: usize, usr_msg_flag: u8, #[allow(unused)] @@ -35,11 +107,11 @@ pub(crate) struct MsgHeader { impl MsgHeader { pub fn new(msg: &Bytes, usr_msg_flag: u8) -> Result { let data_len = msg.len(); - if data_len > MAX_DATA_LEN { + if data_len > u32::MAX as usize { return Err(Error::MaxLengthExceeded); } Ok(Self { - version: VERSION, + version: MSG_PROTOCOL_VERSION, data_len, usr_msg_flag, reserved: [0, 0], @@ -54,7 +126,7 @@ impl MsgHeader { self.usr_msg_flag } - pub fn to_bytes(&self) -> [u8; HEADER_LEN] { + pub fn to_bytes(&self) -> [u8; MSG_HEADER_LEN] { let version = self.version.to_be_bytes(); let data_len = self.data_len.to_be_bytes(); [ @@ -70,8 +142,8 @@ impl MsgHeader { ] } - pub fn from_bytes(bytes: [u8; HEADER_LEN]) -> Self { - let version = i16::from_be_bytes([bytes[0], bytes[1]]); + pub fn from_bytes(bytes: [u8; MSG_HEADER_LEN]) -> Self { + let version = u16::from_be_bytes([bytes[0], bytes[1]]); let data_len = usize::from_be_bytes([0, 0, 0, 0, bytes[2], bytes[3], bytes[4], bytes[5]]); let usr_msg_flag = bytes[6]; Self { @@ -82,51 +154,3 @@ impl MsgHeader { } } } - -/// Final type serialised and sent on the wire by QuicP2p -#[derive(Serialize, Deserialize, Debug, Clone)] -pub enum WireMsg { - EndpointEchoReq, - EndpointEchoResp(SocketAddr), - UserMsg(Bytes), -} - -const USER_MSG_FLAG: u8 = 0; - -impl Into<(Bytes, u8)> for WireMsg { - fn into(self) -> (Bytes, u8) { - match self { - WireMsg::UserMsg(ref m) => (m.clone(), USER_MSG_FLAG), - _ => ( - From::from(unwrap!(bincode::serialize(&self))), - !USER_MSG_FLAG, - ), - } - } -} - -impl WireMsg { - pub fn from_raw(raw: Vec, msg_flag: u8) -> Result { - if raw.is_empty() { - Err(Error::EmptyResponse) - } else if msg_flag == USER_MSG_FLAG { - Ok(WireMsg::UserMsg(From::from(raw))) - } else if msg_flag == !USER_MSG_FLAG { - Ok(bincode::deserialize(&raw)?) - } else { - Err(Error::InvalidWireMsgFlag) - } - } -} - -impl fmt::Display for WireMsg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - WireMsg::UserMsg(ref m) => { - write!(f, "WireMsg::UserMsg({})", utils::bin_data_format(&*m)) - } - WireMsg::EndpointEchoReq => write!(f, "WireMsg::EndpointEchoReq"), - WireMsg::EndpointEchoResp(ref sa) => write!(f, "WireMsg::EndpointEchoResp({})", sa), - } - } -} diff --git a/tests/quic_p2p.rs b/tests/quic_p2p.rs index 7973c9ad..7ad675c8 100644 --- a/tests/quic_p2p.rs +++ b/tests/quic_p2p.rs @@ -34,7 +34,7 @@ fn random_msg() -> Bytes { async fn successful_connection() -> Result<()> { let qp2p = new_qp2p(); let peer1 = qp2p.new_endpoint()?; - let peer1_addr = peer1.local_address(); + let peer1_addr = peer1.our_endpoint()?; let peer2 = qp2p.new_endpoint()?; let _connection = peer2.connect_to(&peer1_addr).await?; @@ -45,7 +45,7 @@ async fn successful_connection() -> Result<()> { .await .ok_or_else(|| Error::Unexpected("No incoming connection".to_string()))?; - assert_eq!(incoming_messages.remote_addr(), peer2.local_address()); + assert_eq!(incoming_messages.remote_addr(), peer2.our_endpoint()?); Ok(()) } @@ -54,7 +54,7 @@ async fn successful_connection() -> Result<()> { async fn bi_directional_streams() -> Result<()> { let qp2p = new_qp2p(); let peer1 = qp2p.new_endpoint()?; - let peer1_addr = peer1.local_address(); + let peer1_addr = peer1.our_endpoint()?; let peer2 = qp2p.new_endpoint()?; let connection = peer2.connect_to(&peer1_addr).await?; @@ -110,11 +110,11 @@ async fn bi_directional_streams() -> Result<()> { async fn uni_directional_streams() -> Result<()> { let qp2p = new_qp2p(); let peer1 = qp2p.new_endpoint()?; - let peer1_addr = peer1.local_address(); + let peer1_addr = peer1.our_endpoint()?; let mut incoming_conn_peer1 = peer1.listen()?; let peer2 = qp2p.new_endpoint()?; - let peer2_addr = peer2.local_address(); + let peer2_addr = peer2.our_endpoint()?; let mut incoming_conn_peer2 = peer2.listen()?; // Peer 2 sends a message