Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions monarch_extension/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,21 +330,6 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
to_stream: p.parseStreamRef("to_stream")?,
})
});
m.insert(key("CreatePipe"), |p| {
let function = p.parseFunction("function")?;
let args = p.parse("args")?;
let kwargs = p.parse("kwargs")?;
let (args, kwargs) = func_call_args_to_wire_values(Some(&function), &args, &kwargs)?;
Ok(WorkerMessage::CreatePipe {
result: p.parseRef("result")?,
key: p.parse("key")?,
function,
max_messages: p.parse("max_messages")?,
mesh: p.parseRef("device_mesh")?,
args,
kwargs,
})
});
m.insert(key("SendValue"), |p| {
let function = p.parseOptionalFunction("function")?;
let args: Bound<'_, PyTuple> = p.parse("args")?;
Expand Down
102 changes: 2 additions & 100 deletions monarch_messages/src/wire_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,110 +171,12 @@ impl From<PyObject> for WireValue {
}
}

impl WireValue {
fn from_pyobject_with_torch_op_arg_type(
obj: Bound<'_, PyAny>,
type_: &torch_sys::call_op::TypePtr,
num_elements: i32,
allow_nums_as_tensors: bool,
) -> PyResult<Self> {
if type_.is_tensor() || type_.is_optional_tensor() {
if type_.is_optional_tensor() && obj.is_none() {
return Ok(WireValue::None(()));
} else if let Ok(ref_) = Ref::from_py_object(&obj) {
return Ok(WireValue::Ref(ref_));
}
}
if type_.is_tensor_list() || type_.is_optional_tensor_list() {
if type_.is_optional_tensor_list() && obj.is_none() {
return Ok(WireValue::None(()));
}
let list = obj.downcast::<PyList>()?;
let len = list.len();
if len == 0 {
return Ok(WireValue::RefList(vec![]));
}
// SAFETY: We know it is within bounds
let item = unsafe { list.get_item_unchecked(0) };
if let Ok(ref_) = Ref::from_py_object(&item) {
let mut ref_list = Vec::with_capacity(len);
ref_list.push(ref_);
for item in list.iter().skip(1) {
ref_list.push(Ref::from_py_object(&item).map_err(|_| {
PyValueError::new_err(format!(
"Expected homogeneous list of refs got: {:?}",
list
))
})?);
}
return Ok(WireValue::RefList(ref_list));
}
}
OpaqueIValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors)
.map(WireValue::IValue)
}
}

pub fn func_call_args_to_wire_values(
func: Option<&ResolvableFunction>,
args: &Bound<'_, PyTuple>,
kwargs: &Bound<'_, PyDict>,
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
if let Some((op, overload)) = func.and_then(|func| func.as_torch_op()) {
torch_op_args_to_wire_values(&op, &overload, args, kwargs)
} else {
python_func_args_to_wire_value(args, kwargs)
}
}

fn torch_op_args_to_wire_values(
op: &str,
overload: &str,
_func: Option<&ResolvableFunction>,
args: &Bound<'_, PyTuple>,
kwargs: &Bound<'_, PyDict>,
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
let args_info = torch_sys::call_op::get_schema_args_info(op, overload).map_err(|err| {
PyValueError::new_err(format!(
"Failed to get the operator schema for {}::{}: {}",
op, overload, err
))
})?;

let args = args
.iter()
.zip(&args_info)
.map(|(arg, arg_info)| {
WireValue::from_pyobject_with_torch_op_arg_type(
arg,
arg_info.type_,
arg_info.num_elements,
arg_info.allows_number_as_tensor,
)
})
.collect::<Result<Vec<_>, _>>()?;
let kwargs = kwargs
.iter()
.map(|(k, v)| {
let key = k.extract::<String>()?;
let arg_info = args_info
.iter()
.find(|arg_info| arg_info.name == key)
.ok_or_else(|| {
PyValueError::new_err(format!(
"Torch op {}::{} does not support kwarg {}",
op, overload, key
))
})?;
let val = WireValue::from_pyobject_with_torch_op_arg_type(
v,
arg_info.type_,
arg_info.num_elements,
arg_info.allows_number_as_tensor,
)?;
Ok((key, val))
})
.collect::<Result<HashMap<_, _>, PyErr>>()?;
Ok((args, kwargs))
python_func_args_to_wire_value(args, kwargs)
}

