From 44e0f47ceb386ae3dc1fa301b3c64c7a039e0b10 Mon Sep 17 00:00:00 2001 From: Dmitry Rodionov Date: Tue, 9 Apr 2024 20:35:43 +0400 Subject: [PATCH] Fixes and refactor (#37) * Fix buffer, move offset out of loop, store handle * Fix lost packets if channel was closed, Refactor * Improve error handling in create_stream --- src/lib.rs | 279 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 172 insertions(+), 107 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9a597df..f7f7913 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, select, sync::mpsc::{self, UnboundedReceiver, UnboundedSender}, + task::JoinHandle, }; mod error; @@ -77,120 +78,21 @@ impl IpStackConfig { pub struct IpStack { accept_receiver: UnboundedReceiver, + pub handle: JoinHandle>, } impl IpStack { - pub fn new(config: IpStackConfig, mut device: D) -> IpStack + pub fn new(config: IpStackConfig, device: D) -> IpStack where - D: AsyncRead + AsyncWrite + std::marker::Unpin + std::marker::Send + 'static, + D: AsyncRead + AsyncWrite + Unpin + Send + 'static, { let (accept_sender, accept_receiver) = mpsc::unbounded_channel::(); + let handle = run(config, device, accept_sender); - tokio::spawn(async move { - let mut streams: AHashMap> = - AHashMap::new(); - let mut buffer = [0u8; u16::MAX as usize]; - - let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); - loop { - select! { - Ok(n) = device.read(&mut buffer) => { - let offset = if config.packet_information && cfg!(unix) {4} else {0}; - let Ok(packet) = NetworkPacket::parse(&buffer[offset..n]) else { - accept_sender.send(IpStackStream::UnknownNetwork(buffer[offset..n].to_vec()))?; - continue; - }; - if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { - accept_sender.send( - IpStackStream::UnknownTransport(IpStackUnknownTransport::new( - packet.src_addr().ip(), - packet.dst_addr().ip(), - packet.payload, - &packet.ip, - config.mtu, - pkt_sender.clone() - )) - )?; - continue; - } - - match streams.entry(packet.network_tuple()){ - Occupied(entry) =>{ - if let Err(e) = entry.get().send(packet){ - trace!("Send packet error \"{}\"", e); - } - } - Vacant(entry) => { - match packet.transport_protocol(){ - IpStackPacketProtocol::Tcp(h) => { - match IpStackTcpStream::new( - packet.src_addr(), - packet.dst_addr(), - h, - pkt_sender.clone(), - config.mtu, - config.tcp_timeout - ){ - Ok(stream) => { - entry.insert(stream.stream_sender()); - accept_sender.send(IpStackStream::Tcp(stream))?; - } - Err(e) => { - if matches!(e,IpStackError::InvalidTcpPacket){ - trace!("Invalid TCP packet"); - continue; - } - error!("IpStackTcpStream::new failed \"{}\"", e); - } - } - } - IpStackPacketProtocol::Udp => { - let stream = IpStackUdpStream::new( - packet.src_addr(), - packet.dst_addr(), - packet.payload, - pkt_sender.clone(), - config.mtu, - config.udp_timeout - ); - entry.insert(stream.stream_sender()); - accept_sender.send(IpStackStream::Udp(stream))?; - } - IpStackPacketProtocol::Unknown => { - unreachable!() - } - } - } - } - } - Some(packet) = pkt_receiver.recv() => { - if packet.ttl() == 0{ - streams.remove(&packet.reverse_network_tuple()); - continue; - } - #[allow(unused_mut)] - let Ok(mut packet_byte) = packet.to_bytes() else{ - trace!("to_bytes error"); - continue; - }; - #[cfg(unix)] - if config.packet_information { - if packet.src_addr().is_ipv4(){ - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); - } else{ - packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); - } - } - device.write_all(&packet_byte).await?; - // device.flush().await.unwrap(); - } - } - } - #[allow(unreachable_code)] - Ok::<(), IpStackError>(()) - }); - - IpStack { accept_receiver } + IpStack { + accept_receiver, + handle, + } } pub async fn accept(&mut self) -> Result { @@ -200,3 +102,166 @@ impl IpStack { .ok_or(IpStackError::AcceptError) } } + +fn run( + config: IpStackConfig, + mut device: D, + accept_sender: UnboundedSender, +) -> JoinHandle> +where + D: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + let mut streams: AHashMap> = AHashMap::new(); + let offset = if config.packet_information && cfg!(unix) { + 4 + } else { + 0 + }; + let mut buffer = [0_u8; u16::MAX as usize + 4]; + let (pkt_sender, mut pkt_receiver) = mpsc::unbounded_channel::(); + + tokio::spawn(async move { + loop { + select! { + Ok(n) = device.read(&mut buffer) => { + if let Some(stream) = process_read( + &buffer[offset..n], + &mut streams, + &pkt_sender, + &config, + )? { + accept_sender.send(stream)?; + } + } + Some(packet) = pkt_receiver.recv() => { + process_recv( + packet, + &mut streams, + &mut device, + #[cfg(unix)] + config.packet_information, + ) + .await?; + } + } + } + }) +} + +fn process_read( + data: &[u8], + streams: &mut AHashMap>, + pkt_sender: &UnboundedSender, + config: &IpStackConfig, +) -> Result> { + let Ok(packet) = NetworkPacket::parse(data) else { + return Ok(Some(IpStackStream::UnknownNetwork(data.to_owned()))); + }; + + if let IpStackPacketProtocol::Unknown = packet.transport_protocol() { + return Ok(Some(IpStackStream::UnknownTransport( + IpStackUnknownTransport::new( + packet.src_addr().ip(), + packet.dst_addr().ip(), + packet.payload, + &packet.ip, + config.mtu, + pkt_sender.clone(), + ), + ))); + } + + Ok(match streams.entry(packet.network_tuple()) { + Occupied(mut entry) => { + if let Err(e) = entry.get().send(packet) { + trace!("New stream because: {}", e); + create_stream(e.0, config, pkt_sender)?.map(|s| { + entry.insert(s.0); + s.1 + }) + } else { + None + } + } + Vacant(entry) => create_stream(packet, config, pkt_sender)?.map(|s| { + entry.insert(s.0); + s.1 + }), + }) +} + +fn create_stream( + packet: NetworkPacket, + config: &IpStackConfig, + pkt_sender: &UnboundedSender, +) -> Result, IpStackStream)>> { + match packet.transport_protocol() { + IpStackPacketProtocol::Tcp(h) => { + match IpStackTcpStream::new( + packet.src_addr(), + packet.dst_addr(), + h, + pkt_sender.clone(), + config.mtu, + config.tcp_timeout, + ) { + Ok(stream) => Ok(Some((stream.stream_sender(), IpStackStream::Tcp(stream)))), + Err(e) => { + if matches!(e, IpStackError::InvalidTcpPacket) { + trace!("Invalid TCP packet"); + Ok(None) + } else { + error!("IpStackTcpStream::new failed \"{}\"", e); + Err(e) + } + } + } + } + IpStackPacketProtocol::Udp => { + let stream = IpStackUdpStream::new( + packet.src_addr(), + packet.dst_addr(), + packet.payload, + pkt_sender.clone(), + config.mtu, + config.udp_timeout, + ); + Ok(Some((stream.stream_sender(), IpStackStream::Udp(stream)))) + } + IpStackPacketProtocol::Unknown => { + unreachable!() + } + } +} + +async fn process_recv( + packet: NetworkPacket, + streams: &mut AHashMap>, + device: &mut D, + #[cfg(unix)] packet_information: bool, +) -> Result<()> +where + D: AsyncWrite + Unpin + 'static, +{ + if packet.ttl() == 0 { + streams.remove(&packet.reverse_network_tuple()); + return Ok(()); + } + #[allow(unused_mut)] + let Ok(mut packet_byte) = packet.to_bytes() else { + trace!("to_bytes error"); + return Ok(()); + }; + #[cfg(unix)] + if packet_information { + if packet.src_addr().is_ipv4() { + packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP4].concat()); + } else { + packet_byte.splice(0..0, [TUN_FLAGS, TUN_PROTO_IP6].concat()); + } + } + device.write_all(&packet_byte).await?; + // device.flush().await.unwrap(); + + Ok(()) +}