Skip to content

Commit

Permalink
Fix fin race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
SajjadPourali committed Mar 4, 2024
1 parent f6e618e commit 6b23ebf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
2 changes: 1 addition & 1 deletion examples/tun_wintun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
};
println!("==== New TCP connection ====");
tokio::spawn(async move {
let _ = tokio::io::copy_bidirectional(&mut tcp, &mut s).await;
_ = tokio::io::copy_bidirectional(&mut tcp, &mut s).await;
println!("====== end tcp connection ======");
});
}
Expand Down
49 changes: 28 additions & 21 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ impl AsyncRead for IpStackTcpStream {
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
loop {
if matches!(self.tcb.get_state(), TcpState::FinWait2(false)) {
if matches!(self.tcb.get_state(), TcpState::FinWait2(false))
&& self.packet_to_send.is_none()
{
self.packet_to_send =
Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::Closed);
Expand Down Expand Up @@ -256,7 +258,21 @@ impl AsyncRead for IpStackTcpStream {
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
return std::task::Poll::Ready(Ok(()));
}
if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) {
let flags = tcp_flags::FIN | tcp_flags::ACK;
self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::FinWait2(true));
continue;
} else if matches!(self.shutdown, Shutdown::Pending(_))
&& matches!(self.tcb.get_state(), TcpState::Established)
{
let flags = tcp_flags::FIN | tcp_flags::ACK;
self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::FinWait1(false));

continue;
}
match self.stream_receiver.poll_recv(cx) {
std::task::Poll::Ready(Some(p)) => {
let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else {
Expand Down Expand Up @@ -366,7 +382,6 @@ impl AsyncRead for IpStackTcpStream {
let flags = tcp_flags::ACK;
self.packet_to_send =
Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
// self.tcb.add_seq_one();
self.tcb.change_state(TcpState::FinWait1(true));
continue;
}
Expand Down Expand Up @@ -400,12 +415,18 @@ impl AsyncRead for IpStackTcpStream {
continue;
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) {
if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
if t.flags() == tcp_flags::ACK {
// panic!("ACK received in FinWait1");
self.tcb.add_ack(1);
self.tcb.change_state(TcpState::FinWait1(true));
continue;
} else if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
let flags = tcp_flags::ACK;
self.tcb.add_seq_one();
self.tcb.add_ack(1);
self.packet_to_send =
Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.tcb.change_send_window(t.inner().window_size);
// self.tcb.add_seq_one();
self.tcb.change_state(TcpState::FinWait2(false));
continue;
}
Expand All @@ -417,23 +438,6 @@ impl AsyncRead for IpStackTcpStream {
}
std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())),
std::task::Poll::Pending => {
if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) {
let flags = tcp_flags::FIN | tcp_flags::ACK;
self.packet_to_send =
Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);

self.tcb.change_state(TcpState::FinWait2(false));
continue;
} else if matches!(self.shutdown, Shutdown::Pending(_))
&& matches!(self.tcb.get_state(), TcpState::Established)
{
let flags = tcp_flags::FIN | tcp_flags::ACK;
self.packet_to_send =
Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::FinWait1(false));

continue;
}
return std::task::Poll::Pending;
}
}
Expand All @@ -447,6 +451,9 @@ impl AsyncWrite for IpStackTcpStream {
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
if !matches!(self.tcb.get_state(), TcpState::Established) {
return std::task::Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
}
self.tcb.reset_timeout();

if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2
Expand Down

0 comments on commit 6b23ebf

Please sign in to comment.