diff --git a/monarch_hyperactor/src/bootstrap.rs b/monarch_hyperactor/src/bootstrap.rs index 0188d6b94..63e19cd22 100644 --- a/monarch_hyperactor/src/bootstrap.rs +++ b/monarch_hyperactor/src/bootstrap.rs @@ -7,7 +7,10 @@ */ use futures::future::try_join_all; +use hyperactor::channel; use hyperactor::channel::ChannelAddr; +use hyperactor::channel::Rx; +use hyperactor::channel::Tx; use hyperactor_mesh::Bootstrap; use hyperactor_mesh::bootstrap::BootstrapCommand; use hyperactor_mesh::bootstrap_or_die; @@ -19,13 +22,17 @@ use pyo3::Bound; use pyo3::PyAny; use pyo3::PyResult; use pyo3::Python; +use pyo3::pyclass; use pyo3::pyfunction; +use pyo3::pymethods; use pyo3::types::PyAnyMethods; use pyo3::types::PyModule; use pyo3::types::PyModuleMethods; use pyo3::wrap_pyfunction; use crate::pytokio::PyPythonTask; +use crate::runtime::get_tokio_runtime; +use crate::runtime::signal_safe_block_on; use crate::v1::host_mesh::PyHostMesh; #[pyfunction] @@ -134,6 +141,86 @@ pub fn attach_to_workers<'py>( }) } +/// Python wrapper for ChannelTx that sends bytes messages. +#[pyclass( + name = "ChannelTx", + module = "monarch._rust_bindings.monarch_hyperactor.bootstrap" +)] +pub struct PyChannelTx { + inner: std::sync::Arc>>, +} + +#[pymethods] +impl PyChannelTx { + /// Send a message (bytes) on the channel. Returns a PyPythonTask that completes when the message has been delivered. + fn send(&self, message: &[u8]) -> PyResult { + let inner = self.inner.clone(); + let message = message.to_vec(); + + PyPythonTask::new(async move { + inner + .send(message) + .await + .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e.0))?; + Ok(()) + }) + } +} + +/// Python wrapper for ChannelRx that receives bytes messages. +#[pyclass( + name = "ChannelRx", + module = "monarch._rust_bindings.monarch_hyperactor.bootstrap" +)] +pub struct PyChannelRx { + inner: std::sync::Arc>>>, +} + +#[pymethods] +impl PyChannelRx { + /// Receive the next message (bytes) from the channel. 
Returns a PyPythonTask that completes with the message bytes. + fn recv(&self) -> PyResult { + let inner = self.inner.clone(); + + PyPythonTask::new(async move { + let mut rx = inner.lock().await; + let message = rx + .recv() + .await + .map_err(|e| anyhow::anyhow!("Failed to receive message: {}", e))?; + Ok(message) + }) + } +} + +/// Dial a channel address and return a transmitter for sending bytes. +#[pyfunction] +pub fn dial(address: &str) -> PyResult { + let addr = ChannelAddr::from_zmq_url(address)?; + let tx = channel::dial::>(addr).map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to dial channel: {}", e)) + })?; + Ok(PyChannelTx { + inner: std::sync::Arc::new(tx), + }) +} + +/// Serve on a channel address and return a receiver for receiving bytes. +#[pyfunction] +pub fn serve(py: Python<'_>, address: &str) -> PyResult { + let addr = ChannelAddr::from_zmq_url(address)?; + + let rx = signal_safe_block_on(py, async move { + let (_, rx) = channel::serve::>(addr).map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to serve channel: {}", e)) + })?; + Ok::>, pyo3::PyErr>(rx) + })??; + Ok(PyChannelRx { + inner: std::sync::Arc::new(tokio::sync::Mutex::new(rx)), + }) +} + pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { let f = wrap_pyfunction!(bootstrap_main, hyperactor_mod)?; f.setattr( @@ -156,5 +243,22 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul )?; hyperactor_mod.add_function(f)?; + let f = wrap_pyfunction!(dial, hyperactor_mod)?; + f.setattr( + "__module__", + "monarch._rust_bindings.monarch_hyperactor.bootstrap", + )?; + hyperactor_mod.add_function(f)?; + + let f = wrap_pyfunction!(serve, hyperactor_mod)?; + f.setattr( + "__module__", + "monarch._rust_bindings.monarch_hyperactor.bootstrap", + )?; + hyperactor_mod.add_function(f)?; + + hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; + Ok(()) } diff 
--git a/python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi index a2ad0f2a1..e6eeaa4e8 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi @@ -20,3 +20,102 @@ def run_worker_loop_forever(address: str) -> PythonTask[None]: ... def attach_to_workers( workers: List[PythonTask[str]], name: Optional[str] = None ) -> PythonTask[HostMesh]: ... + +class ChannelTx: + """ + A channel transmitter for sending bytes messages over network connections. + + Supports TCP, Unix domain sockets, and in-process channels using ZMQ-style + addressing (e.g., "tcp://hostname:port", "unix:/path", "inproc://port"). + """ + + def send(self, message: bytes) -> PythonTask[None]: + """ + Send a message on the channel and wait for delivery confirmation. + + Args: + message: The bytes message to send. + + Returns: + A PythonTask that completes when the message has been delivered to the + remote end of the channel. + + Raises: + RuntimeError: If the message fails to send. + """ + ... + +class ChannelRx: + """ + A channel receiver for receiving bytes messages over network connections. + + Receives messages from multiple concurrent connections on a single address. + """ + + def recv(self) -> PythonTask[bytes]: + """ + Receive the next message from the channel. + + Returns: + A PythonTask that completes with the received message bytes. + + Raises: + RuntimeError: If the channel is closed or an error occurs. + """ + ... + +def dial(address: str) -> ChannelTx: + """ + Dial a channel address and return a transmitter for sending bytes. + + Establishes a connection to the specified address and returns a ChannelTx + that can be used to send messages. The connection uses reliable delivery + with automatic reconnection and retransmission. 
+ + Args: + address: The channel address in ZMQ-style URL format: + - "tcp://hostname:port" - TCP connection to hostname:port + - "unix:/path/to/socket" - Unix domain socket + - "inproc://name" - In-process channel (local only) + - "ipc://path" - IPC socket (equivalent to unix) + - "metatls://hostname:port" - TLS connection (Meta internal) + + Returns: + A ChannelTx instance for sending messages. + + Raises: + RuntimeError: If the address format is invalid or connection fails. + + Example: + >>> tx = dial("tcp://localhost:12345") + >>> await tx.send(b"Hello, world!") + """ + ... + +def serve(address: str) -> ChannelRx: + """ + Serve on a channel address and return a receiver for incoming messages. + + Starts listening on the specified address and returns a ChannelRx for + receiving messages from multiple concurrent connections. + + Args: + address: The channel address in ZMQ-style URL format: + - "tcp://*:0" - Listen on any available port on all interfaces (the bound address is not returned) + - "tcp://hostname:port" - Listen on specific hostname and port + - "unix:/path/to/socket" - Unix domain socket + - "inproc://name" - In-process channel (local only) + - "ipc://path" - IPC socket (equivalent to unix) + - "metatls://*:port" - TLS listener (Meta internal) + + Returns: + A ChannelRx for receiving messages. + + Raises: + RuntimeError: If the address format is invalid or binding fails. + + Example: + >>> rx = serve("tcp://*:8080") + >>> message = await rx.recv() + """ + ... 
diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index d6659ef64..56dabc1da 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -1601,6 +1601,51 @@ def test_simple_bootstrap(): proc.wait() +def test_channel_dial_serve(): + """Test that dial and serve functions work for sending and receiving bytes messages.""" + from monarch._rust_bindings.monarch_hyperactor.bootstrap import dial, serve + + # Setup: Start a server on a fixed port + rx = serve("tcp://*:18765") + + # Setup: Create a client that connects to the server + tx = dial("tcp://localhost:18765") + + # Execute: Send a message from client to server + test_message = b"Hello, Monarch!" + tx.send(test_message).block_on() + + # Execute: Receive the message on the server + received = rx.recv().block_on() + + # Assert: The received message matches what was sent + assert received == test_message + + +def test_channel_unix_socket(): + """Test that dial and serve work with Unix domain sockets.""" + from monarch._rust_bindings.monarch_hyperactor.bootstrap import dial, serve + + with TemporaryDirectory() as tmpdir: + # Setup: Create a Unix socket address + socket_path = os.path.join(tmpdir, "test.sock") + unix_addr = f"ipc://{socket_path}" + + # Setup: Start a server on the Unix socket + rx = serve(unix_addr) + + # Setup: Create a client + tx = dial(unix_addr) + + # Execute: Send and receive a message + test_message = b"Unix socket test" + tx.send(test_message).block_on() + received = rx.recv().block_on() + + # Assert: Message was received correctly + assert received == test_message + + class HostMeshActor(Actor): @endpoint async def this_host(self) -> HostMesh: