Skip to content

Commit

Permalink
Merge ebe46bd into 54376ef
Browse files Browse the repository at this point in the history
  • Loading branch information
oetyng committed Aug 3, 2021
2 parents 54376ef + ebe46bd commit 036f7c9
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 143 deletions.
4 changes: 2 additions & 2 deletions examples/p2p_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async fn main() -> Result<()> {
let msg = Bytes::from(MSG_MARCO);
println!("Sending to {:?} --> {:?}\n", peer, msg);
node.connect_to(&peer).await?;
node.send_message(msg.clone(), &peer).await?;
node.send_message(&msg, &peer).await?;
}
}

Expand All @@ -72,7 +72,7 @@ async fn main() -> Result<()> {
println!("Received from {:?} --> {:?}", socket_addr, bytes);
if bytes == *MSG_MARCO {
let reply = Bytes::from(MSG_POLO);
node.send_message(reply.clone(), &socket_addr).await?;
node.send_message(&reply, &socket_addr).await?;
println!("Replied to {:?} --> {:?}", socket_addr, reply);
}
println!();
Expand Down
148 changes: 89 additions & 59 deletions src/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
// specific language governing permissions and limitations relating to use of the SAFE Network
// Software.

use crate::Endpoint;
use crate::{
wire_msg::{write_to_stream, EndpointMsg, ECHO_SRVC_MSG_FLAG, USER_MSG_FLAG},
Endpoint,
};

use super::{
connection_pool::{ConnId, ConnectionPool, ConnectionRemover},
Expand Down Expand Up @@ -50,7 +53,7 @@ impl<I: ConnId> Connection<I> {
}

/// Send message to peer using a uni-directional stream.
pub async fn send_uni(&self, msg: Bytes) -> Result<()> {
pub async fn send_uni(&self, msg: &Bytes) -> Result<()> {
let mut send_stream = self.handle_error(self.quic_conn.open_uni().await).await?;
self.handle_error(send_msg(&mut send_stream, msg).await)
.await?;
Expand Down Expand Up @@ -113,13 +116,17 @@ impl SendStream {
}

/// Send a message using the stream created by the initiator
pub async fn send_user_msg(&mut self, msg: Bytes) -> Result<()> {
send_msg(&mut self.quinn_send_stream, msg).await
pub async fn send_user_msg(&mut self, msg: &Bytes) -> Result<()> {
write_to_stream(msg, USER_MSG_FLAG, &mut self.quinn_send_stream).await
}

/// Send a wire message
pub async fn send(&mut self, msg: WireMsg) -> Result<()> {
msg.write_to_stream(&mut self.quinn_send_stream).await
pub async fn send(&mut self, msg: &WireMsg) -> Result<()> {
let (bytes, flag) = match msg {
WireMsg::UserMsg(bytes) => (bytes, USER_MSG_FLAG),
WireMsg::Echo(bytes) => (bytes, ECHO_SRVC_MSG_FLAG),
};
write_to_stream(bytes, flag, &mut self.quinn_send_stream).await
}

/// Gracefully finish current stream
Expand All @@ -141,10 +148,8 @@ pub async fn read_bytes(recv: &mut quinn::RecvStream) -> Result<WireMsg> {
}

// Helper to send bytes to peer using the provided stream.
pub async fn send_msg(mut send_stream: &mut quinn::SendStream, msg: Bytes) -> Result<()> {
let wire_msg = WireMsg::UserMsg(msg);
wire_msg.write_to_stream(&mut send_stream).await?;
Ok(())
pub async fn send_msg(mut send_stream: &mut quinn::SendStream, msg: &Bytes) -> Result<()> {
write_to_stream(msg, USER_MSG_FLAG, &mut send_stream).await
}

pub(super) fn listen_for_incoming_connections<I: ConnId>(
Expand Down Expand Up @@ -296,36 +301,46 @@ async fn read_on_bi_streams<I: ConnId>(
Ok(WireMsg::UserMsg(bytes)) => {
let _ = message_tx.send((peer_addr, bytes)).await;
}
Ok(WireMsg::EndpointEchoReq) => {
if let Err(error) = handle_endpoint_echo_req(peer_addr, &mut send).await {
Ok(WireMsg::Echo(bytes)) => match bincode::deserialize(&bytes) {
Ok(EndpointMsg::EchoReq) => {
if let Err(error) = handle_endpoint_echo_req(peer_addr, &mut send).await
{
warn!(
"Failed to handle Echo Request for peer {:?} with: {:?}",
peer_addr, error
);

return Err(error);
}
}
Ok(EndpointMsg::VerificationReq(address_sent)) => {
if let Err(error) = handle_endpoint_verification_req(
peer_addr,
address_sent,
&mut send,
endpoint,
)
.await
{
warn!("Failed to handle Endpoint verification request for peer {:?} with: {:?}", peer_addr, error);

return Err(error);
}
}
Ok(msg) => {
warn!(
"Failed to handle Echo Request for peer {:?} with: {:?}",
peer_addr, error
"Unexpected message type from peer {:?}: {:?}",
peer_addr, msg
);

return Err(error);
}
}
Ok(WireMsg::EndpointVerificationReq(address_sent)) => {
if let Err(error) = handle_endpoint_verification_req(
peer_addr,
address_sent,
&mut send,
endpoint,
)
.await
{
warn!("Failed to handle Endpoint verification request for peer {:?} with: {:?}", peer_addr, error);

return Err(error);
Err(err) => {
warn!(
"Failed deserializing msg from a bi-stream for peer {:?} with: {:?}",
peer_addr, err
);
break;
}
}
Ok(msg) => {
warn!(
"Unexpected message type from peer {:?}: {:?}",
peer_addr, msg
);
}
},
Err(Error::StreamRead(quinn::ReadExactError::FinishedEarly)) => {
warn!("Stream finished early");
break;
Expand All @@ -350,8 +365,11 @@ async fn handle_endpoint_echo_req(
send_stream: &mut quinn::SendStream,
) -> Result<()> {
trace!("Received Echo Request from peer {:?}", peer_addr);
let message = WireMsg::EndpointEchoResp(peer_addr);
message.write_to_stream(send_stream).await?;
let msg = WireMsg::Echo(From::from(bincode::serialize(&EndpointMsg::EchoResp(
peer_addr,
))?));
let bytes = From::from(bincode::serialize(&msg)?);
write_to_stream(&bytes, ECHO_SRVC_MSG_FLAG, send_stream).await?;
trace!("Responded to Echo request from peer {:?}", peer_addr);
Ok(())
}
Expand All @@ -369,21 +387,25 @@ async fn handle_endpoint_verification_req<I: ConnId>(
);
// Verify if the peer's endpoint is reachable via EchoServiceReq
let (mut temp_send, mut temp_recv) = endpoint.open_bidirectional_stream(&addr_sent).await?;
let message = WireMsg::EndpointEchoReq;
message
.write_to_stream(&mut temp_send.quinn_send_stream)
.await?;
let verified = matches!(
timeout(
Duration::from_secs(30),
WireMsg::read_from_stream(&mut temp_recv.quinn_recv_stream)
)
.await,
Ok(Ok(WireMsg::EndpointEchoResp(_)))
);

let message = WireMsg::EndpointVerificationResp(verified);
message.write_to_stream(send_stream).await?;
let bytes = WireMsg::to_bytes(&EndpointMsg::EchoReq)?;
write_to_stream(&bytes, ECHO_SRVC_MSG_FLAG, &mut temp_send.quinn_send_stream).await?;

let verified = match timeout(
Duration::from_secs(30),
WireMsg::read_from_stream(&mut temp_recv.quinn_recv_stream),
)
.await
{
Ok(Ok(WireMsg::Echo(m))) => {
matches!(bincode::deserialize(&m)?, EndpointMsg::EchoResp(_))
}
_ => false,
};

let bytes = WireMsg::to_bytes(&EndpointMsg::VerificationResp(verified))?;
write_to_stream(&bytes, ECHO_SRVC_MSG_FLAG, send_stream).await?;

trace!(
"Responded to Endpoint verification request from {:?}",
peer_addr
Expand All @@ -395,6 +417,7 @@ async fn handle_endpoint_verification_req<I: ConnId>(
#[cfg(test)]
mod tests {
use crate::api::QuicP2p;
use crate::wire_msg::{write_to_stream, EndpointMsg, ECHO_SRVC_MSG_FLAG};
use crate::{config::Config, wire_msg::WireMsg, Error};
use anyhow::anyhow;
use std::net::{IpAddr, Ipv4Addr};
Expand Down Expand Up @@ -429,13 +452,20 @@ mod tests {
.await
.ok_or(Error::MissingConnection)?;
let (mut send_stream, mut recv_stream) = connection.open_bi().await?;
let message = WireMsg::EndpointEchoReq;
message
.write_to_stream(&mut send_stream.quinn_send_stream)
.await?;
let message = WireMsg::read_from_stream(&mut recv_stream.quinn_recv_stream).await?;
if let WireMsg::EndpointEchoResp(addr) = message {
assert_eq!(addr, peer1_addr);
let bytes = WireMsg::to_bytes(&EndpointMsg::EchoReq)?;
write_to_stream(
&bytes,
ECHO_SRVC_MSG_FLAG,
&mut send_stream.quinn_send_stream,
)
.await?;
let msg = WireMsg::read_from_stream(&mut recv_stream.quinn_recv_stream).await?;
if let WireMsg::Echo(bytes) = msg {
if let EndpointMsg::EchoResp(addr) = bincode::deserialize(&bytes)? {
assert_eq!(addr, peer1_addr);
} else {
anyhow!("Unexpected response to EchoService request");
}
} else {
anyhow!("Unexpected response to EchoService request");
}
Expand Down
79 changes: 51 additions & 28 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// Software.

use crate::connection_pool::ConnId;
use crate::wire_msg::{write_to_stream, EndpointMsg, ECHO_SRVC_MSG_FLAG};

use super::error::Error;
use super::wire_msg::WireMsg;
Expand Down Expand Up @@ -145,21 +146,31 @@ impl<I: ConnId> Endpoint<I> {
.await
.ok_or(Error::MissingConnection)?;
let (mut send, mut recv) = connection.open_bi().await?;
send.send(WireMsg::EndpointVerificationReq(addr)).await?;
send.send(&WireMsg::from_ep_msg(&EndpointMsg::VerificationReq(addr))?)
.await?;
let response = timeout(
Duration::from_secs(ECHO_SERVICE_QUERY_TIMEOUT),
WireMsg::read_from_stream(&mut recv.quinn_recv_stream),
)
.await;
match response {
Ok(Ok(WireMsg::EndpointVerificationResp(valid))) => {
if valid {
info!("Endpoint verification successful! {} is reachable.", addr);
} else {
error!("Endpoint verification failed! {} is not reachable.", addr);
return Err(Error::IncorrectPublicAddress);
Ok(Ok(WireMsg::Echo(bytes))) => match bincode::deserialize(&bytes)? {
EndpointMsg::VerificationResp(valid) => {
if valid {
info!("Endpoint verification successful! {} is reachable.", addr);
} else {
error!("Endpoint verification failed! {} is not reachable.", addr);
return Err(Error::IncorrectPublicAddress);
}
}
}
other => {
error!(
"Unexpected message when verifying public endpoint: {}",
other
);
return Err(Error::UnexpectedMessageType(WireMsg::Echo(bytes)));
}
},
Ok(Ok(other)) => {
error!(
"Unexpected message when verifying public endpoint: {}",
Expand Down Expand Up @@ -413,32 +424,39 @@ impl<I: ConnId> Endpoint<I> {
}
};

let message = WireMsg::EndpointEchoReq;
message.write_to_stream(&mut send_stream).await?;
let bytes = WireMsg::to_bytes(&EndpointMsg::EchoReq)?;
write_to_stream(&bytes, ECHO_SRVC_MSG_FLAG, &mut send_stream).await?;

match timeout(
let msg = match timeout(
Duration::from_secs(ECHO_SERVICE_QUERY_TIMEOUT),
WireMsg::read_from_stream(&mut recv_stream),
)
.await
{
Ok(Ok(WireMsg::EndpointEchoResp(_))) => Ok(()),
Ok(Ok(other)) => {
info!(
"Unexpected message type when verifying reachability: {}",
&other
);
Ok(())
}
Ok(Ok(msg)) => msg,
Ok(Err(err)) => {
info!("Unable to contact peer: {:?}", err);
Err(err)
return Err(err);
}
Err(err) => {
info!("Unable to contact peer: {:?}", err);
Err(Error::NoEchoServiceResponse)
return Err(Error::NoEchoServiceResponse);
}
}
};

let other = match msg {
WireMsg::Echo(bytes) => match bincode::deserialize(&bytes)? {
EndpointMsg::EchoResp(_) => return Ok(()),
_ => WireMsg::Echo(bytes),
},
other => other,
};

info!(
"Unexpected message type when verifying reachability: {}",
&other
);
Ok(())
}

/// Creates a fresh connection without looking at the connection pool and connection duplicator.
Expand Down Expand Up @@ -515,7 +533,7 @@ impl<I: ConnId> Endpoint<I> {

/// Sends a message to a peer. This will attempt to use an existing connection
/// to the destination peer. If a connection does not exist, this will fail with `Error::MissingConnection`
pub async fn try_send_message(&self, msg: Bytes, dest: &SocketAddr) -> Result<()> {
pub async fn try_send_message(&self, msg: &Bytes, dest: &SocketAddr) -> Result<()> {
let connection = self
.get_connection(dest)
.await
Expand All @@ -527,13 +545,13 @@ impl<I: ConnId> Endpoint<I> {
/// Sends a message to a peer. This will attempt to use an existing connection
/// to the peer first. If this connection is broken or doesn't exist
/// a new connection is created and the message is sent.
pub async fn send_message(&self, msg: Bytes, dest: &SocketAddr) -> Result<()> {
if self.try_send_message(msg.clone(), dest).await.is_ok() {
pub async fn send_message(&self, msg: &Bytes, dest: &SocketAddr) -> Result<()> {
if self.try_send_message(msg, dest).await.is_ok() {
return Ok(());
}
self.connect_to(dest).await?;

self.retry(|| async { Ok(self.try_send_message(msg.clone(), dest).await?) })
self.retry(|| async { Ok(self.try_send_message(msg, dest).await?) })
.await
}

Expand Down Expand Up @@ -561,9 +579,14 @@ impl<I: ConnId> Endpoint<I> {
.await
.ok_or(Error::MissingConnection)?;
let (mut send_stream, mut recv_stream) = connection.open_bi().await?;
send_stream.send(WireMsg::EndpointEchoReq).await?;
send_stream
.send(&WireMsg::from_ep_msg(&EndpointMsg::EchoReq)?)
.await?;
match WireMsg::read_from_stream(&mut recv_stream.quinn_recv_stream).await {
Ok(WireMsg::EndpointEchoResp(socket_addr)) => Ok(socket_addr),
Ok(WireMsg::Echo(bytes)) => match bincode::deserialize(&bytes)? {
EndpointMsg::EchoResp(socket_addr) => Ok(socket_addr),
_ => Err(Error::UnexpectedMessageType(WireMsg::Echo(bytes))),
},
Ok(msg) => Err(Error::UnexpectedMessageType(msg)),
Err(err) => Err(err),
}
Expand Down
Loading

0 comments on commit 036f7c9

Please sign in to comment.