diff --git a/monarch_extension/src/convert.rs b/monarch_extension/src/convert.rs index 68b42f5ad..03279131a 100644 --- a/monarch_extension/src/convert.rs +++ b/monarch_extension/src/convert.rs @@ -13,8 +13,8 @@ use hyperactor::ActorId; use monarch_hyperactor::ndslice::PySlice; use monarch_hyperactor::proc::PyActorId; use monarch_messages::controller::Seq; -use monarch_messages::wire_value::func_call_args_to_wire_values; use monarch_messages::worker; +use monarch_messages::worker::ArgsKwargs; use monarch_messages::worker::CallFunctionParams; use monarch_messages::worker::Cloudpickle; use monarch_messages::worker::Factory; @@ -220,17 +220,16 @@ fn create_map(py: Python) -> HashMap { }); m.insert(key("CallFunction"), |p| { let function = p.parseFunction("function")?; - let args = p.parse("args")?; - let kwargs = p.parse("kwargs")?; + let args: Bound<'_, PyTuple> = p.parse("args")?; + let kwargs: Bound<'_, PyDict> = p.parse("kwargs")?; - let (args, kwargs) = func_call_args_to_wire_values(Some(&function), &args, &kwargs)?; + let args_kwargs = ArgsKwargs::from_python(args.into_any(), kwargs.into_any())?; Ok(WorkerMessage::CallFunction(CallFunctionParams { seq: p.parseSeq("ident")?, results: p.parseFlatReferences("result")?, mutates: p.parseRefList("mutates")?, function, - args, - kwargs, + args_kwargs, stream: p.parseStreamRef("stream")?, remote_process_groups: p.parseRefList("remote_process_groups")?, })) @@ -340,14 +339,13 @@ fn create_map(py: Python) -> HashMap { "SendValue with no function must have exactly one argument and no keyword arguments", )); } - let (args, kwargs) = func_call_args_to_wire_values(function.as_ref(), &args, &kwargs)?; + let args_kwargs = ArgsKwargs::from_python(args.into_any(), kwargs.into_any())?; Ok(WorkerMessage::SendValue { seq: p.parseSeq("ident")?, destination: p.parseOptionalRef("destination")?, mutates: p.parseRefList("mutates")?, function, - args, - kwargs, + args_kwargs, stream: p.parseStreamRef("stream")?, }) }); diff --git 
a/monarch_messages/src/wire_value.rs b/monarch_messages/src/wire_value.rs index c89d7ce80..6c30ad4c9 100644 --- a/monarch_messages/src/wire_value.rs +++ b/monarch_messages/src/wire_value.rs @@ -6,34 +6,22 @@ * LICENSE file in the root directory of this source tree. */ -use std::collections::HashMap; - use derive_more::From; use derive_more::TryInto; use enum_as_inner::EnumAsInner; use hyperactor::Named; use monarch_types::PickledPyObject; -use monarch_types::TryIntoPyObjectUnsafe; use pyo3::IntoPyObjectExt; -use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::types::PyBool; -use pyo3::types::PyDict; -use pyo3::types::PyFloat; -use pyo3::types::PyList; use pyo3::types::PyNone; -use pyo3::types::PyString; -use pyo3::types::PyTuple; use serde::Deserialize; use serde::Serialize; use torch_sys::Device; use torch_sys::Layout; use torch_sys::MemoryFormat; -use torch_sys::OpaqueIValue; use torch_sys::ScalarType; use crate::worker::Ref; -use crate::worker::ResolvableFunction; /// A value used as an input to CallFunction. // TODO, this is basically the same as RValue, but with TensorIndices swapped @@ -59,81 +47,20 @@ pub enum WireValue { // empty enum variants. None(()), PyObject(PickledPyObject), - // It is ok to just have IValue without an alias tracking cell as we just use - // WireValue as a way to serialize and send args to workers. We dont mutate the - // IValue and use the opaque wrapper to make accessing the IValue directly - // an unsafe op. - IValue(torch_sys::OpaqueIValue), } impl FromPyObject<'_> for WireValue { fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult { - if let Ok(ref_) = Ref::from_py_object(obj) { - Ok(WireValue::Ref(ref_)) - } else if let Ok(list) = obj.downcast::() { - let len = list.len(); - if len == 0 { - // TODO: This is done for now as this seems to be the most common case for empty lists - // in torch ops but we should use the op schema to do this correctly. 
- return Ok(WireValue::IntList(vec![])); - } - - // SAFETY: We know it is within bounds - let item = unsafe { list.get_item_unchecked(0) }; - let len = list.len(); - if let Ok(int) = item.extract::() { - let mut int_list = Vec::with_capacity(len); - int_list.push(int); - for item in list.iter().skip(1) { - int_list.push(item.extract::().map_err(|_| { - PyValueError::new_err(format!( - "Expected homogeneous list of ints got: {:?}", - list - )) - })?); - } - return Ok(WireValue::IntList(int_list)); - } - if let Ok(ref_) = Ref::from_py_object(&item) { - let mut ref_list = Vec::with_capacity(len); - ref_list.push(ref_); - for item in list.iter().skip(1) { - ref_list.push(Ref::from_py_object(&item).map_err(|_| { - PyValueError::new_err(format!( - "Expected homogeneous list of ints got: {:?}", - list - )) - })?); - } - return Ok(WireValue::RefList(ref_list)); - } - Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?)) - } else if obj.is_none() { - Ok(WireValue::None(())) - } else if let Ok(bool_) = obj.downcast::() { - Ok(WireValue::Bool(bool_.is_true())) - } else if let Ok(int) = obj.extract::() { - Ok(WireValue::Int(int)) - } else if let Ok(double) = obj.downcast::() { - Ok(WireValue::Double(double.value())) - } else if let Ok(string) = obj.downcast::() { - Ok(WireValue::String(string.to_str()?.to_string())) - } else if let Ok(device) = obj.extract::() { - Ok(WireValue::Device(device)) - } else if let Ok(layout) = obj.extract::() { - Ok(WireValue::Layout(layout)) - } else if let Ok(scalar_type) = obj.extract::() { - Ok(WireValue::ScalarType(scalar_type)) - } else if let Ok(memory_format) = obj.extract::() { - Ok(WireValue::MemoryFormat(memory_format)) - } else { - Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?)) - } + Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?)) } } -impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue { - unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult> { +impl<'py> IntoPyObject<'py> for WireValue { + type 
Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> PyResult> { match self { WireValue::Ref(ref_) => ref_.into_bound_py_any(py), WireValue::RefList(ref_list) => ref_list.clone().into_bound_py_any(py), @@ -148,190 +75,12 @@ impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue { WireValue::MemoryFormat(val) => val.into_bound_py_any(py), WireValue::None(()) => PyNone::get(py).into_bound_py_any(py), WireValue::PyObject(val) => val.unpickle(py), - // SAFETY: WireValue is only used for serde between client and worker. - // This function is used to access the args / kwargs of a function call - // on the client side only. - WireValue::IValue(val) => unsafe { val.try_to_object_unsafe(py) }, } } } -impl<'py> IntoPyObject<'py> for WireValue { - type Target = PyAny; - type Output = Bound<'py, PyAny>; - type Error = PyErr; - - fn into_pyobject(self, py: Python<'py>) -> Result { - unsafe { self.try_to_object_unsafe(py) } - } -} - impl From for WireValue { fn from(obj: PyObject) -> Self { Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap())) } } - -pub fn func_call_args_to_wire_values( - _func: Option<&ResolvableFunction>, - args: &Bound<'_, PyTuple>, - kwargs: &Bound<'_, PyDict>, -) -> PyResult<(Vec, HashMap)> { - python_func_args_to_wire_value(args, kwargs) -} - -fn python_func_args_to_wire_value( - args: &Bound<'_, PyTuple>, - kwargs: &Bound<'_, PyDict>, -) -> PyResult<(Vec, HashMap)> { - let args = args - .iter() - .map(|arg| Ok(WireValue::PyObject(PickledPyObject::pickle(&arg)?))) - .collect::>()?; - let kwargs = kwargs - .iter() - .map(|(k, v)| { - Ok(( - k.extract::()?, - WireValue::PyObject(PickledPyObject::pickle(&v)?), - )) - }) - .collect::, PyErr>>()?; - Ok((args, kwargs)) -} - -#[cfg(test)] -mod tests { - use std::assert_matches::assert_matches; - - use anyhow::Result; - use anyhow::bail; - use paste::paste; - use pyo3::Python; - use pyo3::ffi::c_str; - use 
pyo3::types::PyDict; - use torch_sys::DeviceType; - use torch_sys::ScalarType; - - use super::*; - use crate::worker::Ref; - - const MOCK_REFERNCABLE_MODULE: &std::ffi::CStr = c_str!( - r#" -class Referencable: - def __init__(self, ref: int): - self.ref = ref - - def __monarch_ref__(self): - return self.ref -"# - ); - - fn setup() -> Result<()> { - pyo3::prepare_freethreaded_python(); - // We need to load torch to initialize some internal structures used by - // the FFI funcs we use to convert ivalues to/from py objects. - Python::with_gil(|py| py.run(c_str!("import torch"), None, None))?; - Ok(()) - } - - fn create_py_object() -> PyObject { - pyo3::prepare_freethreaded_python(); - Python::with_gil(|py| { - let dict = PyDict::new(py); - dict.set_item("foo", "bar").unwrap(); - dict.into_any().clone().unbind() - }) - } - - macro_rules! generate_wire_value_from_py_tests { - ($($kind:ident, $input:expr);* $(;)?) => { - paste! { - $( - #[test] - fn []() -> Result<()> { - setup()?; - Python::with_gil(|py| { - let actual = $input.into_pyobject(py)?.extract::()?; - assert_matches!(actual, WireValue::$kind(_)); - anyhow::Ok(()) - }) - } - )* - - #[test] - fn test_wire_value_from_py_none() -> Result<()> { - setup()?; - Python::with_gil(|py| { - let obj = PyNone::get(py).into_pyobject(py)?; - let actual = obj.extract::()?; - assert_matches!(actual, WireValue::None(_)); - anyhow::Ok(()) - }) - } - - #[test] - fn test_wire_value_from_py_empty_list() -> Result<()> { - setup()?; - Python::with_gil(|py| { - let obj: PyObject = PyList::empty(py).into_any().unbind(); - let actual = obj.extract::(py)?; - match actual { - WireValue::IntList(list) if list.len() == 0 => (), - _ => bail!("Expected empty list to be converted to empty int list"), - } - anyhow::Ok(()) - }) - } - - #[test] - fn test_wire_value_from_py_referencable_class() -> Result<()> { - setup()?; - Python::with_gil(|py| { - let referencable = PyModule::from_code( - py, - MOCK_REFERNCABLE_MODULE, - 
c_str!("referencable.py"), - c_str!("referencable"), - )?; - let ref_ = referencable.getattr("Referencable")?.call1((1,))?.unbind(); - let actual = ref_.extract::(py)?; - assert_matches!(actual, WireValue::Ref(Ref { id: 1 })); - anyhow::Ok(()) - }) - } - - #[test] - fn test_wire_value_from_py_roundtrip_was_exhaustive() { - let val = WireValue::Int(0); - match val { - $(WireValue::$kind(_) => (),)* - WireValue::None(_) => (), - // Can't test from py here as PyObject behaves as catch all for conversion from PY. - // We will manually convert torch ops args to IValue respecting the schema so its - // not super important to have this. - WireValue::IValue(_) => (), - } - } - } - } - } - - // Generate exhaustive roundtrip tests for all IValue kind. - // If you got a "non-exhaustive patterns" error here, you need to add a new - // test entry for your IValue kind! - generate_wire_value_from_py_tests! { - Bool, false; - Double, 1.23f64; - Int, 123i64; - IntList, vec![1i64]; - Ref, Ref::from(1); - RefList, vec![Ref::from(1), Ref::from(2)]; - String, "foobar".to_owned(); - Device, Device::new(DeviceType::CPU); - Layout, Layout(2); - ScalarType, ScalarType(3); - MemoryFormat, MemoryFormat(1); - PyObject, create_py_object(); - } -} diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index 43c1a822c..cd4ea40ea 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -28,6 +28,7 @@ use hyperactor::RefClient; use hyperactor::Unbind; use hyperactor::reference::ActorId; use monarch_types::SerializablePyErr; +use monarch_types::py_global; use ndslice::Slice; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -282,6 +283,8 @@ impl fmt::Display for Cloudpickle { } } +py_global!(cloudpickle_dumps, "cloudpickle", "dumps"); + #[pyo3::pymethods] impl Cloudpickle { #[new] @@ -301,6 +304,16 @@ impl Cloudpickle { } } +impl Cloudpickle { + pub fn dumps<'py>(obj: Bound<'py, PyAny>) -> PyResult { + let py = obj.py(); + let dumps = 
cloudpickle_dumps(py); + let bytes_obj = dumps.call1((obj,))?; + let bytes = bytes_obj.downcast::()?.as_bytes().to_vec(); + Ok(Self { bytes }) + } +} + #[derive( PartialEq, Serialize, @@ -319,6 +332,61 @@ pub enum ResolvableFunction { FunctionPath(FunctionPath), } +#[derive(PartialEq, Serialize, Deserialize, Debug, Clone, From)] +pub struct ArgsKwargs { + payload: Cloudpickle, +} + +impl ArgsKwargs { + pub fn from_python<'py>(args: Bound<'py, PyAny>, kwargs: Bound<'py, PyAny>) -> PyResult { + // Create tuple (args, kwargs), then cloudpickle it + let py = args.py(); + let tuple = PyTuple::new(py, vec![args, kwargs])?; + let payload = Cloudpickle::dumps(tuple.into_any())?; + Ok(Self { payload }) + } + + pub fn from_wire_values( + args: Vec, + kwargs: HashMap, + ) -> PyResult { + Python::with_gil(|py| { + // Convert WireValue args to Python objects + let py_args: Vec> = args + .into_iter() + .map(|v| v.into_pyobject(py)) + .collect::>()?; + let args_tuple = PyTuple::new(py, py_args)?; + + // Convert WireValue kwargs to Python dict + let kwargs_dict = PyDict::new(py); + for (k, v) in kwargs { + kwargs_dict.set_item(k, v.into_pyobject(py)?)?; + } + + Self::from_python(args_tuple.into_any(), kwargs_dict.into_any()) + }) + } + + pub fn to_python<'py>( + &self, + py: Python<'py>, + ) -> PyResult<(Bound<'py, PyTuple>, Bound<'py, PyDict>)> { + let tuple = self.payload.resolve(py)?; + let tuple = tuple.downcast::()?; + + // Extract args (first element) + let args = tuple.get_item(0)?; + let args_tuple = args.downcast::()?; + + // Extract kwargs (second element) + let kwargs = tuple.get_item(1)?; + let kwargs_dict = kwargs.downcast::()?; + + Ok((args_tuple.clone(), kwargs_dict.clone())) + } +} + impl<'py> IntoPyObject<'py> for ResolvableFunction { type Target = PyAny; type Output = Bound<'py, Self::Target>; @@ -370,10 +438,8 @@ pub struct CallFunctionParams { pub mutates: Vec, /// The function to call. pub function: ResolvableFunction, - /// The arguments to the function. 
- pub args: Vec, - /// The keyword arguments to the function. - pub kwargs: HashMap, + /// The arguments and keyword arguments to the function. + pub args_kwargs: ArgsKwargs, /// The stream to call the function on. pub stream: StreamRef, /// The process groups to execute the function on. @@ -783,12 +849,10 @@ pub enum WorkerMessage { /// Pipe to send value to. If `None`, value is sent to controller. destination: Option, mutates: Vec, - /// Function to resolve the value to retrieve. If `None`, then `args` - /// must contain the value as its only element and `kwargs` must be - /// empty. + /// Function to resolve the value to retrieve. If `None`, then `args_kwargs` + /// must contain the value as the only element in args with no kwargs. function: Option, - args: Vec, - kwargs: HashMap, + args_kwargs: ArgsKwargs, /// The stream to retrieve from. stream: StreamRef, }, diff --git a/monarch_tensor_worker/src/borrow.rs b/monarch_tensor_worker/src/borrow.rs index 16714dbf4..55029e63f 100644 --- a/monarch_tensor_worker/src/borrow.rs +++ b/monarch_tensor_worker/src/borrow.rs @@ -177,9 +177,11 @@ mod tests { use anyhow::Result; use hyperactor::proc::Proc; use monarch_messages::controller::ControllerMessage; + use monarch_messages::worker::ArgsKwargs; use monarch_messages::worker::WorkerMessage; use monarch_messages::worker::WorkerMessageClient; use monarch_messages::worker::WorkerParams; + use pyo3::Python; use timed_test::async_timed_test; use torch_sys::Device; use torch_sys::DeviceType; @@ -225,8 +227,11 @@ mod tests { results: vec![Some(Ref { id: 1 })], mutates: vec![], function: "torch.ops.aten.ones.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::from([("device".into(), WireValue::Device(device))]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3])], + HashMap::from([("device".into(), WireValue::Device(device))]), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -241,8 +246,11 @@ 
mod tests { results: vec![Some(Ref { id: 4 })], mutates: vec![], function: "torch.ops.aten.ones.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::from([("device".into(), WireValue::Device(device))]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3])], + HashMap::from([("device".into(), WireValue::Device(device))]), + ) + .unwrap(), stream: 3.into(), remote_process_groups: vec![], }), @@ -261,8 +269,11 @@ mod tests { results: vec![Some(Ref { id: 6 })], mutates: vec![], function: "torch.ops.aten.sub_.Tensor".into(), - args: vec![WireValue::Ref(Ref { id: 5 }), WireValue::Ref(Ref { id: 1 })], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(Ref { id: 5 }), WireValue::Ref(Ref { id: 1 })], + HashMap::new(), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -274,8 +285,11 @@ mod tests { results: vec![Some(Ref { id: 7 })], mutates: vec![], function: "torch.ops.aten.zeros.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::from([("device".into(), WireValue::Device(device))]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3])], + HashMap::from([("device".into(), WireValue::Device(device))]), + ) + .unwrap(), stream: 3.into(), remote_process_groups: vec![], }), @@ -285,8 +299,11 @@ mod tests { results: vec![Some(Ref { id: 8 })], mutates: vec![], function: "torch.ops.aten.allclose.default".into(), - args: vec![WireValue::Ref(Ref { id: 4 }), WireValue::Ref(Ref { id: 7 })], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(Ref { id: 4 }), WireValue::Ref(Ref { id: 7 })], + HashMap::new(), + ) + .unwrap(), stream: 3.into(), remote_process_groups: vec![], }), @@ -362,8 +379,7 @@ mod tests { results: vec![Some(Ref { id: 1 })], mutates: vec![], function: "torch.ops.aten.idont.exist".into(), - args: vec![], - kwargs: HashMap::new(), + args_kwargs: 
ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -387,8 +403,11 @@ mod tests { results: vec![Some(Ref { id: 4 })], mutates: vec![], function: "torch.ops.aten.sub_.Scalar".into(), - args: vec![WireValue::Ref(3.into()), WireValue::Int(1)], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(3.into()), WireValue::Int(1)], + HashMap::new(), + ) + .unwrap(), stream: 2.into(), remote_process_groups: vec![], }), @@ -397,8 +416,11 @@ mod tests { results: vec![Some(Ref { id: 5 })], mutates: vec![], function: "torch.ops.aten.allclose.default".into(), - args: vec![WireValue::Ref(Ref { id: 4 }), WireValue::Ref(Ref { id: 4 })], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(Ref { id: 4 }), WireValue::Ref(Ref { id: 4 })], + HashMap::new(), + ) + .unwrap(), stream: 2.into(), remote_process_groups: vec![], }), diff --git a/monarch_tensor_worker/src/comm.rs b/monarch_tensor_worker/src/comm.rs index edaafed24..75281b2b2 100644 --- a/monarch_tensor_worker/src/comm.rs +++ b/monarch_tensor_worker/src/comm.rs @@ -1031,9 +1031,11 @@ mod tests { use futures::future::try_join_all; use hyperactor::actor::ActorStatus; use hyperactor::proc::Proc; + use monarch_messages::worker::ArgsKwargs; use monarch_messages::worker::WorkerMessageClient; use monarch_messages::worker::WorkerParams; use ndslice::Slice; + use pyo3::Python; use timed_test::async_timed_test; use torch_sys::DeviceIndex; use torch_sys::Layout; @@ -1304,11 +1306,14 @@ mod tests { results: vec![Some(2.into())], mutates: vec![], function: "torch.ops.aten.ones.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::from([( - "device".into(), - WireValue::Device("cuda".try_into().unwrap()), - )]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3])], + HashMap::from([( + "device".into(), + 
WireValue::Device("cuda".try_into().unwrap()), + )]), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1335,11 +1340,14 @@ mod tests { results: vec![Some(4.into())], mutates: vec![], function: "torch.ops.aten.full.default".into(), - args: vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], - kwargs: HashMap::from([( - "device".into(), - WireValue::Device("cuda".try_into().unwrap()), - )]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], + HashMap::from([( + "device".into(), + WireValue::Device("cuda".try_into().unwrap()), + )]), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1348,8 +1356,11 @@ mod tests { results: vec![Some(5.into())], mutates: vec![], function: "torch.ops.aten.allclose.default".into(), - args: vec![WireValue::Ref(3.into()), WireValue::Ref(4.into())], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(3.into()), WireValue::Ref(4.into())], + HashMap::new(), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1382,8 +1393,11 @@ mod tests { results: vec![Some(7.into())], mutates: vec![], function: "torch.ops.aten.full.default".into(), - args: vec![WireValue::IntList(vec![2, 3]), WireValue::Double(4.0)], - kwargs: HashMap::from([("device".into(), WireValue::Device("cuda".try_into()?))]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3]), WireValue::Double(4.0)], + HashMap::from([("device".into(), WireValue::Device("cuda".try_into()?))]), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1392,8 +1406,11 @@ mod tests { results: vec![Some(8.into())], mutates: vec![], function: "torch.ops.aten.allclose.default".into(), - args: vec![WireValue::Ref(6.into()), WireValue::Ref(7.into())], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(6.into()), WireValue::Ref(7.into())], + 
HashMap::new(), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1408,7 +1425,8 @@ mod tests { .await? .unwrap() .unwrap() - .try_into()?; + .try_into() + .unwrap(); assert!(val, "allreduce sum produced unexpected value: {val}"); let val: bool = workers[0] @@ -1416,7 +1434,8 @@ mod tests { .await? .unwrap() .unwrap() - .try_into()?; + .try_into() + .unwrap(); assert!(val, "allreduce sum produced unexpected value: {val}"); for worker in workers.into_iter() { @@ -1482,11 +1501,14 @@ mod tests { results: vec![Some(1.into())], mutates: vec![], function: "torch.ops.aten.full.default".into(), - args: vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], - kwargs: HashMap::from([( - "device".into(), - WireValue::Device("cuda".try_into().unwrap()), - )]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], + HashMap::from([( + "device".into(), + WireValue::Device("cuda".try_into().unwrap()), + )]), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1554,11 +1576,14 @@ mod tests { results: vec![Some(2.into())], mutates: vec![], function: "torch.ops.aten.full.default".into(), - args: vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], - kwargs: HashMap::from([( - "device".into(), - WireValue::Device("cuda".try_into().unwrap()), - )]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], + HashMap::from([( + "device".into(), + WireValue::Device("cuda".try_into().unwrap()), + )]), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1567,8 +1592,11 @@ mod tests { results: vec![Some(3.into())], mutates: vec![], function: "torch.ops.aten.allclose.default".into(), - args: vec![WireValue::Ref(1.into()), WireValue::Ref(2.into())], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(1.into()), WireValue::Ref(2.into())], + HashMap::new(), + ) + 
.unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1637,11 +1665,14 @@ mod tests { results: vec![Some(1.into())], mutates: vec![], function: "torch.ops.aten.full.default".into(), - args: vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], - kwargs: HashMap::from([( - "device".into(), - WireValue::Device("cuda".try_into().unwrap()), - )]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], + HashMap::from([( + "device".into(), + WireValue::Device("cuda".try_into().unwrap()), + )]), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1650,11 +1681,14 @@ mod tests { results: vec![Some(2.into())], mutates: vec![], function: "torch.ops.aten.full.default".into(), - args: vec![WireValue::IntList(vec![2, 3]), WireValue::Double(4.0)], - kwargs: HashMap::from([( - "device".into(), - WireValue::Device("cuda".try_into().unwrap()), - )]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3]), WireValue::Double(4.0)], + HashMap::from([( + "device".into(), + WireValue::Device("cuda".try_into().unwrap()), + )]), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1682,11 +1716,14 @@ mod tests { results: vec![Some(3.into())], mutates: vec![], function: "torch.ops.aten.full.default".into(), - args: vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], - kwargs: HashMap::from([( - "device".into(), - WireValue::Device("cuda".try_into().unwrap()), - )]), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], + HashMap::from([( + "device".into(), + WireValue::Device("cuda".try_into().unwrap()), + )]), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), @@ -1695,8 +1732,11 @@ mod tests { results: vec![Some(4.into())], mutates: vec![], function: "torch.ops.aten.allclose.default".into(), - args: vec![WireValue::Ref(2.into()), WireValue::Ref(3.into())], - 
kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(2.into()), WireValue::Ref(3.into())], + HashMap::new(), + ) + .unwrap(), stream: 0.into(), remote_process_groups: vec![], }), diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index 1c36ed254..dcb537664 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -69,6 +69,7 @@ use monarch_messages::controller::Seq; use monarch_messages::wire_value::WireValue; use monarch_messages::worker::ActorCallParams; use monarch_messages::worker::ActorMethodParams; +use monarch_messages::worker::ArgsKwargs; use monarch_messages::worker::CallFunctionParams; use monarch_messages::worker::Factory; use monarch_messages::worker::Reduction; @@ -744,8 +745,7 @@ impl WorkerMessageHandler for WorkerActor { destination: Option, mutates: Vec, function: Option, - args: Vec, - kwargs: HashMap, + args_kwargs: ArgsKwargs, stream: StreamRef, ) -> Result<()> { // Resolve the stream. 
@@ -772,8 +772,7 @@ impl WorkerMessageHandler for WorkerActor { cx.self_id().clone(), mutates, function, - args, - kwargs, + args_kwargs, device_meshes, ) .await @@ -1156,8 +1155,11 @@ mod tests { results: vec![Some(0.into())], mutates: vec![], function: "torch.ops.aten.ones.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3])], + HashMap::new(), + ) + .unwrap(), stream: 1.into(), remote_process_groups: vec![], }), @@ -1166,8 +1168,11 @@ mod tests { results: vec![Some(Ref { id: 2 })], mutates: vec![0.into()], function: "torch.ops.aten.sub_.Scalar".into(), - args: vec![WireValue::Ref(0.into()), WireValue::Int(1)], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(0.into()), WireValue::Int(1)], + HashMap::new(), + ) + .unwrap(), stream: 1.into(), remote_process_groups: vec![], }), @@ -1176,8 +1181,11 @@ mod tests { results: vec![Some(Ref { id: 3 })], mutates: vec![], function: "torch.ops.aten.zeros.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3])], + HashMap::new(), + ) + .unwrap(), stream: 1.into(), remote_process_groups: vec![], }), @@ -1186,8 +1194,11 @@ mod tests { results: vec![Some(Ref { id: 4 })], mutates: vec![], function: "torch.ops.aten.allclose.default".into(), - args: vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })], + HashMap::new(), + ) + .unwrap(), stream: 1.into(), remote_process_groups: vec![], }), @@ -1249,8 +1260,7 @@ mod tests { results: vec![Some(0.into())], mutates: vec![], function: "torch.ops.aten.rand.default".into(), - args: vec![], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values(vec![], 
HashMap::new()).unwrap(), stream: 1.into(), remote_process_groups: vec![], }), @@ -1314,8 +1324,7 @@ mod tests { results: vec![Some(Ref { id: 2 })], mutates: vec![0.into()], function: "i.dont.exist".into(), - args: vec![], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(), stream: 1.into(), remote_process_groups: vec![], }), @@ -1380,8 +1389,7 @@ mod tests { results: vec![Some(0.into())], mutates: vec![], function: "i.dont.exist".into(), - args: vec![], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(), stream: 1.into(), remote_process_groups: vec![], }), @@ -1390,8 +1398,11 @@ mod tests { results: vec![Some(1.into())], mutates: vec![], function: "torch.ops.aten.sub_.Scalar".into(), - args: vec![WireValue::Ref(0.into())], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::Ref(0.into())], + HashMap::new(), + ) + .unwrap(), stream: 1.into(), remote_process_groups: vec![], }), @@ -1541,8 +1552,11 @@ mod tests { results: vec![Some(Ref { id: i + 2 })], mutates: vec![], function: "torch.ops.aten.ones.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::new(), + args_kwargs: ArgsKwargs::from_wire_values( + vec![WireValue::IntList(vec![2, 3])], + HashMap::new(), + ) + .unwrap(), stream: (i % 2).into(), remote_process_groups: vec![], }, diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 5b9b4b482..7203f0ed3 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -46,6 +46,7 @@ use monarch_messages::controller::Seq; use monarch_messages::controller::WorkerError; use monarch_messages::worker::ActorCallParams; use monarch_messages::worker::ActorMethodParams; +use monarch_messages::worker::ArgsKwargs; use monarch_messages::worker::CallFunctionError; use monarch_messages::worker::CallFunctionParams; use 
monarch_messages::worker::SeqError; @@ -217,8 +218,7 @@ pub enum StreamMessage { worker_actor_id: ActorId, mutates: Vec, function: Option, - args: Vec, - kwargs: HashMap, + args_kwargs: ArgsKwargs, device_meshes: HashMap, }, @@ -652,7 +652,6 @@ impl StreamActor { WireValue::MemoryFormat(val) => RValue::MemoryFormat(val), WireValue::PyObject(val) => RValue::PyObject(val), WireValue::None(()) => RValue::None, - WireValue::IValue(val) => RValue::Opaque(val.into()), }; Ok(ret) } @@ -745,8 +744,7 @@ impl StreamActor { py: Python<'py>, cx: &Context, function: Option, - args: Vec, - kwargs: HashMap, + args_kwargs: ArgsKwargs, mutates: &[Ref], device_meshes: HashMap, remote_process_groups: HashMap< @@ -754,6 +752,9 @@ impl StreamActor { (DeviceMesh, Vec, Arc>), >, ) -> Result, CallFunctionError> { + let (args_tuple, kwargs_dict) = args_kwargs + .to_python(py) + .map_err(|e| CallFunctionError::Error(e.into()))?; let function = function .map(|function| { function.resolve(py).map_err(|e| { @@ -806,10 +807,8 @@ impl StreamActor { // this function. let mut multiborrow = MultiBorrow::new(); - let resolve = |val: WireValue| { - val.into_bound_py_any(py) - .map_err(SerializablePyErr::from_fn(py))? - .extract::>() + let resolve = |val: Bound<'py, PyAny>| { + val.extract::>() .map_err(SerializablePyErr::from_fn(py))? 
.try_into_map(|obj| { Ok(if let Ok(ref_) = Ref::from_py_object(obj.bind(py)) { @@ -827,14 +826,21 @@ impl StreamActor { }) }; - // Resolve refs - let py_args: Vec> = args - .into_iter() - .map(resolve) + // Resolve args and kwargs + let py_args: Vec> = args_tuple + .iter() + .map(|item| resolve(item)) .collect::>()?; - let py_kwargs: HashMap<_, PyTree> = kwargs - .into_iter() - .map(|(k, object)| Ok((k, resolve(object)?))) + + let py_kwargs: HashMap> = kwargs_dict + .iter() + .map(|(k, v)| { + let key = k + .extract::() + .map_err(SerializablePyErr::from_fn(py))?; + let value = resolve(v)?; + Ok((key, value)) + }) .collect::>()?; // Add a shared-borrow for each rvalue reference. @@ -894,8 +900,7 @@ impl StreamActor { &mut self, cx: &hyperactor::Context, function: ResolvableFunction, - args: Vec, - kwargs: HashMap, + args_kwargs: ArgsKwargs, mutates: &[Ref], device_meshes: HashMap, remote_process_groups: HashMap< @@ -908,8 +913,7 @@ impl StreamActor { py, cx, Some(function), - args, - kwargs, + args_kwargs, mutates, device_meshes, remote_process_groups, @@ -975,8 +979,7 @@ impl StreamActor { seq: Seq, mutates: Vec, function: Option, - args: Vec, - kwargs: HashMap, + args_kwargs: ArgsKwargs, device_meshes: HashMap, ) -> Result<()> { let rank = self.rank; @@ -988,8 +991,7 @@ impl StreamActor { py, cx, function, - args, - kwargs, + args_kwargs, &mutates, device_meshes, HashMap::new(), @@ -1087,8 +1089,7 @@ impl StreamMessageHandler for StreamActor { self.call_python_fn_pytree( cx, params.function, - params.args, - params.kwargs, + params.args_kwargs, &params.mutates, device_meshes, remote_process_groups, @@ -1512,56 +1513,59 @@ impl StreamMessageHandler for StreamActor { worker_actor_id: ActorId, mutates: Vec, function: Option, - args: Vec, - kwargs: HashMap, + args_kwargs: ArgsKwargs, device_meshes: HashMap, ) -> Result<()> { if self.respond_with_python_message { return self - .send_value_python_message(cx, seq, mutates, function, args, kwargs, device_meshes) +
.send_value_python_message(cx, seq, mutates, function, args_kwargs, device_meshes) .await; } - let result = if let Some(function) = function { - // If a function was provided, use that to resolve the value. - tokio::task::block_in_place(|| { - self.call_python_fn_pytree( - cx, - function, - args, - kwargs, - &mutates, - device_meshes, - HashMap::new(), - ) - }) - } else { - // If there's no function provided, there should be exactly one arg - // and no kwargs. - match (args.len(), kwargs.len()) { - (1, 0) => Python::with_gil(|py| { - let arg = args[0] - .clone() - .into_pyobject(py) - .map_err(SerializablePyErr::from_fn(py))?; - arg.extract::>() - .map_err(SerializablePyErr::from_fn(py))? - .try_into_map(|obj| { - let bound_obj = obj.bind(py); - if let Ok(ref_) = Ref::from_py_object(bound_obj) { - self.ref_to_rvalue(&ref_) - } else { - Ok(bound_obj - .extract::() - .map_err(SerializablePyErr::from_fn(py))?) - } - }) - }), - _ => Err(CallFunctionError::TooManyArgsForValue( - format!("{:?}", args), - format!("{:?}", kwargs), - )), + + let result = (|| -> Result, CallFunctionError> { + if let Some(function) = function { + // If a function was provided, use that to resolve the value. + tokio::task::block_in_place(|| { + self.call_python_fn_pytree( + cx, + function, + args_kwargs, + &mutates, + device_meshes, + HashMap::new(), + ) + }) + } else { + // If there's no function provided, there should be exactly one arg + // and no kwargs. + Python::with_gil(|py| { + let (args, kwargs) = args_kwargs + .to_python(py) + .map_err(|e| CallFunctionError::Error(e.into()))?; + match (args.len(), kwargs.len()) { + (1, 0) => { + let arg = args.get_item(0).map_err(SerializablePyErr::from_fn(py))?; + arg.extract::>() + .map_err(SerializablePyErr::from_fn(py))? + .try_into_map(|obj| { + let bound_obj = obj.bind(py); + if let Ok(ref_) = Ref::from_py_object(bound_obj) { + self.ref_to_rvalue(&ref_) + } else { + Ok(bound_obj + .extract::() + .map_err(SerializablePyErr::from_fn(py))?) 
+ } + }) + } + _ => Err(CallFunctionError::TooManyArgsForValue( + format!("args with {} elements", args.len()), + format!("kwargs with {} elements", kwargs.len()), + )), + } + }) } - }; + })(); let value = match result { Ok(rvalue) => { @@ -2142,8 +2146,11 @@ mod tests { stream_actor.actor_id().clone(), Vec::new(), None, - vec![WireValue::PyObject(ref_to_send)], - HashMap::new(), + ArgsKwargs::from_wire_values( + vec![WireValue::PyObject(ref_to_send)], + HashMap::new(), + ) + .unwrap(), HashMap::new(), ) .await