From c4336cdd9668f08c3c63b21dfba92aa8699145e4 Mon Sep 17 00:00:00 2001
From: zdevito
Date: Fri, 21 Nov 2025 14:24:01 -0800
Subject: [PATCH] Remove unused Python bindings to Rust objects

These bindings were only used by tests that exercised the bindings
themselves; nothing else in the codebase uses them.

Differential Revision: [D87675341](https://our.internmc.facebook.com/intern/diff/D87675341/)

[ghstack-poisoned]
---
 monarch_extension/src/client.rs            |    8 +-
 monarch_extension/src/controller.rs        |  138 --
 monarch_extension/src/lib.rs               |    6 -
 monarch_extension/src/tensor_worker.rs     | 1390 -----------------
 .../monarch_extension/controller.pyi       |   36 +-
 .../monarch_extension/tensor_worker.pyi    |  975 +-----------
 python/monarch/common/messages.py          |  274 ----
 .../controller/rust_backend/controller.py  |    1 -
 python/tests/_monarch/test_controller.py   |   79 -
 python/tests/_monarch/test_worker.py       |  358 -----
 10 files changed, 9 insertions(+), 3256 deletions(-)
 delete mode 100644 monarch_extension/src/controller.rs
 delete mode 100644 python/tests/_monarch/test_controller.py
 delete mode 100644 python/tests/_monarch/test_worker.py

diff --git a/monarch_extension/src/client.rs b/monarch_extension/src/client.rs
index a743c2b96..b49f1ea8c 100644
--- a/monarch_extension/src/client.rs
+++ b/monarch_extension/src/client.rs
@@ -20,6 +20,7 @@ use hyperactor_multiprocess::system_actor::SystemMessageClient;
 use hyperactor_multiprocess::system_actor::SystemSnapshotFilter;
 use hyperactor_multiprocess::system_actor::WorldSnapshot;
 use hyperactor_multiprocess::system_actor::WorldSnapshotProcInfo;
+use monarch_hyperactor::ndslice::PySlice;
 use monarch_hyperactor::proc::ControllerError;
 use monarch_hyperactor::proc::InstanceWrapper;
 use monarch_hyperactor::proc::PyActorId;
@@ -51,9 +52,14 @@ use pyo3::types::PyNone;
 use tokio::sync::Mutex;
 use torch_sys::RValue;
 
-use crate::controller::PyRanks;
 use crate::convert::convert;
 
+#[derive(Clone, FromPyObject)]
+pub enum PyRanks {
+    Slice(PySlice),
+    SliceList(Vec<PySlice>),
+}
+
 #[pyclass(frozen, module = "monarch._rust_bindings.monarch_extension.client")]
 pub struct WorkerResponse {
     seq: Seq,
diff --git a/monarch_extension/src/controller.rs b/monarch_extension/src/controller.rs
deleted file mode 100644
index a617c1d93..000000000
--- a/monarch_extension/src/controller.rs
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-/// These the controller messages that are exposed to python to allow the client to construct and
-/// send messages to the controller. For more details of the definitions take a look at
-/// [`monarch_messages::controller::ControllerMessage`].
-use hyperactor::data::Serialized; -use monarch_hyperactor::ndslice::PySlice; -use monarch_hyperactor::proc::PySerialized; -use monarch_messages::controller::Seq; -use monarch_messages::controller::*; -use monarch_messages::worker::Ref; -use pyo3::exceptions::PyRuntimeError; -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; - -use crate::tensor_worker::PyWorkerMessage; - -#[pyclass( - frozen, - get_all, - module = "monarch._rust_bindings.monarch_extension.controller" -)] -struct Node { - seq: Seq, - defs: Vec, - uses: Vec, -} - -#[pymethods] -impl Node { - #[new] - #[pyo3(signature = (*, seq, defs, uses))] - fn new(seq: Seq, defs: Vec, uses: Vec) -> Self { - Node { seq, defs, uses } - } - - fn serialize(&self) -> PyResult { - PySerialized::new(&ControllerMessage::Node { - seq: self.seq, - defs: self.defs.clone(), - uses: self.uses.clone(), - }) - } - - #[staticmethod] - fn from_serialized(serialized: &PySerialized) -> PyResult { - let message = serialized.deserialized()?; - match message { - ControllerMessage::Node { seq, defs, uses } => Ok(Node { seq, defs, uses }), - _ => Err(PyValueError::new_err(format!( - "Expected Node message, got {:?}", - message - ))), - } - } -} - -#[derive(Clone, FromPyObject)] -pub enum PyRanks { - Slice(PySlice), - SliceList(Vec), -} - -#[pyclass(frozen, module = "monarch._rust_bindings.monarch_extension.controller")] -struct Send { - ranks: Ranks, - message: Serialized, -} - -#[pymethods] -impl Send { - #[new] - #[pyo3(signature = (*, ranks, message))] - fn new(ranks: PyRanks, message: PyRef) -> PyResult { - let ranks = match ranks { - PyRanks::Slice(r) => Ranks::Slice(r.into()), - PyRanks::SliceList(r) => { - if r.is_empty() { - return Err(PyValueError::new_err("Send requires at least one rank")); - } - Ranks::SliceList(r.into_iter().map(|r| r.into()).collect()) - } - }; - // println!("send: {:?}", &message.message); - Ok(Self { - ranks, - message: message.to_serialized()?, - }) - } - - #[getter] - fn ranks(&self) -> Vec { - match &self.ranks { - Ranks::Slice(r) => vec![r.clone().into()], - Ranks::SliceList(r) => r.iter().map(|r| r.clone().into()).collect(), - } - } - - #[getter] - fn message(&self, py: Python<'_>) -> PyResult { - let worker_message = self.message.deserialized().map_err(|err| { - PyRuntimeError::new_err(format!("Failed to deserialize worker message: {}", err)) - })?; - crate::tensor_worker::worker_message_to_py(py, &worker_message) - } - - fn serialize(&self) -> PyResult { - let msg = ControllerMessage::Send { - ranks: self.ranks.clone(), - message: self.message.clone(), - }; - PySerialized::new(&msg) - } - - #[staticmethod] - fn from_serialized(serialized: &PySerialized) -> PyResult { - let message = serialized.deserialized()?; - match message { - ControllerMessage::Send { ranks, message } => Ok(Send { ranks, message }), - _ => Err(PyValueError::new_err(format!( - "Expected Send message, got {:?}", - message - ))), - } - } -} - -pub(crate) fn register_python_bindings(controller_mod: &Bound<'_, PyModule>) -> PyResult<()> { - controller_mod.add_class::()?; - controller_mod.add_class::()?; - Ok(()) -} diff --git a/monarch_extension/src/lib.rs b/monarch_extension/src/lib.rs index 706cdec52..f98f2db31 100644 --- a/monarch_extension/src/lib.rs +++ b/monarch_extension/src/lib.rs @@ -12,8 +12,6 @@ mod client; pub mod code_sync; #[cfg(feature = "tensor_engine")] -mod controller; -#[cfg(feature = "tensor_engine")] pub mod convert; #[cfg(feature = "tensor_engine")] mod debugger; @@ -108,10 +106,6 @@ pub fn mod_init(module: &Bound<'_, 
PyModule>) -> PyResult<()> { module, "monarch_extension.tensor_worker", )?)?; - controller::register_python_bindings(&get_or_add_new_module( - module, - "monarch_extension.controller", - )?)?; debugger::register_python_bindings(&get_or_add_new_module( module, "monarch_extension.debugger", diff --git a/monarch_extension/src/tensor_worker.rs b/monarch_extension/src/tensor_worker.rs index fe45f75ae..d0f0f28ad 100644 --- a/monarch_extension/src/tensor_worker.rs +++ b/monarch_extension/src/tensor_worker.rs @@ -6,1402 +6,12 @@ * LICENSE file in the root directory of this source tree. */ -/// These are the worker messages exposed through pyo3 to python. -/// The actual documentation of the messages can be found in [`monarch_messages::worker::WorkerMessage`] -/// This split is currently needed to customize the constructors for the messages and due to the -/// fact that the current pyo3 version has weaker support for proc macro based pyclass customization. -/// A version bump + figuring out if we want to expose unittest structs to python along with portref etc -/// + support for constructor specialization will help avoid duplication here. -/// TODO: Potentially too many clones of slice objects, might need to refactor to avoid that. -use std::collections::HashMap; -use std::ops::DerefMut; - -use anyhow::Result; -use hyperactor::data::Serialized; -use hyperactor::reference::ActorId; -use monarch_hyperactor::ndslice::PySlice; -use monarch_hyperactor::proc::PyActorId; -use monarch_hyperactor::runtime::get_tokio_runtime; -use monarch_messages::wire_value::WireValue; -use monarch_messages::wire_value::func_call_args_to_wire_values; use monarch_messages::worker::*; -use monarch_types::TryIntoPyObjectUnsafe; -use pyo3::IntoPyObjectExt; -use pyo3::exceptions::PyRuntimeError; -use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::PyDict; -use pyo3::types::PyTuple; -use torch_sys_cuda::nccl::ReduceOp; -use torch_sys_cuda::nccl::UniqueId; - -#[pyclass( - name = "WorkerMessage", - subclass, - module = "monarch._rust_bindings.monarch_extension.tensor_worker" -)] -pub(crate) struct PyWorkerMessage { - pub message: WorkerMessage, -} - -impl PyWorkerMessage { - pub(crate) fn to_serialized(&self) -> PyResult { - Serialized::serialize(&self.message) - .map_err(|err| PyRuntimeError::new_err(format!("Failed to serialize message: {err}"))) - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct BackendNetworkInit; - -#[pymethods] -impl BackendNetworkInit { - #[new] - fn new() -> PyResult<(Self, PyWorkerMessage)> { - Ok(( - Self, - PyWorkerMessage { - message: WorkerMessage::BackendNetworkInit( - UniqueId::new().map_err(|err| PyRuntimeError::new_err(err.to_string()))?, - ), - }, - )) - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct BackendNetworkPointToPointInit; - -#[pymethods] -impl BackendNetworkPointToPointInit { - #[new] - #[pyo3(signature = (*, from_stream, to_stream))] - fn new(from_stream: StreamRef, to_stream: StreamRef) -> PyResult<(Self, PyWorkerMessage)> { - Ok(( - Self, - PyWorkerMessage { - message: WorkerMessage::BackendNetworkPointToPointInit { - from_stream, - to_stream, - }, - }, - )) - } - - #[getter] - fn from_stream(self_: PyRef) -> StreamRef { - self_ - .as_ref() - .message - .as_backend_network_point_to_point_init() - .unwrap() - .0 - .clone() - } - - #[getter] - fn to_stream(self_: PyRef) -> StreamRef { - self_ - .as_ref() - 
.message - .as_backend_network_point_to_point_init() - .unwrap() - .1 - .clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct CallFunction; - -#[pymethods] -impl CallFunction { - #[new] - #[pyo3(signature = (*, seq, results, mutates, function, args, kwargs, stream, remote_process_groups))] - fn new( - seq: u64, - results: Vec>, - mutates: Vec, - function: ResolvableFunction, - args: &Bound<'_, PyTuple>, - kwargs: &Bound<'_, PyDict>, - stream: StreamRef, - remote_process_groups: Vec, - ) -> PyResult<(Self, PyWorkerMessage)> { - let (args, kwargs) = func_call_args_to_wire_values(Some(&function), args, kwargs)?; - Ok(( - Self, - PyWorkerMessage { - message: WorkerMessage::CallFunction(CallFunctionParams { - seq: seq.into(), - results, - mutates, - function, - args, - kwargs, - stream, - remote_process_groups, - }), - }, - )) - } - - #[getter] - fn seq(self_: PyRef) -> u64 { - self_ - .as_ref() - .message - .as_call_function() - .unwrap() - .seq - .into() - } - - #[getter] - fn results(self_: PyRef) -> Vec> { - self_ - .as_ref() - .message - .as_call_function() - .unwrap() - .results - .clone() - } - - #[getter] - fn mutates(self_: PyRef) -> Vec { - self_ - .as_ref() - .message - .as_call_function() - .unwrap() - .mutates - .clone() - } - - #[getter] - fn function(self_: PyRef) -> ResolvableFunction { - self_ - .as_ref() - .message - .as_call_function() - .unwrap() - .function - .clone() - } - - #[getter] - fn args(self_: PyRef) -> PyResult { - wire_values_to_args( - self_.py(), - self_ - .as_ref() - .message - .as_call_function() - .unwrap() - .args - .clone(), - ) - } - - #[getter] - fn kwargs(self_: PyRef) -> PyResult { - wire_values_to_kwargs( - self_.py(), - self_ - .as_ref() - .message - .as_call_function() - .unwrap() - .kwargs - .clone(), - ) - } - - #[getter] - fn stream(self_: PyRef) -> StreamRef { - self_ - .as_ref() - .message - .as_call_function() - .unwrap() - .stream - .clone() - } - - #[getter] - fn remote_process_groups(self_: PyRef) -> Vec { - self_ - .as_ref() - .message - .as_call_function() - .unwrap() - .remote_process_groups - .clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct CreateStream; - -#[pymethods] -impl CreateStream { - #[new] - #[pyo3(signature = (*, id, stream_creation))] - fn new( - id: StreamRef, - stream_creation: StreamCreationMode, - ) -> PyResult<(Self, PyWorkerMessage)> { - Ok(( - Self, - PyWorkerMessage { - message: WorkerMessage::CreateStream { - id, - stream_creation, - }, - }, - )) - } - - #[getter] - fn id(self_: PyRef) -> StreamRef { - self_.as_ref().message.as_create_stream().unwrap().0.clone() - } - - #[getter] - fn stream_creation(self_: PyRef) -> StreamCreationMode { - *self_.as_ref().message.as_create_stream().unwrap().1 - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct CreateDeviceMesh; - -#[pymethods] -impl CreateDeviceMesh { - #[new] - #[pyo3(signature = (*, result, names, ranks))] - fn new(result: Ref, names: Vec, ranks: &PySlice) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::CreateDeviceMesh { - result, - names, - ranks: ranks.clone().into(), - }, - }, - ) - } - - #[getter] - fn result(self_: PyRef) -> Ref { - self_ - .as_ref() - .message - .as_create_device_mesh() - .unwrap() - .0 - .clone() - } - - #[getter] - fn names(self_: PyRef) -> Vec { - self_ 
- .as_ref() - .message - .as_create_device_mesh() - .unwrap() - .1 - .clone() - } - - #[getter] - fn ranks(self_: PyRef) -> PySlice { - self_ - .as_ref() - .message - .as_create_device_mesh() - .unwrap() - .2 - .clone() - .into() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct CreateRemoteProcessGroup; - -#[pymethods] -impl CreateRemoteProcessGroup { - #[new] - #[pyo3(signature = (*, result, device_mesh, dims))] - fn new(result: Ref, device_mesh: Ref, dims: Vec) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::CreateRemoteProcessGroup { - result, - device_mesh, - dims, - }, - }, - ) - } - - #[getter] - fn result(self_: PyRef) -> Ref { - self_ - .as_ref() - .message - .as_create_remote_process_group() - .unwrap() - .0 - .clone() - } - - #[getter] - fn device_mesh(self_: PyRef) -> Ref { - self_ - .as_ref() - .message - .as_create_remote_process_group() - .unwrap() - .1 - .clone() - } - - #[getter] - fn dims(self_: PyRef) -> Vec { - self_ - .as_ref() - .message - .as_create_remote_process_group() - .unwrap() - .2 - .clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct BorrowCreate; - -#[pymethods] -impl BorrowCreate { - #[new] - #[pyo3(signature = (*, result, borrow, tensor, from_stream, to_stream))] - fn new( - result: Ref, - borrow: u64, - tensor: Ref, - from_stream: StreamRef, - to_stream: StreamRef, - ) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::BorrowCreate { - result, - borrow, - tensor, - from_stream, - to_stream, - }, - }, - ) - } - - #[getter] - fn result(self_: PyRef) -> Ref { - self_.as_ref().message.as_borrow_create().unwrap().0.clone() - } - - #[getter] - fn borrow(self_: PyRef) -> u64 { - *self_.as_ref().message.as_borrow_create().unwrap().1 - } - - #[getter] - fn tensor(self_: PyRef) -> Ref { - self_.as_ref().message.as_borrow_create().unwrap().2.clone() - } - - #[getter] - fn from_stream(self_: PyRef) -> StreamRef { - self_.as_ref().message.as_borrow_create().unwrap().3.clone() - } - - #[getter] - fn to_stream(self_: PyRef) -> StreamRef { - self_.as_ref().message.as_borrow_create().unwrap().4.clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct BorrowFirstUse; - -#[pymethods] -impl BorrowFirstUse { - #[new] - #[pyo3(signature = (*, borrow))] - fn new(borrow: u64) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::BorrowFirstUse { borrow }, - }, - ) - } - - #[getter] - fn borrow(self_: PyRef) -> u64 { - *self_.as_ref().message.as_borrow_first_use().unwrap() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct BorrowLastUse; - -#[pymethods] -impl BorrowLastUse { - #[new] - #[pyo3(signature = (*, borrow))] - fn new(borrow: u64) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::BorrowLastUse { borrow }, - }, - ) - } - - #[getter] - fn borrow(self_: PyRef) -> u64 { - *self_.as_ref().message.as_borrow_last_use().unwrap() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct BorrowDrop; - -#[pymethods] -impl BorrowDrop { - #[new] - #[pyo3(signature = (*, borrow))] - fn new(borrow: u64) -> (Self, PyWorkerMessage) { - ( - Self, - 
PyWorkerMessage { - message: WorkerMessage::BorrowDrop { borrow }, - }, - ) - } - - #[getter] - fn borrow(self_: PyRef) -> u64 { - *self_.as_ref().message.as_borrow_drop().unwrap() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct DeleteRefs; - -#[pymethods] -impl DeleteRefs { - #[new] - #[pyo3(signature = (*, refs))] - fn new(refs: Vec) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::DeleteRefs(refs), - }, - ) - } - - #[getter] - fn refs(self_: PyRef) -> Vec { - self_.as_ref().message.as_delete_refs().unwrap().clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct RequestStatus; - -#[pymethods] -impl RequestStatus { - #[new] - #[pyo3(signature = (*, seq, controller))] - fn new(seq: u64, controller: bool) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::RequestStatus { - seq: seq.into(), - controller, - }, - }, - ) - } - - #[getter] - fn seq(self_: PyRef) -> u64 { - self_.as_ref().message.as_request_status().unwrap().0.into() - } - - #[getter] - fn controller(self_: PyRef) -> bool { - self_ - .as_ref() - .message - .as_request_status() - .unwrap() - .1 - .clone() - } -} - -#[pyclass( - name = "ReductionType", - module = "monarch._rust_bindings.monarch_extension.tensor_worker", - eq, - eq_int -)] -#[derive(Clone, PartialEq)] -enum PyReduction { - Stack, - Sum, - Prod, - Max, - Min, - Avg, -} - -impl From for Reduction { - fn from(value: PyReduction) -> Self { - match value { - PyReduction::Stack => Reduction::Stack, - PyReduction::Sum => Reduction::ReduceOp(ReduceOp::Sum), - PyReduction::Prod => Reduction::ReduceOp(ReduceOp::Prod), - PyReduction::Max => Reduction::ReduceOp(ReduceOp::Max), - PyReduction::Min => Reduction::ReduceOp(ReduceOp::Min), - PyReduction::Avg => Reduction::ReduceOp(ReduceOp::Avg), - } - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct Reduce; - -#[pymethods] -impl Reduce { - #[new] - #[pyo3(signature = (*, result, tensor, factory, mesh, dims, stream, scatter, in_place, reduction, out))] - fn new( - result: Ref, - tensor: Ref, - factory: Factory, - mesh: Ref, - dims: Vec, - stream: StreamRef, - scatter: bool, - in_place: bool, - reduction: PyReduction, - out: Option, - ) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::Reduce { - result, - tensor, - factory, - mesh, - stream, - dims, - reduction: reduction.into(), - scatter, - in_place, - out, - }, - }, - ) - } - - #[getter] - fn result(self_: PyRef) -> Ref { - self_.as_ref().message.as_reduce().unwrap().0.clone() - } - - #[getter] - fn tensor(self_: PyRef) -> Ref { - self_.as_ref().message.as_reduce().unwrap().1.clone() - } - - #[getter] - fn factory(self_: PyRef) -> Factory { - self_.as_ref().message.as_reduce().unwrap().2.clone() - } - - #[getter] - fn mesh(self_: PyRef) -> Ref { - self_.as_ref().message.as_reduce().unwrap().3.clone() - } - - #[getter] - fn stream(self_: PyRef) -> StreamRef { - self_.as_ref().message.as_reduce().unwrap().4.clone() - } - - #[getter] - fn dims(self_: PyRef) -> Vec { - self_.as_ref().message.as_reduce().unwrap().5.clone() - } - - #[getter] - fn reduction(self_: PyRef) -> PyReduction { - match self_.as_ref().message.as_reduce().unwrap().6 { - Reduction::Stack => PyReduction::Stack, - Reduction::ReduceOp(ReduceOp::Sum) => PyReduction::Sum, - 
Reduction::ReduceOp(ReduceOp::Prod) => PyReduction::Prod, - Reduction::ReduceOp(ReduceOp::Max) => PyReduction::Max, - Reduction::ReduceOp(ReduceOp::Min) => PyReduction::Min, - Reduction::ReduceOp(ReduceOp::Avg) => PyReduction::Avg, - } - } - - #[getter] - fn scatter(self_: PyRef) -> bool { - *self_.as_ref().message.as_reduce().unwrap().7 - } - - #[getter] - fn in_place(self_: PyRef) -> bool { - *self_.as_ref().message.as_reduce().unwrap().8 - } - - #[getter] - fn out(self_: PyRef) -> Option { - self_.as_ref().message.as_reduce().unwrap().9.clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct SendTensor; - -#[pymethods] -impl SendTensor { - #[new] - #[pyo3(signature = (*, tensor, from_stream, to_stream, from_ranks, to_ranks, result, factory))] - fn new( - tensor: Ref, - from_stream: StreamRef, - to_stream: StreamRef, - from_ranks: &PySlice, - to_ranks: &PySlice, - result: Ref, - factory: Factory, - ) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::SendTensor { - result, - from_ranks: from_ranks.clone().into(), - to_ranks: to_ranks.clone().into(), - tensor, - factory, - from_stream, - to_stream, - }, - }, - ) - } - - #[getter] - fn result(self_: PyRef) -> Ref { - self_.as_ref().message.as_send_tensor().unwrap().0.clone() - } - - #[getter] - fn from_ranks(self_: PyRef) -> PySlice { - self_ - .as_ref() - .message - .as_send_tensor() - .unwrap() - .1 - .clone() - .into() - } - - #[getter] - fn to_ranks(self_: PyRef) -> PySlice { - self_ - .as_ref() - .message - .as_send_tensor() - .unwrap() - .2 - .clone() - .into() - } - - #[getter] - fn tensor(self_: PyRef) -> Ref { - self_.as_ref().message.as_send_tensor().unwrap().3.clone() - } - - #[getter] - fn factory(self_: PyRef) -> Factory { - self_.as_ref().message.as_send_tensor().unwrap().4.clone() - } - - #[getter] - fn from_stream(self_: PyRef) -> StreamRef { - self_.as_ref().message.as_send_tensor().unwrap().5.clone() - } - - #[getter] - fn to_stream(self_: PyRef) -> StreamRef { - self_.as_ref().message.as_send_tensor().unwrap().6.clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct CreatePipe; - -#[pymethods] -impl CreatePipe { - #[new] - #[pyo3(signature = (*, key, result, mesh, function, max_messages, args, kwargs))] - fn new( - key: String, - result: Ref, - mesh: Ref, - function: ResolvableFunction, - max_messages: i64, - args: &Bound<'_, PyTuple>, - kwargs: &Bound<'_, PyDict>, - ) -> PyResult<(Self, PyWorkerMessage)> { - let (args, kwargs) = func_call_args_to_wire_values(Some(&function), args, kwargs)?; - Ok(( - Self, - PyWorkerMessage { - message: WorkerMessage::CreatePipe { - result, - key, - function, - max_messages, - mesh, - args, - kwargs, - }, - }, - )) - } - - #[getter] - fn result(self_: PyRef) -> Ref { - self_.as_ref().message.as_create_pipe().unwrap().0.clone() - } - - #[getter] - fn key(self_: PyRef) -> String { - self_.as_ref().message.as_create_pipe().unwrap().1.clone() - } - - #[getter] - fn function(self_: PyRef) -> ResolvableFunction { - self_.as_ref().message.as_create_pipe().unwrap().2.clone() - } - - #[getter] - fn max_messages(self_: PyRef) -> i64 { - self_.as_ref().message.as_create_pipe().unwrap().3.clone() - } - - #[getter] - fn mesh(self_: PyRef) -> Ref { - self_.as_ref().message.as_create_pipe().unwrap().4.clone() - } - - #[getter] - fn args(self_: PyRef) -> PyResult { - wire_values_to_args( - self_.py(), - 
self_.as_ref().message.as_create_pipe().unwrap().5.clone(), - ) - } - - #[getter] - fn kwargs(self_: PyRef) -> PyResult { - wire_values_to_kwargs( - self_.py(), - self_.as_ref().message.as_create_pipe().unwrap().6.clone(), - ) - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct SendValue; - -#[pymethods] -impl SendValue { - #[new] - #[pyo3(signature = (*, seq, destination, mutates, function, args, kwargs, stream))] - fn new( - seq: u64, - destination: Option, - mutates: Vec, - function: Option, - args: &Bound<'_, PyTuple>, - kwargs: &Bound<'_, PyDict>, - stream: StreamRef, - ) -> PyResult<(Self, PyWorkerMessage)> { - if function.is_none() && (args.len() != 1 || !kwargs.is_empty()) { - return Err(PyValueError::new_err( - "SendValue with no function must have exactly one argument and no keyword arguments", - )); - } - let (args, kwargs) = func_call_args_to_wire_values(function.as_ref(), args, kwargs)?; - - Ok(( - Self, - PyWorkerMessage { - message: WorkerMessage::SendValue { - seq: seq.into(), - destination, - mutates, - function, - args, - kwargs, - stream, - }, - }, - )) - } - - #[getter] - fn seq(self_: PyRef) -> u64 { - self_.as_ref().message.as_send_value().unwrap().0.into() - } - - #[getter] - fn destination(self_: PyRef) -> Option { - self_.as_ref().message.as_send_value().unwrap().1.clone() - } - - #[getter] - fn mutates(self_: PyRef) -> Vec { - self_.as_ref().message.as_send_value().unwrap().2.clone() - } - - #[getter] - fn function(self_: PyRef) -> Option { - self_.as_ref().message.as_send_value().unwrap().3.clone() - } - - #[getter] - fn args(self_: PyRef) -> PyResult { - wire_values_to_args( - self_.py(), - self_.as_ref().message.as_send_value().unwrap().4.clone(), - ) - } - - #[getter] - fn kwargs(self_: PyRef) -> PyResult { - wire_values_to_kwargs( - self_.py(), - self_.as_ref().message.as_send_value().unwrap().5.clone(), - ) - } - - #[getter] - fn stream(self_: PyRef) -> StreamRef { - self_.as_ref().message.as_send_value().unwrap().6.clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct SplitComm; - -#[pymethods] -impl SplitComm { - #[new] - #[pyo3(signature = (*, dims, device_mesh, stream))] - fn new(dims: Vec, device_mesh: Ref, stream: StreamRef) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::SplitComm { - dims, - device_mesh, - stream, - config: None, - }, - }, - ) - } - - #[getter] - fn dims(self_: PyRef) -> Vec { - self_.as_ref().message.as_split_comm().unwrap().0.clone() - } - - #[getter] - fn device_mesh(self_: PyRef) -> Ref { - self_.as_ref().message.as_split_comm().unwrap().1.clone() - } - #[getter] - fn stream(self_: PyRef) -> StreamRef { - self_.as_ref().message.as_split_comm().unwrap().2.clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct SplitCommForProcessGroup; - -#[pymethods] -impl SplitCommForProcessGroup { - #[new] - #[pyo3(signature = (*, remote_process_group, stream))] - fn new(remote_process_group: Ref, stream: StreamRef) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::SplitCommForProcessGroup { - remote_process_group, - stream, - config: None, - }, - }, - ) - } - - #[getter] - fn remote_process_group(self_: PyRef) -> Ref { - self_.as_ref().message.as_split_comm().unwrap().1.clone() - } - - #[getter] - fn stream(self_: PyRef) -> 
StreamRef { - self_.as_ref().message.as_split_comm().unwrap().2.clone() - } -} - -#[pyclass(extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct DefineRecording; - -#[pymethods] -impl DefineRecording { - #[new] - #[pyo3(signature = (*, result, nresults, nformals, commands, ntotal_messages, index))] - fn new( - result: Ref, - nresults: usize, - nformals: usize, - commands: Vec>, - ntotal_messages: usize, - index: usize, - ) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::DefineRecording { - result, - nresults, - nformals, - commands: commands - .into_iter() - .map(|command| command.message.clone()) - .collect(), - ntotal_messages, - index, - }, - }, - ) - } - - fn append(mut self_: PyRefMut<'_, Self>, command: PyRef) -> PyResult<()> { - self_ - .as_super() - .deref_mut() - .message - .as_define_recording_mut() - .unwrap() - .3 - .push(command.message.clone()); - Ok(()) - } - - fn append_call_function( - mut self_: PyRefMut<'_, Self>, - seq: u64, - results: Vec>, - mutates: Vec, - function: ResolvableFunction, - args: &Bound<'_, PyTuple>, - kwargs: &Bound<'_, PyDict>, - stream: StreamRef, - remote_process_groups: Vec, - ) -> PyResult<()> { - let (args, kwargs) = func_call_args_to_wire_values(Some(&function), args, kwargs)?; - self_ - .as_super() - .deref_mut() - .message - .as_define_recording_mut() - .unwrap() - .3 - .push(WorkerMessage::CallFunction(CallFunctionParams { - seq: seq.into(), - results, - mutates, - function, - args, - kwargs, - stream, - remote_process_groups, - })); - Ok(()) - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct RecordingFormal; - -#[pymethods] -impl RecordingFormal { - #[new] - #[pyo3(signature = (*, result, argument_index, stream))] - fn new(result: Ref, argument_index: usize, stream: StreamRef) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::RecordingFormal { - result, - argument_index, - stream, - }, - }, - ) - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct RecordingResult; - -#[pymethods] -impl RecordingResult { - #[new] - #[pyo3(signature = (*, result, output_index, stream))] - fn new(result: Ref, output_index: usize, stream: StreamRef) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::RecordingResult { - result, - output_index, - stream, - }, - }, - ) - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct CallRecording; - -#[pymethods] -impl CallRecording { - #[new] - #[pyo3(signature = (*, seq, recording, results, actuals))] - fn new( - seq: u64, - recording: Ref, - results: Vec, - actuals: Vec, - ) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::CallRecording { - seq: seq.into(), - recording, - results, - actuals, - }, - }, - ) - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct PipeRecv; - -#[pymethods] -impl PipeRecv { - #[new] - #[pyo3(signature = (*, seq, pipe, results, stream))] - fn new( - seq: u64, - pipe: Ref, - results: Vec>, - stream: StreamRef, - ) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::PipeRecv { - seq: seq.into(), - results, - pipe, - stream, - }, - }, - ) - } - - #[getter] - fn seq(self_: 
PyRef) -> u64 { - self_.as_ref().message.as_pipe_recv().unwrap().0.into() - } - - #[getter] - fn pipe(self_: PyRef) -> Ref { - self_.as_ref().message.as_pipe_recv().unwrap().2.clone() - } - - #[getter] - fn results(self_: PyRef) -> Vec> { - self_.as_ref().message.as_pipe_recv().unwrap().1.clone() - } - - #[getter] - fn stream(self_: PyRef) -> StreamRef { - self_.as_ref().message.as_pipe_recv().unwrap().3.clone() - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct Exit; - -#[pymethods] -impl Exit { - #[new] - #[pyo3(signature = (*, error_reason))] - fn new(error_reason: Option<(Option, String)>) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::Exit { - error: error_reason - .map(|(actor_id, reason)| (actor_id.map(|id| ActorId::from(&id)), reason)), - }, - }, - ) - } -} - -#[pyclass(frozen, extends=PyWorkerMessage, module = "monarch._rust_bindings.monarch_extension.tensor_worker")] -struct CommandGroup; - -#[pymethods] -impl CommandGroup { - #[new] - #[pyo3(signature = (*, commands))] - fn new(commands: Vec>) -> (Self, PyWorkerMessage) { - ( - Self, - PyWorkerMessage { - message: WorkerMessage::CommandGroup( - commands - .into_iter() - .map(|command| command.message.clone()) - .collect(), - ), - }, - ) - } - - #[getter] - fn commands(self_: PyRef) -> PyResult> { - let py = self_.py(); - self_ - .as_ref() - .message - .as_command_group() - .unwrap() - .iter() - .map(|message| worker_message_to_py(py, message)) - .collect::, PyErr>>() - } -} - -fn wire_values_to_args(py: Python<'_>, args: Vec) -> PyResult { - let py_ags = args - .into_iter() - .map(|arg| { - // SAFETY: This is ok as its used just to return the wire value back to the user and not - // to mutate it. - unsafe { - arg.try_to_object_unsafe(py).map_err(|err| { - PyValueError::new_err(format!( - "Failed to convert a pytree of WireValues to a PyObject: {:?}", - err - )) - }) - } - }) - .collect::, PyErr>>()?; - Ok(PyTuple::new(py, py_ags)?.into()) -} - -fn wire_values_to_kwargs(py: Python<'_>, kwargs: HashMap) -> PyResult { - kwargs - .into_iter() - .map(|(k, v)| { - // SAFETY: This is ok as its used just to return the wire value back to the user and not - // to mutate it. - Ok((k.clone(), unsafe { - v.try_to_object_unsafe(py).map_err(|err| { - PyValueError::new_err(format!( - "Failed to convert a pytree of WireValues to a PyObject: {:?}", - err - )) - })? - })) - }) - .collect::, PyErr>>()? - .into_py_any(py) -} - -// TODO: This can become an impl on WorkerMessage once we adjust crate split with monarch_messages -pub(crate) fn worker_message_to_py(py: Python<'_>, message: &WorkerMessage) -> PyResult { - let initializer = PyClassInitializer::from(PyWorkerMessage { - message: message.clone(), - }); - match message { - WorkerMessage::BackendNetworkInit { .. } => { - Py::new(py, initializer.add_subclass(BackendNetworkInit {}))?.into_py_any(py) - } - WorkerMessage::BackendNetworkPointToPointInit { .. } => Py::new( - py, - initializer.add_subclass(BackendNetworkPointToPointInit {}), - )? - .into_py_any(py), - WorkerMessage::CallFunction { .. } => { - Py::new(py, initializer.add_subclass(CallFunction {}))?.into_py_any(py) - } - WorkerMessage::CreateStream { .. } => { - Py::new(py, initializer.add_subclass(CreateStream {}))?.into_py_any(py) - } - WorkerMessage::CreateRemoteProcessGroup { .. } => { - Py::new(py, initializer.add_subclass(CreateRemoteProcessGroup {}))?.into_py_any(py) - } - WorkerMessage::CreateDeviceMesh { .. 
} => { - Py::new(py, initializer.add_subclass(CreateDeviceMesh {}))?.into_py_any(py) - } - WorkerMessage::BorrowCreate { .. } => { - Py::new(py, initializer.add_subclass(BorrowCreate {}))?.into_py_any(py) - } - WorkerMessage::BorrowFirstUse { .. } => { - Py::new(py, initializer.add_subclass(BorrowFirstUse {}))?.into_py_any(py) - } - WorkerMessage::BorrowLastUse { .. } => { - Py::new(py, initializer.add_subclass(BorrowLastUse {}))?.into_py_any(py) - } - WorkerMessage::BorrowDrop { .. } => { - Py::new(py, initializer.add_subclass(BorrowDrop {}))?.into_py_any(py) - } - WorkerMessage::DeleteRefs { .. } => { - Py::new(py, initializer.add_subclass(DeleteRefs {}))?.into_py_any(py) - } - WorkerMessage::RequestStatus { .. } => { - Py::new(py, initializer.add_subclass(RequestStatus {}))?.into_py_any(py) - } - WorkerMessage::Reduce { .. } => { - Py::new(py, initializer.add_subclass(Reduce {}))?.into_py_any(py) - } - WorkerMessage::SendTensor { .. } => { - Py::new(py, initializer.add_subclass(SendTensor {}))?.into_py_any(py) - } - WorkerMessage::CreatePipe { .. } => { - Py::new(py, initializer.add_subclass(CreatePipe {}))?.into_py_any(py) - } - WorkerMessage::SendValue { .. } => { - Py::new(py, initializer.add_subclass(SendValue {}))?.into_py_any(py) - } - WorkerMessage::PipeRecv { .. } => { - Py::new(py, initializer.add_subclass(PipeRecv {}))?.into_py_any(py) - } - WorkerMessage::SplitComm { .. } => { - Py::new(py, initializer.add_subclass(SplitComm {}))?.into_py_any(py) - } - WorkerMessage::SplitCommForProcessGroup { .. } => { - Py::new(py, initializer.add_subclass(SplitCommForProcessGroup {}))?.into_py_any(py) - } - WorkerMessage::Exit { .. } => { - Py::new(py, initializer.add_subclass(Exit {}))?.into_py_any(py) - } - WorkerMessage::CommandGroup { .. } => { - Py::new(py, initializer.add_subclass(CommandGroup {}))?.into_py_any(py) - } - WorkerMessage::DefineRecording { .. } => { - Py::new(py, initializer.add_subclass(DefineRecording {}))?.into_py_any(py) - } - WorkerMessage::RecordingFormal { .. } => { - Py::new(py, initializer.add_subclass(RecordingFormal {}))?.into_py_any(py) - } - WorkerMessage::RecordingResult { .. } => { - Py::new(py, initializer.add_subclass(RecordingResult {}))?.into_py_any(py) - } - WorkerMessage::CallRecording { .. } => { - Py::new(py, initializer.add_subclass(CallRecording {}))?.into_py_any(py) - } - WorkerMessage::SetRefUnitTestsOnly { .. } => unimplemented!(), - WorkerMessage::GetRefUnitTestsOnly { .. } => unimplemented!(), - WorkerMessage::SendResultOfActorCall { .. } => unimplemented!(), - WorkerMessage::CallActorMethod { .. 
} => unimplemented!(), - } -} - pub(crate) fn register_python_bindings(worker_mod: &Bound<'_, PyModule>) -> PyResult<()> { - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; worker_mod.add_class::()?; worker_mod.add_class::()?; - worker_mod.add_class::()?; worker_mod.add_class::()?; worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - worker_mod.add_class::()?; - Ok(()) } diff --git a/python/monarch/_rust_bindings/monarch_extension/controller.pyi b/python/monarch/_rust_bindings/monarch_extension/controller.pyi index eaf30bbc0..7bb2a102f 100644 --- a/python/monarch/_rust_bindings/monarch_extension/controller.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/controller.pyi @@ -9,12 +9,10 @@ from enum import Enum from typing import Any, final, List, Optional, Tuple, Union -from monarch._rust_bindings.monarch_extension.tensor_worker import Ref, WorkerMessage +from monarch._rust_bindings.monarch_extension.tensor_worker import Ref from monarch._rust_bindings.monarch_hyperactor.proc import Serialized -from monarch._rust_bindings.monarch_hyperactor.shape import Slice - @final class Node: """ @@ -60,35 +58,3 @@ class Node: def from_serialized(serialized: Serialized) -> Node: """Deserialize the message from a Serialized object.""" ... - -@final -class Send: - """ - Send a message to the workers mapping to the ranks provided in the given slices. - - Args: - - `ranks`: Slices of ranks of the workers to send the message to. - - `message`: Message to send to the workers. - """ - - def __init__( - self, *, ranks: Slice | List[Slice], message: WorkerMessage - ) -> None: ... - @property - def ranks(self) -> List[Slice]: - """Slices of ranks of the workers to send the message to.""" - ... - - @property - def message(self) -> WorkerMessage: - """Message to send to the workers.""" - ... - - def serialize(self) -> Serialized: - """Serialize the message into a Serialized object.""" - ... - - @staticmethod - def from_serialized(serialized: Serialized) -> Send: - """Deserialize the message from a Serialized object.""" - ... diff --git a/python/monarch/_rust_bindings/monarch_extension/tensor_worker.pyi b/python/monarch/_rust_bindings/monarch_extension/tensor_worker.pyi index afa0dd88f..b7eba5cf9 100644 --- a/python/monarch/_rust_bindings/monarch_extension/tensor_worker.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/tensor_worker.pyi @@ -6,11 +6,7 @@ # pyre-unsafe -from typing import Callable, final, Optional, Sequence, Tuple - -import torch -from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch._rust_bindings.monarch_hyperactor.shape import Slice +from typing import Callable, final @final class Ref: @@ -65,49 +61,6 @@ class StreamRef: def __ge__(self, other: Ref) -> bool: ... def __hash__(self) -> int: ... 
-@final -class TensorFactory: - """ - Factory class to hold necessary metadata to create tensors on the worker. - - Args: - - `size`: The size of the tensor. - - `dtype`: The data type of the tensor. - - `layout`: The layout of the tensor. - - `device`: The device of the tensor. (TODO: support torch.device) - """ - - def __init__( - self, - *, - size: Sequence[int], - # pyre-ignore - dtype: torch.dtype, - # pyre-ignore - layout: torch.layout, - # pyre-ignore - device: torch.device, - ) -> None: ... - @property - def size(self) -> tuple[int, ...]: - """The size of the tensor.""" - ... - - @property - def dtype(self) -> torch.dtype: - """The data type of the tensor.""" - ... - - @property - def layout(self) -> torch.layout: - """The layout of the tensor.""" - ... - - @property - def device(self) -> str: - """The device of the tensor.""" - ... - @final class FunctionPath: """ @@ -144,929 +97,3 @@ class Cloudpickle: ... ResolvableFunction = FunctionPath | Cloudpickle - -@final -class StreamCreationMode: - """ - Used to specify what CUDA stream to use for the worker stream creation. - """ - - UseDefaultStream: StreamCreationMode - CreateNewStream: StreamCreationMode - - def __eq__(self, value: StreamCreationMode) -> bool: ... - def __ne__(self, value: StreamCreationMode) -> bool: ... - def __repr__(self) -> str: ... - def __int__(self) -> int: ... - -@final -class ReductionType: - """Used to specify the reduction type for the Reduce command.""" - - Stack: ReductionType - Sum: ReductionType - Prod: ReductionType - Max: ReductionType - Min: ReductionType - Avg: ReductionType - - def __eq__(self, value: ReductionType) -> bool: ... - def __ne__(self, value: ReductionType) -> bool: ... - def __repr__(self) -> str: ... - def __int__(self) -> int: ... - -class WorkerMessage: - """ - The base class for all messages that can be sent to the worker. - This class is not meant to be instantiated or inherited directly. - Instead, use the subclasses of this class to send messages to - the worker. - TODO: Expose all subclasses as attributes of this class. - """ - - ... - -@final -class BackendNetworkInit(WorkerMessage): - """Instruct the worker to initialize the backend network.""" - - def __init__(self) -> None: ... - -@final -class BackendNetworkPointToPointInit(WorkerMessage): - """Instruct the worker to initialize the backend network for point-to-point communication.""" - - def __init__(self, *, from_stream: StreamRef, to_stream: StreamRef) -> None: ... - @property - def from_stream(self) -> StreamRef: - """Reference to the src stream to use for the point-to-point communication.""" - ... - - @property - def to_stream(self) -> StreamRef: - """Reference to the dst stream to use for the point-to-point communication.""" - ... - -@final -class CallFunction(WorkerMessage): - """ - Instruct the worker to call a function, either a torch op - or a Python `remote_function`. - - Args: - - `seq`: Sequence number of the message. - - `results`: References to the values that the function returns. - - `mutates`: References to the values that the function mutates. - - `function`: Fully qualified path to the function. - - `args`: Pytree-serializable arguments to the function. - - `kwargs`: Pytree-serializable keyword arguments to the function. - - `stream`: Reference to the stream the worker should use to execute the function. - - `remote_process_groups`: References to the process groups the worker should use to execute the function. 
- """ - - def __init__( - self, - *, - seq: int, - results: Sequence[Ref | None], - mutates: Sequence[Ref], - function: ResolvableFunction, - args: tuple[object, ...], - kwargs: dict[str, object], - stream: StreamRef, - remote_process_groups: Sequence[Ref], - ) -> None: ... - @property - def seq(self) -> int: - """Sequence number of the message.""" - ... - - @property - def results(self) -> list[Ref | None]: - """References to the values that the function returns.""" - ... - - @property - def mutates(self) -> list[Ref]: - """References to the values that the function mutates.""" - ... - - @property - def function(self) -> ResolvableFunction: - """Fully qualified path to the function.""" - ... - - @property - def args(self) -> tuple[object, ...]: - """ - Pytree-serializable arguments to the function. - Accessing this property can be expensive as it clones. - """ - ... - - @property - def kwargs(self) -> dict[str, object]: - """ - Pytree-serializable keyword arguments to the function. - Accessing this property can be expensive as it clones. - """ - ... - - @property - def stream(self) -> StreamRef: - """Reference to the stream the worker should use to execute the function.""" - ... - - @property - def remote_process_groups(self) -> list[Ref]: - """References to the process groups the worker should use to execute the function.""" - ... - -@final -class CreateStream(WorkerMessage): - """ - Instruct the worker to create a new stream. Worker will execute commands - on streams concurrently. - - Args: - - `id`: The id of the stream on the worker. - - `stream_creation`: The CUDA stream to use for the created stream. - """ - - def __init__( - self, *, id: StreamRef, stream_creation: StreamCreationMode - ) -> None: ... - @property - def id(self) -> StreamRef: - """The id of the stream on the worker.""" - ... - - @property - def stream_creation(self) -> StreamCreationMode: - """The CUDA stream to use for the created stream.""" - ... - -@final -class CreateDeviceMesh(WorkerMessage): - """ - Instruct the worker to create a new device mesh which can be used to schedule - efficient inter-worker communication. - - Args: - - `result`: Reference to the created device mesh. - - `names`: Names of the dimensions in the device mesh. - - `ranks`: Multi-dimensional slice of the ranks of the devices in the device mesh. - The number of dimensions must match the number of names. - """ - - def __init__(self, *, result: Ref, names: Sequence[str], ranks: Slice) -> None: ... - @property - def result(self) -> Ref: - """The reference to the created device mesh.""" - ... - - @property - def names(self) -> list[str]: - """The names of the dimensions in the device mesh.""" - ... - - @property - def ranks(self) -> Slice: - """The multi-dimensional slice of the ranks of the devices in the device mesh.""" - ... - -@final -class CreateRemoteProcessGroup(WorkerMessage): - """ - Instruct the worker to create a new PyTorch process group to allow UDFs to - perform collectives. - - Args: - - `result`: Reference to the created process group. - - `device_mesh`: Device mesh to create group on. - - `dims`: Device mesh dimensions group should use. - """ - - def __init__( - self, *, result: Ref, device_mesh: Ref, dims: Sequence[str] - ) -> None: ... - @property - def result(self) -> Ref: - """The reference to the created process group.""" - ... - - @property - def device_mesh(self) -> Ref: - """The names of the dimensions in the device mesh.""" - ... 
- - @property - def dims(self) -> list[str]: - """The device mesh dimension to communicate over.""" - ... - -@final -class BorrowCreate(WorkerMessage): - """ - Instruct the worker to create a borrow of a tensor from one stream to another. - - Args: - - `result`: Reference to the resulting borrowed tensor - - `borrow`: The ID for the borrow. - - `tensor`: Reference to the tensor to borrow. - - `from_stream`: Reference to the stream to borrow from. - - `to_stream`: Reference to the stream to borrow to. - """ - - def __init__( - self, - *, - result: Ref, - borrow: int, - tensor: Ref, - from_stream: StreamRef, - to_stream: StreamRef, - ) -> None: ... - @property - def result(self) -> Ref: - """The reference to the resulting borrowed tensor.""" - ... - - @property - def borrow(self) -> int: - """The ID for the borrow.""" - ... - - @property - def tensor(self) -> Ref: - """The reference to the tensor to borrow.""" - ... - - @property - def from_stream(self) -> StreamRef: - """The reference to the stream to borrow from.""" - ... - - @property - def to_stream(self) -> StreamRef: - """The reference to the stream to borrow to.""" - ... - -@final -class BorrowFirstUse(WorkerMessage): - """ - A synchronization marker for the worker on first use of the borrowed tensor. - - Args: - - borrow: The ID for the borrow. - """ - - def __init__(self, *, borrow: int) -> None: ... - @property - def borrow(self) -> int: - """The ID for the borrow.""" - ... - -@final -class BorrowLastUse(WorkerMessage): - """ - A synchronization marker for the worker on last use of the borrowed tensor. - - Args: - - borrow: The ID for the borrow. - """ - - def __init__(self, *, borrow: int) -> None: ... - @property - def borrow(self) -> int: - """The ID for the borrow.""" - ... - -@final -class BorrowDrop(WorkerMessage): - """ - Instruct the worker to drop a borrow of a tensor. - - Args: - - borrow: The ID for the borrow. - """ - - def __init__(self, *, borrow: int) -> None: ... - @property - def borrow(self) -> int: - """The ID for the borrow.""" - ... - -@final -class DeleteRefs(WorkerMessage): - """ - Instruct the worker to delete the values referenced by the given refs - from its state. - - Args: - - refs: References to the values to delete. - """ - - def __init__(self, *, refs: Sequence[Ref]) -> None: ... - @property - def refs(self) -> list[Ref]: - """References to the values to delete.""" - ... - -@final -class RequestStatus(WorkerMessage): - """ - Instruct the worker to respond back when all the messages before this - message have been processed on all streams. - - Args: - - seq: Sequence number of the message. - - controller: Whether this message was sent by the controller. - """ - - def __init__(self, *, seq: int, controller: bool) -> None: ... - @property - def seq(self) -> int: - """Sequence number of the message.""" - ... - - @property - def controller(self) -> bool: - """If this message was sent by the controller.""" - ... - -@final -class Reduce(WorkerMessage): - """ - Perform a reduction operation, using an efficient communication backend. - - Args: - - `result`: Reference to the resulting tensor. - - `tensor`: Reference to the tensor to reduce. - - `factory`: Tensor metadata to create the resulting tensor if `tensor` - is not available for some reason. - - `mesh`: Reference to the device mesh to use for the reduction. - - `stream`: Reference to the stream to use for the reduction. - - `dims`: The dimensions of the device mesh to reduce over. - - `reduction`: The reduction type to use for the reduction. 
- - `scatter`: Whether to evenly split the resulting tensor across the ranks - of the request `dim` in the device mesh. - - `in_place`: Whether to perform the reduction in-place on `tensor`. - """ - - def __init__( - self, - *, - result: Ref, - tensor: Ref, - factory: TensorFactory, - mesh: Ref, - stream: StreamRef, - dims: Sequence[str], - reduction: ReductionType, - scatter: bool, - in_place: bool, - out: Ref | None, - ) -> None: ... - @property - def result(self) -> Ref: - """Reference to the resulting tensor.""" - ... - - @property - def tensor(self) -> Ref: - """Reference to the tensor to reduce.""" - ... - - @property - def factory(self) -> TensorFactory: - """ - Tensor metadata to create the resulting tensor if `tensor` is not - available for some reason. - """ - ... - - @property - def mesh(self) -> Ref: - """Reference to the device mesh to use for the reduction.""" - ... - - @property - def stream(self) -> StreamRef: - """Reference to the stream to use for the reduction.""" - ... - - @property - def dims(self) -> list[str]: - """The dimension of the device mesh to reduce over.""" - ... - - @property - def reduction(self) -> ReductionType: - """The reduction type to use for the reduction.""" - ... - - @property - def scatter(self) -> bool: - """ - Whether to evenly split the resulting tensor across the ranks of the - request `dim` in the device mesh. - """ - ... - - @property - def in_place(self) -> bool: - """Whether to perform the reduction in-place on `tensor`.""" - ... - - @property - def out(self) -> Ref: - """Reference to the out tensor.""" - ... - -@final -class SendTensor(WorkerMessage): - """ - Send a tenser from one slice of ranks to another slice of ranks. - - Args: - - `result`: Reference to the resulting tensor. - - `from_ranks`: Slice of ranks to send the tensor from. - - `to_ranks`: Slice of ranks to send the tensor to. - - `tensor`: Reference to the tensor to send. - - `factory`: Tensor metadata to create the resulting tensor if `tensor` - is not available for some reason. - - `from_stream`: Reference to the src stream to use for this operation. - - `to_stream`: Reference to the dst stream to use for this operation. - """ - - def __init__( - self, - *, - result: Ref, - from_ranks: Slice, - to_ranks: Slice, - tensor: Ref, - factory: TensorFactory, - from_stream: StreamRef, - to_stream: StreamRef, - ) -> None: ... - @property - def result(self) -> Ref: - """Reference to the resulting tensor.""" - ... - - @property - def from_ranks(self) -> Slice: - """Slice of ranks to send the tensor from.""" - ... - - @property - def to_ranks(self) -> Slice: - """Slice of ranks to send the tensor to.""" - ... - - @property - def tensor(self) -> Ref: - """Reference to the tensor to send.""" - ... - - @property - def factory(self) -> TensorFactory: - """ - Tensor metadata to create the resulting tensor if `tensor` is not - available for some reason. - """ - ... - - @property - def from_stream(self) -> StreamRef: - """Reference to the src stream to use for this operation.""" - ... - - @property - def to_stream(self) -> StreamRef: - """Reference to the dst stream to use for this operation.""" - ... - -@final -class CreatePipe(WorkerMessage): - """ - Create a pipe on the worker. - - Args: - - `result`: Reference to the resulting pipe. - - `key`: The key of the pipe this mainly exists for backwards compatibility - with the python impl. - - `function`: Fully qualified path to the function to call to create the pipe. - - `max_messages`: Maximum number of messages to buffer in the pipe. 
- - `mesh`: Reference to the device mesh on which the pipes have been created. - - `args`: Pytree-serializable arguments to the function. - - `kwargs`: Pytree-serializable keyword arguments to the function. - """ - - def __init__( - self, - *, - result: Ref, - key: str, - function: ResolvableFunction, - max_messages: int, - mesh: Ref, - args: tuple[object, ...], - kwargs: dict[str, object], - ) -> None: ... - @property - def result(self) -> Ref: - """Reference to the resulting pipe.""" - ... - - @property - def key(self) -> str: - """The key of the pipe this mainly exists for backwards compatibility with the python impl.""" - ... - - @property - def function(self) -> ResolvableFunction: - """Fully qualified path to the function to call to create the pipe.""" - ... - - @property - def max_messages(self) -> int: - """Maximum number of messages to buffer in the pipe.""" - ... - - @property - def mesh(self) -> Ref: - """Reference to the device mesh on which the pipes have been created.""" - ... - - @property - def args(self) -> tuple[object, ...]: - """ - Pytree-serializable arguments to the function. - Accessing this property can be expensive as it clones. - """ - ... - - @property - def kwargs(self) -> dict[str, object]: - """ - Pytree-serializable keyword arguments to the function. - Accessing this property can be expensive as it clones. - """ - ... - -@final -class SendValue(WorkerMessage): - """ - Send a value from one slice of ranks to another slice of ranks. - - Args: - - `seq`: Sequence number of the message. - - `destination`: Reference to the destination (Pipe) of the value. If `None` - the value will be sent to the controller. - - `function`: Fully qualified path to the function to call to transform the - value before sending it. - - `mutates`: References to the values that the function mutates. - - `args`: Pytree-serializable arguments to the function. If `function` is - `None` this must be a single value to send. - - `kwargs`: Pytree-serializable keyword arguments to the function. If - `function` is `None` this must be empty. - - `stream`: Reference to the stream the worker should use to execute the - operation. - """ - - def __init__( - self, - *, - seq: int, - destination: Ref | None, - function: ResolvableFunction | None, - mutates: Sequence[Ref], - args: tuple[object, ...], - kwargs: dict[str, object], - stream: StreamRef, - ) -> None: ... - @property - def seq(self) -> int: - """Sequence number of the message.""" - ... - - @property - def destination(self) -> Ref | None: - """Reference to the destination (Pipe) of the value. If `None` the value will be sent to the controller.""" - ... - - @property - def function(self) -> ResolvableFunction | None: - """Fully qualified path to the function to call to transform the value before sending it.""" - ... - - @property - def mutates(self) -> list[Ref]: - """References to the values that the function mutates.""" - ... - - @property - def args(self) -> list[object]: - """ - Pytree-serializable arguments to the function. - If `function` is `None` this must be a single value to send. - Accessing this property can be expensive as it clones. - """ - ... - - @property - def kwargs(self) -> dict[str, object]: - """ - Pytree-serializable keyword arguments to the function. - If `function` is `None` this must be empty. - Accessing this property can be expensive as it clones. - """ - ... - - @property - def stream(self) -> StreamRef: - """Reference to the stream the worker should use to execute the operation.""" - ... 
- -@final -class PipeRecv(WorkerMessage): - """ - Receive a value from a pipe. - - Args: - - `seq`: Sequence number of the message. - - `pipe`: Reference to the pipe to receive from. - - `results`: References to the values that the pipe returns. - - `stream`: Reference to the stream the worker should use to execute the - operation. - """ - - def __init__( - self, - *, - seq: int, - pipe: Ref, - results: Sequence[Ref | None], - stream: StreamRef, - ) -> None: ... - @property - def seq(self) -> int: - """Sequence number of the message.""" - ... - - @property - def pipe(self) -> Ref: - """Reference to the pipe to receive from.""" - ... - - @property - def results(self) -> list[Ref | None]: - """References to the values that the pipe returns.""" - ... - - @property - def stream(self) -> StreamRef: - """Reference to the stream the worker should use to execute the operation.""" - ... - -@final -class CommandGroup(WorkerMessage): - """ - A group of commands that should be executed on the worker. - - Args: - - `commands`: The commands to execute. - """ - - def __init__(self, *, commands: Sequence[WorkerMessage]) -> None: ... - @property - def commands(self) -> list[WorkerMessage]: - """The commands to execute.""" - ... - -@final -class Exit(WorkerMessage): - """Instruct the worker to exit.""" - - def __init__( - self, *, error_reason: Optional[tuple[Optional[ActorId], str]] - ) -> None: ... - -@final -class SplitComm(WorkerMessage): - """ - Create a new communicator on each rank in `ranks`, capable of communicating - with its peers along the specified dimensions. - - Args: - - `dims`: The device mesh dimensions along which the constructed - communicator should be able to exchange data. - - `device_mesh`: The device mesh associated with the new communicator. One - communicator will be created for every member of the mesh. - - `stream`: The stream associated with the communicator. Communicator - operations will be ordered with respect to other operations scheduled on - this stream. - """ - - def __init__( - self, - *, - dims: Sequence[str], - device_mesh: Ref, - stream: StreamRef, - ) -> None: ... - @property - def dims(self) -> Sequence[str]: - """ - The device mesh dimensions along which the constructed communicator - should be able to exchange data. - """ - ... - - @property - def device_mesh(self) -> Ref: - """ - The device mesh associated with the new communicator. One - communicator will be created for every member of the mesh. - """ - ... - - @property - def stream(self) -> StreamRef: - """ - The stream associated with the communicator. Communicator operations - will be ordered with respect to other operations scheduled on this - stream. - """ - ... - -@final -class SplitCommForProcessGroup(WorkerMessage): - """ - Create a new communicator for the given `remote_process_group` for the given - `stream`, capable of communicating with its peers along the specified - dimensions. - - Args: - - `remote_process_group`: The process group associated with the new - communicator. One communicator will be created for every member of the - mesh. - - `stream`: The stream associated with the communicator. Communicator - operations will be ordered with respect to other operations scheduled on - this stream. - """ - - def __init__( - self, - *, - remote_process_group: Ref, - stream: StreamRef, - ) -> None: ... - @property - def remote_process_group(self) -> Ref: - """ - The remote process group associated with the new communicator. One - communicator will be created for every member of the mesh. - """ - ... 
- - @property - def stream(self) -> StreamRef: - """ - The stream associated with the communicator. Communicator operations - will be ordered with respect to other operations scheduled on this - stream. - """ - ... - -@final -class DefineRecording(WorkerMessage): - """ - Defines (part of) a new recording on the worker. This is a list of commands - representing the execution of a function that was defined using - monarch.compile. If there are too many commands to send in a single - DefineRecording message, the commands may be chunked into `ntotal_messages`, - with the `index` field indicating how to order the DefineRecording messages - for a single recording. - - Args: - - `result`: The ref associated with this recording that will be used to - call it in the future. - - `nresults`: The number of output tensors. - - `nformals`: The number of input tensors. - - `commands`: The list of commands to run. - - `ntotal_messages`: How many total DefineRecording messages make up this - recording. - - `index`: This DefineRecording message's index in the set of messages - that make up this recording. - """ - - def __init__( - self, - *, - result: Ref, - nresults: int, - nformals: int, - commands: Sequence[WorkerMessage], - ntotal_messages: int, - index: int, - ) -> None: ... - def append(self, command: WorkerMessage) -> None: - """ - Append a command to the DefineRecording. - - Args: - - `command`: The WorkerMessage to append. - """ - ... - - def append_call_function( - self, - *, - seq: int, - results: Sequence[Ref | None], - mutates: Sequence[Ref], - function: ResolvableFunction, - args: tuple[object, ...], - kwargs: dict[str, object], - stream: StreamRef, - remote_process_groups: Sequence[Ref], - ) -> None: - """ - Append a CallFunction command to the DefineRecording. - - Args: - - `seq`: Sequence number of the message. - - `results`: References to the values that the function returns. - - `mutates`: References to the values that the function mutates. - - `function`: Fully qualified path to the function. - - `args`: Pytree-serializable arguments to the function. - - `kwargs`: Pytree-serializable keyword arguments to the function. - - `stream`: Reference to the stream the worker should use to execute the function. - - `remote_process_groups`: References to the process groups the worker should use to execute the function. - """ - ... - -@final -class RecordingFormal(WorkerMessage): - """ - Defines an input tensor for a recording. - - Args: - - `result`: The ref that will be used to pass the input tensor to the - recording. - - `argument_index`: The index of the input tensor in the list of input tensors. - - `stream`: The stream that this input tensor will be used on. - """ - - def __init__( - self, - *, - result: Ref, - argument_index: int, - stream: StreamRef, - ) -> None: ... - -@final -class RecordingResult(WorkerMessage): - """ - Defines an output tensor for a recording. - - Args: - - `result`: The ref that will be used to store the output tensor. - - `output_index`: The index of the output tensor in the list of output tensors. - - `stream`: The stream that this output tensor will come from. - """ - - def __init__( - self, - *, - result: Ref, - output_index: int, - stream: StreamRef, - ) -> None: ... - -@final -class CallRecording(WorkerMessage): - """ - Calls a recording that was previously defined using DefineRecording. - - Args: - - `seq`: The sequence number of the invocation. - - `recording`: The ref of the recording to call. 
- - `results`: The list of refs where the result tensors from the recording - will be stored. - - `actuals`: The list of refs of input tensors to the recording. - """ - - def __init__( - self, - *, - seq: int, - recording: Ref, - results: Sequence[Ref], - actuals: Sequence[Ref], - ) -> None: ... diff --git a/python/monarch/common/messages.py b/python/monarch/common/messages.py index 6af3e4e73..0503db2f1 100644 --- a/python/monarch/common/messages.py +++ b/python/monarch/common/messages.py @@ -10,13 +10,11 @@ from traceback import FrameSummary from typing import ( - cast, Dict, List, Literal, NamedTuple, Optional, - Protocol, Sequence, Tuple, TYPE_CHECKING, @@ -76,58 +74,22 @@ def _ref(r: Referenceable | tensor_worker.Ref) -> tensor_worker.Ref: return r -# We cant do inheritance with NamedTuple so we can use this protocol for -# type casting for now until we can move to rust messages entirely. -# Preferring this over a massive if else to keep everything co-located and -# easier to identify drift. -class SupportsToRustMessage(Protocol): - def to_rust_message(self) -> tensor_worker.WorkerMessage: ... - - class CreateDeviceMesh(NamedTuple): result: DeviceMesh names: Dims ranks: NDSlice - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.CreateDeviceMesh( - result=tensor_worker.Ref(id=self.result.ref), - names=self.names, - ranks=NDSlice( - offset=self.ranks.offset, - sizes=self.ranks.sizes, - strides=self.ranks.strides, - ), - ) - class CreateStream(NamedTuple): result: "StreamRef" default: bool - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.CreateStream( - id=tensor_worker.StreamRef(id=self.result.ref), - stream_creation=( - tensor_worker.StreamCreationMode.UseDefaultStream - if self.default - else tensor_worker.StreamCreationMode.CreateNewStream - ), - ) - class CreateRemoteProcessGroup(NamedTuple): result: Referenceable device_mesh: DeviceMesh dims: Dims - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.CreateRemoteProcessGroup( - result=tensor_worker.Ref(id=none_throws(self.result.ref)), - device_mesh=tensor_worker.Ref(id=self.device_mesh.ref), - dims=self.dims, - ) - class CallFunction(NamedTuple): ident: int @@ -140,78 +102,27 @@ class CallFunction(NamedTuple): device_mesh: DeviceMesh remote_process_groups: List[RemoteProcessGroup] - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.CallFunction( - seq=self.ident, - results=_result_to_references(self.result), - mutates=[_ref(r) for r in self.mutates], - function=_to_rust_function(self.function), - args=self.args, - kwargs=self.kwargs, - stream=tensor_worker.StreamRef(id=self.stream.ref), - remote_process_groups=[ - tensor_worker.Ref(id=none_throws(remote_process_group.ref)) - for remote_process_group in self.remote_process_groups - ], - ) - class Exit(NamedTuple): destroy_pg: bool error: Optional[RemoteException | DeviceException | Exception] - def to_rust_message(self) -> tensor_worker.WorkerMessage: - actor_id = None - error_message = None - if isinstance(self.error, (RemoteException, DeviceException)): - actor_id = self.error.source_actor_id - error_message = self.error.message - elif self.error is not None: - error_message = str(self.error) - - error_reason = None if error_message is None else (actor_id, error_message) - return tensor_worker.Exit(error_reason=error_reason) - class CommandGroup(NamedTuple): commands: List[NamedTuple] - def to_rust_message(self) -> tensor_worker.WorkerMessage: - 
rust_commands = [] - for c in self.commands: - if hasattr(c, "to_rust_message"): - c = cast(SupportsToRustMessage, c) - rust_commands.append(c.to_rust_message()) - else: - raise NotImplementedError(f"Unsupported command {c}") - return tensor_worker.CommandGroup(commands=rust_commands) - class RecordingFormal(NamedTuple): result: Tensor | tensor_worker.Ref argument_index: int stream: "StreamRef" - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.RecordingFormal( - result=_ref(self.result), - argument_index=self.argument_index, - stream=tensor_worker.StreamRef(id=self.stream.ref), - ) - class RecordingResult(NamedTuple): input: Tensor | tensor_worker.Ref output_index: int stream: "StreamRef" - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.RecordingResult( - result=_ref(self.input), - output_index=self.output_index, - stream=tensor_worker.StreamRef(id=self.stream.ref), - ) - class DefineRecording(NamedTuple): result: Recording @@ -221,38 +132,6 @@ class DefineRecording(NamedTuple): ntotal_messages: int message_index: int - def to_rust_message(self) -> tensor_worker.WorkerMessage: - define_recording = tensor_worker.DefineRecording( - result=tensor_worker.Ref(id=none_throws(self.result.ref)), - nresults=self.nresults, - nformals=self.nformals, - commands=[], - ntotal_messages=self.ntotal_messages, - index=self.message_index, - ) - for c in self.commands: - if hasattr(c, "to_rust_message"): - c = cast(SupportsToRustMessage, c) - if isinstance(c, CallFunction): - define_recording.append_call_function( - seq=c.ident, - results=_result_to_references(c.result), - mutates=[_ref(r) for r in c.mutates], - function=_to_rust_function(c.function), - args=c.args, - kwargs=c.kwargs, - stream=tensor_worker.StreamRef(id=c.stream.ref), - remote_process_groups=[ - tensor_worker.Ref(id=none_throws(remote_process_group.ref)) - for remote_process_group in c.remote_process_groups - ], - ) - else: - define_recording.append(c.to_rust_message()) - else: - raise NotImplementedError(f"Unsupported command {c}") - return define_recording - class CallRecording(NamedTuple): ident: int @@ -260,23 +139,10 @@ class CallRecording(NamedTuple): results: List[Tensor | tensor_worker.Ref] actuals: List[Tensor | tensor_worker.Ref] - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.CallRecording( - seq=self.ident, - recording=tensor_worker.Ref(id=none_throws(self.recording.ref)), - results=[_ref(r) for r in self.results], - actuals=[_ref(r) for r in self.actuals], - ) - class DeleteRefs(NamedTuple): refs: List[int] - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.DeleteRefs( - refs=[tensor_worker.Ref(id=r) for r in self.refs] - ) - # This is worker <> controller/backend comms only will be supported differently class Restarted(NamedTuple): @@ -293,19 +159,6 @@ class SendValue(NamedTuple): kwargs: Dict[str, object] stream: StreamRef - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.SendValue( - seq=self.ident, - destination=( - tensor_worker.Ref(id=self.destination.ref) if self.destination else None - ), - mutates=[_ref(r) for r in self.mutates], - function=_to_rust_function(self.function) if self.function else None, - args=self.args, - kwargs=self.kwargs, - stream=tensor_worker.StreamRef(id=self.stream.ref), - ) - # Worker -> Controller comm only handled differently class FetchResult(NamedTuple): @@ -345,9 +198,6 @@ class RequestStatus(NamedTuple): ident: int controller: 
bool - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.RequestStatus(seq=self.ident, controller=self.controller) - class BorrowCreate(NamedTuple): result: Tensor | tensor_worker.Ref @@ -356,42 +206,18 @@ class BorrowCreate(NamedTuple): from_stream: StreamRef to_stream: StreamRef - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.BorrowCreate( - result=_ref(self.result), - borrow=self.borrow, - tensor=_ref(self.tensor), - from_stream=tensor_worker.StreamRef(id=self.from_stream.ref), - to_stream=tensor_worker.StreamRef(id=self.to_stream.ref), - ) - class BorrowDrop(NamedTuple): borrow: int # id of borrowed tensor - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.BorrowDrop( - borrow=self.borrow, - ) - class BorrowFirstUse(NamedTuple): borrow: int # id of borrowed tensor - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.BorrowFirstUse( - borrow=self.borrow, - ) - class BorrowLastUse(NamedTuple): borrow: int # id of borrowed tensor - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.BorrowLastUse( - borrow=self.borrow, - ) - class SendTensor(NamedTuple): result: Tensor | tensor_worker.Ref @@ -402,30 +228,6 @@ class SendTensor(NamedTuple): from_stream: StreamRef to_stream: StreamRef - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.SendTensor( - result=_ref(self.result), - from_ranks=NDSlice( - offset=self.from_ranks.offset, - sizes=self.from_ranks.sizes, - strides=self.from_ranks.strides, - ), - to_ranks=NDSlice( - offset=self.to_ranks.offset, - sizes=self.to_ranks.sizes, - strides=self.to_ranks.strides, - ), - tensor=_ref(self.tensor), - factory=tensor_worker.TensorFactory( - size=self.factory.size, - dtype=self.factory.dtype, - device=self.factory.device, - layout=self.factory.layout, - ), - from_stream=tensor_worker.StreamRef(id=self.from_stream.ref), - to_stream=tensor_worker.StreamRef(id=self.to_stream.ref), - ) - class SendResultOfActorCall(NamedTuple): seq: int @@ -449,24 +251,11 @@ class SplitComm(NamedTuple): device_mesh: DeviceMesh stream: StreamRef - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.SplitComm( - dims=self.dims, - device_mesh=tensor_worker.Ref(id=self.device_mesh.ref), - stream=tensor_worker.StreamRef(id=self.stream.ref), - ) - class SplitCommForProcessGroup(NamedTuple): remote_process_group: DeviceMesh stream: StreamRef - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.SplitCommForProcessGroup( - remote_process_group=tensor_worker.Ref(id=self.remote_process_group.ref), - stream=tensor_worker.StreamRef(id=self.stream.ref), - ) - class Reduce(NamedTuple): result: Tensor | tensor_worker.Ref @@ -480,41 +269,6 @@ class Reduce(NamedTuple): inplace: bool out: Tensor | tensor_worker.Ref | None - def to_rust_message(self) -> tensor_worker.WorkerMessage: - match self.reduction: - case "sum": - reduction = tensor_worker.ReductionType.Sum - case "prod": - reduction = tensor_worker.ReductionType.Prod - case "stack": - reduction = tensor_worker.ReductionType.Stack - case "avg": - reduction = tensor_worker.ReductionType.Avg - case "min": - reduction = tensor_worker.ReductionType.Min - case "max": - reduction = tensor_worker.ReductionType.Max - case _: - raise ValueError(f"Unsupported reduction {self.reduction}") - - return tensor_worker.Reduce( - result=_ref(self.result), - tensor=_ref(self.local_tensor), - 
factory=tensor_worker.TensorFactory( - size=self.factory.size, - dtype=self.factory.dtype, - device=self.factory.device, - layout=self.factory.layout, - ), - mesh=tensor_worker.Ref(id=self.source_mesh.ref), - stream=tensor_worker.StreamRef(id=self.stream.ref), - dims=self.dims, - reduction=reduction, - scatter=self.scatter, - in_place=self.inplace, - out=_ref(self.out) if self.out is not None else None, - ) - class CreatePipe(NamedTuple): result: Pipe @@ -525,17 +279,6 @@ class CreatePipe(NamedTuple): args: Tuple[object, ...] kwargs: Dict[str, object] - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.CreatePipe( - result=tensor_worker.Ref(id=self.result.ref), - key=self.key, - function=_to_rust_function(self.function), - max_messages=self.max_messages, - mesh=tensor_worker.Ref(id=self.device_mesh.ref), - args=self.args, - kwargs=self.kwargs, - ) - class PipeRecv(NamedTuple): ident: int @@ -543,33 +286,16 @@ class PipeRecv(NamedTuple): pipe: Pipe stream: StreamRef - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.PipeRecv( - seq=self.ident, - results=_result_to_references(self.result), - pipe=tensor_worker.Ref(id=self.pipe.ref), - stream=tensor_worker.StreamRef(id=self.stream.ref), - ) - class BackendNetworkInit(NamedTuple): hostname: str | None = None port: int | None = None - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.BackendNetworkInit() - class BackendNetworkPointToPointInit(NamedTuple): from_stream: StreamRef to_stream: StreamRef - def to_rust_message(self) -> tensor_worker.WorkerMessage: - return tensor_worker.BackendNetworkPointToPointInit( - from_stream=tensor_worker.StreamRef(id=self.from_stream.ref), - to_stream=tensor_worker.StreamRef(id=self.to_stream.ref), - ) - # TODO: This is not supported on the rust side and might be only needed for remote funcs class DebuggerRead(NamedTuple): diff --git a/python/monarch/controller/rust_backend/controller.py b/python/monarch/controller/rust_backend/controller.py index f6dfecf4f..eed5c6900 100644 --- a/python/monarch/controller/rust_backend/controller.py +++ b/python/monarch/controller/rust_backend/controller.py @@ -34,7 +34,6 @@ from monarch.common.controller_api import LogMessage, MessageResult from monarch.common.device_mesh import no_mesh from monarch.common.invocation import DeviceException, RemoteException -from monarch.common.messages import SupportsToRustMessage from monarch.common.tensor import Tensor from monarch.controller.debugger import read as debugger_read, write as debugger_write from pyre_extensions import none_throws diff --git a/python/tests/_monarch/test_controller.py b/python/tests/_monarch/test_controller.py deleted file mode 100644 index dd5d67158..000000000 --- a/python/tests/_monarch/test_controller.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-strict - -from unittest import TestCase - -import pytest - -from monarch._rust_bindings.monarch_extension import ( # @manual=//monarch/monarch_extension:monarch_extension - controller, - tensor_worker, -) -from monarch._rust_bindings.monarch_hyperactor import shape - - -class TestController(TestCase): - def test_node(self) -> None: - node = controller.Node(seq=10, defs=[tensor_worker.Ref(id=1)], uses=[]) - self.assertEqual(node.seq, 10) - self.assertEqual(node.defs, [tensor_worker.Ref(id=1)]) - self.assertEqual(node.uses, []) - - def test_send(self) -> None: - msg = controller.Send( - ranks=[shape.Slice(offset=1, sizes=[3], strides=[1])], - message=tensor_worker.CreateStream( - id=tensor_worker.StreamRef(id=10), - stream_creation=tensor_worker.StreamCreationMode.CreateNewStream, - ), - ) - self.assertEqual(len(msg.ranks), 1) - self.assertEqual(msg.ranks[0].ndim, 1) - message = msg.message - assert isinstance(message, tensor_worker.CreateStream) - self.assertEqual( - message.stream_creation, tensor_worker.StreamCreationMode.CreateNewStream - ) - self.assertEqual(message.id, tensor_worker.StreamRef(id=10)) - - def test_send_no_ranks(self) -> None: - with self.assertRaises(ValueError): - controller.Send( - ranks=[], - message=tensor_worker.CreateStream( - id=tensor_worker.StreamRef(id=10), - stream_creation=tensor_worker.StreamCreationMode.CreateNewStream, - ), - ) - - def test_node_serde(self) -> None: - node = controller.Node(seq=10, defs=[tensor_worker.Ref(id=1)], uses=[]) - serialized = node.serialize() - deserialized = controller.Node.from_serialized(serialized) - self.assertEqual(node.seq, deserialized.seq) - self.assertEqual(node.defs, deserialized.defs) - self.assertEqual(node.uses, deserialized.uses) - - def test_send_serde(self) -> None: - msg = controller.Send( - ranks=[shape.Slice(offset=1, sizes=[3], strides=[1])], - message=tensor_worker.CreateStream( - id=tensor_worker.StreamRef(id=10), - stream_creation=tensor_worker.StreamCreationMode.CreateNewStream, - ), - ) - serialized = msg.serialize() - deserialized = controller.Send.from_serialized(serialized) - self.assertEqual(len(msg.ranks), len(deserialized.ranks)) - self.assertEqual(msg.ranks[0].ndim, deserialized.ranks[0].ndim) - message = msg.message - deser_message = deserialized.message - assert isinstance(message, tensor_worker.CreateStream) - assert isinstance(deser_message, tensor_worker.CreateStream) - self.assertEqual(message.stream_creation, deser_message.stream_creation) - self.assertEqual(message.id, deser_message.id) diff --git a/python/tests/_monarch/test_worker.py b/python/tests/_monarch/test_worker.py deleted file mode 100644 index 1a8d5436d..000000000 --- a/python/tests/_monarch/test_worker.py +++ /dev/null @@ -1,358 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-strict - -import math -from typing import cast -from unittest import TestCase - -import cloudpickle - -from monarch._rust_bindings.monarch_extension import tensor_worker -from monarch._rust_bindings.monarch_hyperactor import shape -from pyre_extensions import none_throws - - -def is_nan(val: int) -> bool: - return math.isnan(val) - - -class MockReferencable: - def __init__(self, ref: int) -> None: - self.ref = ref - - def __monarch_ref__(self) -> int: - return self.ref - - -class TestWorker(TestCase): - def test_backend_network_init(self) -> None: - msg = tensor_worker.BackendNetworkInit() - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - - def test_backend_network_init_point_to_point(self) -> None: - msg = tensor_worker.BackendNetworkPointToPointInit( - from_stream=tensor_worker.StreamRef(id=1), - to_stream=tensor_worker.StreamRef(id=2), - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.from_stream.id, 1) - self.assertEqual(msg.to_stream.id, 2) - - def test_call_function(self) -> None: - import torch - - msg = tensor_worker.CallFunction( - seq=10, - results=[tensor_worker.Ref(id=2), None], - mutates=[], - function=tensor_worker.FunctionPath(path="torch.ops.aten.ones.default"), - args=([2, 3],), - kwargs={"device": "cpu", "pin_memory": False}, - stream=tensor_worker.StreamRef(id=1), - remote_process_groups=[], - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.seq, 10) - self.assertEqual(msg.results, [tensor_worker.Ref(id=2), None]) - self.assertEqual(msg.mutates, []) - self.assertEqual( - cast(tensor_worker.FunctionPath, msg.function).path, - "torch.ops.aten.ones.default", - ) - self.assertEqual(msg.args, ([2, 3],)) - self.assertEqual( - msg.kwargs, {"device": torch.device("cpu"), "pin_memory": False} - ) - self.assertIsInstance(msg.kwargs["pin_memory"], bool) - # we cannot use isinstance to assert bool vs int - self.assertTrue(msg.kwargs["pin_memory"] is False) - self.assertEqual(msg.stream, tensor_worker.StreamRef(id=1)) - - def test_call_function_live_function(self) -> None: - msg = tensor_worker.CallFunction( - seq=10, - results=[], - mutates=[], - function=tensor_worker.Cloudpickle(bytes=cloudpickle.dumps(is_nan)), - args=(), - kwargs={}, - stream=tensor_worker.StreamRef(id=1), - remote_process_groups=[], - ) - self.assertFalse(msg.function.resolve()(4)) - - def test_call_function_referencable_args(self) -> None: - msg = tensor_worker.CallFunction( - seq=10, - results=[], - mutates=[], - function=tensor_worker.FunctionPath( - path="torch.ops.aten._foreach_add.Tensor" - ), - args=([MockReferencable(1), MockReferencable(2)],), - kwargs={ - "other": MockReferencable(3), - }, - stream=tensor_worker.StreamRef(id=1), - remote_process_groups=[], - ) - self.assertEqual( - msg.args, ([tensor_worker.Ref(id=1), tensor_worker.Ref(id=2)],) - ) - self.assertEqual(msg.kwargs, {"other": tensor_worker.Ref(id=3)}) - - def test_create_stream(self) -> None: - msg = tensor_worker.CreateStream( - id=tensor_worker.StreamRef(id=10), - stream_creation=tensor_worker.StreamCreationMode.CreateNewStream, - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.id.id, 10) - self.assertEqual( - msg.stream_creation, tensor_worker.StreamCreationMode.CreateNewStream - ) - - def test_create_device_mesh(self) -> None: - msg = tensor_worker.CreateDeviceMesh( - result=tensor_worker.Ref(id=10), - names=("x", "y"), - ranks=shape.Slice(offset=0, sizes=[2, 3], strides=[3, 1]), - ) - 
self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.result.id, 10) - self.assertEqual(msg.names, ["x", "y"]) - self.assertEqual(msg.ranks.ndim, 2) - - def test_create_remote_process_group(self) -> None: - msg = tensor_worker.CreateRemoteProcessGroup( - result=tensor_worker.Ref(id=10), - device_mesh=tensor_worker.Ref(id=12), - dims=("x", "y"), - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.result.id, 10) - self.assertEqual(msg.device_mesh.id, 12) - self.assertEqual(msg.dims, ["x", "y"]) - - def test_borrow_create(self) -> None: - msg = tensor_worker.BorrowCreate( - result=tensor_worker.Ref(id=10), - borrow=23, - tensor=tensor_worker.Ref(id=20), - from_stream=tensor_worker.StreamRef(id=1), - to_stream=tensor_worker.StreamRef(id=2), - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.result.id, 10) - self.assertEqual(msg.borrow, 23) - self.assertEqual(msg.tensor.id, 20) - self.assertEqual(msg.from_stream.id, 1) - self.assertEqual(msg.to_stream.id, 2) - - def test_borrow_first_use(self) -> None: - msg = tensor_worker.BorrowFirstUse( - borrow=23, - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.borrow, 23) - - def test_borrow_last_use(self) -> None: - msg = tensor_worker.BorrowLastUse( - borrow=23, - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.borrow, 23) - - def test_borrow_drop(self) -> None: - msg = tensor_worker.BorrowDrop( - borrow=23, - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.borrow, 23) - - def test_delete_refs(self) -> None: - msg = tensor_worker.DeleteRefs( - refs=[tensor_worker.Ref(id=1), tensor_worker.Ref(id=2)] - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.refs, [tensor_worker.Ref(id=1), tensor_worker.Ref(id=2)]) - - def test_request_status(self) -> None: - msg = tensor_worker.RequestStatus(seq=10, controller=False) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.seq, 10) - - def test_reduce(self) -> None: - import torch - - msg = tensor_worker.Reduce( - result=tensor_worker.Ref(id=10), - tensor=tensor_worker.Ref(id=20), - factory=tensor_worker.TensorFactory( - size=(2, 3), - dtype=torch.bfloat16, - device=torch.device("cpu"), - layout=torch.sparse_csr, - ), - mesh=tensor_worker.Ref(id=30), - dims=("x",), - stream=tensor_worker.StreamRef(id=1), - scatter=False, - in_place=False, - reduction=tensor_worker.ReductionType.Stack, - out=None, - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.result.id, 10) - self.assertEqual(msg.tensor.id, 20) - self.assertEqual(msg.factory.size, (2, 3)) - self.assertEqual(msg.factory.dtype, torch.bfloat16) - self.assertEqual(msg.factory.device, str(torch.device("cpu"))) - self.assertEqual(msg.factory.layout, torch.sparse_csr) - self.assertEqual(msg.mesh.id, 30) - self.assertEqual(msg.dims, ["x"]) - self.assertEqual(msg.stream.id, 1) - self.assertEqual(msg.scatter, False) - self.assertEqual(msg.in_place, False) - self.assertEqual(msg.reduction, tensor_worker.ReductionType.Stack) - self.assertIsNone(msg.out) - - msg = tensor_worker.Reduce( - result=tensor_worker.Ref(id=10), - tensor=tensor_worker.Ref(id=20), - factory=tensor_worker.TensorFactory( - size=(2, 3), - dtype=torch.bfloat16, - device=torch.device("cpu"), - layout=torch.sparse_csr, - ), - mesh=tensor_worker.Ref(id=30), - 
dims=("x",), - stream=tensor_worker.StreamRef(id=1), - scatter=False, - in_place=False, - reduction=tensor_worker.ReductionType.Stack, - out=tensor_worker.Ref(id=40), - ) - self.assertEqual(msg.out.id, 40) - - def test_send_tensor(self) -> None: - import torch - - msg = tensor_worker.SendTensor( - tensor=tensor_worker.Ref(id=10), - from_stream=tensor_worker.StreamRef(id=1), - to_stream=tensor_worker.StreamRef(id=2), - from_ranks=shape.Slice(offset=0, sizes=[2, 3], strides=[3, 1]), - to_ranks=shape.Slice(offset=0, sizes=[3, 4, 5], strides=[20, 5, 1]), - result=tensor_worker.Ref(id=2), - factory=tensor_worker.TensorFactory( - size=(2, 5), - dtype=torch.float32, - device=torch.device("cuda"), - layout=torch.strided, - ), - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.tensor.id, 10) - self.assertEqual(msg.from_stream.id, 1) - self.assertEqual(msg.to_stream.id, 2) - self.assertEqual(msg.from_ranks.ndim, 2) - self.assertEqual(msg.to_ranks.ndim, 3) - self.assertEqual(msg.result.id, 2) - self.assertEqual(msg.factory.size, (2, 5)) - self.assertEqual(msg.factory.dtype, torch.float32) - - def test_create_pipe(self) -> None: - msg = tensor_worker.CreatePipe( - result=tensor_worker.Ref(id=10), - key="some_key", - function=tensor_worker.FunctionPath(path="builtins.range"), - max_messages=1, - mesh=tensor_worker.Ref(id=20), - args=(1, 10), - kwargs={}, - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.result.id, 10) - self.assertEqual(msg.key, "some_key") - self.assertEqual( - cast(tensor_worker.FunctionPath, msg.function).path, "builtins.range" - ) - self.assertEqual(msg.mesh.id, 20) - self.assertEqual(msg.args, (1, 10)) - self.assertEqual(msg.kwargs, {}) - - def test_send_value(self) -> None: - msg = tensor_worker.SendValue( - seq=100, - destination=tensor_worker.Ref(id=10), - mutates=[], - function=None, - args=(500,), - kwargs={}, - stream=tensor_worker.StreamRef(id=1), - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.seq, 100) - self.assertEqual(none_throws(msg.destination).id, 10) - self.assertEqual(msg.mutates, []) - self.assertEqual(msg.function, None) - self.assertEqual(msg.args, (500,)) - self.assertEqual(msg.kwargs, {}) - self.assertEqual(msg.stream.id, 1) - - def test_pipe_recv(self) -> None: - msg = tensor_worker.PipeRecv( - seq=101, - results=[tensor_worker.Ref(id=10), tensor_worker.Ref(id=11)], - pipe=tensor_worker.Ref(id=1), - stream=tensor_worker.StreamRef(id=2), - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(msg.seq, 101) - self.assertEqual( - msg.results, [tensor_worker.Ref(id=10), tensor_worker.Ref(id=11)] - ) - self.assertEqual(msg.pipe.id, 1) - self.assertEqual(msg.stream.id, 2) - - def test_exit(self) -> None: - msg = tensor_worker.Exit(error_reason=None) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - - def test_command_group(self) -> None: - msg = tensor_worker.CommandGroup( - commands=[ - tensor_worker.CallFunction( - seq=10, - results=[tensor_worker.Ref(id=2), None], - mutates=[], - function=tensor_worker.FunctionPath( - path="torch.ops.aten.ones.default" - ), - args=([2, 3],), - kwargs={"device": "cpu"}, - stream=tensor_worker.StreamRef(id=1), - remote_process_groups=[], - ), - tensor_worker.Exit(error_reason=None), - ] - ) - self.assertTrue(isinstance(msg, tensor_worker.WorkerMessage)) - self.assertEqual(len(msg.commands), 2) - msg0 = msg.commands[0] - assert isinstance(msg0, 
tensor_worker.CallFunction) - self.assertEqual(msg0.seq, 10) - self.assertEqual(msg0.results, [tensor_worker.Ref(id=2), None]) - self.assertEqual( - cast(tensor_worker.FunctionPath, msg0.function).path, - "torch.ops.aten.ones.default", - ) - self.assertEqual(msg0.args, ([2, 3],)) - self.assertTrue(isinstance(msg.commands[1], tensor_worker.Exit))
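
Note: the removed bindings were exercised only at the construct-and-serialize level. A minimal sketch of that round-trip, lifted from the deleted python/tests/_monarch/test_controller.py above and assuming the pre-removal bindings are still importable (controller.Send, Send.serialize, and Send.from_serialized no longer exist after this diff), looks like this:

    # Runs only against the pre-removal bindings; shown purely to illustrate
    # the coverage that the deleted tests provided.
    from monarch._rust_bindings.monarch_extension import controller, tensor_worker
    from monarch._rust_bindings.monarch_hyperactor import shape

    # Build a Send message addressed to one slice of ranks, wrapping a worker message.
    msg = controller.Send(
        ranks=[shape.Slice(offset=1, sizes=[3], strides=[1])],
        message=tensor_worker.CreateStream(
            id=tensor_worker.StreamRef(id=10),
            stream_creation=tensor_worker.StreamCreationMode.CreateNewStream,
        ),
    )

    # Round-trip through the serialized form and check the payload survives.
    roundtripped = controller.Send.from_serialized(msg.serialize())
    assert isinstance(roundtripped.message, tensor_worker.CreateStream)
    assert roundtripped.message.id == tensor_worker.StreamRef(id=10)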