Skip to content

Commit

Permalink
Improve code readability
Browse files Browse the repository at this point in the history
  • Loading branch information
SajjadPourali committed Mar 18, 2024
1 parent f5c141c commit 528ac24
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 45 deletions.
1 change: 1 addition & 0 deletions src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod tcp_flags {
pub const RST: u8 = 0b00000100;
pub const SYN: u8 = 0b00000010;
pub const FIN: u8 = 0b00000001;
pub const NON: u8 = 0b00000000;
}

#[derive(Debug, Clone)]
Expand Down
83 changes: 38 additions & 45 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::{
error::IpStackError,
packet::{tcp_flags, IpStackPacketProtocol, TcpPacket, TransportHeader},
packet::{
tcp_flags::{ACK, FIN, NON, PSH, RST, SYN},
IpStackPacketProtocol, TcpPacket, TransportHeader,
},
stream::tcb::{Tcb, TcpState},
DROP_TTL, TTL,
};
Expand Down Expand Up @@ -83,8 +86,7 @@ impl IpStackTcpStream {
write_notify: None,
};
if !tcp.inner().syn {
let flags = tcp_flags::RST | tcp_flags::ACK;
_ = pkt_sender.send(stream.create_rev_packet(flags, TTL, None, Vec::new())?);
_ = pkt_sender.send(stream.create_rev_packet(RST | ACK, TTL, None, Vec::new())?);
Err(IpStackError::InvalidTcpPacket)
} else {
Ok(stream)
Expand Down Expand Up @@ -117,19 +119,19 @@ impl IpStackTcpStream {
);

tcp_header.acknowledgment_number = self.tcb.get_ack();
if flags & tcp_flags::SYN != 0 {
if flags & SYN != 0 {
tcp_header.syn = true;
}
if flags & tcp_flags::ACK != 0 {
if flags & ACK != 0 {
tcp_header.ack = true;
}
if flags & tcp_flags::RST != 0 {
if flags & RST != 0 {
tcp_header.rst = true;
}
if flags & tcp_flags::FIN != 0 {
if flags & FIN != 0 {
tcp_header.fin = true;
}
if flags & tcp_flags::PSH != 0 {
if flags & PSH != 0 {
tcp_header.psh = true;
}

Expand Down Expand Up @@ -208,7 +210,7 @@ impl AsyncRead for IpStackTcpStream {
&& self.packet_to_send.is_none()
{
self.packet_to_send =
Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::Closed);
self.shutdown.ready();
return Poll::Ready(Ok(()));
Expand All @@ -218,9 +220,8 @@ impl AsyncRead for IpStackTcpStream {
if matches!(Pin::new(&mut self.tcb.timeout).poll(cx), Poll::Ready(_)) {
#[cfg(feature = "log")]
trace!("timeout reached for {:?}", self.dst_addr);
let flags = tcp_flags::RST | tcp_flags::ACK;
self.packet_sender
.send(self.create_rev_packet(flags, TTL, None, Vec::new())?)
.send(self.create_rev_packet(RST | ACK, TTL, None, Vec::new())?)
.or(Err(ErrorKind::UnexpectedEof))?;
self.tcb.change_state(TcpState::Closed);
self.shutdown.ready();
Expand All @@ -230,8 +231,8 @@ impl AsyncRead for IpStackTcpStream {
self.tcb.reset_timeout();

if matches!(self.tcb.get_state(), TcpState::SynReceived(false)) {
let flags = tcp_flags::SYN | tcp_flags::ACK;
self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.packet_to_send =
Some(self.create_rev_packet(SYN | ACK, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::SynReceived(true));
}
Expand All @@ -249,21 +250,21 @@ impl AsyncRead for IpStackTcpStream {
self.tcb.add_ack(b.len() as u32);
buf.put_slice(&b);
self.packet_sender
.send(self.create_rev_packet(tcp_flags::ACK, TTL, None, Vec::new())?)
.send(self.create_rev_packet(ACK, TTL, None, Vec::new())?)
.or(Err(ErrorKind::UnexpectedEof))?;
return Poll::Ready(Ok(()));
}
if matches!(self.tcb.get_state(), TcpState::FinWait1(true)) {
let flags = tcp_flags::FIN | tcp_flags::ACK;
self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.packet_to_send =
Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?);
self.tcb.add_seq_one();
self.tcb.change_state(TcpState::FinWait2(true));
continue;
} else if matches!(self.shutdown, Shutdown::Pending(_))
&& matches!(self.tcb.get_state(), TcpState::Established)
{
let flags = tcp_flags::FIN | tcp_flags::ACK;
self.packet_to_send = Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
self.packet_to_send =
Some(self.create_rev_packet(FIN | ACK, TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::FinWait1(false));

continue;
Expand All @@ -273,9 +274,9 @@ impl AsyncRead for IpStackTcpStream {
let IpStackPacketProtocol::Tcp(t) = p.transport_protocol() else {
unreachable!()
};
if t.flags() & tcp_flags::RST != 0 {
if t.flags() & RST != 0 {
self.packet_to_send =
Some(self.create_rev_packet(0, DROP_TTL, None, Vec::new())?);
Some(self.create_rev_packet(NON, DROP_TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::Closed);
self.shutdown.ready();
return Poll::Ready(Err(Error::from(ErrorKind::ConnectionReset)));
Expand All @@ -288,13 +289,13 @@ impl AsyncRead for IpStackTcpStream {
}

if matches!(self.tcb.get_state(), TcpState::SynReceived(true)) {
if t.flags() == tcp_flags::ACK {
if t.flags() == ACK {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.change_send_window(t.inner().window_size);
self.tcb.change_state(TcpState::Established);
}
} else if matches!(self.tcb.get_state(), TcpState::Established) {
if t.flags() == tcp_flags::ACK {
if t.flags() == ACK {
match self.tcb.check_pkt_type(&t, &p.payload) {
PacketStatus::WindowUpdate => {
self.tcb.change_send_window(t.inner().window_size);
Expand All @@ -308,12 +309,8 @@ impl AsyncRead for IpStackTcpStream {
PacketStatus::KeepAlive => {
self.tcb.change_last_ack(t.inner().acknowledgment_number);
self.tcb.change_send_window(t.inner().window_size);
self.packet_to_send = Some(self.create_rev_packet(
tcp_flags::ACK,
TTL,
None,
Vec::new(),
)?);
self.packet_to_send =
Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?);
continue;
}
PacketStatus::RetransmissionRequest => {
Expand All @@ -328,7 +325,7 @@ impl AsyncRead for IpStackTcpStream {
// if t.inner().sequence_number != self.tcb.get_ack() {
// dbg!(t.inner().sequence_number);
// self.packet_to_send = Some(self.create_rev_packet(
// tcp_flags::ACK,
// ACK,
// TTL,
// None,
// Vec::new(),
Expand All @@ -344,7 +341,7 @@ impl AsyncRead for IpStackTcpStream {
// buf.put_slice(&p.payload);
// self.tcb.add_ack(p.payload.len() as u32);
// self.packet_to_send = Some(self.create_rev_packet(
// tcp_flags::ACK,
// ACK,
// TTL,
// None,
// Vec::new(),
Expand All @@ -368,15 +365,14 @@ impl AsyncRead for IpStackTcpStream {
}
};
}
if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
if t.flags() == (FIN | ACK) {
self.tcb.add_ack(1);
let flags = tcp_flags::ACK;
self.packet_to_send =
Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?);
self.tcb.change_state(TcpState::FinWait1(true));
continue;
}
if t.flags() == (tcp_flags::PSH | tcp_flags::ACK) {
if t.flags() == (PSH | ACK) {
if !matches!(
self.tcb.check_pkt_type(&t, &p.payload),
PacketStatus::NewPacket
Expand All @@ -395,7 +391,7 @@ impl AsyncRead for IpStackTcpStream {
self.tcb.change_send_window(t.inner().window_size);
// buf.put_slice(&p.payload);
// self.packet_to_send = Some(self.create_rev_packet(
// tcp_flags::ACK,
// ACK,
// TTL,
// None,
// Vec::new(),
Expand All @@ -406,23 +402,22 @@ impl AsyncRead for IpStackTcpStream {
continue;
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait1(false)) {
if t.flags() == tcp_flags::ACK {
if t.flags() == ACK {
// panic!("ACK received in FinWait1");
self.tcb.add_ack(1);
self.tcb.change_state(TcpState::FinWait1(true));
continue;
} else if t.flags() == (tcp_flags::FIN | tcp_flags::ACK) {
let flags = tcp_flags::ACK;
} else if t.flags() == (FIN | ACK) {
self.tcb.add_seq_one();
self.tcb.add_ack(1);
self.packet_to_send =
Some(self.create_rev_packet(flags, TTL, None, Vec::new())?);
Some(self.create_rev_packet(ACK, TTL, None, Vec::new())?);
self.tcb.change_send_window(t.inner().window_size);
self.tcb.change_state(TcpState::FinWait2(false));
continue;
}
} else if matches!(self.tcb.get_state(), TcpState::FinWait2(true))
&& t.flags() == tcp_flags::ACK
&& t.flags() == ACK
{
self.tcb.change_state(TcpState::FinWait2(false));
}
Expand Down Expand Up @@ -461,8 +456,7 @@ impl AsyncWrite for IpStackTcpStream {
}
}

let flags = tcp_flags::PSH | tcp_flags::ACK;
let packet = self.create_rev_packet(flags, TTL, None, buf.to_vec())?;
let packet = self.create_rev_packet(PSH | ACK, TTL, None, buf.to_vec())?;
let seq = self.tcb.seq;
let payload_len = packet.payload.len();
let payload = packet.payload.clone();
Expand All @@ -488,8 +482,7 @@ impl AsyncWrite for IpStackTcpStream {
.and_then(|s| self.tcb.inflight_packets.iter().position(|p| p.seq == s))
.and_then(|p| self.tcb.inflight_packets.get(p))
{
let flags = tcp_flags::PSH | tcp_flags::ACK;
let packet = self.create_rev_packet(flags, TTL, i.seq, i.payload.to_vec())?;
let packet = self.create_rev_packet(PSH | ACK, TTL, i.seq, i.payload.to_vec())?;

self.packet_sender
.send(packet)
Expand Down Expand Up @@ -528,7 +521,7 @@ impl AsyncWrite for IpStackTcpStream {

impl Drop for IpStackTcpStream {
fn drop(&mut self) {
if let Ok(p) = self.create_rev_packet(0, DROP_TTL, None, Vec::new()) {
if let Ok(p) = self.create_rev_packet(NON, DROP_TTL, None, Vec::new()) {
_ = self.packet_sender.send(p);
}
}
Expand Down

0 comments on commit 528ac24

Please sign in to comment.