Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions monarch_hyperactor/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
*/

use futures::future::try_join_all;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::Rx;
use hyperactor::channel::Tx;
use hyperactor_mesh::Bootstrap;
use hyperactor_mesh::bootstrap::BootstrapCommand;
use hyperactor_mesh::bootstrap_or_die;
Expand All @@ -19,13 +22,17 @@ use pyo3::Bound;
use pyo3::PyAny;
use pyo3::PyResult;
use pyo3::Python;
use pyo3::pyclass;
use pyo3::pyfunction;
use pyo3::pymethods;
use pyo3::types::PyAnyMethods;
use pyo3::types::PyModule;
use pyo3::types::PyModuleMethods;
use pyo3::wrap_pyfunction;

use crate::pytokio::PyPythonTask;
use crate::runtime::get_tokio_runtime;
use crate::runtime::signal_safe_block_on;
use crate::v1::host_mesh::PyHostMesh;

#[pyfunction]
Expand Down Expand Up @@ -134,6 +141,86 @@ pub fn attach_to_workers<'py>(
})
}

/// Python wrapper for ChannelTx that sends bytes messages.
#[pyclass(
name = "ChannelTx",
module = "monarch._rust_bindings.monarch_hyperactor.bootstrap"
)]
pub struct PyChannelTx {
inner: std::sync::Arc<channel::ChannelTx<Vec<u8>>>,
}

#[pymethods]
impl PyChannelTx {
/// Send a message (bytes) on the channel. Returns a PyPythonTask that completes when the message has been delivered.
fn send(&self, message: &[u8]) -> PyResult<PyPythonTask> {
let inner = self.inner.clone();
let message = message.to_vec();

PyPythonTask::new(async move {
inner
.send(message)
.await
.map_err(|e| anyhow::anyhow!("Failed to send message: {}", e.0))?;
Ok(())
})
}
}

/// Python wrapper for ChannelRx that receives bytes messages.
#[pyclass(
name = "ChannelRx",
module = "monarch._rust_bindings.monarch_hyperactor.bootstrap"
)]
pub struct PyChannelRx {
inner: std::sync::Arc<tokio::sync::Mutex<channel::ChannelRx<Vec<u8>>>>,
}

#[pymethods]
impl PyChannelRx {
/// Receive the next message (bytes) from the channel. Returns a PyPythonTask that completes with the message bytes.
fn recv(&self) -> PyResult<PyPythonTask> {
let inner = self.inner.clone();

PyPythonTask::new(async move {
let mut rx = inner.lock().await;
let message = rx
.recv()
.await
.map_err(|e| anyhow::anyhow!("Failed to receive message: {}", e))?;
Ok(message)
})
}
}

/// Open a sending handle to the channel at `address`.
///
/// `address` is parsed with `ChannelAddr::from_zmq_url`; a parse failure is
/// propagated to the caller as-is.
///
/// # Errors
///
/// Raises a Python `RuntimeError` with the text "Failed to dial channel: ..."
/// when `channel::dial` fails.
#[pyfunction]
pub fn dial(address: &str) -> PyResult<PyChannelTx> {
    let target = ChannelAddr::from_zmq_url(address)?;

    match channel::dial::<Vec<u8>>(target) {
        Ok(tx) => Ok(PyChannelTx {
            inner: std::sync::Arc::new(tx),
        }),
        Err(err) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
            "Failed to dial channel: {}",
            err
        ))),
    }
}

/// Start serving on `address` and return the receiver for incoming bytes.
///
/// `address` is parsed with `ChannelAddr::from_zmq_url`, then
/// `channel::serve` is invoked inside `signal_safe_block_on`.
///
/// NOTE(review): `channel::serve` yields a pair whose first element is
/// discarded here (presumably the bound address, TODO confirm). Only the
/// receiver is handed back to Python.
///
/// # Errors
///
/// Raises a Python `RuntimeError` with the text "Failed to serve channel: ..."
/// when `channel::serve` fails; errors from address parsing or from
/// `signal_safe_block_on` itself are propagated unchanged.
#[pyfunction]
pub fn serve(py: Python<'_>, address: &str) -> PyResult<PyChannelRx> {
    let listen_addr = ChannelAddr::from_zmq_url(address)?;

    let bind = async move {
        match channel::serve::<Vec<u8>>(listen_addr) {
            Ok((_, receiver)) => Ok::<channel::ChannelRx<Vec<u8>>, pyo3::PyErr>(receiver),
            Err(err) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                "Failed to serve channel: {}",
                err
            ))),
        }
    };

    // First `?` unwraps the block_on result, second the bind result.
    let receiver = signal_safe_block_on(py, bind)??;

    Ok(PyChannelRx {
        inner: std::sync::Arc::new(tokio::sync::Mutex::new(receiver)),
    })
}

pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
let f = wrap_pyfunction!(bootstrap_main, hyperactor_mod)?;
f.setattr(
Expand All @@ -156,5 +243,22 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
)?;
hyperactor_mod.add_function(f)?;

let f = wrap_pyfunction!(dial, hyperactor_mod)?;
f.setattr(
"__module__",
"monarch._rust_bindings.monarch_hyperactor.bootstrap",
)?;
hyperactor_mod.add_function(f)?;

let f = wrap_pyfunction!(serve, hyperactor_mod)?;
f.setattr(
"__module__",
"monarch._rust_bindings.monarch_hyperactor.bootstrap",
)?;
hyperactor_mod.add_function(f)?;

hyperactor_mod.add_class::<PyChannelTx>()?;
hyperactor_mod.add_class::<PyChannelRx>()?;

