Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(autonat): use quick-protobuf-codec #4787

Merged
merged 3 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions protocols/autonat/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ libp2p-identity = { workspace = true }
log = "0.4"
rand = "0.8"
quick-protobuf = "0.8"
quick-protobuf-codec = { workspace = true }
asynchronous-codec = "0.6.2"

[dev-dependencies]
async-std = { version = "1.10", features = ["attributes"] }
Expand Down
100 changes: 45 additions & 55 deletions protocols/autonat/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@

use crate::proto;
use async_trait::async_trait;
use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use libp2p_core::{upgrade, Multiaddr};
use asynchronous_codec::{FramedRead, FramedWrite};
use futures::io::{AsyncRead, AsyncWrite};
use futures::{SinkExt, StreamExt};
use libp2p_core::Multiaddr;
use libp2p_identity::PeerId;
use libp2p_request_response::{self as request_response};
use libp2p_swarm::StreamProtocol;
use quick_protobuf::{BytesReader, Writer};
use std::{convert::TryFrom, io};

/// The protocol name used for negotiating with multistream-select.
Expand All @@ -44,8 +45,12 @@ impl request_response::Codec for AutoNatCodec {
where
T: AsyncRead + Send + Unpin,
{
let bytes = upgrade::read_length_prefixed(io, 1024).await?;
let request = DialRequest::from_bytes(&bytes)?;
let message = FramedRead::new(io, codec())
.next()
.await
.ok_or(io::ErrorKind::UnexpectedEof)??;
let request = DialRequest::from_proto(message)?;

Ok(request)
}

Expand All @@ -57,8 +62,12 @@ impl request_response::Codec for AutoNatCodec {
where
T: AsyncRead + Send + Unpin,
{
let bytes = upgrade::read_length_prefixed(io, 1024).await?;
let response = DialResponse::from_bytes(&bytes)?;
let message = FramedRead::new(io, codec())
.next()
.await
.ok_or(io::ErrorKind::UnexpectedEof)??;
let response = DialResponse::from_proto(message)?;

Ok(response)
}

Expand All @@ -71,8 +80,11 @@ impl request_response::Codec for AutoNatCodec {
where
T: AsyncWrite + Send + Unpin,
{
upgrade::write_length_prefixed(io, data.into_bytes()).await?;
io.close().await
let mut framed = FramedWrite::new(io, codec());
framed.send(data.into_proto()).await?;
framed.close().await?;

Ok(())
}

async fn write_response<T>(
Expand All @@ -84,24 +96,26 @@ impl request_response::Codec for AutoNatCodec {
where
T: AsyncWrite + Send + Unpin,
{
upgrade::write_length_prefixed(io, data.into_bytes()).await?;
io.close().await
let mut framed = FramedWrite::new(io, codec());
framed.send(data.into_proto()).await?;
framed.close().await?;

Ok(())
}
}

fn codec() -> quick_protobuf_codec::Codec<proto::Message> {
quick_protobuf_codec::Codec::<proto::Message>::new(1024)
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DialRequest {
pub peer_id: PeerId,
pub addresses: Vec<Multiaddr>,
}

impl DialRequest {
pub fn from_bytes(bytes: &[u8]) -> Result<Self, io::Error> {
use quick_protobuf::MessageRead;

let mut reader = BytesReader::from_bytes(bytes);
let msg = proto::Message::from_reader(&mut reader, bytes)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
pub fn from_proto(msg: proto::Message) -> Result<Self, io::Error> {
if msg.type_pb != Some(proto::MessageType::DIAL) {
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type"));
}
Expand Down Expand Up @@ -143,17 +157,15 @@ impl DialRequest {
})
}

pub fn into_bytes(self) -> Vec<u8> {
use quick_protobuf::MessageWrite;

pub fn into_proto(self) -> proto::Message {
let peer_id = self.peer_id.to_bytes();
let addrs = self
.addresses
.into_iter()
.map(|addr| addr.to_vec())
.collect();

let msg = proto::Message {
proto::Message {
type_pb: Some(proto::MessageType::DIAL),
dial: Some(proto::Dial {
peer: Some(proto::PeerInfo {
Expand All @@ -162,12 +174,7 @@ impl DialRequest {
}),
}),
dialResponse: None,
};

let mut buf = Vec::with_capacity(msg.get_size());
let mut writer = Writer::new(&mut buf);
msg.write_message(&mut writer).expect("Encoding to succeed");
buf
}
}
}

Expand Down Expand Up @@ -217,12 +224,7 @@ pub struct DialResponse {
}

impl DialResponse {
pub fn from_bytes(bytes: &[u8]) -> Result<Self, io::Error> {
use quick_protobuf::MessageRead;

let mut reader = BytesReader::from_bytes(bytes);
let msg = proto::Message::from_reader(&mut reader, bytes)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
pub fn from_proto(msg: proto::Message) -> Result<Self, io::Error> {
if msg.type_pb != Some(proto::MessageType::DIAL_RESPONSE) {
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid type"));
}
Expand Down Expand Up @@ -258,9 +260,7 @@ impl DialResponse {
})
}

pub fn into_bytes(self) -> Vec<u8> {
use quick_protobuf::MessageWrite;

pub fn into_proto(self) -> proto::Message {
let dial_response = match self.result {
Ok(addr) => proto::DialResponse {
status: Some(proto::ResponseStatus::OK),
Expand All @@ -274,23 +274,17 @@ impl DialResponse {
},
};

let msg = proto::Message {
proto::Message {
type_pb: Some(proto::MessageType::DIAL_RESPONSE),
dial: None,
dialResponse: Some(dial_response),
};

let mut buf = Vec::with_capacity(msg.get_size());
let mut writer = Writer::new(&mut buf);
msg.write_message(&mut writer).expect("Encoding to succeed");
buf
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use quick_protobuf::MessageWrite;

#[test]
fn test_request_encode_decode() {
Expand All @@ -301,8 +295,8 @@ mod tests {
"/ip4/192.168.1.42/tcp/30333".parse().unwrap(),
],
};
let bytes = request.clone().into_bytes();
let request2 = DialRequest::from_bytes(&bytes).unwrap();
let proto = request.clone().into_proto();
let request2 = DialRequest::from_proto(proto).unwrap();
assert_eq!(request, request2);
}

Expand All @@ -312,8 +306,8 @@ mod tests {
result: Ok("/ip4/8.8.8.8/tcp/30333".parse().unwrap()),
status_text: None,
};
let bytes = response.clone().into_bytes();
let response2 = DialResponse::from_bytes(&bytes).unwrap();
let proto = response.clone().into_proto();
let response2 = DialResponse::from_proto(proto).unwrap();
assert_eq!(response, response2);
}

Expand All @@ -323,8 +317,8 @@ mod tests {
result: Err(ResponseError::DialError),
status_text: Some("dial failed".to_string()),
};
let bytes = response.clone().into_bytes();
let response2 = DialResponse::from_bytes(&bytes).unwrap();
let proto = response.clone().into_proto();
let response2 = DialResponse::from_proto(proto).unwrap();
assert_eq!(response, response2);
}

Expand All @@ -350,11 +344,7 @@ mod tests {
dialResponse: None,
};

let mut bytes = Vec::with_capacity(msg.get_size());
let mut writer = Writer::new(&mut bytes);
msg.write_message(&mut writer).expect("Encoding to succeed");

let request = DialRequest::from_bytes(&bytes).expect("not to fail");
let request = DialRequest::from_proto(msg).expect("not to fail");

assert_eq!(request.addresses, vec![valid_multiaddr])
}
Expand Down