diff --git a/monarch_extension/src/convert.rs b/monarch_extension/src/convert.rs index a0dd02b80..68b42f5ad 100644 --- a/monarch_extension/src/convert.rs +++ b/monarch_extension/src/convert.rs @@ -330,21 +330,6 @@ fn create_map(py: Python) -> HashMap { to_stream: p.parseStreamRef("to_stream")?, }) }); - m.insert(key("CreatePipe"), |p| { - let function = p.parseFunction("function")?; - let args = p.parse("args")?; - let kwargs = p.parse("kwargs")?; - let (args, kwargs) = func_call_args_to_wire_values(Some(&function), &args, &kwargs)?; - Ok(WorkerMessage::CreatePipe { - result: p.parseRef("result")?, - key: p.parse("key")?, - function, - max_messages: p.parse("max_messages")?, - mesh: p.parseRef("device_mesh")?, - args, - kwargs, - }) - }); m.insert(key("SendValue"), |p| { let function = p.parseOptionalFunction("function")?; let args: Bound<'_, PyTuple> = p.parse("args")?; diff --git a/monarch_messages/src/wire_value.rs b/monarch_messages/src/wire_value.rs index 3e0b1fc1b..c89d7ce80 100644 --- a/monarch_messages/src/wire_value.rs +++ b/monarch_messages/src/wire_value.rs @@ -40,16 +40,7 @@ use crate::worker::ResolvableFunction; // out for refs. And IValue is the same as RValue, but with real tensors and // C++ types. I wonder if there is a nicer way to express this relationship. // TODO extend this to support other types of values, like bytes, dicts etc. -#[derive( - Serialize, - Deserialize, - Debug, - Clone, - TryInto, - Named, - From, - EnumAsInner -)] +#[derive(Serialize, Deserialize, Debug, Clone, TryInto, Named, From)] pub enum WireValue { // Make sure boolean goes ealier than int as bool is a subclass of int. // Otherwise, bool will be converted to int. @@ -165,116 +156,28 @@ impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue { } } -impl From for WireValue { - fn from(obj: PyObject) -> Self { - Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap())) - } -} +impl<'py> IntoPyObject<'py> for WireValue { + type Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; -impl WireValue { - fn from_pyobject_with_torch_op_arg_type( - obj: Bound<'_, PyAny>, - type_: &torch_sys::call_op::TypePtr, - num_elements: i32, - allow_nums_as_tensors: bool, - ) -> PyResult { - if type_.is_tensor() || type_.is_optional_tensor() { - if type_.is_optional_tensor() && obj.is_none() { - return Ok(WireValue::None(())); - } else if let Ok(ref_) = Ref::from_py_object(&obj) { - return Ok(WireValue::Ref(ref_)); - } - } - if type_.is_tensor_list() || type_.is_optional_tensor_list() { - if type_.is_optional_tensor_list() && obj.is_none() { - return Ok(WireValue::None(())); - } - let list = obj.downcast::()?; - let len = list.len(); - if len == 0 { - return Ok(WireValue::RefList(vec![])); - } - // SAFETY: We know it is within bounds - let item = unsafe { list.get_item_unchecked(0) }; - if let Ok(ref_) = Ref::from_py_object(&item) { - let mut ref_list = Vec::with_capacity(len); - ref_list.push(ref_); - for item in list.iter().skip(1) { - ref_list.push(Ref::from_py_object(&item).map_err(|_| { - PyValueError::new_err(format!( - "Expected homogeneous list of refs got: {:?}", - list - )) - })?); - } - return Ok(WireValue::RefList(ref_list)); - } - } - OpaqueIValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors) - .map(WireValue::IValue) + fn into_pyobject(self, py: Python<'py>) -> Result { + unsafe { self.try_to_object_unsafe(py) } } } -pub fn func_call_args_to_wire_values( - func: Option<&ResolvableFunction>, - args: &Bound<'_, PyTuple>, - kwargs: &Bound<'_, PyDict>, -) -> PyResult<(Vec, HashMap)> { - if let Some((op, overload)) = func.and_then(|func| func.as_torch_op()) { - torch_op_args_to_wire_values(&op, &overload, args, kwargs) - } else { - python_func_args_to_wire_value(args, kwargs) +impl From for WireValue { + fn from(obj: PyObject) -> Self { + Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap())) } } -fn torch_op_args_to_wire_values( - op: &str, - overload: &str, +pub fn func_call_args_to_wire_values( + _func: Option<&ResolvableFunction>, args: &Bound<'_, PyTuple>, kwargs: &Bound<'_, PyDict>, ) -> PyResult<(Vec, HashMap)> { - let args_info = torch_sys::call_op::get_schema_args_info(op, overload).map_err(|err| { - PyValueError::new_err(format!( - "Failed to get the operator schema for {}::{}: {}", - op, overload, err - )) - })?; - - let args = args - .iter() - .zip(&args_info) - .map(|(arg, arg_info)| { - WireValue::from_pyobject_with_torch_op_arg_type( - arg, - arg_info.type_, - arg_info.num_elements, - arg_info.allows_number_as_tensor, - ) - }) - .collect::, _>>()?; - let kwargs = kwargs - .iter() - .map(|(k, v)| { - let key = k.extract::()?; - let arg_info = args_info - .iter() - .find(|arg_info| arg_info.name == key) - .ok_or_else(|| { - PyValueError::new_err(format!( - "Torch op {}::{} does not support kwarg {}", - op, overload, key - )) - })?; - let val = WireValue::from_pyobject_with_torch_op_arg_type( - v, - arg_info.type_, - arg_info.num_elements, - arg_info.allows_number_as_tensor, - )?; - Ok((key, val)) - }) - .collect::, PyErr>>()?; - Ok((args, kwargs)) + python_func_args_to_wire_value(args, kwargs) } fn python_func_args_to_wire_value( diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index 255571b90..43c1a822c 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -340,21 +340,6 @@ impl ResolvableFunction { } } - pub fn as_torch_op<'a>(&'a self) -> Option<(String, String)> { - match self { - Self::FunctionPath(func) => match func.path.split(".").collect::>().as_slice() { - ["torch", "ops", namespace, op_name, "default"] => { - Some((format!("{}::{}", namespace, op_name), String::new())) - } - ["torch", "ops", namespace, op_name, overload] => { - Some((format!("{}::{}", namespace, op_name), overload.to_string())) - } - _ => None, - }, - _ => None, - } - } - /// For testing: this is a special remote function path that induces a panic /// when called. pub fn panic_if_requested(&self) { @@ -367,13 +352,6 @@ impl ResolvableFunction { _ => (), } } - - pub fn supports_pytree_args(&self) -> bool { - match self { - Self::Cloudpickle(_) => true, - Self::FunctionPath(_) => self.as_torch_op().is_none(), - } - } } impl> From for ResolvableFunction { @@ -800,16 +778,6 @@ pub enum WorkerMessage { to_stream: StreamRef, }, - CreatePipe { - result: Ref, - key: String, - function: ResolvableFunction, - max_messages: i64, - mesh: Ref, - args: Vec, - kwargs: HashMap, - }, - SendValue { seq: Seq, /// Pipe to send value to. If `None`, value is sent to controller. diff --git a/monarch_tensor_worker/src/borrow.rs b/monarch_tensor_worker/src/borrow.rs index 35957e6d7..16714dbf4 100644 --- a/monarch_tensor_worker/src/borrow.rs +++ b/monarch_tensor_worker/src/borrow.rs @@ -424,7 +424,7 @@ mod tests { .err() .context("expected error")?; assert!( - error.contains("torch operator error"), + error.contains("failed to resolve function"), "If a borrowed value contains an error, downstream calls should propagate that error (unexpected error string: {})", error, ); diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index 972bb0673..1c36ed254 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -381,14 +381,11 @@ impl WorkerMessageHandler for WorkerActor { self.maybe_add_stream_to_recording(cx, params.stream) .await?; - let device_meshes = if params.function.as_torch_op().is_some() { - HashMap::new() - } else { - self.device_meshes - .iter() - .map(|(k, v)| (k.clone(), v.0.clone())) - .collect() - }; + let device_meshes = self + .device_meshes + .iter() + .map(|(k, v)| (k.clone(), v.0.clone())) + .collect(); let mut remote_process_groups = HashMap::new(); for remote_process_group_ref in ¶ms.remote_process_groups { @@ -636,22 +633,6 @@ impl WorkerMessageHandler for WorkerActor { Ok(()) } - async fn create_pipe( - &mut self, - _cx: &hyperactor::Context, - _result: Ref, - // TODO(agallagher): This is used in the python impl to name the socket - // path to use for comms, but we don't currently use a named socket. - _key: String, - _function: ResolvableFunction, - _max_messages: i64, - _device_mesh: Ref, - _args: Vec, - _kwargs: HashMap, - ) -> Result<()> { - panic!("create_pipe is no longer implemented") - } - async fn send_tensor( &mut self, cx: &hyperactor::Context, @@ -770,7 +751,7 @@ impl WorkerMessageHandler for WorkerActor { // Resolve the stream. let stream = self.try_get_stream(stream)?; - let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) { + let device_meshes = if function.is_none() { HashMap::new() } else { self.device_meshes @@ -1440,294 +1421,6 @@ mod tests { Ok(()) } - #[async_timed_test(timeout_secs = 60)] - async fn py_remote_function_calls() -> Result<()> { - test_setup()?; - - let proc = Proc::local(); - let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); - - let worker_handle = proc - .spawn::( - "worker", - WorkerParams { - world_size: 1, - rank: 0, - device_index: None, - controller_actor: controller_ref, - }, - ) - .await - .unwrap(); - let (split_arg, sort_list, mesh_ref, dim, layout, none, scalar, device, memory_format) = - Python::with_gil(|py| { - let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar") - .into_any() - .try_into()?; - let sort_list: PickledPyObject = - PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?; - let mesh_ref: PickledPyObject = Ref { id: 5 }.into_bound_py_any(py)?.try_into()?; - let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?; - let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?; - let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?; - let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?; - let device: PickledPyObject = py - .import("torch")? - .getattr("device")? - .call1(("cuda:1",))? - .try_into()?; - let memory_format: PickledPyObject = py - .import("torch")? - .getattr("contiguous_format")? - .try_into()?; - PyResult::Ok(( - split_arg, - sort_list, - mesh_ref, - dim, - layout, - none, - scalar, - device, - memory_format, - )) - })?; - - worker_handle - .command_group( - &client, - vec![ - WorkerMessage::CreateStream { - id: 1.into(), - stream_creation: StreamCreationMode::UseDefaultStream, - }, - WorkerMessage::CallFunction(CallFunctionParams { - seq: 0.into(), - results: vec![Some(0.into()), Some(Ref { id: 2 })], - mutates: vec![], - function: "os.path.split".into(), - args: vec![split_arg.into()], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - WorkerMessage::CallFunction(CallFunctionParams { - seq: 2.into(), - results: vec![Some(4.into()), None, None, None, None], - mutates: vec![], - function: "builtins.sorted".into(), - args: vec![sort_list.into()], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - WorkerMessage::CreateDeviceMesh { - result: 5.into(), - names: vec!["x".into()], - ranks: Slice::new(0, vec![2], vec![1]).unwrap(), - }, - WorkerMessage::CallFunction(CallFunctionParams { - seq: 2.into(), - results: vec![Some(6.into())], - mutates: vec![], - function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(), - args: vec![mesh_ref.into(), dim.into()], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - WorkerMessage::CallFunction(CallFunctionParams { - seq: 4.into(), - results: vec![Some(7.into())], - mutates: vec![], - function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type" - .into(), - args: vec![scalar.into()], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - WorkerMessage::CallFunction(CallFunctionParams { - seq: 5.into(), - results: vec![Some(8.into())], - mutates: vec![], - function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(), - args: vec![layout.into()], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - WorkerMessage::CallFunction(CallFunctionParams { - seq: 6.into(), - results: vec![Some(9.into())], - mutates: vec![], - function: "monarch.monarch_tensor_worker.test_utils.test_none".into(), - args: vec![none.into()], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - // Verify that a function that returns `None` matches up with an - // empty result list. - WorkerMessage::CallFunction(CallFunctionParams { - seq: 7.into(), - results: vec![None], - mutates: vec![], - function: "monarch.monarch_tensor_worker.test_utils.none".into(), - args: vec![], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - WorkerMessage::CallFunction(CallFunctionParams { - seq: 8.into(), - results: vec![Some(10.into())], - mutates: vec![], - function: "monarch.monarch_tensor_worker.test_utils.test_device".into(), - args: vec![device.into()], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - WorkerMessage::CallFunction(CallFunctionParams { - seq: 9.into(), - results: vec![Some(11.into())], - mutates: vec![], - function: "monarch.monarch_tensor_worker.test_utils.test_memory_format" - .into(), - args: vec![memory_format.into()], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - // Test that list of tests can be passes correctly - WorkerMessage::CallFunction(CallFunctionParams { - seq: 10.into(), - results: vec![Some(12.into())], - mutates: vec![], - function: "torch.ops.aten.ones.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - WorkerMessage::CallFunction(CallFunctionParams { - seq: 11.into(), - results: vec![Some(13.into())], - mutates: vec![], - function: "torch.ops.aten.stack.default".into(), - args: vec![WireValue::RefList(vec![12.into(), 12.into()])], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - ], - ) - .await - .unwrap(); - - let result1: String = worker_handle - .get_ref_unit_tests_only(&client, 0.into(), 1.into()) - .await - .unwrap() - .unwrap() - .unwrap() - .try_into() - .unwrap(); - let result2: String = worker_handle - .get_ref_unit_tests_only(&client, 2.into(), 1.into()) - .await - .unwrap() - .unwrap() - .unwrap() - .try_into() - .unwrap(); - let result3: i64 = worker_handle - .get_ref_unit_tests_only(&client, 4.into(), 1.into()) - .await - .unwrap() - .unwrap() - .unwrap() - .try_into() - .unwrap(); - let result4: i64 = worker_handle - .get_ref_unit_tests_only(&client, 6.into(), 1.into()) - .await - .unwrap() - .unwrap() - .unwrap() - .try_into() - .unwrap(); - assert_eq!( - ScalarType::Float, - worker_handle - .get_ref_unit_tests_only(&client, 7.into(), 1.into()) - .await - .unwrap() - .unwrap() - .unwrap() - .try_into() - .unwrap() - ); - assert_eq!( - Layout::Strided, - worker_handle - .get_ref_unit_tests_only(&client, 8.into(), 1.into()) - .await - .unwrap() - .unwrap() - .unwrap() - .try_into() - .unwrap() - ); - assert_matches!( - worker_handle - .get_ref_unit_tests_only(&client, 9.into(), 1.into()) - .await - .unwrap() - .unwrap() - .unwrap(), - WireValue::None(()), - ); - let device: Device = CudaDevice::new(DeviceIndex(1)).into(); - assert_eq!( - device, - worker_handle - .get_ref_unit_tests_only(&client, 10.into(), 1.into()) - .await - .unwrap() - .unwrap() - .unwrap() - .try_into() - .unwrap() - ); - assert_matches!( - worker_handle - .get_ref_unit_tests_only(&client, 11.into(), 1.into()) - .await - .unwrap() - .unwrap() - .unwrap(), - WireValue::MemoryFormat(MemoryFormat::Contiguous), - ); - - worker_handle.drain_and_stop().unwrap(); - worker_handle.await; - let error_responses = controller_rx.drain(); - assert!( - error_responses.is_empty(), - "Expected no error responses, got: {:#?}", - error_responses - ); - - assert_eq!(result1, "/fbs/fbc/foo"); - assert_eq!(result2, "bar"); - assert_eq!(result3, 1); - assert_eq!(result4, 0); - - Ok(()) - } - #[async_timed_test(timeout_secs = 60)] async fn delete_refs() -> Result<()> { test_setup()?; @@ -1931,194 +1624,6 @@ mod tests { worker_handle2.await; } - #[async_timed_test(timeout_secs = 60)] - async fn send_value() -> Result<()> { - test_setup()?; - - let proc = Proc::local(); - let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); - - let worker_handle = proc - .spawn::( - "worker", - WorkerParams { - world_size: 1, - rank: 0, - device_index: None, - controller_actor: controller_ref, - }, - ) - .await - .unwrap(); - worker_handle - .command_group( - &client, - vec![ - WorkerMessage::CreateStream { - id: 1.into(), - stream_creation: StreamCreationMode::UseDefaultStream, - }, - WorkerMessage::CallFunction(CallFunctionParams { - seq: 0.into(), - results: vec![Some(0.into())], - mutates: vec![], - function: "torch.ops.aten.ones.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - }), - WorkerMessage::SendValue { - seq: 1.into(), - destination: None, - mutates: vec![], - function: None, - args: vec![WireValue::Ref(0.into())], - kwargs: HashMap::new(), - stream: 1.into(), - }, - WorkerMessage::SendValue { - seq: 2.into(), - destination: None, - mutates: vec![], - function: Some("torch.ops.aten.var_mean.default".into()), - args: vec![WireValue::Ref(0.into())], - kwargs: HashMap::new(), - stream: 1.into(), - }, - WorkerMessage::Exit { error: None }, - ], - ) - .await - .unwrap(); - - worker_handle.drain_and_stop()?; - assert_matches!(worker_handle.await, ActorStatus::Stopped); - - let mut responses = controller_rx.drain(); - assert_eq!( - responses.len(), - 3, - "Expected one response, got: {:#?}", - responses - ); - - match responses.pop().unwrap() { - ControllerMessage::FetchResult { seq, value } => { - assert_eq!(seq, 2.into()); - let value = value.unwrap().deserialized::>().unwrap(); - assert_eq!(value.leaves().len(), 2); - } - resp => panic!("unexpected response {:#?}", resp), - }; - match responses.pop().unwrap() { - ControllerMessage::FetchResult { seq, .. } => { - assert_eq!(seq, 1.into()) - } - resp => panic!("unexpected response {:#?}", resp), - }; - Ok(()) - } - - #[async_timed_test(timeout_secs = 60)] - async fn send_value_err_result() -> Result<()> { - test_setup()?; - - let proc = Proc::local(); - let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); - - let worker_handle = proc - .spawn::( - "worker", - WorkerParams { - world_size: 1, - rank: 0, - device_index: None, - controller_actor: controller_ref, - }, - ) - .await - .unwrap(); - - let ref_arg: PickledPyObject = - Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?; - - worker_handle - .command_group( - &client, - vec![ - WorkerMessage::CreateStream { - id: 1.into(), - stream_creation: StreamCreationMode::UseDefaultStream, - }, - WorkerMessage::SetRefUnitTestsOnly { - reference: Ref { id: 2 }, - value: WireValue::Bool(false), - stream: 1.into(), - }, - WorkerMessage::SendValue { - seq: 1.into(), - destination: None, - mutates: vec![Ref { id: 2 }], - function: Some("non.existent.function".into()), - args: vec![], - kwargs: HashMap::new(), - stream: 1.into(), - }, - WorkerMessage::SendValue { - seq: 2.into(), - destination: None, - mutates: vec![], - function: None, - args: vec![ref_arg.into()], - kwargs: HashMap::new(), - stream: 1.into(), - }, - WorkerMessage::Exit { error: None }, - ], - ) - .await - .unwrap(); - - worker_handle.drain_and_stop()?; - assert_matches!(worker_handle.await, ActorStatus::Stopped); - - let mut responses = controller_rx.drain(); - assert_eq!( - responses.len(), - 3, - "Expected one response, got: {:#?}", - responses - ); - match responses.pop() { - Some(ControllerMessage::FetchResult { seq, value }) => { - assert_eq!(seq, 2.into()); - assert!(value.is_err()); - assert!( - value - .unwrap_err() - .backtrace - .contains("failed to resolve function") - ); - } - _ => panic!("unexpected response {:#?}", responses), - } - match responses.pop() { - Some(ControllerMessage::FetchResult { seq, value }) => { - assert_eq!(seq, 1.into()); - assert!(value.is_err()); - assert!( - value - .unwrap_err() - .backtrace - .contains("failed to resolve function") - ); - } - _ => panic!("unexpected response {:#?}", responses), - } - Ok(()) - } - #[allow(dead_code)] fn get_random_channel_addr() -> ChannelAddr { let random_string = rand::thread_rng() diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 605cda893..5b9b4b482 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -53,8 +53,8 @@ use monarch_messages::worker::StreamRef; use monarch_types::PyTree; use monarch_types::SerializablePyErr; use monarch_types::TryIntoPyObjectUnsafe; +use pyo3::IntoPyObjectExt; use pyo3::prelude::*; -use pyo3::types::PyTuple; use tokio::runtime::Handle; use tokio::sync::Mutex; use tokio::task::JoinHandle; @@ -740,34 +740,6 @@ impl StreamActor { Ok(()) } - fn call_torch_op( - &self, - op: String, - overload: String, - args: Vec, - kwargs: HashMap, - ) -> Result, CallFunctionError> { - let args = args - .into_iter() - .map(|arg| self.wire_to_rvalue(arg)) - .collect::, _>>()?; - let kwargs = kwargs - .into_iter() - .map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue))) - .collect::, CallFunctionError>>()?; - - let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?; - - // Handle the case where the op returns nothing and convert it to a list of None. - // This is to ensure handle results does not error out as the client will call - // such a function with expected results of size 1. - Ok(if results.is_empty() { - vec![RValue::None] - } else { - results - }) - } - fn call_python_fn<'py>( &mut self, py: Python<'py>, @@ -835,14 +807,7 @@ impl StreamActor { let mut multiborrow = MultiBorrow::new(); let resolve = |val: WireValue| { - val.into_py_object() - .map_err(|e| { - CallFunctionError::UnsupportedArgType( - format!("{:?}", function), - format!("{:?}", e), - ) - })? - .unpickle(py) + val.into_bound_py_any(py) .map_err(SerializablePyErr::from_fn(py))? .extract::>() .map_err(SerializablePyErr::from_fn(py))? @@ -879,7 +844,7 @@ impl StreamActor { .flat_map(|o| o.iter()) .for_each(|arg| { if let PyArg::RValue(rval) = arg { - multiborrow.add(rval, BorrowType::Shared); + multiborrow.add(&rval, BorrowType::Shared); } }); @@ -1118,21 +1083,17 @@ impl StreamMessageHandler for StreamActor { params.results, ¶ms.mutates, async |self| { - tokio::task::block_in_place(|| match params.function.as_torch_op() { - Some((op, overload)) => { - self.call_torch_op(op, overload, params.args, params.kwargs) - } - _ => self - .call_python_fn_pytree( - cx, - params.function, - params.args, - params.kwargs, - ¶ms.mutates, - device_meshes, - remote_process_groups, - ) - .map(|results| results.into_leaves()), + tokio::task::block_in_place(|| { + self.call_python_fn_pytree( + cx, + params.function, + params.args, + params.kwargs, + ¶ms.mutates, + device_meshes, + remote_process_groups, + ) + .map(|results| results.into_leaves()) }) }, ) @@ -1562,58 +1523,25 @@ impl StreamMessageHandler for StreamActor { } let result = if let Some(function) = function { // If a function was provided, use that to resolve the value. - match function.as_torch_op() { - Some((op, overload)) => { - self.call_torch_op(op, overload, args, kwargs) - .map(|rvalues| { - if rvalues.len() == 1 { - Ok(rvalues[0].clone().into()) - } else { - // TODO: Replace with native pytrees when possible - Python::with_gil(|py| { - Ok((|| { - let py_rvalues = rvalues - .into_iter() - // SAFETY: This inherits the unsafety of `try_to_object_unsafe`. - .map(|rvalue| unsafe { - rvalue.try_to_object_unsafe(py) - }) - .collect::, _>>()?; - PyTuple::new(py, &py_rvalues)?.extract::>() - })() - .map_err(SerializablePyErr::from_fn(py))?) - }) - } - })? - } - // Use block-in-place to allow nested callbacks to re-enter the - // runtime to run async code. - _ => tokio::task::block_in_place(|| { - self.call_python_fn_pytree( - cx, - function, - args, - kwargs, - &mutates, - device_meshes, - HashMap::new(), - ) - }), - } + tokio::task::block_in_place(|| { + self.call_python_fn_pytree( + cx, + function, + args, + kwargs, + &mutates, + device_meshes, + HashMap::new(), + ) + }) } else { // If there's no function provided, there should be exactly one arg // and no kwargs. match (args.len(), kwargs.len()) { (1, 0) => Python::with_gil(|py| { let arg = args[0] - .as_py_object() - .ok_or_else(|| { - CallFunctionError::UnsupportedArgType( - "send_value".to_string(), - "expected a PyObject as the first arg".to_string(), - ) - })? - .unpickle(py) + .clone() + .into_pyobject(py) .map_err(SerializablePyErr::from_fn(py))?; arg.extract::>() .map_err(SerializablePyErr::from_fn(py))? @@ -2602,264 +2530,6 @@ mod tests { Ok(()) } - #[async_timed_test(timeout_secs = 60)] - async fn test_call_function_in_recording() -> Result<()> { - let mut test_setup = TestSetup::new().await?; - - // Define a recording equivalent to: - // def f(x, y): - // w = x + y - // nonlocal z - // z.add_(1.0) - // return w + z - test_setup - .stream_actor - .define_recording(&test_setup.client, 0.into()) - .await?; - - let formal0_ref = test_setup.next_ref(); - let formal0_index = 0; - test_setup - .stream_actor - .recording_formal(&test_setup.client, formal0_ref, formal0_index) - .await?; - - let formal1_ref = test_setup.next_ref(); - let formal1_index = 1; - test_setup - .stream_actor - .recording_formal(&test_setup.client, formal1_ref, formal1_index) - .await?; - - let captured_ref = test_setup.next_ref(); - let result_captured_ref = test_setup.next_ref(); - let add_one_function = - ResolvableFunction::FunctionPath("torch.ops.aten.add_.Scalar".into()); - let add_tensors_function = - ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into()); - - let add_result_ref_0 = test_setup.next_ref(); - test_setup - .stream_actor - .call_function( - &test_setup.client, - CallFunctionParams { - seq: 100.into(), - function: add_tensors_function.clone(), - args: vec![WireValue::Ref(formal0_ref), WireValue::Ref(formal1_ref)], - kwargs: HashMap::new(), - results: vec![Some(add_result_ref_0)], - mutates: vec![], - stream: 0.into(), - remote_process_groups: Vec::new(), - }, - HashMap::new(), - HashMap::new(), - ) - .await?; - - test_setup - .stream_actor - .call_function( - &test_setup.client, - CallFunctionParams { - seq: 101.into(), - function: add_one_function, - args: vec![WireValue::Ref(captured_ref), WireValue::Double(1.0)], - kwargs: HashMap::new(), - results: vec![Some(result_captured_ref)], - mutates: vec![captured_ref], - stream: 0.into(), - remote_process_groups: Vec::new(), - }, - HashMap::new(), - HashMap::new(), - ) - .await?; - - let add_result_ref_1 = test_setup.next_ref(); - test_setup - .stream_actor - .call_function( - &test_setup.client, - CallFunctionParams { - seq: 102.into(), - function: add_tensors_function, - args: vec![ - WireValue::Ref(add_result_ref_0), - WireValue::Ref(captured_ref), - ], - kwargs: HashMap::new(), - results: vec![Some(add_result_ref_1)], - mutates: vec![], - stream: 0.into(), - remote_process_groups: Vec::new(), - }, - HashMap::new(), - HashMap::new(), - ) - .await?; - - test_setup - .stream_actor - .recording_result(&test_setup.client, add_result_ref_1, 0) - .await?; - - test_setup - .stream_actor - .delete_refs( - &test_setup.client, - vec![add_result_ref_0, add_result_ref_1, result_captured_ref], - ) - .await?; - - test_setup - .stream_actor - .finalize_recording(&test_setup.client, 0.into()) - .await?; - - let actual0_ref = test_setup.next_ref(); - test_setup.set_tensor(actual0_ref, &[1.0, 2.0, 3.0]).await?; - - let actual1_ref = test_setup.next_ref(); - test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?; - - test_setup - .set_tensor(captured_ref, &[7.0, 8.0, 9.0]) - .await?; - - let actual_result_ref = test_setup.next_ref(); - test_setup - .stream_actor - .call_recording( - &test_setup.client, - 0.into(), - 0.into(), - vec![actual_result_ref], - vec![actual0_ref, actual1_ref], - ) - .await?; - - assert!( - test_setup - .allclose(actual_result_ref, &[13.0, 16.0, 19.0]) - .await - ); - - // Set actual1_tensor to a bad shape which will cause the recording to fail. - test_setup.set_tensor(actual1_ref, &[4.0, 5.0]).await?; - - let actual_result_ref = test_setup.next_ref(); - test_setup - .stream_actor - .call_recording( - &test_setup.client, - 1.into(), - 0.into(), - vec![actual_result_ref], - vec![actual0_ref, actual1_ref], - ) - .await?; - - // Both inputs should still be valid. - for ref_ in [actual0_ref, actual1_ref] { - let _ = test_setup - .stream_actor - .get_tensor_ref_unit_tests_only(&test_setup.client, ref_) - .await? - .unwrap() - .unwrap(); - } - - for ref_ in [captured_ref, actual_result_ref] { - let result_error = test_setup - .stream_actor - .get_tensor_ref_unit_tests_only(&test_setup.client, ref_) - .await? - .unwrap() - .unwrap_err(); - // Check that the error contains the expected strings - let error_str = result_error.to_string(); - assert!( - error_str.contains("torch operator error"), - "Error should contain 'torch operator failed': {}", - error_str - ); - } - - let controller_msg = test_setup.controller_rx.recv().await.unwrap(); - match controller_msg { - ControllerMessage::RemoteFunctionFailed { seq, error } => { - assert_eq!(seq, 1.into()); - assert!( - error.backtrace.contains("torch operator error"), - "Unexpected WorkerError: {:?}", - error - ); - } - _ => panic!("Unexpected controller message: {:?}", controller_msg), - }; - - // Reset input tensor to a valid shape. - test_setup.set_tensor(actual1_ref, &[4.0, 5.0, 6.0]).await?; - - // captured_tensor should still have an error, so calling - // the recording should set DependentErrors and not report - // anything to the controller. - let actual_result_ref = test_setup.next_ref(); - test_setup - .stream_actor - .call_recording( - &test_setup.client, - 2.into(), - 0.into(), - vec![actual_result_ref], - vec![actual0_ref, actual1_ref], - ) - .await?; - - // Both inputs should still be valid. - for ref_ in [actual0_ref, actual1_ref] { - let _ = test_setup - .stream_actor - .get_tensor_ref_unit_tests_only(&test_setup.client, ref_) - .await? - .unwrap() - .unwrap(); - } - - for ref_ in [captured_ref, actual_result_ref] { - let result_error = test_setup - .stream_actor - .get_tensor_ref_unit_tests_only(&test_setup.client, ref_) - .await? - .unwrap() - .unwrap_err(); - // Check that the error contains the expected strings - let error_str = result_error.to_string(); - assert!( - error_str.contains("torch operator error"), - "Error should contain input error: {}", - error_str - ); - } - - // This tests that the DependentError was never reported to the controller. - // If it were reported to the controller, the next message would match - // RemoteFunctionFailed instead of FetchResult. - check_fetch_result_error( - &test_setup.client, - test_setup.stream_actor.clone(), - 3.into(), - captured_ref, - &mut test_setup.controller_rx, - "torch operator error", - ) - .await; - - Ok(()) - } - #[async_timed_test(timeout_secs = 60)] async fn test_borrow_create_duplicate_borrow() -> Result<()> { let mut test_setup = TestSetup::new().await?; @@ -2964,928 +2634,4 @@ mod tests { Ok(()) } - - #[async_timed_test(timeout_secs = 60)] - async fn test_borrow_in_recording() -> Result<()> { - let mut test_setup = TestSetup::new().await?; - - let borrower_stream = test_setup - .proc - .spawn::( - "stream1", - StreamParams { - world_size: 1, - rank: 0, - creation_mode: StreamCreationMode::CreateNewStream, - id: 1.into(), - device: Some(CudaDevice::new(0.into())), - controller_actor: test_setup.controller_actor.clone(), - respond_with_python_message: false, - }, - ) - .await?; - - let lender_stream = test_setup.stream_actor.clone(); - - let borrow_id = 1; - let (first_use_sender, first_use_receiver) = test_setup.client.open_port(); - let (last_use_sender, last_use_receiver) = test_setup.client.open_port(); - - // Stream 1: Define a recording that creates a borrow and drops it. - lender_stream - .define_recording(&test_setup.client, 0.into()) - .await?; - - let formal_ref = test_setup.next_ref(); - lender_stream - .recording_formal(&test_setup.client, formal_ref, 0) - .await?; - - lender_stream - .borrow_create(&test_setup.client, borrow_id, formal_ref, first_use_sender) - .await?; - - lender_stream - .borrow_drop( - &test_setup.client, - borrow_id, - Arc::new(Mutex::new(last_use_receiver)), - ) - .await?; - - lender_stream - .finalize_recording(&test_setup.client, 0.into()) - .await?; - - let borrower_tensor_ref = test_setup.next_ref(); - let borrower_tensor = TensorCell::new(factory_float_tensor( - &[1.0, 2.0, 3.0], - "cuda".try_into().unwrap(), - )); - - borrower_stream - .set_tensor_ref_unit_tests_only( - &test_setup.client, - borrower_tensor_ref, - Ok(borrower_tensor.clone()), - ) - .await?; - - // Stream 2: Define a recording that uses the borrow from Stream 1. - borrower_stream - .define_recording(&test_setup.client, 0.into()) - .await?; - - let borrowed_ref = test_setup.next_ref(); - - borrower_stream - .borrow_first_use( - &test_setup.client, - borrow_id, - borrowed_ref, - Arc::new(Mutex::new(first_use_receiver)), - ) - .await?; - - let result_ref = test_setup.next_ref(); - borrower_stream - .call_function( - &test_setup.client, - CallFunctionParams { - seq: 100.into(), - function: ResolvableFunction::FunctionPath("torch.ops.aten.add.Tensor".into()), - args: vec![ - WireValue::Ref(borrowed_ref), - WireValue::Ref(borrower_tensor_ref), - ], - kwargs: HashMap::new(), - results: vec![Some(result_ref)], - mutates: vec![], - stream: 1.into(), - remote_process_groups: Vec::new(), - }, - HashMap::new(), - HashMap::new(), - ) - .await?; - - borrower_stream - .borrow_last_use(&test_setup.client, borrow_id, borrowed_ref, last_use_sender) - .await?; - - borrower_stream - .recording_result(&test_setup.client, result_ref, 0) - .await?; - - borrower_stream - .finalize_recording(&test_setup.client, 0.into()) - .await?; - - // Set up a tensor in the lender stream and call the recording. - let input_tensor_ref = test_setup.next_ref(); - test_setup - .set_tensor(input_tensor_ref, &[4.0, 5.0, 6.0]) - .await?; - - let result_tensor_ref = test_setup.next_ref(); - - let lender_future = lender_stream.call_recording( - &test_setup.client, - 0.into(), - 0.into(), - vec![], - vec![input_tensor_ref], - ); - - let borrower_future = borrower_stream.call_recording( - &test_setup.client, - 0.into(), - 0.into(), - vec![result_tensor_ref], - vec![], - ); - - tokio::try_join!(lender_future, borrower_future)?; - - let result_tensor = borrower_stream - .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref) - .await? - .unwrap() - .unwrap(); - - let expected_tensor = TensorCell::new(factory_float_tensor( - &[5.0, 7.0, 9.0], - "cpu".try_into().unwrap(), - )); - assert!(allclose(&result_tensor.borrow(), &expected_tensor.borrow()).unwrap()); - - // Set borrower_tensor to a tensor with only 2 elements to cause a failure. - let invalid_borrower_tensor = TensorCell::new(factory_float_tensor( - &[1.0, 2.0], - "cuda".try_into().unwrap(), - )); - borrower_stream - .set_tensor_ref_unit_tests_only( - &test_setup.client, - borrower_tensor_ref, - Ok(invalid_borrower_tensor.clone()), - ) - .await?; - - // Call the recording again. - let lender_future = lender_stream.call_recording( - &test_setup.client, - 1.into(), - 0.into(), - vec![], - vec![input_tensor_ref], - ); - - let borrower_future = borrower_stream.call_recording( - &test_setup.client, - 1.into(), - 0.into(), - vec![result_tensor_ref], - vec![], - ); - - tokio::try_join!(lender_future, borrower_future)?; - - // Check that the borrower_stream reports the error to the controller. - let controller_msg = test_setup.controller_rx.recv().await.unwrap(); - match controller_msg { - ControllerMessage::RemoteFunctionFailed { seq, error } => { - assert_eq!(seq, 1.into()); - assert!( - error.backtrace.contains("recording failed"), - "Unexpected WorkerError: {:?}", - error - ); - assert_eq!(&error.worker_actor_id, borrower_stream.actor_id()); - } - _ => panic!("Unexpected controller message: {:?}", controller_msg), - }; - - // Check that no error was reported from the lender stream - check_fetch_result_value( - &test_setup.client, - lender_stream.clone(), - 2.into(), - input_tensor_ref, - &mut test_setup.controller_rx, - ) - .await; - - // Set the recording's input tensor to an error. - let input_error = fake_seq_error(anyhow!("input error")); - lender_stream - .set_tensor_ref_unit_tests_only( - &test_setup.client, - input_tensor_ref, - Err(input_error.clone()), - ) - .await?; - - let lender_future = lender_stream.call_recording( - &test_setup.client, - 3.into(), - 0.into(), - vec![], - vec![input_tensor_ref], - ); - - let borrower_future = borrower_stream.call_recording( - &test_setup.client, - 3.into(), - 0.into(), - vec![result_tensor_ref], - vec![], - ); - - tokio::try_join!(lender_future, borrower_future)?; - - // Verify that borrower_stream sets a CallFunctionError::DependentError on result_tensor_ref. - let result_error = borrower_stream - .get_tensor_ref_unit_tests_only(&test_setup.client, result_tensor_ref) - .await? - .unwrap() - .unwrap_err(); - - // Check that the error contains the expected strings - let error_str = result_error.to_string(); - assert!( - error_str.contains("input error"), - "Error should contain input error: {}", - error_str - ); - - // Since we're checking for pointer equality in the original code, we need to ensure - // the error is propagated correctly. We can check that the original error message is contained. - let input_error_str = input_error.to_string(); - assert!( - error_str.contains(&input_error_str), - "Error should contain the original error: {}", - error_str - ); - - // Verify that neither stream sends a failure message to the controller. - check_fetch_result_error( - &test_setup.client, - lender_stream, - 4.into(), - input_tensor_ref, - &mut test_setup.controller_rx, - "input error", - ) - .await; - - // Verify that neither stream sends a failure message to the controller. - check_fetch_result_error( - &test_setup.client, - borrower_stream, - 5.into(), - result_tensor_ref, - &mut test_setup.controller_rx, - "input error", - ) - .await; - - Ok(()) - } - - #[async_timed_test(timeout_secs = 60)] - async fn test_reduce_in_recording() -> Result<()> { - let mut test_setup = TestSetup::new().await?; - let recording_ref = test_setup.next_ref(); - - let comm = Arc::new( - test_setup - .proc - .spawn::( - "comm", - CommParams::New { - device: CudaDevice::new(0.into()), - unique_id: UniqueId::new()?, - world_size: 1, - rank: 0, - }, - ) - .await?, - ); - - let factory = Factory { - size: vec![3], - dtype: torch_sys::ScalarType::Float, - layout: torch_sys::Layout::Strided, - device: "cuda".try_into().unwrap(), - }; - - let reduction = Reduction::ReduceOp(torch_sys_cuda::nccl::ReduceOp::Sum); - - test_setup - .stream_actor - .define_recording(&test_setup.client, recording_ref) - .await?; - - let formal_tensor_ref_0 = test_setup.next_ref(); - let formal_tensor_ref_1 = test_setup.next_ref(); - let formal_tensor_ref_2 = test_setup.next_ref(); - - test_setup - .stream_actor - .recording_formal(&test_setup.client, formal_tensor_ref_0, 0) - .await?; - test_setup - .stream_actor - .recording_formal(&test_setup.client, formal_tensor_ref_1, 1) - .await?; - test_setup - .stream_actor - .recording_formal(&test_setup.client, formal_tensor_ref_2, 2) - .await?; - - let intermediate_tensor_ref_0 = test_setup.next_ref(); - - // Handle case with in_place = true. - test_setup - .stream_actor - .reduce( - &test_setup.client, - comm.clone(), - 1, - intermediate_tensor_ref_0, - formal_tensor_ref_0, - factory.clone(), - reduction.clone(), - false, - true, - None, - ) - .await?; - - // Handle case with in_place = false and out = None. - let intermediate_tensor_ref_1 = test_setup.next_ref(); - test_setup - .stream_actor - .reduce( - &test_setup.client, - comm.clone(), - 1, - intermediate_tensor_ref_1, - formal_tensor_ref_1, - factory.clone(), - reduction.clone(), - false, - false, - None, - ) - .await?; - - let intermediate_tensor_ref_2 = test_setup.next_ref(); - - // Third reduce call with out = formal_tensor_ref_2 - test_setup - .stream_actor - .reduce( - &test_setup.client, - comm.clone(), - 1, - intermediate_tensor_ref_2, - intermediate_tensor_ref_1, - factory.clone(), - reduction.clone(), - false, - false, - Some(formal_tensor_ref_2), - ) - .await?; - - test_setup - .stream_actor - .recording_result(&test_setup.client, intermediate_tensor_ref_2, 0) - .await?; - - test_setup - .stream_actor - .finalize_recording(&test_setup.client, recording_ref) - .await?; - - let input_tensor_ref_0 = test_setup.next_ref(); - let input_tensor_ref_1 = test_setup.next_ref(); - let input_tensor_ref_2 = test_setup.next_ref(); - - test_setup - .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0]) - .await?; - - test_setup - .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0]) - .await?; - - test_setup - .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0]) - .await?; - - let output_ref = test_setup.next_ref(); - - test_setup - .stream_actor - .call_recording( - &test_setup.client, - 0.into(), - recording_ref, - vec![output_ref], - vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2], - ) - .await?; - - // Validate that input_tensor_ref_0 is unchanged. - assert!( - test_setup - .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0]) - .await - ); - // All the other inputs/outputs should be equal to input 1 - for ref_ in [input_tensor_ref_1, input_tensor_ref_2, output_ref] { - assert!(test_setup.allclose(ref_, &[4.0, 5.0, 6.0]).await); - } - - // Set an error on input 0 - let input_error = fake_seq_error(anyhow!("input error")); - test_setup - .stream_actor - .set_tensor_ref_unit_tests_only( - &test_setup.client, - input_tensor_ref_0, - Err(input_error.clone()), - ) - .await?; - - test_setup - .stream_actor - .call_recording( - &test_setup.client, - 1.into(), - recording_ref, - vec![output_ref], - vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2], - ) - .await?; - - // Verify that input_tensor_ref_0, input_tensor_ref_2, and output_ref have a dependent error. - for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] { - test_setup - .validate_dependent_error(ref_, input_error.clone()) - .await; - } - - // Verify that input_tensor_ref_1 is untouched. - assert!( - test_setup - .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0]) - .await - ); - - // Verify that no failure was reported to the controller. - check_fetch_result_value( - &test_setup.client, - test_setup.stream_actor.clone(), - 2.into(), - input_tensor_ref_1, - &mut test_setup.controller_rx, - ) - .await; - - // Reset input tensors 0 and 2 to their original values - test_setup - .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0]) - .await?; - test_setup - .set_tensor(input_tensor_ref_2, &[7.0, 8.0, 9.0]) - .await?; - - // Set an error on input tensor 1 - test_setup - .stream_actor - .set_tensor_ref_unit_tests_only( - &test_setup.client, - input_tensor_ref_1, - Err(input_error.clone()), - ) - .await?; - - test_setup - .stream_actor - .call_recording( - &test_setup.client, - 3.into(), - recording_ref, - vec![output_ref], - vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2], - ) - .await?; - - // Validate that the mutated inputs and the output have a dependent error containing - // the input error - for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] { - test_setup - .validate_dependent_error(ref_, input_error.clone()) - .await; - } - - // Validate that no error was reported to the controller - check_fetch_result_error( - &test_setup.client, - test_setup.stream_actor.clone(), - 4.into(), - input_tensor_ref_1, - &mut test_setup.controller_rx, - "input error", - ) - .await; - - // Reset input tensors 0 and 1 to their original values - test_setup - .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0]) - .await?; - test_setup - .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0]) - .await?; - - // Set an error on input tensor 2 - test_setup - .stream_actor - .set_tensor_ref_unit_tests_only( - &test_setup.client, - input_tensor_ref_2, - Err(input_error.clone()), - ) - .await?; - - test_setup - .stream_actor - .call_recording( - &test_setup.client, - 5.into(), - recording_ref, - vec![output_ref], - vec![input_tensor_ref_0, input_tensor_ref_1, input_tensor_ref_2], - ) - .await?; - - // Validate that input tensor 1 has its original values - assert!( - test_setup - .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0]) - .await - ); - - // Validate that the mutated inputs and the output have a dependent error containing - // the input error - for ref_ in [input_tensor_ref_0, input_tensor_ref_2, output_ref] { - test_setup - .validate_dependent_error(ref_, input_error.clone()) - .await; - } - - // Validate that no error was reported to the controller - check_fetch_result_value( - &test_setup.client, - test_setup.stream_actor.clone(), - 6.into(), - input_tensor_ref_1, - &mut test_setup.controller_rx, - ) - .await; - - Ok(()) - } - - #[async_timed_test(timeout_secs = 60)] - async fn test_send_tensor_in_recording() -> Result<()> { - let mut test_setup = TestSetup::new_with_world_size(2).await?; - let recording_ref = test_setup.next_ref(); - - let unique_id = UniqueId::new()?; - let comm0 = test_setup.proc.spawn::( - "comm0", - CommParams::New { - device: CudaDevice::new(0.into()), - unique_id: unique_id.clone(), - world_size: 2, - rank: 0, - }, - ); - let comm1 = test_setup.proc.spawn::( - "comm1", - CommParams::New { - device: CudaDevice::new(1.into()), - unique_id, - world_size: 2, - rank: 1, - }, - ); - let (comm0, comm1) = tokio::try_join!(comm0, comm1)?; - let comm0 = Arc::new(comm0); - let comm1 = Arc::new(comm1); - - let factory = Factory { - size: vec![3], - dtype: torch_sys::ScalarType::Float, - layout: torch_sys::Layout::Strided, - device: "cuda".try_into().unwrap(), - }; - - let send_stream = test_setup.stream_actor.clone(); - let recv_stream = test_setup - .proc - .spawn::( - "recv_stream", - StreamParams { - world_size: 2, - rank: 1, - creation_mode: StreamCreationMode::CreateNewStream, - id: 1.into(), - device: Some(CudaDevice::new(1.into())), - controller_actor: test_setup.controller_actor.clone(), - respond_with_python_message: false, - }, - ) - .await?; - - send_stream - .define_recording(&test_setup.client, recording_ref) - .await?; - recv_stream - .define_recording(&test_setup.client, recording_ref) - .await?; - - let formal_tensor_ref_0 = test_setup.next_ref(); - let formal_tensor_ref_1 = test_setup.next_ref(); - - send_stream - .recording_formal(&test_setup.client, formal_tensor_ref_0, 0) - .await?; - send_stream - .recording_formal(&test_setup.client, formal_tensor_ref_1, 1) - .await?; - - let _ref = test_setup.next_ref(); - send_stream - .send_tensor( - &test_setup.client, - _ref, - None, - Some(1), - formal_tensor_ref_0, - factory.clone(), - comm0.clone(), - ) - .await?; - - let result_ref_0 = test_setup.next_ref(); - let _ref = test_setup.next_ref(); - recv_stream - .send_tensor( - &test_setup.client, - result_ref_0, - Some(0), - None, - _ref, - factory.clone(), - comm1, - ) - .await?; - - let result_ref_1 = test_setup.next_ref(); - send_stream - .send_tensor( - &test_setup.client, - result_ref_1, - Some(0), - Some(0), - formal_tensor_ref_1, - factory.clone(), - comm0, - ) - .await?; - - send_stream - .recording_result(&test_setup.client, result_ref_1, 0) - .await?; - recv_stream - .recording_result(&test_setup.client, result_ref_0, 0) - .await?; - - send_stream - .finalize_recording(&test_setup.client, recording_ref) - .await?; - recv_stream - .finalize_recording(&test_setup.client, recording_ref) - .await?; - - let input_tensor_ref_0 = test_setup.next_ref(); - let input_tensor_ref_1 = test_setup.next_ref(); - test_setup - .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0]) - .await?; - test_setup - .set_tensor(input_tensor_ref_1, &[4.0, 5.0, 6.0]) - .await?; - - let actual_result_ref_0 = test_setup.next_ref(); - let actual_result_ref_1 = test_setup.next_ref(); - let send_fut = send_stream.call_recording( - &test_setup.client, - 0.into(), - recording_ref, - vec![actual_result_ref_1], - vec![input_tensor_ref_0, input_tensor_ref_1], - ); - let recv_fut = recv_stream.call_recording( - &test_setup.client, - 0.into(), - recording_ref, - vec![actual_result_ref_0], - vec![], - ); - tokio::try_join!(send_fut, recv_fut)?; - - assert!( - test_setup - .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0]) - .await - ); - assert!( - test_setup - .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0]) - .await - ); - assert!( - test_setup - .allclose(actual_result_ref_1, &[4.0, 5.0, 6.0]) - .await - ); - - let actual_result_0 = recv_stream - .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0) - .await - .unwrap() - .unwrap() - .unwrap(); - assert!(allclose( - &actual_result_0.borrow(), - &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap()) - )?); - - // Validate that failure wasn't reported to controller. - check_fetch_result_value( - &test_setup.client, - send_stream.clone(), - 1.into(), - actual_result_ref_1, - &mut test_setup.controller_rx, - ) - .await; - check_fetch_result_value( - &test_setup.client, - recv_stream.clone(), - 2.into(), - actual_result_ref_0, - &mut test_setup.controller_rx, - ) - .await; - - let input_error = fake_seq_error(anyhow!("input error")); - send_stream - .set_tensor_ref_unit_tests_only( - &test_setup.client, - input_tensor_ref_0, - Err(input_error.clone()), - ) - .await?; - - let send_fut = send_stream.call_recording( - &test_setup.client, - 3.into(), - recording_ref, - vec![actual_result_ref_1], - vec![input_tensor_ref_0, input_tensor_ref_1], - ); - let recv_fut = recv_stream.call_recording( - &test_setup.client, - 3.into(), - recording_ref, - vec![actual_result_ref_0], - vec![], - ); - tokio::try_join!(send_fut, recv_fut)?; - - // The result on recv_stream should have a value, but it will be garbage. - let _ = recv_stream - .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0) - .await - .unwrap() - .unwrap() - .unwrap(); - - test_setup - .validate_dependent_error(actual_result_ref_1, input_error.clone()) - .await; - - // Input 1 should be untouched. - assert!( - test_setup - .allclose(input_tensor_ref_1, &[4.0, 5.0, 6.0]) - .await - ); - - // Validate that failure wasn't reported to controller. - check_fetch_result_error( - &test_setup.client, - send_stream.clone(), - 4.into(), - actual_result_ref_1, - &mut test_setup.controller_rx, - "input error", - ) - .await; - check_fetch_result_value( - &test_setup.client, - recv_stream.clone(), - 5.into(), - actual_result_ref_0, - &mut test_setup.controller_rx, - ) - .await; - - test_setup - .set_tensor(input_tensor_ref_0, &[1.0, 2.0, 3.0]) - .await?; - send_stream - .set_tensor_ref_unit_tests_only( - &test_setup.client, - input_tensor_ref_1, - Err(input_error.clone()), - ) - .await?; - - let send_fut = send_stream.call_recording( - &test_setup.client, - 6.into(), - recording_ref, - vec![actual_result_ref_1], - vec![input_tensor_ref_0, input_tensor_ref_1], - ); - let recv_fut = recv_stream.call_recording( - &test_setup.client, - 6.into(), - recording_ref, - vec![actual_result_ref_0], - vec![], - ); - tokio::try_join!(send_fut, recv_fut)?; - - let actual_result_0 = recv_stream - .get_tensor_ref_unit_tests_only(&test_setup.client, actual_result_ref_0) - .await - .unwrap() - .unwrap() - .unwrap(); - assert!(allclose( - &actual_result_0.borrow(), - &factory_float_tensor(&[1.0, 2.0, 3.0], "cpu".try_into().unwrap()) - )?); - - assert!( - test_setup - .allclose(input_tensor_ref_0, &[1.0, 2.0, 3.0]) - .await - ); - - test_setup - .validate_dependent_error(actual_result_ref_1, input_error) - .await; - - // Validate that failure wasn't reported to controller. - check_fetch_result_error( - &test_setup.client, - send_stream.clone(), - 7.into(), - actual_result_ref_1, - &mut test_setup.controller_rx, - "input error", - ) - .await; - check_fetch_result_value( - &test_setup.client, - recv_stream.clone(), - 8.into(), - actual_result_ref_0, - &mut test_setup.controller_rx, - ) - .await; - - Ok(()) - } }