From 6b23ebf28151027bbdbb34778c1e902ad955e22b Mon Sep 17 00:00:00 2001 From: SajjadPourali Date: Mon, 4 Mar 2024 18:05:57 -0500 Subject: [PATCH] Fix fin race condition --- examples/tun_wintun.rs | 2 +- src/stream/tcp.rs | 49 ++++++++++++++++++++++++------------------ 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/examples/tun_wintun.rs b/examples/tun_wintun.rs index 397cc85..1df05c0 100644 --- a/examples/tun_wintun.rs +++ b/examples/tun_wintun.rs @@ -63,7 +63,7 @@ async fn main() -> Result<(), Box> { }; println!("==== New TCP connection ===="); tokio::spawn(async move { - let _ = tokio::io::copy_bidirectional(&mut tcp, &mut s).await; + _ = tokio::io::copy_bidirectional(&mut tcp, &mut s).await; println!("====== end tcp connection ======"); }); } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index 54a0f00..f230cc5 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -208,7 +208,9 @@ impl AsyncRead for IpStackTcpStream { buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { loop { - if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) { + if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) + && self.packet_to_send.is_none() + { self.packet_to_send = Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?); self.tcb.change_state(TcpState::Closed); @@ -256,7 +258,21 @@ impl AsyncRead for IpStackTcpStream { .map_err(|_| Error::from(ErrorKind::UnexpectedEof))?; return std::task::Poll::Ready(Ok(())); } + if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) { + let flags = tcp_flags::FIN | tcp_flags::ACK; + self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); + self.tcb.add_seq_one(); + self.tcb.change_state(TcpState::FinWait2(true)); + continue; + } else if matches!(self.shutdown, Shutdown::Pending(_)) + && matches!(self.tcb.get_state(), TcpState::Established) + { + let flags = tcp_flags::FIN | tcp_flags::ACK; + self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); + self.tcb.change_state(TcpState::FinWait1(false)); + continue; + } match self.stream_receiver.poll_recv(cx) { std::task::Poll::Ready(Some(p)) => { let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else { @@ -366,7 +382,6 @@ impl AsyncRead for IpStackTcpStream { let flags = tcp_flags::ACK; self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); - // self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait1(true)); continue; } @@ -400,12 +415,18 @@ impl AsyncRead for IpStackTcpStream { continue; } } else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) { - if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { + if t.flags() == tcp_flags::ACK { + // panic!("ACK received in FinWait1"); + self.tcb.add_ack(1); + self.tcb.change_state(TcpState::FinWait1(true)); + continue; + } else if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) { let flags = tcp_flags::ACK; + self.tcb.add_seq_one(); + self.tcb.add_ack(1); self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); self.tcb.change_send_window(t.inner().window_size); - // self.tcb.add_seq_one(); self.tcb.change_state(TcpState::FinWait2(false)); continue; } @@ -417,23 +438,6 @@ impl AsyncRead for IpStackTcpStream { } std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())), std::task::Poll::Pending => { - if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) { - let flags = tcp_flags::FIN | tcp_flags::ACK; - self.packet_to_send = - Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); - - self.tcb.change_state(TcpState::FinWait2(false)); - continue; - } else if matches!(self.shutdown, Shutdown::Pending(_)) - && matches!(self.tcb.get_state(), TcpState::Established) - { - let flags = tcp_flags::FIN | tcp_flags::ACK; - self.packet_to_send = - Some(self.create_rev_packet(flags, TTL, None, Vec::new())?); - self.tcb.change_state(TcpState::FinWait1(false)); - - continue; - } return std::task::Poll::Pending; } } @@ -447,6 +451,9 @@ impl AsyncWrite for IpStackTcpStream { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { + if !matches!(self.tcb.get_state(), TcpState::Established) { + return std::task::Poll::Ready(Err(Error::from(ErrorKind::NotConnected))); + } self.tcb.reset_timeout(); if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2