diff --git a/controller/src/lib.rs b/controller/src/lib.rs index 09ce6d28d..8034357c1 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -74,7 +74,12 @@ const CASTING_FANOUT_SIZE: usize = 8; /// compute tasks. This acts a proxy managing comms with the workers and handling things like history, /// data dependency, worker lifecycles etc for the client abstracting it away. #[derive(Debug)] -#[hyperactor::export_spawn(ControllerMessage)] +#[hyperactor::export( + spawn = true, + handlers = [ + ControllerMessage, + ], +)] pub(crate) struct ControllerActor { client_actor_ref: OnceCell>, comm_actor_ref: ActorRef, @@ -1772,7 +1777,11 @@ mod tests { } #[derive(Debug)] - #[hyperactor::export(PanickingMessage)] + #[hyperactor::export( + handlers = [ + PanickingMessage, + ], + )] struct PanickingActor; #[async_trait] diff --git a/hyper/src/commands/demo.rs b/hyper/src/commands/demo.rs index 5a72c8eb0..0db6d92d8 100644 --- a/hyper/src/commands/demo.rs +++ b/hyper/src/commands/demo.rs @@ -207,7 +207,12 @@ enum DemoMessage { } #[derive(Debug)] -#[hyperactor::export_spawn(DemoMessage)] +#[hyperactor::export( + spawn = true, + handlers = [ + DemoMessage, + ], +)] struct DemoActor; #[async_trait] diff --git a/hyperactor/example/derive.rs b/hyperactor/example/derive.rs index 8f882b934..7b07289e6 100644 --- a/hyperactor/example/derive.rs +++ b/hyperactor/example/derive.rs @@ -36,7 +36,12 @@ enum ShoppingList { // Define an actor. #[derive(Debug)] -#[hyperactor::export_spawn(ShoppingList)] +#[hyperactor::export( + spawn = true, + handlers = [ + ShoppingList, + ], +)] struct ShoppingListActor(HashSet); #[async_trait] diff --git a/hyperactor/src/actor/mod.rs b/hyperactor/src/actor/mod.rs index d4a9c0d38..19f6afd08 100644 --- a/hyperactor/src/actor/mod.rs +++ b/hyperactor/src/actor/mod.rs @@ -1005,7 +1005,7 @@ mod tests { impl MultiValuesTest {} #[derive(Debug)] - #[hyperactor::export(u64, String)] + #[hyperactor::export(handlers = [u64, String])] struct MultiActor(MultiValues); #[async_trait] diff --git a/hyperactor/src/actor/remote.rs b/hyperactor/src/actor/remote.rs index 70846939f..0c9e5bafb 100644 --- a/hyperactor/src/actor/remote.rs +++ b/hyperactor/src/actor/remote.rs @@ -143,7 +143,7 @@ mod tests { use crate::Instance; #[derive(Debug)] - #[hyperactor::export(())] + #[hyperactor::export(handlers = [()])] struct MyActor; #[async_trait] diff --git a/hyperactor/src/lib.rs b/hyperactor/src/lib.rs index f1738ca48..44e98e150 100644 --- a/hyperactor/src/lib.rs +++ b/hyperactor/src/lib.rs @@ -116,8 +116,6 @@ pub use hyperactor_macros::RefClient; #[doc(inline)] pub use hyperactor_macros::export; #[doc(inline)] -pub use hyperactor_macros::export_spawn; -#[doc(inline)] pub use hyperactor_macros::forward; #[doc(inline)] pub use hyperactor_macros::instrument; diff --git a/hyperactor/src/test_utils/pingpong.rs b/hyperactor/src/test_utils/pingpong.rs index fa82f2dfd..3ded17854 100644 --- a/hyperactor/src/test_utils/pingpong.rs +++ b/hyperactor/src/test_utils/pingpong.rs @@ -64,7 +64,7 @@ impl PingPongActorParams { /// A PingPong actor that can play the PingPong game by sending messages around. #[derive(Debug)] -#[hyperactor::export(PingPongMessage)] +#[hyperactor::export(handlers = [PingPongMessage])] pub struct PingPongActor { params: PingPongActorParams, } diff --git a/hyperactor_macros/src/lib.rs b/hyperactor_macros/src/lib.rs index cc8667455..2f8a04f6c 100644 --- a/hyperactor_macros/src/lib.rs +++ b/hyperactor_macros/src/lib.rs @@ -22,6 +22,7 @@ use syn::Data; use syn::DataEnum; use syn::DeriveInput; use syn::Expr; +use syn::ExprLit; use syn::Field; use syn::Fields; use syn::Ident; @@ -32,6 +33,9 @@ use syn::Meta; use syn::MetaNameValue; use syn::Token; use syn::Type; +use syn::bracketed; +use syn::parse::Parse; +use syn::parse::ParseStream; use syn::parse_macro_input; use syn::punctuated::Punctuated; use syn::spanned::Spanned; @@ -443,7 +447,12 @@ fn parse_message_enum(input: DeriveInput) -> Result, syn::Error> { /// /// // Define an actor. /// #[derive(Debug)] -/// #[hyperactor::export_spawn(ShoppingList)] +/// #[hyperactor::export( +/// spawn = true, +/// handlers = [ +/// ShoppingList, +/// ], +/// )] /// struct ShoppingListActor(HashSet); /// /// #[async_trait] @@ -1139,45 +1148,97 @@ pub fn named_derive(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } -/// Exports an actor so that it may be bound to [`hyperactor::ActorRef`]s. -/// -/// The macro must be provided with the set of types that are exported, and -/// which may therefore be dispatched through references to the actor. -#[proc_macro_attribute] -pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream { - export_impl("export", attr, &parse_macro_input!(item as DeriveInput)) +/// Attribute Struct for [`fn export`] macro. +struct ExportAttr { + spawn: bool, + handlers: Vec, } -/// A version of [`export`] which also makes the actor remotely spawnable. -#[proc_macro_attribute] -pub fn export_spawn(attr: TokenStream, item: TokenStream) -> TokenStream { - let input: DeriveInput = parse_macro_input!(item as DeriveInput); - - let mut exported = export_impl("export_spawn", attr, &input); +impl Parse for ExportAttr { + fn parse(input: ParseStream) -> syn::Result { + let mut spawn = false; + let mut handlers = vec![]; + + while !input.is_empty() { + let key: Ident = input.parse()?; + input.parse::()?; + + if key == "spawn" { + let expr: Expr = input.parse()?; + if let Expr::Lit(ExprLit { + lit: Lit::Bool(b), .. + }) = expr + { + spawn = b.value; + } else { + return Err(syn::Error::new_spanned( + expr, + "expected boolean for `spawn`", + )); + } + } else if key == "handlers" { + let content; + bracketed!(content in input); + let types = content.parse_terminated(Type::parse, Token![,])?; + if types.is_empty() { + return Err(syn::Error::new_spanned( + types, + "`handlers` must include at least one type", + )); + } + handlers = types.into_iter().collect(); + } else { + return Err(syn::Error::new_spanned( + key, + "unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`", + )); + } - let data_type_name = &input.ident; - exported.extend(TokenStream::from(quote! { - hyperactor::remote!(#data_type_name); - })); + // optional trailing comma + let _ = input.parse::(); + } - exported + Ok(ExportAttr { spawn, handlers }) + } } -fn export_impl(which: &'static str, attr: TokenStream, input: &DeriveInput) -> TokenStream { +/// Exports handlers for this actor. The set of exported handlers +/// determine the messages that may be sent to remote references of +/// the actor ([`hyperaxtor::ActorRef`]). Only messages that implement +/// [`hyperactor::RemoteMessage`] may be exported. +/// +/// Additionally, an exported actor may be remotely spawned, +/// indicated by `spawn = true`. Such actors must also ensure that +/// their parameter type implements [`hyperactor::RemoteMessage`]. +/// +/// # Example +/// +/// In the following example, `MyActor` can be spawned remotely. It also has +/// exports handlers for two message types, `MyMessage` and `MyOtherMessage`. +/// Consequently, `ActorRef`s of the actor's type may dispatch messages of these +/// types. +/// +/// ```ignore +/// #[export( +/// spawn = true, +/// handlers = [ +/// MyMessage, +/// MyOtherMessage, +/// ], +/// )] +/// struct MyActor {} +/// ``` +#[proc_macro_attribute] +pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream { + let input: DeriveInput = parse_macro_input!(item as DeriveInput); let data_type_name = &input.ident; - let attr_args = - parse_macro_input!(attr with Punctuated::::parse_terminated); - if attr_args.is_empty() { - return TokenStream::from( - syn::Error::new_spanned(attr_args, format!("`{}` expects one or more type path arguments\n\n= help: use `#[{}(MyType, MyOtherType)]`", which, which)).to_compile_error(), - ); - } + let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr); let mut handles = Vec::new(); let mut bindings = Vec::new(); - for ty in &attr_args { + for ty in &handlers { handles.push(quote! { impl hyperactor::actor::RemoteHandles<#ty> for #data_type_name {} }); @@ -1186,26 +1247,33 @@ fn export_impl(which: &'static str, attr: TokenStream, input: &DeriveInput) -> T }); } - let expanded = quote! { - #input + let mut expanded = quote! { + #input - impl hyperactor::actor::RemoteActor for #data_type_name {} + impl hyperactor::actor::RemoteActor for #data_type_name {} - #(#handles)* + #(#handles)* - // Always export the `Signal` type. - impl hyperactor::actor::RemoteHandles for #data_type_name {} + // Always export the `Signal` type. + impl hyperactor::actor::RemoteHandles for #data_type_name {} - impl hyperactor::actor::Binds<#data_type_name> for #data_type_name { - fn bind(ports: &hyperactor::proc::Ports) { - #(#bindings)* - } - } + impl hyperactor::actor::Binds<#data_type_name> for #data_type_name { + fn bind(ports: &hyperactor::proc::Ports) { + #(#bindings)* + } + } impl hyperactor::data::Named for #data_type_name { fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name)) } } }; + if spawn { + expanded.extend(quote! { + + hyperactor::remote!(#data_type_name); + }); + } + TokenStream::from(expanded) } diff --git a/hyperactor_macros/tests/export.rs b/hyperactor_macros/tests/export.rs index 764918321..05d4c9467 100644 --- a/hyperactor_macros/tests/export.rs +++ b/hyperactor_macros/tests/export.rs @@ -18,7 +18,7 @@ use serde::Deserialize; use crate::Serialize; #[derive(Debug)] -#[hyperactor::export(TestMessage, (), MyGeneric<()>)] +#[hyperactor::export(handlers = [TestMessage, (), MyGeneric<()>,])] struct TestActor { // Forward the received message to this port, so it can be inspected by // the unit test. diff --git a/hyperactor_mesh/examples/dining_philosophers.rs b/hyperactor_mesh/examples/dining_philosophers.rs index 4e61cf424..a5c39ff64 100644 --- a/hyperactor_mesh/examples/dining_philosophers.rs +++ b/hyperactor_mesh/examples/dining_philosophers.rs @@ -47,9 +47,12 @@ enum ChopstickStatus { } #[derive(Debug)] -#[hyperactor::export_spawn( - Cast, - IndexedErasedUnbound>, +#[hyperactor::export( + spawn = true, + handlers = [ + Cast, + IndexedErasedUnbound>, + ], )] struct PhilosopherActor { /// Status of left and right chopsticks diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index d7ff06776..dcbc44dc5 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -375,15 +375,18 @@ pub(crate) mod test_util { // 'hyperactor_mesh_test_bootstrap' for the `tests::process` actor // mesh test suite. #[derive(Debug)] - #[hyperactor::export_spawn( - Cast, - Cast, - Cast, - GetRank, - Relay, - IndexedErasedUnbound>, - IndexedErasedUnbound>, - IndexedErasedUnbound>, + #[hyperactor::export( + spawn = true, + handlers = [ + Cast, + Cast, + Cast, + GetRank, + Relay, + IndexedErasedUnbound>, + IndexedErasedUnbound>, + IndexedErasedUnbound>, + ], )] pub struct TestActor; diff --git a/hyperactor_mesh/src/comm/mod.rs b/hyperactor_mesh/src/comm/mod.rs index d5e853d57..e82aa2bdd 100644 --- a/hyperactor_mesh/src/comm/mod.rs +++ b/hyperactor_mesh/src/comm/mod.rs @@ -66,7 +66,14 @@ struct ReceiveState { /// This is the comm actor used for efficient and scalable message multicasting /// and result accumulation. #[derive(Debug)] -#[hyperactor::export_spawn(CommActorMode, CastMessage, ForwardMessage)] +#[hyperactor::export( + spawn = true, + handlers = [ + CommActorMode, + CastMessage, + ForwardMessage, + ], +)] pub struct CommActor { /// Each world will use its own seq num from this caster. send_seq: HashMap, @@ -434,7 +441,15 @@ pub mod test_utils { } #[derive(Debug)] - #[hyperactor::export_spawn(TestMessage, Cast, IndexedErasedUnbound, IndexedErasedUnbound>)] + #[hyperactor::export( + spawn = true, + handlers = [ + TestMessage, + Cast, + IndexedErasedUnbound, + IndexedErasedUnbound>, + ], + )] pub struct TestActor { // Forward the received message to this port, so it can be inspected by // the unit test. diff --git a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs index 97c05aa73..000c63908 100644 --- a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs +++ b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs @@ -88,7 +88,7 @@ pub(crate) enum MeshAgentMessage { /// A mesh agent is responsible for managing procs in a [`ProcMesh`]. #[derive(Debug)] -#[hyperactor::export(MeshAgentMessage)] +#[hyperactor::export(handlers=[MeshAgentMessage])] pub struct MeshAgent { proc: Proc, remote: Remote, diff --git a/hyperactor_mesh/src/test_utils.rs b/hyperactor_mesh/src/test_utils.rs index 4ec4d219d..e1ace7b40 100644 --- a/hyperactor_mesh/src/test_utils.rs +++ b/hyperactor_mesh/src/test_utils.rs @@ -40,8 +40,11 @@ impl Unbind for EmptyMessage { /// No-op actor. #[derive(Debug, PartialEq)] #[hyperactor::export( - EmptyMessage, - Cast, IndexedErasedUnbound> + handlers = [ + EmptyMessage, + Cast, + IndexedErasedUnbound>, + ], )] pub struct EmptyActor(); diff --git a/hyperactor_multiprocess/src/proc_actor.rs b/hyperactor_multiprocess/src/proc_actor.rs index 9f1d00d59..70ef5f9e8 100644 --- a/hyperactor_multiprocess/src/proc_actor.rs +++ b/hyperactor_multiprocess/src/proc_actor.rs @@ -316,7 +316,12 @@ pub struct BootstrappedProc { /// the lifecycle of all of the proc's actors, and to route messages /// accordingly. #[derive(Debug)] -#[hyperactor::export(ProcMessage, MailboxAdminMessage)] +#[hyperactor::export( + handlers = [ + ProcMessage, + MailboxAdminMessage, + ], +)] pub struct ProcActor { params: ProcActorParams, state: ProcState, @@ -931,7 +936,12 @@ mod tests { } #[derive(Debug)] - #[hyperactor::export_spawn(TestActorMessage)] + #[hyperactor::export( + spawn = true, + handlers = [ + TestActorMessage, + ], + )] struct TestActor; #[derive(Handler, HandleClient, RefClient, Serialize, Deserialize, Debug, Named)] @@ -1000,7 +1010,12 @@ mod tests { // Sleep #[derive(Debug)] - #[hyperactor::export_spawn(u64)] + #[hyperactor::export( + spawn = true, + handlers = [ + u64, + ], + )] struct SleepActor {} #[async_trait] diff --git a/hyperactor_multiprocess/src/system_actor.rs b/hyperactor_multiprocess/src/system_actor.rs index 5f45984ff..461fe1d7b 100644 --- a/hyperactor_multiprocess/src/system_actor.rs +++ b/hyperactor_multiprocess/src/system_actor.rs @@ -1085,7 +1085,13 @@ enum SystemStopMessage { /// procs. The system actor also provides a central mailbox that can /// route messages to any live actor in the system. #[derive(Debug, Clone)] -#[hyperactor::export(SystemMessage, ProcSupervisionMessage, WorldSupervisionMessage)] +#[hyperactor::export( + handlers = [ + SystemMessage, + ProcSupervisionMessage, + WorldSupervisionMessage, + ], +)] pub struct SystemActor { params: SystemActorParams, supervision_state: SystemSupervisionState, diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 64508b412..fda8a03eb 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -273,7 +273,14 @@ impl PythonActorHandle { /// An actor for which message handlers are implemented in Python. #[derive(Debug)] -#[hyperactor::export_spawn(PythonMessage, Cast, IndexedErasedUnbound>)] +#[hyperactor::export( + spawn = true, + handlers = [ + PythonMessage, + Cast, + IndexedErasedUnbound>, + ], +)] pub(super) struct PythonActor { /// The Python object that we delegate message handling to. An instance of /// `monarch.actor_mesh._Actor`. diff --git a/monarch_rdma/examples/parameter_server.rs b/monarch_rdma/examples/parameter_server.rs index 88bc86344..1918ab27f 100644 --- a/monarch_rdma/examples/parameter_server.rs +++ b/monarch_rdma/examples/parameter_server.rs @@ -91,7 +91,14 @@ const BUFFER_SIZE: usize = 8; // Parameter Server Actor #[derive(Debug)] -#[hyperactor::export_spawn(PsGetBuffers, PsUpdate, Log)] +#[hyperactor::export( + spawn = true, + handlers = [ + PsGetBuffers, + PsUpdate, + Log, + ], +)] pub struct ParameterServerActor { weights_data: Box<[u8]>, grad_buffer_data: Box<[Box<[u8]>]>, @@ -252,11 +259,18 @@ impl Handler for ParameterServerActor { // Worker Actor #[derive(Debug)] -#[hyperactor::export_spawn( - Cast, IndexedErasedUnbound>, - Cast, IndexedErasedUnbound>, - Cast, IndexedErasedUnbound>, - Cast, IndexedErasedUnbound>, +#[hyperactor::export( + spawn = true, + handlers = [ + Cast, + IndexedErasedUnbound>, + Cast, + IndexedErasedUnbound>, + Cast, + IndexedErasedUnbound>, + Cast, + IndexedErasedUnbound>, + ], )] pub struct WorkerActor { ps_weights_handle: Option, diff --git a/monarch_rdma/src/rdma_manager_actor.rs b/monarch_rdma/src/rdma_manager_actor.rs index 83afdf853..d8b4a5900 100644 --- a/monarch_rdma/src/rdma_manager_actor.rs +++ b/monarch_rdma/src/rdma_manager_actor.rs @@ -152,7 +152,12 @@ pub enum RdmaManagerMessage { } #[derive(Debug)] -#[hyperactor::export_spawn(RdmaManagerMessage)] +#[hyperactor::export( + spawn = true, + handlers = [ + RdmaManagerMessage, + ], +)] pub struct RdmaManagerActor { // Map between ActorIds and their corresponding RdmaQueuePair qp_map: HashMap, diff --git a/monarch_simulator/src/controller.rs b/monarch_simulator/src/controller.rs index 3b840166f..ccf026b9e 100644 --- a/monarch_simulator/src/controller.rs +++ b/monarch_simulator/src/controller.rs @@ -36,7 +36,12 @@ use tokio::sync::OnceCell; use crate::worker::WorkerActor; #[derive(Debug)] -#[hyperactor::export_spawn(ControllerMessage)] +#[hyperactor::export( + spawn = true, + handlers = [ + ControllerMessage, + ], +)] pub struct SimControllerActor { client_actor_ref: OnceCell>, worker_actor_ref: ActorRef, diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index f3d63e8bd..827817a9b 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -151,7 +151,13 @@ fn reduce_op>( } #[derive(Debug)] -#[hyperactor::export_spawn(WorkerMessage, IndexedErasedUnbound)] +#[hyperactor::export( + spawn = true, + handlers = [ + WorkerMessage, + IndexedErasedUnbound, + ], +)] pub struct WorkerActor { rank: usize, worker_actor_id: ActorId, diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index d76ada3c6..ea3129884 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -154,7 +154,17 @@ enum Recording { /// /// See [`WorkerMessage`] for what it can do! #[derive(Debug)] -#[hyperactor::export_spawn(WorkerMessage, IndexedErasedUnbound, Cast, Cast, IndexedErasedUnbound>, IndexedErasedUnbound>)] +#[hyperactor::export( + spawn = true, + handlers = [ + WorkerMessage, + IndexedErasedUnbound, + Cast, + Cast, + IndexedErasedUnbound>, + IndexedErasedUnbound>, + ], +)] pub struct WorkerActor { device: Option, streams: HashMap>>,