fn python_func_args_to_wire_value(
Expand Down
32 changes: 0 additions & 32 deletions monarch_messages/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,21 +340,6 @@ impl ResolvableFunction {
}
}

pub fn as_torch_op<'a>(&'a self) -> Option<(String, String)> {
match self {
Self::FunctionPath(func) => match func.path.split(".").collect::<Vec<_>>().as_slice() {
["torch", "ops", namespace, op_name, "default"] => {
Some((format!("{}::{}", namespace, op_name), String::new()))
}
["torch", "ops", namespace, op_name, overload] => {
Some((format!("{}::{}", namespace, op_name), overload.to_string()))
}
_ => None,
},
_ => None,
}
}

/// For testing: this is a special remote function path that induces a panic
/// when called.
pub fn panic_if_requested(&self) {
Expand All @@ -367,13 +352,6 @@ impl ResolvableFunction {
_ => (),
}
}

pub fn supports_pytree_args(&self) -> bool {
match self {
Self::Cloudpickle(_) => true,
Self::FunctionPath(_) => self.as_torch_op().is_none(),
}
}
}

impl<T: Into<String>> From<T> for ResolvableFunction {
Expand Down Expand Up @@ -800,16 +778,6 @@ pub enum WorkerMessage {
to_stream: StreamRef,
},

CreatePipe {
result: Ref,
key: String,
function: ResolvableFunction,
max_messages: i64,
mesh: Ref,
args: Vec<WireValue>,
kwargs: HashMap<String, WireValue>,
},

SendValue {
seq: Seq,
/// Pipe to send value to. If `None`, value is sent to controller.
Expand Down
32 changes: 5 additions & 27 deletions monarch_tensor_worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ use monarch_messages::worker::StreamRef;
use monarch_messages::worker::WorkerMessage;
use monarch_messages::worker::WorkerMessageHandler;
use monarch_messages::worker::WorkerParams;
use monarch_types::PyTree;
use ndslice::Slice;
use pyo3::Python;
use pyo3::types::PyAnyMethods;
Expand All @@ -92,7 +91,6 @@ use stream::StreamParams;
use torch_sys::CudaDevice;
use torch_sys::DeviceIndex;
use torch_sys::Layout;
use torch_sys::RValue;
use torch_sys::ScalarType;
use torch_sys::TensorCell;
use torch_sys::factory_zeros;
Expand Down Expand Up @@ -383,14 +381,10 @@ impl WorkerMessageHandler for WorkerActor {
self.maybe_add_stream_to_recording(cx, params.stream)
.await?;

let device_meshes = if params.function.as_torch_op().is_some() {
HashMap::new()
} else {
self.device_meshes
.iter()
.map(|(k, v)| (k.clone(), v.0.clone()))
.collect()
};
let device_meshes = self.device_meshes
.iter()
.map(|(k, v)| (k.clone(), v.0.clone()))
.collect();

let mut remote_process_groups = HashMap::new();
for remote_process_group_ref in &params.remote_process_groups {
Expand Down Expand Up @@ -638,22 +632,6 @@ impl WorkerMessageHandler for WorkerActor {
Ok(())
}

async fn create_pipe(
&mut self,
_cx: &hyperactor::Context<Self>,
_result: Ref,
// TODO(agallagher): This is used in the python impl to name the socket
// path to use for comms, but we don't currently use a named socket.
_key: String,
_function: ResolvableFunction,
_max_messages: i64,
_device_mesh: Ref,
_args: Vec<WireValue>,
_kwargs: HashMap<String, WireValue>,
) -> Result<()> {
panic!("create_pipe is no longer implemented")
}

async fn send_tensor(
&mut self,
cx: &hyperactor::Context<Self>,
Expand Down Expand Up @@ -772,7 +750,7 @@ impl WorkerMessageHandler for WorkerActor {
// Resolve the stream.
let stream = self.try_get_stream(stream)?;

let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) {
let device_meshes = if function.is_none() {
HashMap::new()
} else {
self.device_meshes
Expand Down
104 changes: 22 additions & 82 deletions monarch_tensor_worker/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ use monarch_types::PyTree;
use monarch_types::SerializablePyErr;
use monarch_types::TryIntoPyObjectUnsafe;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use tokio::runtime::Handle;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -740,34 +739,6 @@ impl StreamActor {
Ok(())
}

fn call_torch_op(
&self,
op: String,
overload: String,
args: Vec<WireValue>,
kwargs: HashMap<String, WireValue>,
) -> Result<Vec<RValue>, CallFunctionError> {
let args = args
.into_iter()
.map(|arg| self.wire_to_rvalue(arg))
.collect::<Result<Vec<_>, _>>()?;
let kwargs = kwargs
.into_iter()
.map(|(k, v)| self.wire_to_rvalue(v).map(|rvalue| (k, rvalue)))
.collect::<Result<HashMap<_, _>, CallFunctionError>>()?;

let results = torch_sys::call_op::call_op(op, overload, &args, &kwargs, true)?;

// Handle the case where the op returns nothing and convert it to a list of None.
// This is to ensure handle results does not error out as the client will call
// such a function with expected results of size 1.
Ok(if results.is_empty() {
vec![RValue::None]
} else {
results
})
}

fn call_python_fn<'py>(
&mut self,
py: Python<'py>,
Expand Down Expand Up @@ -1118,21 +1089,17 @@ impl StreamMessageHandler for StreamActor {
params.results,
&params.mutates,
async |self| {
tokio::task::block_in_place(|| match params.function.as_torch_op() {
Some((op, overload)) => {
self.call_torch_op(op, overload, params.args, params.kwargs)
}
_ => self
.call_python_fn_pytree(
cx,
params.function,
params.args,
params.kwargs,
&params.mutates,
device_meshes,
remote_process_groups,
)
.map(|results| results.into_leaves()),
tokio::task::block_in_place(|| {
self.call_python_fn_pytree(
cx,
params.function,
params.args,
params.kwargs,
&params.mutates,
device_meshes,
remote_process_groups,
)
.map(|results| results.into_leaves())
})
},
)
Expand Down Expand Up @@ -1562,44 +1529,17 @@ impl StreamMessageHandler for StreamActor {
}
let result = if let Some(function) = function {
// If a function was provided, use that to resolve the value.
match function.as_torch_op() {
Some((op, overload)) => {
self.call_torch_op(op, overload, args, kwargs)
.map(|rvalues| {
if rvalues.len() == 1 {
Ok(rvalues[0].clone().into())
} else {
// TODO: Replace with native pytrees when possible
Python::with_gil(|py| {
Ok((|| {
let py_rvalues = rvalues
.into_iter()
// SAFETY: This inherits the unsafety of `try_to_object_unsafe`.
.map(|rvalue| unsafe {
rvalue.try_to_object_unsafe(py)
})
.collect::<Result<Vec<_>, _>>()?;
PyTuple::new(py, &py_rvalues)?.extract::<PyTree<RValue>>()
})()
.map_err(SerializablePyErr::from_fn(py))?)
})
}
})?
}
// Use block-in-place to allow nested callbacks to re-enter the
// runtime to run async code.
_ => tokio::task::block_in_place(|| {
self.call_python_fn_pytree(
cx,
function,
args,
kwargs,
&mutates,
device_meshes,
HashMap::new(),
)
}),
}
tokio::task::block_in_place(|| {
self.call_python_fn_pytree(
cx,
function,
args,
kwargs,
&mutates,
device_meshes,
HashMap::new(),
)
})
} else {
// If there's no function provided, there should be exactly one arg
// and no kwargs.
Expand Down
Loading