Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 64 additions & 76 deletions hyperactor_mesh/src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ pub trait Alloc {
}

/// The address that should be used to serve the client's router.
fn client_router_addr(&self) -> AllocAssignedAddr {
AllocAssignedAddr(ChannelAddr::any(self.transport()))
fn client_router_addr(&self) -> ChannelAddr {
ChannelAddr::any(self.transport())
}
}

Expand Down Expand Up @@ -479,92 +479,80 @@ impl<A: ?Sized + Send + Alloc> AllocExt for A {
}
}

/// A new type to indicate this addr is assigned by alloc.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocAssignedAddr(ChannelAddr);

impl AllocAssignedAddr {
pub(crate) fn new(addr: ChannelAddr) -> AllocAssignedAddr {
AllocAssignedAddr(addr)
/// If addr is Tcp or Metatls, use its IP address or hostname to create
/// a new addr with port unspecified.
///
/// for other types of addr, return "any" address.
pub(crate) fn with_unspecified_port_or_any(addr: &ChannelAddr) -> ChannelAddr {
match addr {
ChannelAddr::Tcp(socket) => {
let mut new_socket = socket.clone();
new_socket.set_port(0);
ChannelAddr::Tcp(new_socket)
}
ChannelAddr::MetaTls(MetaTlsAddr::Socket(socket)) => {
let mut new_socket = socket.clone();
new_socket.set_port(0);
ChannelAddr::MetaTls(MetaTlsAddr::Socket(new_socket))
}
ChannelAddr::MetaTls(MetaTlsAddr::Host { hostname, port: _ }) => {
ChannelAddr::MetaTls(MetaTlsAddr::Host {
hostname: hostname.clone(),
port: 0,
})
}
_ => addr.transport().any(),
}
}

/// If addr is Tcp or Metatls, use its IP address or hostname to create
/// a new addr with port unspecified.
///
/// for other types of addr, return "any" address.
pub(crate) fn with_unspecified_port_or_any(addr: &ChannelAddr) -> AllocAssignedAddr {
let new_addr = match addr {
ChannelAddr::Tcp(socket) => {
let mut new_socket = socket.clone();
new_socket.set_port(0);
ChannelAddr::Tcp(new_socket)
}
ChannelAddr::MetaTls(MetaTlsAddr::Socket(socket)) => {
let mut new_socket = socket.clone();
new_socket.set_port(0);
ChannelAddr::MetaTls(MetaTlsAddr::Socket(new_socket))
}
ChannelAddr::MetaTls(MetaTlsAddr::Host { hostname, port: _ }) => {
ChannelAddr::MetaTls(MetaTlsAddr::Host {
hostname: hostname.clone(),
port: 0,
})
}
_ => addr.transport().any(),
pub(crate) fn serve_with_config<M: RemoteMessage>(
mut serve_addr: ChannelAddr,
) -> anyhow::Result<(ChannelAddr, ChannelRx<M>)> {
fn set_as_inaddr_any(original: &mut SocketAddr) {
let inaddr_any: IpAddr = match &original {
SocketAddr::V4(_) => Ipv4Addr::UNSPECIFIED.into(),
SocketAddr::V6(_) => Ipv6Addr::UNSPECIFIED.into(),
};
AllocAssignedAddr(new_addr)
original.set_ip(inaddr_any);
}

pub(crate) fn serve_with_config<M: RemoteMessage>(
self,
) -> anyhow::Result<(ChannelAddr, ChannelRx<M>)> {
fn set_as_inaddr_any(original: &mut SocketAddr) {
let inaddr_any: IpAddr = match &original {
SocketAddr::V4(_) => Ipv4Addr::UNSPECIFIED.into(),
SocketAddr::V6(_) => Ipv6Addr::UNSPECIFIED.into(),
};
original.set_ip(inaddr_any);
}

let use_inaddr_any = config::global::get(REMOTE_ALLOC_BIND_TO_INADDR_ANY);
let mut bind_to = self.0;
let mut original_ip: Option<IpAddr> = None;
match &mut bind_to {
ChannelAddr::Tcp(socket) => {
original_ip = Some(socket.ip().clone());
if use_inaddr_any {
set_as_inaddr_any(socket);
tracing::debug!("binding {} to INADDR_ANY", original_ip.as_ref().unwrap(),);
}
if socket.port() == 0 {
socket.set_port(next_allowed_port(socket.ip().clone())?);
}
let use_inaddr_any = config::global::get(REMOTE_ALLOC_BIND_TO_INADDR_ANY);
let mut original_ip: Option<IpAddr> = None;
match &mut serve_addr {
ChannelAddr::Tcp(socket) => {
original_ip = Some(socket.ip().clone());
if use_inaddr_any {
set_as_inaddr_any(socket);
tracing::debug!("binding {} to INADDR_ANY", original_ip.as_ref().unwrap(),);
}
_ => {
if use_inaddr_any {
tracing::debug!(
"can only bind to INADDR_ANY for TCP; got transport {}, addr {}",
bind_to.transport(),
bind_to
);
}
if socket.port() == 0 {
socket.set_port(next_allowed_port(socket.ip().clone())?);
}
};
}
_ => {
if use_inaddr_any {
tracing::debug!(
"can only bind to INADDR_ANY for TCP; got transport {}, addr {}",
serve_addr.transport(),
serve_addr
);
}
}
};

let (mut bound, rx) = channel::serve(bind_to)?;
let (mut bound, rx) = channel::serve(serve_addr)?;

// Restore the original IP address if we used INADDR_ANY.
match &mut bound {
ChannelAddr::Tcp(socket) => {
if use_inaddr_any {
socket.set_ip(original_ip.unwrap());
}
// Restore the original IP address if we used INADDR_ANY.
match &mut bound {
ChannelAddr::Tcp(socket) => {
if use_inaddr_any {
socket.set_ip(original_ip.unwrap());
}
_ => (),
}

Ok((bound, rx))
_ => (),
}

Ok((bound, rx))
}

enum AllowedPorts {
Expand Down
38 changes: 20 additions & 18 deletions hyperactor_mesh/src/alloc/remoteprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ use tokio_stream::wrappers::WatchStream;
use tokio_util::sync::CancellationToken;

use crate::alloc::Alloc;
use crate::alloc::AllocAssignedAddr;
use crate::alloc::AllocConstraints;
use crate::alloc::AllocSpec;
use crate::alloc::Allocator;
Expand All @@ -70,6 +69,8 @@ use crate::alloc::ProcessAllocator;
use crate::alloc::REMOTE_ALLOC_BOOTSTRAP_ADDR;
use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
use crate::alloc::process::ClientContext;
use crate::alloc::serve_with_config;
use crate::alloc::with_unspecified_port_or_any;
use crate::shortuuid::ShortUuid;

/// Control messages sent from remote process allocator to local allocator.
Expand All @@ -91,7 +92,7 @@ pub enum RemoteProcessAllocatorMessage {
/// the client_context will go to the message header instead
client_context: Option<ClientContext>,
/// The address allocator should use for its forwarder.
forwarder_addr: AllocAssignedAddr,
forwarder_addr: ChannelAddr,
},
/// Stop allocation.
Stop,
Expand Down Expand Up @@ -317,11 +318,11 @@ impl RemoteProcessAllocator {
bootstrap_addr: ChannelAddr,
hosts: Vec<String>,
cancel_token: CancellationToken,
forwarder_addr: AllocAssignedAddr,
forwarder_addr: ChannelAddr,
) {
tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
// start proc message forwarder
let (forwarder_addr, forwarder_rx) = match forwarder_addr.serve_with_config() {
let (forwarder_addr, forwarder_rx) = match serve_with_config(forwarder_addr) {
Ok(v) => v,
Err(e) => {
tracing::error!("failed to to bootstrap forwarder actor: {}", e);
Expand Down Expand Up @@ -626,11 +627,11 @@ impl RemoteProcessAlloc {
initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
) -> Result<Self, anyhow::Error> {
let alloc_serve_addr = match config::global::try_get_cloned(REMOTE_ALLOC_BOOTSTRAP_ADDR) {
Some(addr_str) => AllocAssignedAddr::new(addr_str.parse()?),
None => AllocAssignedAddr::new(ChannelAddr::any(spec.transport.clone())),
Some(addr_str) => addr_str.parse()?,
None => ChannelAddr::any(spec.transport.clone()),
};

let (bootstrap_addr, rx) = alloc_serve_addr.serve_with_config()?;
let (bootstrap_addr, rx) = serve_with_config(alloc_serve_addr)?;

tracing::info!(
"starting alloc for {} on: {}",
Expand Down Expand Up @@ -825,7 +826,7 @@ impl RemoteProcessAlloc {
// its host's private IP address, while its known addres to
// alloc is a public IP address. In some environment, that
// could lead to port unreachable error.
forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&remote_addr),
forwarder_addr: with_unspecified_port_or_any(&remote_addr),
};
tracing::info!(
name = message.as_ref(),
Expand Down Expand Up @@ -1208,8 +1209,8 @@ impl Alloc for RemoteProcessAlloc {
/// one could lead to port unreachable error.
///
/// For other channel types, this method still uses ChannelAddr::any.
fn client_router_addr(&self) -> AllocAssignedAddr {
AllocAssignedAddr::with_unspecified_port_or_any(&self.bootstrap_addr)
fn client_router_addr(&self) -> ChannelAddr {
with_unspecified_port_or_any(&self.bootstrap_addr)
}
}

Expand All @@ -1236,6 +1237,7 @@ mod test {
use crate::alloc::MockAllocWrapper;
use crate::alloc::MockAllocator;
use crate::alloc::ProcStopReason;
use crate::alloc::with_unspecified_port_or_any;
use crate::proc_mesh::mesh_agent::ProcMeshAgent;

async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
Expand Down Expand Up @@ -1372,7 +1374,7 @@ mod test {
bootstrap_addr,
hosts: vec![],
client_context: None,
forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()),
forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
})
.await
.unwrap();
Expand Down Expand Up @@ -1526,7 +1528,7 @@ mod test {
bootstrap_addr,
hosts: vec![],
client_context: None,
forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()),
forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
})
.await
.unwrap();
Expand Down Expand Up @@ -1631,7 +1633,7 @@ mod test {
bootstrap_addr: bootstrap_addr.clone(),
hosts: vec![],
client_context: None,
forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()),
forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
})
.await
.unwrap();
Expand All @@ -1656,7 +1658,7 @@ mod test {
bootstrap_addr,
hosts: vec![],
client_context: None,
forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()),
forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
})
.await
.unwrap();
Expand Down Expand Up @@ -1754,7 +1756,7 @@ mod test {
bootstrap_addr,
hosts: vec![],
client_context: None,
forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()),
forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
})
.await
.unwrap();
Expand Down Expand Up @@ -1846,7 +1848,7 @@ mod test {
bootstrap_addr,
hosts: vec![],
client_context: None,
forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()),
forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
})
.await
.unwrap();
Expand Down Expand Up @@ -1941,7 +1943,7 @@ mod test {
client_context: Some(ClientContext {
trace_id: test_trace_id.to_string(),
}),
forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()),
forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
})
.await
.unwrap();
Expand Down Expand Up @@ -2016,7 +2018,7 @@ mod test {
bootstrap_addr,
hosts: vec![],
client_context: None,
forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()),
forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
})
.await
.unwrap();
Expand Down
7 changes: 3 additions & 4 deletions hyperactor_mesh/src/proc_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ use crate::alloc::AllocatedProc;
use crate::alloc::AllocatorError;
use crate::alloc::ProcState;
use crate::alloc::ProcStopReason;
use crate::alloc::serve_with_config;
use crate::assign::Ranks;
use crate::comm::CommActorMode;
use crate::proc_mesh::mesh_agent::GspawnResult;
Expand Down Expand Up @@ -379,10 +380,8 @@ impl ProcMesh {
);

// Ensure that the router is served so that agents may reach us.
let (router_channel_addr, router_rx) = alloc
.client_router_addr()
.serve_with_config()
.map_err(AllocatorError::Other)?;
let (router_channel_addr, router_rx) =
serve_with_config(alloc.client_router_addr()).map_err(AllocatorError::Other)?;
router.serve(router_rx);
tracing::info!("router channel started listening on addr: {router_channel_addr}");

Expand Down
Loading