From 409692f5fa50d02d926a525a1ac8793ec3751d1f Mon Sep 17 00:00:00 2001 From: RoDmitry Date: Mon, 8 Apr 2024 13:06:37 +0000 Subject: [PATCH 1/3] Fix buffer, move offset out of loop, store handle --- src/lib.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9a597df..dbe45b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, select, sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + task::JoinHandle, }; mod error; @@ -77,6 +78,7 @@ impl IpStackConfig { pub struct IpStack { accept_receiver: UnboundedReceiver, + pub handle: JoinHandle>, } impl IpStack { @@ -86,16 +88,20 @@ impl IpStack { { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); - tokio::spawn(async move { + let handle = tokio::spawn(async move { let mut streams: AHashMap> = AHashMap::new(); - let mut buffer = [0u8; u16::MAX as usize]; + let offset = if config.packet_information && cfg!(unix) { + 4 + } else { + 0 + }; + let mut buffer = [0u8; u16::MAX as usize + 4]; let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); loop { select! { Ok(n) = device.read(&mut buffer) => { - let offset = if config.packet_information && cfg!(unix) {4} else {0}; let Ok(packet) = NetworkPacket::parse(&buffer[offset..n]) else { accept_sender.send(IpStackStream::UnknownNetwork(buffer[offset..n].to_vec()))?; continue; @@ -186,11 +192,12 @@ impl IpStack { } } } - #[allow(unreachable_code)] - Ok::<(), IpStackError>(()) }); - IpStack { accept_receiver } + IpStack { + accept_receiver, + handle, + } } pub async fn accept(&mut self) -> Result { From d5ad2425167e4d48c76aebb867f4742bac3cd64c Mon Sep 17 00:00:00 2001 From: RoDmitry Date: Mon, 8 Apr 2024 18:02:10 +0000 Subject: [PATCH 2/3] Fix lost packets if channel was closed, Refactor --- src/lib.rs | 273 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 165 insertions(+), 108 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index dbe45b4..1bc2e13 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -82,117 +82,12 @@ pub struct IpStack { } impl IpStack { - pub fn new(config: IpStackConfig, mut device: D) -> IpStack + pub fn new(config: IpStackConfig, device: D) -> IpStack where - D: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static, + D: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); - - let handle = tokio::spawn(async move { - let mut streams: AHashMap> = - AHashMap::new(); - let offset = if config.packet_information && cfg!(unix) { - 4 - } else { - 0 - }; - let mut buffer = [0u8; u16::MAX as usize + 4]; - - let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); - loop { - select! { - Ok(n) = device.read(&mut buffer) => { - let Ok(packet) = NetworkPacket::parse(&buffer[offset..n]) else { - accept_sender.send(IpStackStream::UnknownNetwork(buffer[offset..n].to_vec()))?; - continue; - }; - if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { - accept_sender.send( - IpStackStream::UnknownTransport(IpStackUnknownTransport::new( - packet.src_addr().ip(), - packet.dst_addr().ip(), - packet.payload, - &packet.ip, - config.mtu, - pkt_sender.clone() - )) - )?; - continue; - } - - match streams.entry(packet.network_tuple()){ - Occupied(entry) =>{ - if let Err(e) = entry.get().send(packet){ - trace!("Send packet error \"{}\"", e); - } - } - Vacant(entry) => { - match packet.transport_protocol(){ - IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new( - packet.src_addr(), - packet.dst_addr(), - h, - pkt_sender.clone(), - config.mtu, - config.tcp_timeout - ){ - Ok(stream) => { - entry.insert(stream.stream_sender()); - accept_sender.send(IpStackStream::Tcp(stream))?; - } - Err(e) => { - if matches!(e,IpStackError::InvalidTcpPacket){ - trace!("Invalid TCP packet"); - continue; - } - error!("IpStackTcpStream::new failed \"{}\"", e); - } - } - } - IpStackPacketProtocol::Udp => { - let stream = IpStackUdpStream::new( - packet.src_addr(), - packet.dst_addr(), - packet.payload, - pkt_sender.clone(), - config.mtu, - config.udp_timeout - ); - entry.insert(stream.stream_sender()); - accept_sender.send(IpStackStream::Udp(stream))?; - } - IpStackPacketProtocol::Unknown => { - unreachable!() - } - } - } - } - } - Some(packet) = pkt_receiver.recv() => { - if packet.ttl() == 0{ - streams.remove(&packet.reverse_network_tuple()); - continue; - } - #[allow(unused_mut)] - let Ok(mut packet_byte) = packet.to_bytes() else{ - trace!("to_bytes error"); - continue; - }; - #[cfg(unix)] - if config.packet_information { - if packet.src_addr().is_ipv4(){ - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); - } else{ - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); - } - } - device.write_all(&packet_byte).await?; - // device.flush().await.unwrap(); - } - } - } - }); + let handle = run(config, device, accept_sender); IpStack { accept_receiver, @@ -207,3 +102,165 @@ impl IpStack { .ok_or(IpStackError::AcceptError) } } + +fn run( + config: IpStackConfig, + mut device: D, + accept_sender: UnboundedSender, +) -> JoinHandle> +where + D: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let mut streams: AHashMap> = AHashMap::new(); + let offset = if config.packet_information && cfg!(unix) { + 4 + } else { + 0 + }; + let mut buffer = [0_u8; u16::MAX as usize + 4]; + let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); + + tokio::spawn(async move { + loop { + select! { + Ok(n) = device.read(&mut buffer) => { + if let Some(stream) = process_read( + &buffer[offset..n], + &mut streams, + &pkt_sender, + &config, + )? { + accept_sender.send(stream)?; + } + } + Some(packet) = pkt_receiver.recv() => { + process_recv( + packet, + &mut streams, + &mut device, + #[cfg(unix)] + config.packet_information, + ) + .await?; + } + } + } + }) +} + +fn process_read( + data: &[u8], + streams: &mut AHashMap>, + pkt_sender: &UnboundedSender, + config: &IpStackConfig, +) -> Result> { + let Ok(packet) = NetworkPacket::parse(data) else { + return Ok(Some(IpStackStream::UnknownNetwork(data.to_owned()))); + }; + + if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { + return Ok(Some(IpStackStream::UnknownTransport( + IpStackUnknownTransport::new( + packet.src_addr().ip(), + packet.dst_addr().ip(), + packet.payload, + &packet.ip, + config.mtu, + pkt_sender.clone(), + ), + ))); + } + + Ok(match streams.entry(packet.network_tuple()) { + Occupied(mut entry) => { + if let Err(e) = entry.get().send(packet) { + trace!("New stream because: {}", e); + create_stream(e.0, config, pkt_sender)?.map(|s| { + entry.insert(s.0); + s.1 + }) + } else { + None + } + } + Vacant(entry) => create_stream(packet, config, pkt_sender)?.map(|s| { + entry.insert(s.0); + s.1 + }), + }) +} + +fn create_stream( + packet: NetworkPacket, + config: &IpStackConfig, + pkt_sender: &UnboundedSender, +) -> Result, IpStackStream)>> { + match packet.transport_protocol() { + IpStackPacketProtocol::Tcp(h) => { + match IpStackTcpStream::new( + packet.src_addr(), + packet.dst_addr(), + h, + pkt_sender.clone(), + config.mtu, + config.tcp_timeout, + ) { + Ok(stream) => Ok(Some((stream.stream_sender(), IpStackStream::Tcp(stream)))), + Err(e) => { + if matches!(e, IpStackError::InvalidTcpPacket) { + trace!("Invalid TCP packet"); + } else { + error!("IpStackTcpStream::new failed \"{}\"", e); + } + Ok(None) + } + } + } + IpStackPacketProtocol::Udp => { + let stream = IpStackUdpStream::new( + packet.src_addr(), + packet.dst_addr(), + packet.payload, + pkt_sender.clone(), + config.mtu, + config.udp_timeout, + ); + Ok(Some((stream.stream_sender(), IpStackStream::Udp(stream)))) + } + IpStackPacketProtocol::Unknown => { + unreachable!() + } + } +} + +async fn process_recv( + packet: NetworkPacket, + streams: &mut AHashMap>, + device: &mut D, + #[cfg(unix)] packet_information: bool, +) -> Result<()> +where + D: AsyncWrite + Unpin + 'static, +{ + if packet.ttl() == 0 { + streams.remove(&packet.reverse_network_tuple()); + return Ok(()); + } + #[allow(unused_mut)] + let Ok(mut packet_byte) = packet.to_bytes() else { + trace!("to_bytes error"); + return Ok(()); + }; + #[cfg(unix)] + if packet_information { + if packet.src_addr().is_ipv4() { + packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); + } else { + packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); + } + } + device.write_all(&packet_byte).await?; + // device.flush().await.unwrap(); + + Ok(()) +} From a98e2b7b04ed819f1c6192ec783190681aaa194e Mon Sep 17 00:00:00 2001 From: SajjadPourali Date: Tue, 9 Apr 2024 12:30:31 -0400 Subject: [PATCH 3/3] Improve error handling in create_stream --- src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 1bc2e13..f7f7913 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -209,10 +209,11 @@ fn create_stream( Err(e) => { if matches!(e, IpStackError::InvalidTcpPacket) { trace!("Invalid TCP packet"); + Ok(None) } else { error!("IpStackTcpStream::new failed \"{}\"", e); + Err(e) } - Ok(None) } } }