Skip to content

Commit

Permalink
Fix timeout issues (#25)
Browse files Browse the repository at this point in the history
* Fix timeout issues

* Another fixing

* minor changes

* Prevent creating stream when incorrect packet received + add connection check to poll_flush

* minor changes

* minor changes

* minor changes

* minor changes

---------

Co-authored-by: SajjadPourali <sajjad@pourali.com>
  • Loading branch information
ssrlive and SajjadPourali committed Mar 14, 2024
1 parent 05988a3 commit ba40e3d
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 69 deletions.
14 changes: 8 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,25 @@ impl IpStack {
Occupied(entry) =>{
if let Err(_x) = entry.get().send(packet){
#[cfg(feature = "log")]
trace!("{}", _x);
trace!("Send packet error \"{}\"", _x);
}
}
Vacant(entry) => {
match packet.transport_protocol(){
IpStackPacketProtocol::Tcp(h) => {
match IpStackTcpStream::new(packet.src_addr(),packet.dst_addr(),h, pkt_sender.clone(),config.mtu,config.tcp_timeout).await{
Ok(stream) => {
if stream.is_closed(){
continue;
}
entry.insert(stream.stream_sender());
accept_sender.send(IpStackStream::Tcp(stream))?;
}
Err(_e) => {
Err(e) => {
if matches!(e,IpStackError::InvalidTcpPacket){
#[cfg(feature = "log")]
trace!("Invalid TCP packet");
continue;
}
#[cfg(feature = "log")]
error!("{}", _e);
error!("IpStackTcpStream::new failed \"{}\"", e);
}
}
}
Expand Down
14 changes: 7 additions & 7 deletions src/stream/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6};

pub use self::tcp::IpStackTcpStream;
pub use self::udp::IpStackUdpStream;
Expand All @@ -22,11 +22,11 @@ impl IpStackStream {
IpStackStream::Tcp(tcp) => tcp.local_addr(),
IpStackStream::Udp(udp) => udp.local_addr(),
IpStackStream::UnknownNetwork(_) => {
SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0))
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0))
}
IpStackStream::UnknownTransport(unknown) => match unknown.src_addr() {
std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
},
}
}
Expand All @@ -35,11 +35,11 @@ impl IpStackStream {
IpStackStream::Tcp(tcp) => tcp.peer_addr(),
IpStackStream::Udp(udp) => udp.peer_addr(),
IpStackStream::UnknownNetwork(_) => {
SocketAddr::V4(SocketAddrV4::new(std::net::Ipv4Addr::new(0, 0, 0, 0), 0))
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0))
}
IpStackStream::UnknownTransport(unknown) => match unknown.dst_addr() {
std::net::IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
std::net::IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
IpAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr, 0)),
IpAddr::V6(addr) => SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, 0)),
},
}
}
Expand Down
104 changes: 49 additions & 55 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{
io::{Error, ErrorKind},
net::SocketAddr,
pin::Pin,
task::Waker,
task::{Context, Poll, Waker},
time::Duration,
};
use tokio::{
Expand Down Expand Up @@ -70,7 +70,7 @@ impl IpStackTcpStream {
) -> Result<IpStackTcpStream, IpStackError> {
let (stream_sender, stream_receiver) = mpsc::unbounded_channel::<NetworkPacket>();

let mut stream = IpStackTcpStream {
let stream = IpStackTcpStream {
src_addr,
dst_addr,
stream_sender,
Expand All @@ -84,12 +84,11 @@ impl IpStackTcpStream {
};
if !tcp.inner().syn {
let flags = tcp_flags::RST | tcp_flags::ACK;
pkt_sender
.send(stream.create_rev_packet(flags, TTL, None, Vec::new())?)
.map_err(|_| IpStackError::InvalidTcpPacket)?;
stream.tcb.change_state(TcpState::Closed);
_ = pkt_sender.send(stream.create_rev_packet(flags, TTL, None, Vec::new())?);
Err(IpStackError::InvalidTcpPacket)
} else {
Ok(stream)
}
Ok(stream)
}

pub(crate) fn stream_sender(&self) -> UnboundedSender<NetworkPacket> {
Expand Down Expand Up @@ -174,12 +173,12 @@ impl IpStackTcpStream {
etherparse::NetHeaders::Ipv4(ref ip_header, _) => {
tcp_header.checksum = tcp_header
.calc_checksum_ipv4(ip_header, &payload)
.map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
.or(Err(ErrorKind::InvalidInput))?;
}
etherparse::NetHeaders::Ipv6(ref ip_header, _) => {
tcp_header.checksum = tcp_header
.calc_checksum_ipv6(ip_header, &payload)
.map_err(|_e| Error::from(ErrorKind::InvalidInput))?;
.or(Err(ErrorKind::InvalidInput))?;
}
}
Ok(NetworkPacket {
Expand All @@ -196,17 +195,14 @@ impl IpStackTcpStream {
pub fn peer_addr(&self) -> SocketAddr {
self.dst_addr
}
pub(crate) fn is_closed(&self) -> bool {
matches!(self.tcb.get_state(), TcpState::Closed)
}
}

impl AsyncRead for IpStackTcpStream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
) -> Poll<std::io::Result<()>> {
loop {
if matches!(self.tcb.get_state(), TcpState::FinWait2(false))
&& self.packet_to_send.is_none()
Expand All @@ -215,21 +211,20 @@ impl AsyncRead for IpStackTcpStream {
Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::Closed);
self.shutdown.ready();
return std::task::Poll::Ready(Ok(()));
return Poll::Ready(Ok(()));
}
let min = cmp::min(self.tcb.get_available_read_buffer_size() as u16, u16::MAX);
self.tcb.change_recv_window(min);
if matches!(
Pin::new(&mut self.tcb.timeout).poll(cx),
std::task::Poll::Ready(_)
) {
if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) {
#[cfg(feature = "log")]
trace!("timeout reached for {:?}", self.dst_addr);
let flags = tcp_flags::RST | tcp_flags::ACK;
self.packet_sender
.send(self.create_rev_packet(flags, TTL, None, Vec::new())?)
.map_err(|_| ErrorKind::UnexpectedEof)?;
return std::task::Poll::Ready(Err(Error::from(ErrorKind::TimedOut)));
.or(Err(ErrorKind::UnexpectedEof))?;
self.tcb.change_state(TcpState::Closed);
self.shutdown.ready();
return Poll::Ready(Err(Error::from(ErrorKind::TimedOut)));
}

self.tcb.reset_timeout();
Expand All @@ -244,19 +239,19 @@ impl AsyncRead for IpStackTcpStream {
if let Some(packet) = self.packet_to_send.take() {
self.packet_sender
.send(packet)
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
.or(Err(ErrorKind::UnexpectedEof))?;
if matches!(self.tcb.get_state(), TcpState::Closed) {
self.shutdown.ready();
return std::task::Poll::Ready(Ok(()));
return Poll::Ready(Ok(()));
}
}
if let Some(b) = self.tcb.get_unordered_packets() {
self.tcb.add_ack(b.len() as u32);
buf.put_slice(&b);
self.packet_sender
.send(self.create_rev_packet(tcp_flags::ACK, TTL, None, Vec::new())?)
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
return std::task::Poll::Ready(Ok(()));
.or(Err(ErrorKind::UnexpectedEof))?;
return Poll::Ready(Ok(()));
}
if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) {
let flags = tcp_flags::FIN | tcp_flags::ACK;
Expand All @@ -274,17 +269,16 @@ impl AsyncRead for IpStackTcpStream {
continue;
}
match self.stream_receiver.poll_recv(cx) {
std::task::Poll::Ready(Some(p)) => {
Poll::Ready(Some(p)) => {
let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else {
unreachable!()
};
if t.flags() & tcp_flags::RST != 0 {
self.packet_to_send =
Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::Closed);
return std::task::Poll::Ready(Err(Error::from(
ErrorKind::ConnectionReset,
)));
self.shutdown.ready();
return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset)));
}
if matches!(
self.tcb.check_pkt_type(&t, &p.payload),
Expand Down Expand Up @@ -325,11 +319,8 @@ impl AsyncRead for IpStackTcpStream {
PacketStatus::RetransmissionRequest => {
self.tcb.change_send_window(t.inner().window_size);
self.tcb.retransmission = Some(t.inner().acknowledgment_number);
if matches!(
self.as_mut().poll_flush(cx),
std::task::Poll::Pending
) {
return std::task::Poll::Pending;
if matches!(self.as_mut().poll_flush(cx), Poll::Pending) {
return Poll::Pending;
}
continue;
}
Expand Down Expand Up @@ -364,7 +355,7 @@ impl AsyncRead for IpStackTcpStream {
self.write_notify = None;
};
continue;
// return std::task::Poll::Ready(Ok(()));
// return Poll::Ready(Ok(()));
}
PacketStatus::Ack => {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
Expand Down Expand Up @@ -409,7 +400,7 @@ impl AsyncRead for IpStackTcpStream {
// None,
// Vec::new(),
// )?);
// return std::task::Poll::Ready(Ok(()));
// return Poll::Ready(Ok(()));
self.tcb
.add_unordered_packet(t.inner().sequence_number, &p.payload);
continue;
Expand All @@ -436,9 +427,9 @@ impl AsyncRead for IpStackTcpStream {
self.tcb.change_state(TcpState::FinWait2(false));
}
}
std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())),
std::task::Poll::Pending => {
return std::task::Poll::Pending;
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Pending => {
return Poll::Pending;
}
}
}
Expand All @@ -448,25 +439,25 @@ impl AsyncRead for IpStackTcpStream {
impl AsyncWrite for IpStackTcpStream {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
cx: &mut Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
) -> Poll<std::io::Result<usize>> {
if !matches!(self.tcb.get_state(), TcpState::Established) {
return std::task::Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
}
self.tcb.reset_timeout();

if (self.tcb.send_window as u64) < self.tcb.avg_send_window.0 / 2
|| self.tcb.is_send_buffer_full()
{
self.write_notify = Some(cx.waker().clone());
return std::task::Poll::Pending;
return Poll::Pending;
}

if self.tcb.retransmission.is_some() {
self.write_notify = Some(cx.waker().clone());
if matches!(self.as_mut().poll_flush(cx), std::task::Poll::Pending) {
return std::task::Poll::Pending;
if matches!(self.as_mut().poll_flush(cx), Poll::Pending) {
return Poll::Pending;
}
}

Expand All @@ -478,16 +469,19 @@ impl AsyncWrite for IpStackTcpStream {

self.packet_sender
.send(packet)
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
.or(Err(ErrorKind::UnexpectedEof))?;
self.tcb.add_inflight_packet(seq, &payload);

std::task::Poll::Ready(Ok(payload_len))
Poll::Ready(Ok(payload_len))
}

fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
if !matches!(self.tcb.get_state(), TcpState::Established) {
return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
}
if let Some(i) = self
.tcb
.retransmission
Expand All @@ -499,7 +493,7 @@ impl AsyncWrite for IpStackTcpStream {

self.packet_sender
.send(packet)
.map_err(|_| Error::from(ErrorKind::UnexpectedEof))?;
.or(Err(ErrorKind::UnexpectedEof))?;
self.tcb.retransmission = None;
} else if let Some(_i) = self.tcb.retransmission {
#[cfg(feature = "log")]
Expand All @@ -515,18 +509,18 @@ impl AsyncWrite for IpStackTcpStream {
}
panic!("Please report these values at: https://github.com/narrowlink/ipstack/");
}
std::task::Poll::Ready(Ok(()))
Poll::Ready(Ok(()))
}

fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
match &self.shutdown {
Shutdown::Ready => std::task::Poll::Ready(Ok(())),
Shutdown::Ready => Poll::Ready(Ok(())),
Shutdown::Pending(_) | Shutdown::None => {
self.shutdown.pending(cx.waker().clone());
std::task::Poll::Pending
Poll::Pending
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/stream/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl AsyncWrite for IpStackUdpStream {
let payload_len = packet.payload.len();
self.packet_sender
.send(packet)
.map_err(|_| std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?;
.or(Err(std::io::ErrorKind::UnexpectedEof))?;
std::task::Poll::Ready(Ok(payload_len))
}

Expand Down

0 comments on commit ba40e3d

Please sign in to comment.