From 65ab27d0452a1d4e46104f5a50f1acf8b088490f Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Thu, 18 Apr 2024 00:30:05 +0800 Subject: [PATCH] reading code --- src/lib.rs | 52 +++++++++++++++++++++------------------ src/stream/tcp.rs | 15 +++++------ src/stream/tcp_wrapper.rs | 14 ++++------- src/stream/udp.rs | 20 +++++++-------- src/stream/unknown.rs | 7 +++--- 5 files changed, 52 insertions(+), 56 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 6fa7ade..a997502 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,9 @@ use tokio::{ task::JoinHandle, }; +pub(crate) type PacketSender = UnboundedSender; +pub(crate) type PacketReceiver = UnboundedReceiver; + mod error; mod packet; pub mod stream; @@ -62,17 +65,21 @@ impl Default for IpStackConfig { } impl IpStackConfig { - pub fn tcp_timeout(&mut self, timeout: Duration) { + pub fn tcp_timeout(&mut self, timeout: Duration) -> &mut Self { self.tcp_timeout = timeout; + self } - pub fn udp_timeout(&mut self, timeout: Duration) { + pub fn udp_timeout(&mut self, timeout: Duration) -> &mut Self { self.udp_timeout = timeout; + self } - pub fn mtu(&mut self, mtu: u16) { + pub fn mtu(&mut self, mtu: u16) -> &mut Self { self.mtu = mtu; + self } - pub fn packet_information(&mut self, packet_information: bool) { + pub fn packet_information(&mut self, packet_information: bool) -> &mut Self { self.packet_information = packet_information; + self } } @@ -111,12 +118,9 @@ fn run( where D: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let mut streams: AHashMap> = AHashMap::new(); - let offset = if config.packet_information && cfg!(unix) { - 4 - } else { - 0 - }; + let mut streams: AHashMap = AHashMap::new(); + let pi = config.packet_information; + let offset = if pi && cfg!(unix) { 4 } else { 0 }; let mut buffer = [0_u8; u16::MAX as usize + 4]; let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); @@ -124,7 +128,7 @@ where loop { select! { Ok(n) = device.read(&mut buffer) => { - if let Some(stream) = process_read( + if let Some(stream) = process_device_read( &buffer[offset..n], &mut streams, &pkt_sender, @@ -134,12 +138,12 @@ where } } Some(packet) = pkt_receiver.recv() => { - process_recv( + process_upstream_recv( packet, &mut streams, &mut device, #[cfg(unix)] - config.packet_information, + pi, ) .await?; } @@ -148,10 +152,10 @@ where }) } -fn process_read( +fn process_device_read( data: &[u8], - streams: &mut AHashMap>, - pkt_sender: &UnboundedSender, + streams: &mut AHashMap, + pkt_sender: &PacketSender, config: &IpStackConfig, ) -> Option { let Ok(packet) = NetworkPacket::parse(data) else { @@ -193,8 +197,8 @@ fn process_read( fn create_stream( packet: NetworkPacket, config: &IpStackConfig, - pkt_sender: &UnboundedSender, -) -> Option<(UnboundedSender, IpStackStream)> { + pkt_sender: &PacketSender, +) -> Option<(PacketSender, IpStackStream)> { match packet.transport_protocol() { IpStackPacketProtocol::Tcp(h) => { match IpStackTcpStream::new( @@ -233,9 +237,9 @@ fn create_stream( } } -async fn process_recv( +async fn process_upstream_recv( packet: NetworkPacket, - streams: &mut AHashMap>, + streams: &mut AHashMap, device: &mut D, #[cfg(unix)] packet_information: bool, ) -> Result<()> @@ -247,19 +251,19 @@ where return Ok(()); } #[allow(unused_mut)] - let Ok(mut packet_byte) = packet.to_bytes() else { + let Ok(mut packet_bytes) = packet.to_bytes() else { trace!("to_bytes error"); return Ok(()); }; #[cfg(unix)] if packet_information { if packet.src_addr().is_ipv4() { - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); + packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); } else { - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); + packet_bytes.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); } } - device.write_all(&packet_byte).await?; + device.write_all(&packet_bytes).await?; // device.flush().await.unwrap(); Ok(()) diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 552a3ec..f4dd4d1 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -5,7 +5,7 @@ use crate::{ IpStackPacketProtocol, TcpHeaderWrapper, TransportHeader, }, stream::tcb::{Tcb, TcpState}, - DROP_TTL, TTL, + PacketReceiver, PacketSender, DROP_TTL, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel}; use std::{ @@ -18,10 +18,7 @@ use std::{ task::{Context, Poll, Waker}, time::Duration, }; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc::{UnboundedReceiver, UnboundedSender}, -}; +use tokio::io::{AsyncRead, AsyncWrite}; use log::{trace, warn}; @@ -53,8 +50,8 @@ impl Shutdown { pub(crate) struct IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, - stream_receiver: UnboundedReceiver, - packet_sender: UnboundedSender, + stream_receiver: PacketReceiver, + packet_sender: PacketSender, packet_to_send: Option, tcb: Tcb, mtu: u16, @@ -67,8 +64,8 @@ impl IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, tcp: TcpHeaderWrapper, - pkt_sender: UnboundedSender, - stream_receiver: UnboundedReceiver, + pkt_sender: PacketSender, + stream_receiver: PacketReceiver, mtu: u16, tcp_timeout: Duration, ) -> Result { diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs index e4b1161..19bddd2 100644 --- a/src/stream/tcp_wrapper.rs +++ b/src/stream/tcp_wrapper.rs @@ -1,20 +1,16 @@ use super::tcp::IpStackTcpStream as IpStackTcpStreamInner; use crate::{ packet::{NetworkPacket, TcpHeaderWrapper}, - IpStackError, + IpStackError, PacketSender, }; use std::{net::SocketAddr, pin::Pin, time::Duration}; -use tokio::{ - io::AsyncWriteExt, - sync::mpsc::{self, UnboundedSender}, - time::timeout, -}; +use tokio::{io::AsyncWriteExt, sync::mpsc, time::timeout}; pub struct IpStackTcpStream { inner: Option>, peer_addr: SocketAddr, local_addr: SocketAddr, - stream_sender: mpsc::UnboundedSender, + stream_sender: PacketSender, } impl IpStackTcpStream { @@ -22,7 +18,7 @@ impl IpStackTcpStream { local_addr: SocketAddr, peer_addr: SocketAddr, tcp: TcpHeaderWrapper, - pkt_sender: UnboundedSender, + pkt_sender: PacketSender, mtu: u16, tcp_timeout: Duration, ) -> Result { @@ -50,7 +46,7 @@ impl IpStackTcpStream { pub fn peer_addr(&self) -> SocketAddr { self.peer_addr } - pub fn stream_sender(&self) -> UnboundedSender { + pub fn stream_sender(&self) -> PacketSender { self.stream_sender.clone() } } diff --git a/src/stream/udp.rs b/src/stream/udp.rs index b7e6bc0..c9edc90 100644 --- a/src/stream/udp.rs +++ b/src/stream/udp.rs @@ -1,12 +1,12 @@ use crate::{ packet::{IpHeader, NetworkPacket, TransportHeader}, - IpStackError, TTL, + IpStackError, PacketReceiver, PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header, UdpHeader}; use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration}; use tokio::{ io::{AsyncRead, AsyncWrite}, - sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + sync::mpsc, time::Sleep, }; @@ -14,10 +14,10 @@ use tokio::{ pub struct IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, - stream_sender: UnboundedSender, - stream_receiver: UnboundedReceiver, - packet_sender: UnboundedSender, - first_paload: Option>, + stream_sender: PacketSender, + stream_receiver: PacketReceiver, + packet_sender: PacketSender, + first_payload: Option>, timeout: Pin>, udp_timeout: Duration, mtu: u16, @@ -28,7 +28,7 @@ impl IpStackUdpStream { src_addr: SocketAddr, dst_addr: SocketAddr, payload: Vec, - packet_sender: UnboundedSender, + packet_sender: PacketSender, mtu: u16, udp_timeout: Duration, ) -> Self { @@ -40,14 +40,14 @@ impl IpStackUdpStream { stream_sender, stream_receiver, packet_sender, - first_paload: Some(payload), + first_payload: Some(payload), timeout: Box::pin(tokio::time::sleep_until(deadline)), udp_timeout, mtu, } } - pub(crate) fn stream_sender(&self) -> UnboundedSender { + pub(crate) fn stream_sender(&self) -> PacketSender { self.stream_sender.clone() } @@ -126,7 +126,7 @@ impl AsyncRead for IpStackUdpStream { cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { - if let Some(p) = self.first_paload.take() { + if let Some(p) = self.first_payload.take() { buf.put_slice(&p); return std::task::Poll::Ready(Ok(())); } diff --git a/src/stream/unknown.rs b/src/stream/unknown.rs index 5eccc5d..838d93f 100644 --- a/src/stream/unknown.rs +++ b/src/stream/unknown.rs @@ -1,10 +1,9 @@ use crate::{ packet::{IpHeader, NetworkPacket, TransportHeader}, - TTL, + PacketSender, TTL, }; use etherparse::{IpNumber, Ipv4Header, Ipv6FlowLabel, Ipv6Header}; use std::{io::Error, mem, net::IpAddr}; -use tokio::sync::mpsc::UnboundedSender; pub struct IpStackUnknownTransport { src_addr: IpAddr, @@ -12,7 +11,7 @@ pub struct IpStackUnknownTransport { payload: Vec, protocol: IpNumber, mtu: u16, - packet_sender: UnboundedSender, + packet_sender: PacketSender, } impl IpStackUnknownTransport { @@ -22,7 +21,7 @@ impl IpStackUnknownTransport { payload: Vec, ip: &IpHeader, mtu: u16, - packet_sender: UnboundedSender, + packet_sender: PacketSender, ) -> Self { let protocol = match ip { IpHeader::Ipv4(ip) => ip.protocol,