Skip to content

Commit

Permalink
reading code
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrlive committed Apr 22, 2024
1 parent 082b2b6 commit a98e271
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 27 deletions.
13 changes: 8 additions & 5 deletions src/stream/tcb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ use tokio::time::Sleep;
const MAX_UNACK: u32 = 1024 * 16; // 16KB
const READ_BUFFER_SIZE: usize = 1024 * 16; // 16KB

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Default)]
pub enum TcpState {
SynReceived(bool), // bool means if syn/ack is sent
#[default]
Established,
FinWait1(bool),
FinWait2(bool), // bool means waiting for ack
Closed,
}
#[derive(Clone, Debug)]

#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Default)]
pub(super) enum PacketStatus {
#[default]
WindowUpdate,
Invalid,
RetransmissionRequest,
Expand Down Expand Up @@ -104,8 +107,8 @@ impl Tcb {
pub(super) fn change_state(&mut self, state: TcpState) {
self.state = state;
}
pub(super) fn get_state(&self) -> &TcpState {
&self.state
pub(super) fn get_state(&self) -> TcpState {
self.state
}
pub(super) fn change_send_window(&mut self, window: u16) {
let avg_send_window = ((self.avg_send_window.0 * self.avg_send_window.1) + window as u64)
Expand Down Expand Up @@ -170,7 +173,7 @@ impl Tcb {
let distance = ack.wrapping_sub(self.last_ack);
self.last_ack = self.last_ack.wrapping_add(distance);

if matches!(self.state, TcpState::Established) {
if self.state == TcpState::Established {
if let Some(i) = self.inflight_packets.iter().position(|p| p.contains(ack)) {
let mut inflight_packet = self.inflight_packets.remove(i);
let distance = ack.wrapping_sub(inflight_packet.seq) as usize;
Expand Down
43 changes: 24 additions & 19 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl IpStackTcpStream {
src_addr: SocketAddr,
dst_addr: SocketAddr,
tcp: TcpHeaderWrapper,
pkt_sender: PacketSender,
packet_sender: PacketSender,
stream_receiver: PacketReceiver,
mtu: u16,
tcp_timeout: Duration,
Expand All @@ -73,7 +73,7 @@ impl IpStackTcpStream {
src_addr,
dst_addr,
stream_receiver,
packet_sender: pkt_sender.clone(),
packet_sender,
packet_to_send: None,
tcb: Tcb::new(tcp.inner().sequence_number + 1, tcp_timeout),
mtu,
Expand All @@ -84,7 +84,10 @@ impl IpStackTcpStream {
return Ok(stream);
}
if !tcp.inner().rst {
_ = pkt_sender.send(stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?);
let pkt = stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?;
if let Err(err) = stream.packet_sender.send(pkt) {
log::warn!("Error sending RST/ACK packet: {:?}", err);
}
}
Err(IpStackError::InvalidTcpPacket)
}
Expand Down Expand Up @@ -156,7 +159,8 @@ impl IpStackTcpStream {
tcp_header.header_len() as u16,
);
payload.truncate(payload_len as usize);
ip_h.payload_length = (payload.len() + tcp_header.header_len()) as u16;
let len = payload.len() + tcp_header.header_len();
ip_h.set_payload_length(len).map_err(IpStackError::from)?;

IpHeader::Ipv6(ip_h)
}
Expand Down Expand Up @@ -190,17 +194,17 @@ impl AsyncRead for IpStackTcpStream {
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
loop {
if matches!(self.tcb.get_state(), TcpState::FinWait2(false))
&& self.packet_to_send.is_none()
{
if self.tcb.get_state() == TcpState::FinWait2(false) && self.packet_to_send.is_none() {
self.packet_to_send =
Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::Closed);
self.shutdown.ready();
return Poll::Ready(Ok(()));
}

let min = self.tcb.get_available_read_buffer_size() as u16;
self.tcb.change_recv_window(min);

if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) {
trace!("timeout reached for {:?}", self.dst_addr);
self.packet_sender
Expand All @@ -210,10 +214,9 @@ impl AsyncRead for IpStackTcpStream {
self.shutdown.ready();
return Poll::Ready(Err(Error::from(ErrorKind::TimedOut)));
}

self.tcb.reset_timeout();

if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) {
if self.tcb.get_state() == TcpState::SynReceived(false) {
self.packet_to_send =
Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
Expand All @@ -224,7 +227,7 @@ impl AsyncRead for IpStackTcpStream {
self.packet_sender
.send(packet)
.or(Err(ErrorKind::UnexpectedEof))?;
if matches!(self.tcb.get_state(), TcpState::Closed) {
if self.tcb.get_state() == TcpState::Closed {
self.shutdown.ready();
return Poll::Ready(Ok(()));
}
Expand All @@ -242,15 +245,15 @@ impl AsyncRead for IpStackTcpStream {
.or(Err(ErrorKind::UnexpectedEof))?;
return Poll::Ready(Ok(()));
}
if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) {
if self.tcb.get_state() == TcpState::FinWait1(true) {
self.packet_to_send =
Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
self.tcb.add_ack(1);
self.tcb.change_state(TcpState::FinWait2(true));
continue;
} else if matches!(self.shutdown, Shutdown::Pending(_))
&& matches!(self.tcb.get_state(), TcpState::Established)
&& self.tcb.get_state() == TcpState::Established
&& self.tcb.last_ack == self.tcb.seq
{
self.packet_to_send =
Expand Down Expand Up @@ -278,13 +281,13 @@ impl AsyncRead for IpStackTcpStream {
continue;
}

if matches!(self.tcb.get_state(), TcpState::SynReceived(true)) {
if self.tcb.get_state() == TcpState::SynReceived(true) {
if t.flags() == ACK {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.change_send_window(t.inner().window_size);
self.tcb.change_state(TcpState::Established);
}
} else if matches!(self.tcb.get_state(), TcpState::Established) {
} else if self.tcb.get_state() == TcpState::Established {
if t.flags() == ACK {
match self.tcb.check_pkt_type(&t, &p.payload) {
PacketStatus::WindowUpdate => {
Expand Down Expand Up @@ -391,7 +394,7 @@ impl AsyncRead for IpStackTcpStream {
.add_unordered_packet(t.inner().sequence_number, &p.payload);
continue;
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) {
} else if self.tcb.get_state() == TcpState::FinWait1(false) {
if t.flags() == ACK {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.add_ack(1);
Expand All @@ -405,7 +408,7 @@ impl AsyncRead for IpStackTcpStream {
self.tcb.change_state(TcpState::FinWait2(true));
continue;
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait2(true)) {
} else if self.tcb.get_state() == TcpState::FinWait2(true) {
if t.flags() == ACK {
self.tcb.change_state(TcpState::FinWait2(false));
} else if t.flags() == (FIN | ACK) {
Expand All @@ -428,7 +431,7 @@ impl AsyncWrite for IpStackTcpStream {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
if !matches!(self.tcb.get_state(), TcpState::Established) {
if self.tcb.get_state() != TcpState::Established {
return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
}
self.tcb.reset_timeout();
Expand Down Expand Up @@ -463,7 +466,7 @@ impl AsyncWrite for IpStackTcpStream {
mut self: std::pin::Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
if !matches!(self.tcb.get_state(), TcpState::Established) {
if self.tcb.get_state() != TcpState::Established {
return Poll::Ready(Err(Error::from(ErrorKind::NotConnected)));
}
if let Some(i) = self
Expand Down Expand Up @@ -513,7 +516,9 @@ impl AsyncWrite for IpStackTcpStream {
impl Drop for IpStackTcpStream {
fn drop(&mut self) {
if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) {
_ = self.packet_sender.send(p);
if let Err(err) = self.packet_sender.send(p) {
log::trace!("Error sending NON packet: {:?}", err);
}
}
}
}
7 changes: 4 additions & 3 deletions src/stream/tcp_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ impl IpStackTcpStream {
mtu,
tcp_timeout,
)
.map(Box::new)
.map(|inner| IpStackTcpStream {
inner: Some(inner),
inner: Some(Box::new(inner)),
peer_addr,
local_addr,
stream_sender,
Expand Down Expand Up @@ -107,7 +106,9 @@ impl Drop for IpStackTcpStream {
fn drop(&mut self) {
if let Some(mut inner) = self.inner.take() {
tokio::spawn(async move {
_ = timeout(Duration::from_secs(2), inner.shutdown()).await;
if let Err(err) = timeout(Duration::from_secs(2), inner.shutdown()).await {
log::warn!("Error while dropping IpStackTcpStream: {:?}", err);
}
});
}
}
Expand Down

0 comments on commit a98e271

Please sign in to comment.