diff --git a/hyperactor_mesh/src/alloc.rs b/hyperactor_mesh/src/alloc.rs index f66097b87..3676beaeb 100644 --- a/hyperactor_mesh/src/alloc.rs +++ b/hyperactor_mesh/src/alloc.rs @@ -343,8 +343,8 @@ pub trait Alloc { } /// The address that should be used to serve the client's router. - fn client_router_addr(&self) -> AllocAssignedAddr { - AllocAssignedAddr(ChannelAddr::any(self.transport())) + fn client_router_addr(&self) -> ChannelAddr { + ChannelAddr::any(self.transport()) } } @@ -479,92 +479,80 @@ impl AllocExt for A { } } -/// A new type to indicate this addr is assigned by alloc. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AllocAssignedAddr(ChannelAddr); - -impl AllocAssignedAddr { - pub(crate) fn new(addr: ChannelAddr) -> AllocAssignedAddr { - AllocAssignedAddr(addr) +/// If addr is Tcp or Metatls, use its IP address or hostname to create +/// a new addr with port unspecified. +/// +/// for other types of addr, return "any" address. +pub(crate) fn with_unspecified_port_or_any(addr: &ChannelAddr) -> ChannelAddr { + match addr { + ChannelAddr::Tcp(socket) => { + let mut new_socket = socket.clone(); + new_socket.set_port(0); + ChannelAddr::Tcp(new_socket) + } + ChannelAddr::MetaTls(MetaTlsAddr::Socket(socket)) => { + let mut new_socket = socket.clone(); + new_socket.set_port(0); + ChannelAddr::MetaTls(MetaTlsAddr::Socket(new_socket)) + } + ChannelAddr::MetaTls(MetaTlsAddr::Host { hostname, port: _ }) => { + ChannelAddr::MetaTls(MetaTlsAddr::Host { + hostname: hostname.clone(), + port: 0, + }) + } + _ => addr.transport().any(), } +} - /// If addr is Tcp or Metatls, use its IP address or hostname to create - /// a new addr with port unspecified. - /// - /// for other types of addr, return "any" address. - pub(crate) fn with_unspecified_port_or_any(addr: &ChannelAddr) -> AllocAssignedAddr { - let new_addr = match addr { - ChannelAddr::Tcp(socket) => { - let mut new_socket = socket.clone(); - new_socket.set_port(0); - ChannelAddr::Tcp(new_socket) - } - ChannelAddr::MetaTls(MetaTlsAddr::Socket(socket)) => { - let mut new_socket = socket.clone(); - new_socket.set_port(0); - ChannelAddr::MetaTls(MetaTlsAddr::Socket(new_socket)) - } - ChannelAddr::MetaTls(MetaTlsAddr::Host { hostname, port: _ }) => { - ChannelAddr::MetaTls(MetaTlsAddr::Host { - hostname: hostname.clone(), - port: 0, - }) - } - _ => addr.transport().any(), +pub(crate) fn serve_with_config( + mut serve_addr: ChannelAddr, +) -> anyhow::Result<(ChannelAddr, ChannelRx)> { + fn set_as_inaddr_any(original: &mut SocketAddr) { + let inaddr_any: IpAddr = match &original { + SocketAddr::V4(_) => Ipv4Addr::UNSPECIFIED.into(), + SocketAddr::V6(_) => Ipv6Addr::UNSPECIFIED.into(), }; - AllocAssignedAddr(new_addr) + original.set_ip(inaddr_any); } - pub(crate) fn serve_with_config( - self, - ) -> anyhow::Result<(ChannelAddr, ChannelRx)> { - fn set_as_inaddr_any(original: &mut SocketAddr) { - let inaddr_any: IpAddr = match &original { - SocketAddr::V4(_) => Ipv4Addr::UNSPECIFIED.into(), - SocketAddr::V6(_) => Ipv6Addr::UNSPECIFIED.into(), - }; - original.set_ip(inaddr_any); - } - - let use_inaddr_any = config::global::get(REMOTE_ALLOC_BIND_TO_INADDR_ANY); - let mut bind_to = self.0; - let mut original_ip: Option = None; - match &mut bind_to { - ChannelAddr::Tcp(socket) => { - original_ip = Some(socket.ip().clone()); - if use_inaddr_any { - set_as_inaddr_any(socket); - tracing::debug!("binding {} to INADDR_ANY", original_ip.as_ref().unwrap(),); - } - if socket.port() == 0 { - socket.set_port(next_allowed_port(socket.ip().clone())?); - } + let use_inaddr_any = config::global::get(REMOTE_ALLOC_BIND_TO_INADDR_ANY); + let mut original_ip: Option = None; + match &mut serve_addr { + ChannelAddr::Tcp(socket) => { + original_ip = Some(socket.ip().clone()); + if use_inaddr_any { + set_as_inaddr_any(socket); + tracing::debug!("binding {} to INADDR_ANY", original_ip.as_ref().unwrap(),); } - _ => { - if use_inaddr_any { - tracing::debug!( - "can only bind to INADDR_ANY for TCP; got transport {}, addr {}", - bind_to.transport(), - bind_to - ); - } + if socket.port() == 0 { + socket.set_port(next_allowed_port(socket.ip().clone())?); } - }; + } + _ => { + if use_inaddr_any { + tracing::debug!( + "can only bind to INADDR_ANY for TCP; got transport {}, addr {}", + serve_addr.transport(), + serve_addr + ); + } + } + }; - let (mut bound, rx) = channel::serve(bind_to)?; + let (mut bound, rx) = channel::serve(serve_addr)?; - // Restore the original IP address if we used INADDR_ANY. - match &mut bound { - ChannelAddr::Tcp(socket) => { - if use_inaddr_any { - socket.set_ip(original_ip.unwrap()); - } + // Restore the original IP address if we used INADDR_ANY. + match &mut bound { + ChannelAddr::Tcp(socket) => { + if use_inaddr_any { + socket.set_ip(original_ip.unwrap()); } - _ => (), } - - Ok((bound, rx)) + _ => (), } + + Ok((bound, rx)) } enum AllowedPorts { diff --git a/hyperactor_mesh/src/alloc/remoteprocess.rs b/hyperactor_mesh/src/alloc/remoteprocess.rs index f54a17fe1..a49868755 100644 --- a/hyperactor_mesh/src/alloc/remoteprocess.rs +++ b/hyperactor_mesh/src/alloc/remoteprocess.rs @@ -59,7 +59,6 @@ use tokio_stream::wrappers::WatchStream; use tokio_util::sync::CancellationToken; use crate::alloc::Alloc; -use crate::alloc::AllocAssignedAddr; use crate::alloc::AllocConstraints; use crate::alloc::AllocSpec; use crate::alloc::Allocator; @@ -70,6 +69,8 @@ use crate::alloc::ProcessAllocator; use crate::alloc::REMOTE_ALLOC_BOOTSTRAP_ADDR; use crate::alloc::process::CLIENT_TRACE_ID_LABEL; use crate::alloc::process::ClientContext; +use crate::alloc::serve_with_config; +use crate::alloc::with_unspecified_port_or_any; use crate::shortuuid::ShortUuid; /// Control messages sent from remote process allocator to local allocator. @@ -91,7 +92,7 @@ pub enum RemoteProcessAllocatorMessage { /// the client_context will go to the message header instead client_context: Option, /// The address allocator should use for its forwarder. - forwarder_addr: AllocAssignedAddr, + forwarder_addr: ChannelAddr, }, /// Stop allocation. Stop, @@ -317,11 +318,11 @@ impl RemoteProcessAllocator { bootstrap_addr: ChannelAddr, hosts: Vec, cancel_token: CancellationToken, - forwarder_addr: AllocAssignedAddr, + forwarder_addr: ChannelAddr, ) { tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}"); // start proc message forwarder - let (forwarder_addr, forwarder_rx) = match forwarder_addr.serve_with_config() { + let (forwarder_addr, forwarder_rx) = match serve_with_config(forwarder_addr) { Ok(v) => v, Err(e) => { tracing::error!("failed to to bootstrap forwarder actor: {}", e); @@ -626,11 +627,11 @@ impl RemoteProcessAlloc { initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static, ) -> Result { let alloc_serve_addr = match config::global::try_get_cloned(REMOTE_ALLOC_BOOTSTRAP_ADDR) { - Some(addr_str) => AllocAssignedAddr::new(addr_str.parse()?), - None => AllocAssignedAddr::new(ChannelAddr::any(spec.transport.clone())), + Some(addr_str) => addr_str.parse()?, + None => ChannelAddr::any(spec.transport.clone()), }; - let (bootstrap_addr, rx) = alloc_serve_addr.serve_with_config()?; + let (bootstrap_addr, rx) = serve_with_config(alloc_serve_addr)?; tracing::info!( "starting alloc for {} on: {}", @@ -825,7 +826,7 @@ impl RemoteProcessAlloc { // its host's private IP address, while its known addres to // alloc is a public IP address. In some environment, that // could lead to port unreachable error. - forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&remote_addr), + forwarder_addr: with_unspecified_port_or_any(&remote_addr), }; tracing::info!( name = message.as_ref(), @@ -1208,8 +1209,8 @@ impl Alloc for RemoteProcessAlloc { /// one could lead to port unreachable error. /// /// For other channel types, this method still uses ChannelAddr::any. - fn client_router_addr(&self) -> AllocAssignedAddr { - AllocAssignedAddr::with_unspecified_port_or_any(&self.bootstrap_addr) + fn client_router_addr(&self) -> ChannelAddr { + with_unspecified_port_or_any(&self.bootstrap_addr) } } @@ -1236,6 +1237,7 @@ mod test { use crate::alloc::MockAllocWrapper; use crate::alloc::MockAllocator; use crate::alloc::ProcStopReason; + use crate::alloc::with_unspecified_port_or_any; use crate::proc_mesh::mesh_agent::ProcMeshAgent; async fn read_all_created(rx: &mut ChannelRx, alloc_len: usize) { @@ -1372,7 +1374,7 @@ mod test { bootstrap_addr, hosts: vec![], client_context: None, - forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()), + forwarder_addr: with_unspecified_port_or_any(&tx.addr()), }) .await .unwrap(); @@ -1526,7 +1528,7 @@ mod test { bootstrap_addr, hosts: vec![], client_context: None, - forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()), + forwarder_addr: with_unspecified_port_or_any(&tx.addr()), }) .await .unwrap(); @@ -1631,7 +1633,7 @@ mod test { bootstrap_addr: bootstrap_addr.clone(), hosts: vec![], client_context: None, - forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()), + forwarder_addr: with_unspecified_port_or_any(&tx.addr()), }) .await .unwrap(); @@ -1656,7 +1658,7 @@ mod test { bootstrap_addr, hosts: vec![], client_context: None, - forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()), + forwarder_addr: with_unspecified_port_or_any(&tx.addr()), }) .await .unwrap(); @@ -1754,7 +1756,7 @@ mod test { bootstrap_addr, hosts: vec![], client_context: None, - forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()), + forwarder_addr: with_unspecified_port_or_any(&tx.addr()), }) .await .unwrap(); @@ -1846,7 +1848,7 @@ mod test { bootstrap_addr, hosts: vec![], client_context: None, - forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()), + forwarder_addr: with_unspecified_port_or_any(&tx.addr()), }) .await .unwrap(); @@ -1941,7 +1943,7 @@ mod test { client_context: Some(ClientContext { trace_id: test_trace_id.to_string(), }), - forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()), + forwarder_addr: with_unspecified_port_or_any(&tx.addr()), }) .await .unwrap(); @@ -2016,7 +2018,7 @@ mod test { bootstrap_addr, hosts: vec![], client_context: None, - forwarder_addr: AllocAssignedAddr::with_unspecified_port_or_any(&tx.addr()), + forwarder_addr: with_unspecified_port_or_any(&tx.addr()), }) .await .unwrap(); diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 6e43ca654..a02f75c92 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -68,6 +68,7 @@ use crate::alloc::AllocatedProc; use crate::alloc::AllocatorError; use crate::alloc::ProcState; use crate::alloc::ProcStopReason; +use crate::alloc::serve_with_config; use crate::assign::Ranks; use crate::comm::CommActorMode; use crate::proc_mesh::mesh_agent::GspawnResult; @@ -379,10 +380,8 @@ impl ProcMesh { ); // Ensure that the router is served so that agents may reach us. - let (router_channel_addr, router_rx) = alloc - .client_router_addr() - .serve_with_config() - .map_err(AllocatorError::Other)?; + let (router_channel_addr, router_rx) = + serve_with_config(alloc.client_router_addr()).map_err(AllocatorError::Other)?; router.serve(router_rx); tracing::info!("router channel started listening on addr: {router_channel_addr}");