Ok(())
}
99 changes: 99 additions & 0 deletions python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,102 @@ def run_worker_loop_forever(address: str) -> PythonTask[None]: ...
def attach_to_workers(
    workers: List[PythonTask[str]],
    name: Optional[str] = None,
) -> PythonTask[HostMesh]:
    """
    Build a HostMesh from a list of worker tasks.

    Args:
        workers: Tasks that each resolve to a string (presumably a worker
            address; confirm against the Rust implementation).
        name: Optional name; defaults to None.

    Returns:
        A PythonTask that resolves to a HostMesh.
    """
    ...

class ChannelTx:
    """
    Sending half of a bytes channel; instances are returned by ``dial``.

    Channels are addressed with ZMQ-style URLs and can run over TCP, Unix
    domain sockets, or in-process transports
    (e.g., "tcp://hostname:port", "unix:/path", "inproc://port").
    """

    def send(self, message: bytes) -> PythonTask[None]:
        """
        Send one message on the channel.

        Args:
            message: The payload to transmit, as bytes.

        Returns:
            A PythonTask that finishes once the message has been delivered
            to the remote end of the channel.

        Raises:
            RuntimeError: If sending the message fails.
        """
        ...

class ChannelRx:
    """
    Receiving half of a bytes channel; instances are returned by ``serve``.

    A single receiver collects messages arriving from multiple concurrent
    connections to the served address.
    """

    def recv(self) -> PythonTask[bytes]:
        """
        Wait for the next message on the channel.

        Returns:
            A PythonTask that resolves to the bytes of the received message.

        Raises:
            RuntimeError: If the channel is closed or receiving fails.
        """
        ...

def dial(address: str) -> ChannelTx:
    """
    Connect to a channel address and return a transmitter for bytes messages.

    The returned ChannelTx sends messages to the given address. Delivery is
    reliable: the connection reconnects and retransmits automatically.

    Args:
        address: Where to connect, as a ZMQ-style URL:
            - "tcp://hostname:port" - TCP connection to hostname:port
            - "unix:/path/to/socket" - Unix domain socket
            - "inproc://name" - In-process channel (local only)
            - "ipc://path" - IPC socket (equivalent to unix)
            - "metatls://hostname:port" - TLS connection (Meta internal)

    Returns:
        A ChannelTx that can be used to send messages.

    Raises:
        RuntimeError: If the address cannot be parsed or dialing fails.

    Example:
        >>> tx = dial("tcp://localhost:12345")
        >>> await tx.send(b"Hello, world!")
    """
    ...

def serve(address: str) -> ChannelRx:
    """
    Listen on a channel address and return a receiver for bytes messages.

    The returned ChannelRx yields messages arriving from any number of
    concurrent connections to the address.

    Note: only the receiver is returned. When binding to an OS-chosen port
    (e.g. "tcp://*:0"), this function gives no way to learn which port was
    selected.

    Args:
        address: Where to listen, as a ZMQ-style URL:
            - "tcp://*:0" - Listen on any available port on all interfaces
            - "tcp://hostname:port" - Listen on specific hostname and port
            - "unix:/path/to/socket" - Unix domain socket
            - "inproc://name" - In-process channel (local only)
            - "ipc://path" - IPC socket (equivalent to unix)
            - "metatls://*:port" - TLS listener (Meta internal)

    Returns:
        A ChannelRx for receiving messages.

    Raises:
        RuntimeError: If the address cannot be parsed or binding fails.

    Example:
        >>> rx = serve("tcp://*:8080")
        >>> message = await rx.recv()
    """
    ...
45 changes: 45 additions & 0 deletions python/tests/test_python_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,51 @@ def test_simple_bootstrap():
proc.wait()


def test_channel_dial_serve():
    """Test that dial and serve functions work for sending and receiving bytes messages."""
    import socket

    from monarch._rust_bindings.monarch_hyperactor.bootstrap import dial, serve

    # Setup: ask the OS for a currently free port instead of hard-coding one,
    # so the test does not fail when a fixed port is already taken (parallel
    # test runs, leftover processes). serve() does not report the bound
    # address, so "tcp://*:0" cannot be used directly here.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("127.0.0.1", 0))
        port = probe.getsockname()[1]

    # Setup: Start a server on the selected port
    rx = serve(f"tcp://*:{port}")

    # Setup: Create a client that connects to the server
    tx = dial(f"tcp://localhost:{port}")

    # Execute: Send a message from client to server
    test_message = b"Hello, Monarch!"
    tx.send(test_message).block_on()

    # Execute: Receive the message on the server
    received = rx.recv().block_on()

    # Assert: The received message matches what was sent
    assert received == test_message


def test_channel_unix_socket():
    """Round-trip one bytes message through dial/serve over a Unix domain socket."""
    from monarch._rust_bindings.monarch_hyperactor.bootstrap import dial, serve

    payload = b"Unix socket test"

    with TemporaryDirectory() as scratch_dir:
        # The socket file lives inside the temporary directory.
        endpoint = "ipc://" + os.path.join(scratch_dir, "test.sock")

        # Server first, then the client that connects to it.
        receiver = serve(endpoint)
        sender = dial(endpoint)

        # Send the payload and read it back on the serving side.
        sender.send(payload).block_on()
        echoed = receiver.recv().block_on()

        # The bytes must arrive unchanged.
        assert echoed == payload


class HostMeshActor(Actor):
@endpoint
async def this_host(self) -> HostMesh:
Expand Down
Loading