From a98e2715ad236ebcf2d33a9fa46e266afc6d430a Mon Sep 17 00:00:00 2001 From: ssrlive <30760636+ssrlive@users.noreply.github.com> Date: Sun, 21 Apr 2024 23:00:25 +0800 Subject: [PATCH] reading code --- src/stream/tcb.rs | 13 +++++++----- src/stream/tcp.rs | 43 ++++++++++++++++++++++----------------- src/stream/tcp_wrapper.rs | 7 ++++--- 3 files changed, 36 insertions(+), 27 deletions(-) diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 299b437..7bac065 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -5,16 +5,19 @@ use tokio::time::Sleep; const MAX_UNACK: u32 = 1024 * 16; // 16KB const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Default)] pub enum TcpState { SynReceived(bool), // bool means if syn/ack is sent + #[default] Established, FinWait1(bool), FinWait2(bool), // bool means waiting for ack Closed, } -#[derive(Clone, Debug)] + +#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Default)] pub(super) enum PacketStatus { + #[default] WindowUpdate, Invalid, RetransmissionRequest, @@ -104,8 +107,8 @@ impl Tcb { pub(super) fn change_state(&mut self, state: TcpState) { self.state = state; } - pub(super) fn get_state(&self) -> &TcpState { - &self.state + pub(super) fn get_state(&self) -> TcpState { + self.state } pub(super) fn change_send_window(&mut self, window: u16) { let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64) @@ -170,7 +173,7 @@ impl Tcb { let distance = ack.wrapping_sub(self.last_ack); self.last_ack = self.last_ack.wrapping_add(distance); - if matches!(self.state, TcpState::Established) { + if self.state == TcpState::Established { if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) { let mut inflight_packet = self.inflight_packets.remove(i); let distance = ack.wrapping_sub(inflight_packet.seq) as usize; diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index f4dd4d1..f89cab2 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -64,7 +64,7 @@ impl IpStackTcpStream { src_addr: SocketAddr, dst_addr: SocketAddr, tcp: TcpHeaderWrapper, - pkt_sender: PacketSender, + packet_sender: PacketSender, stream_receiver: PacketReceiver, mtu: u16, tcp_timeout: Duration, @@ -73,7 +73,7 @@ impl IpStackTcpStream { src_addr, dst_addr, stream_receiver, - packet_sender: pkt_sender.clone(), + packet_sender, packet_to_send: None, tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout), mtu, @@ -84,7 +84,10 @@ impl IpStackTcpStream { return Ok(stream); } if !tcp.inner().rst { - _ = pkt_sender.send(stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?); + let pkt = stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?; + if let Err(err) = stream.packet_sender.send(pkt) { + log::warn!("Error sending RST/ACK packet: {:?}", err); + } } Err(IpStackError::InvalidTcpPacket) } @@ -156,7 +159,8 @@ impl IpStackTcpStream { tcp_header.header_len() as u16, ); payload.truncate(payload_len as usize); - ip_h.payload_length = (payload.len() + tcp_header.header_len()) as u16; + let len = payload.len() + tcp_header.header_len(); + ip_h.set_payload_length(len).map_err(IpStackError::from)?; IpHeader::Ipv6(ip_h) } @@ -190,17 +194,17 @@ impl AsyncRead for IpStackTcpStream { buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { loop { - if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) - && self.packet_to_send.is_none() - { + if self.tcb.get_state() == TcpState::FinWait2(false) && self.packet_to_send.is_none() { self.packet_to_send = Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); self.shutdown.ready(); return Poll::Ready(Ok(())); } + let min = self.tcb.get_available_read_buffer_size() as u16; self.tcb.change_recv_window(min); + if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) { trace!("timeout reached for {:?}", self.dst_addr); self.packet_sender @@ -210,10 +214,9 @@ impl AsyncRead for IpStackTcpStream { self.shutdown.ready(); return Poll::Ready(Err(Error::from(ErrorKind::TimedOut))); } - self.tcb.reset_timeout(); - if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) { + if self.tcb.get_state() == TcpState::SynReceived(false) { self.packet_to_send = Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); @@ -224,7 +227,7 @@ impl AsyncRead for IpStackTcpStream { self.packet_sender .send(packet) .or(Err(ErrorKind::UnexpectedEof))?; - if matches!(self.tcb.get_state(), TcpState::Closed) { + if self.tcb.get_state() == TcpState::Closed { self.shutdown.ready(); return Poll::Ready(Ok(())); } @@ -242,7 +245,7 @@ impl AsyncRead for IpStackTcpStream { .or(Err(ErrorKind::UnexpectedEof))?; return Poll::Ready(Ok(())); } - if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) { + if self.tcb.get_state() == TcpState::FinWait1(true) { self.packet_to_send = Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?); self.tcb.add_seq_one(); @@ -250,7 +253,7 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_state(TcpState::FinWait2(true)); continue; } else if matches!(self.shutdown, Shutdown::Pending(_)) - && matches!(self.tcb.get_state(), TcpState::Established) + && self.tcb.get_state() == TcpState::Established && self.tcb.last_ack == self.tcb.seq { self.packet_to_send = @@ -278,13 +281,13 @@ impl AsyncRead for IpStackTcpStream { continue; } - if matches!(self.tcb.get_state(), TcpState::SynReceived(true)) { + if self.tcb.get_state() == TcpState::SynReceived(true) { if t.flags() == ACK { self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.change_send_window(t.inner().window_size); self.tcb.change_state(TcpState::Established); } - } else if matches!(self.tcb.get_state(), TcpState::Established) { + } else if self.tcb.get_state() == TcpState::Established { if t.flags() == ACK { match self.tcb.check_pkt_type(&t, &p.payload) { PacketStatus::WindowUpdate => { @@ -391,7 +394,7 @@ impl AsyncRead for IpStackTcpStream { .add_unordered_packet(t.inner().sequence_number, &p.payload); continue; } - } else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) { + } else if self.tcb.get_state() == TcpState::FinWait1(false) { if t.flags() == ACK { self.tcb.change_last_ack(t.inner().acknowledgment_number); self.tcb.add_ack(1); @@ -405,7 +408,7 @@ impl AsyncRead for IpStackTcpStream { self.tcb.change_state(TcpState::FinWait2(true)); continue; } - } else if matches!(self.tcb.get_state(), TcpState::FinWait2(true)) { + } else if self.tcb.get_state() == TcpState::FinWait2(true) { if t.flags() == ACK { self.tcb.change_state(TcpState::FinWait2(false)); } else if t.flags() == (FIN | ACK) { @@ -428,7 +431,7 @@ impl AsyncWrite for IpStackTcpStream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if !matches!(self.tcb.get_state(), TcpState::Established) { + if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } self.tcb.reset_timeout(); @@ -463,7 +466,7 @@ impl AsyncWrite for IpStackTcpStream { mut self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll> { - if !matches!(self.tcb.get_state(), TcpState::Established) { + if self.tcb.get_state() != TcpState::Established { return Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); } if let Some(i) = self @@ -513,7 +516,9 @@ impl AsyncWrite for IpStackTcpStream { impl Drop for IpStackTcpStream { fn drop(&mut self) { if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) { - _ = self.packet_sender.send(p); + if let Err(err) = self.packet_sender.send(p) { + log::trace!("Error sending NON packet: {:?}", err); + } } } } diff --git a/src/stream/tcp_wrapper.rs b/src/stream/tcp_wrapper.rs index 19bddd2..e6653b9 100644 --- a/src/stream/tcp_wrapper.rs +++ b/src/stream/tcp_wrapper.rs @@ -32,9 +32,8 @@ impl IpStackTcpStream { mtu, tcp_timeout, ) - .map(Box::new) .map(|inner| IpStackTcpStream { - inner: Some(inner), + inner: Some(Box::new(inner)), peer_addr, local_addr, stream_sender, @@ -107,7 +106,9 @@ impl Drop for IpStackTcpStream { fn drop(&mut self) { if let Some(mut inner) = self.inner.take() { tokio::spawn(async move { - _ = timeout(Duration::from_secs(2), inner.shutdown()).await; + if let Err(err) = timeout(Duration::from_secs(2), inner.shutdown()).await { + log::warn!("Error while dropping IpStackTcpStream: {:?}", err); + } }); } }