Skip to content

Commit

Permalink
[inetstack] Bug Fix: Properly close connecting sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
iyzhang authored and ppenna committed Jun 12, 2024
1 parent 708427e commit 5213d8b
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
+3 > S seq 0(0) win 65535 <mss 1450, wscale 0>

// Fail to connect.
+3 wait(500, ...) = ETIMEDOUT
+3 wait(500, ...) = ECONNREFUSED
96 changes: 74 additions & 22 deletions src/rust/inetstack/protocols/tcp/active_open.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
//======================================================================================================================

use crate::{
collections::async_queue::SharedAsyncQueue,
collections::{
async_queue::SharedAsyncQueue,
async_value::SharedAsyncValue,
},
expect_some,
inetstack::protocols::{
arp::SharedArpPeer,
Expand Down Expand Up @@ -37,6 +40,7 @@ use crate::{
},
},
runtime::{
conditional_yield_with_timeout,
fail::Fail,
memory::DemiBuffer,
network::{
Expand All @@ -50,7 +54,11 @@ use crate::{
SharedObject,
},
};
use ::futures::channel::mpsc;
use ::futures::{
channel::mpsc,
select_biased,
FutureExt,
};
use ::std::{
net::SocketAddrV4,
ops::{
Expand All @@ -63,6 +71,15 @@ use ::std::{
// Structures
//======================================================================================================================

/// States of a connecting (active-open) socket.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
/// The socket is attempting to establish a connection to the remote peer.
/// This is the state a newly created active-open socket starts in.
Connecting,
/// The socket was closed (via `close()`) before the connection was established.
Closed,
}

pub struct ActiveOpenSocket<N: NetworkRuntime> {
local_isn: SeqNumber,
local: SocketAddrV4,
Expand All @@ -76,6 +93,7 @@ pub struct ActiveOpenSocket<N: NetworkRuntime> {
socket_options: TcpSocketOptions,
arp: SharedArpPeer<N>,
dead_socket_tx: mpsc::UnboundedSender<QDesc>,
state: SharedAsyncValue<State>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -115,6 +133,7 @@ impl<N: NetworkRuntime> SharedActiveOpenSocket<N> {
socket_options: default_socket_options,
arp,
dead_socket_tx,
state: SharedAsyncValue::new(State::Connecting),
})))
}

Expand Down Expand Up @@ -251,17 +270,35 @@ impl<N: NetworkRuntime> SharedActiveOpenSocket<N> {
// Start connection handshake.
let handshake_retries: usize = self.tcp_config.get_handshake_retries();
let handshake_timeout = self.tcp_config.get_handshake_timeout();
for _ in 0..handshake_retries {
// Look up remote MAC address.
// TODO: Do we need to do this every iteration?
let remote_link_addr = match self.clone().arp.query(self.remote.ip().clone()).await {
Ok(r) => r,
// Look up remote MAC address.
let mut retries_left: usize = handshake_retries;
// Look up MAC address.
let remote_link_addr: MacAddress = loop {
match conditional_yield_with_timeout(
self.clone().arp.query(self.remote.ip().clone()).fuse(),
handshake_timeout,
)
.await
{
Ok(r) => break r?,
Err(e) if e.errno == libc::ETIMEDOUT && retries_left > 0 => {
retries_left = retries_left - 1;
},
Err(e) if e.errno == libc::ETIMEDOUT => {
let cause: String = format!("ARP query failed");
error!("connect(): {}", cause);
return Err(Fail::new(libc::ECONNREFUSED, &cause));
},
Err(e) => {
warn!("ARP query failed: {:?}", e);
continue;
let cause: String = format!("ARP query failed: {:?}", e);
error!("connect(): {}", cause);
return Err(e);
},
};
}
};

// Try to connect.
for _ in 0..handshake_retries {
// Set up SYN packet.
let mut tcp_hdr = TcpHeader::new(self.local.port(), self.remote.port());
tcp_hdr.syn = true;
Expand All @@ -287,24 +324,39 @@ impl<N: NetworkRuntime> SharedActiveOpenSocket<N> {
self.transport.transmit(Box::new(segment));

// Wait for either a response or timeout.
match self.recv_queue.pop(Some(handshake_timeout)).await {
let mut recv_queue: SharedAsyncQueue<(Ipv4Header, TcpHeader, DemiBuffer)> = self.recv_queue.clone();
let mut state: SharedAsyncValue<State> = self.state.clone();
select_biased! {
r = state.wait_for_change(None).fuse() => if let Ok(r) = r {
if r == State::Closed {
let cause: &str = "Closing socket while connecting";
warn!("{}", cause);
return Err(Fail::new(libc::ECONNABORTED, &cause));
}
},
r = recv_queue.pop(Some(handshake_timeout)).fuse() => match r {
Ok((_, header, _)) => match self.process_ack(header) {
Ok(socket) => return Ok(socket),
Err(Fail { errno, cause: _ }) if errno == libc::EAGAIN => continue,
Err(e) => return Err(e),
},
Err(Fail { errno, cause: _ }) if errno == libc::ETIMEDOUT => continue,
Err(_) => {
unreachable!(
"either the ack deadline changed or the deadline passed, no other errors are possible!"
)
},
Ok(socket) => return Ok(socket),
Err(Fail { errno, cause: _ }) if errno == libc::EAGAIN => continue,
Err(e) => return Err(e),
},
Err(Fail { errno, cause: _ }) if errno == libc::ETIMEDOUT => continue,
Err(_) => {
unreachable!(
"either the ack deadline changed or the deadline passed, no other errors are possible!"
)
},
}
}
}

let cause: String = format!("connection handshake timed out");
error!("connect(): {}", cause);
Err(Fail::new(libc::ETIMEDOUT, &cause))
Err(Fail::new(libc::ECONNREFUSED, &cause))
}

/// Closes this connecting socket by setting its state to `State::Closed`.
///
/// This call only updates the state value; it does not itself tear down any connection.
pub fn close(&mut self) {
// NOTE(review): the handshake loop appears to wait on this state and abort with ECONNABORTED
// once it becomes `Closed` -- confirm that clones of the state share the same underlying value.
self.state.set(State::Closed);
}

/// Returns the addresses of the two ends of this connection.
Expand Down
6 changes: 5 additions & 1 deletion src/rust/inetstack/protocols/tcp/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,11 @@ impl<N: NetworkRuntime> SharedTcpPeer<N> {
pub async fn connect(&mut self, socket: &mut SharedTcpSocket<N>, remote: SocketAddrV4) -> Result<(), Fail> {
// Check whether we need to allocate an ephemeral port.
let local: SocketAddrV4 = match socket.local() {
Some(addr) => addr,
Some(addr) => {
// If socket is already bound to a local address, use it but remove the old binding.
self.addresses.remove(&SocketId::Passive(addr));
addr
},
None => {
let local_port: u16 = self.runtime.alloc_ephemeral_port()?;
SocketAddrV4::new(self.local_ipv4_addr, local_port)
Expand Down
14 changes: 6 additions & 8 deletions src/rust/inetstack/protocols/tcp/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,9 @@ impl<N: NetworkRuntime> SharedTcpSocket<N> {
Ok(Some(SocketId::Passive(socket.endpoint())))
},
// Closing a connecting socket.
SocketState::Connecting(_) => {
let cause: String = format!("cannot close a connecting socket");
error!("do_close(): {}", &cause);
Err(Fail::new(libc::ENOTSUP, &cause))
SocketState::Connecting(ref mut socket) => {
socket.close();
Ok(Some(SocketId::Active(socket.endpoints().0, socket.endpoints().1)))
},
// Closing a closing socket.
SocketState::Closing(_) => {
Expand All @@ -310,10 +309,9 @@ impl<N: NetworkRuntime> SharedTcpSocket<N> {
Ok(Some(SocketId::Passive(socket.endpoint())))
},
// Closing a connecting socket.
SocketState::Connecting(_) => {
let cause: String = format!("cannot close a connecting socket");
error!("do_close(): {}", &cause);
Err(Fail::new(libc::ENOTSUP, &cause))
SocketState::Connecting(ref mut socket) => {
socket.close();
Ok(Some(SocketId::Active(socket.endpoints().0, socket.endpoints().1)))
},
// Closing a closing socket.
SocketState::Closing(_) => {
Expand Down
2 changes: 1 addition & 1 deletion tools/demikernel_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def run_pipeline(
# STEP 4: Run integration tests.
if test_integration:
if status["checkout"] and status["compile"]:
if libos == "catnap" or libos == "catnapw" or libos == "catloop":
if libos == "catnap" or libos == "catnapw" or libos == "catloop" or libos == "catnip" or libos == "catpowder":
status["integration_tests"] = factory.integration_test().execute()
elif libos == "catmem":
status["integration_tests"] = factory.integration_test("standalone").execute()
Expand Down

0 comments on commit 5213d8b

Please sign in to comment.