Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions controller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ const CASTING_FANOUT_SIZE: usize = 8;
/// compute tasks. This acts a proxy managing comms with the workers and handling things like history,
/// data dependency, worker lifecycles etc for the client abstracting it away.
#[derive(Debug)]
#[hyperactor::export_spawn(ControllerMessage)]
#[hyperactor::export(
spawn = true,
handlers = [
ControllerMessage,
],
)]
pub(crate) struct ControllerActor {
client_actor_ref: OnceCell<ActorRef<ClientActor>>,
comm_actor_ref: ActorRef<CommActor>,
Expand Down Expand Up @@ -1772,7 +1777,11 @@ mod tests {
}

#[derive(Debug)]
#[hyperactor::export(PanickingMessage)]
#[hyperactor::export(
handlers = [
PanickingMessage,
],
)]
struct PanickingActor;

#[async_trait]
Expand Down
7 changes: 6 additions & 1 deletion hyper/src/commands/demo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,12 @@ enum DemoMessage {
}

#[derive(Debug)]
#[hyperactor::export_spawn(DemoMessage)]
#[hyperactor::export(
spawn = true,
handlers = [
DemoMessage,
],
)]
struct DemoActor;

#[async_trait]
Expand Down
7 changes: 6 additions & 1 deletion hyperactor/example/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ enum ShoppingList {

// Define an actor.
#[derive(Debug)]
#[hyperactor::export_spawn(ShoppingList)]
#[hyperactor::export(
spawn = true,
handlers = [
ShoppingList,
],
)]
struct ShoppingListActor(HashSet<String>);

#[async_trait]
Expand Down
2 changes: 1 addition & 1 deletion hyperactor/src/actor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ mod tests {
impl MultiValuesTest {}

#[derive(Debug)]
#[hyperactor::export(u64, String)]
#[hyperactor::export(handlers = [u64, String])]
struct MultiActor(MultiValues);

#[async_trait]
Expand Down
2 changes: 1 addition & 1 deletion hyperactor/src/actor/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ mod tests {
use crate::Instance;

#[derive(Debug)]
#[hyperactor::export(())]
#[hyperactor::export(handlers = [()])]
struct MyActor;

#[async_trait]
Expand Down
2 changes: 0 additions & 2 deletions hyperactor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ pub use hyperactor_macros::RefClient;
#[doc(inline)]
pub use hyperactor_macros::export;
#[doc(inline)]
pub use hyperactor_macros::export_spawn;
#[doc(inline)]
pub use hyperactor_macros::forward;
#[doc(inline)]
pub use hyperactor_macros::instrument;
Expand Down
2 changes: 1 addition & 1 deletion hyperactor/src/test_utils/pingpong.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl PingPongActorParams {

/// A PingPong actor that can play the PingPong game by sending messages around.
#[derive(Debug)]
#[hyperactor::export(PingPongMessage)]
#[hyperactor::export(handlers = [PingPongMessage])]
pub struct PingPongActor {
params: PingPongActorParams,
}
Expand Down
146 changes: 107 additions & 39 deletions hyperactor_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use syn::Data;
use syn::DataEnum;
use syn::DeriveInput;
use syn::Expr;
use syn::ExprLit;
use syn::Field;
use syn::Fields;
use syn::Ident;
Expand All @@ -32,6 +33,9 @@ use syn::Meta;
use syn::MetaNameValue;
use syn::Token;
use syn::Type;
use syn::bracketed;
use syn::parse::Parse;
use syn::parse::ParseStream;
use syn::parse_macro_input;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
Expand Down Expand Up @@ -443,7 +447,12 @@ fn parse_message_enum(input: DeriveInput) -> Result<Vec<Message>, syn::Error> {
///
/// // Define an actor.
/// #[derive(Debug)]
/// #[hyperactor::export_spawn(ShoppingList)]
/// #[hyperactor::export(
/// spawn = true,
/// handlers = [
/// ShoppingList,
/// ],
/// )]
/// struct ShoppingListActor(HashSet<String>);
///
/// #[async_trait]
Expand Down Expand Up @@ -1139,45 +1148,97 @@ pub fn named_derive(input: TokenStream) -> TokenStream {
TokenStream::from(expanded)
}

/// Exports an actor so that it may be bound to [`hyperactor::ActorRef`]s.
///
/// The macro must be provided with the set of types that are exported, and
/// which may therefore be dispatched through references to the actor.
#[proc_macro_attribute]
pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
export_impl("export", attr, &parse_macro_input!(item as DeriveInput))
/// Attribute Struct for [`fn export`] macro.
struct ExportAttr {
spawn: bool,
handlers: Vec<Type>,
}

/// A version of [`export`] which also makes the actor remotely spawnable.
#[proc_macro_attribute]
pub fn export_spawn(attr: TokenStream, item: TokenStream) -> TokenStream {
let input: DeriveInput = parse_macro_input!(item as DeriveInput);

let mut exported = export_impl("export_spawn", attr, &input);
impl Parse for ExportAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut spawn = false;
let mut handlers = vec![];

while !input.is_empty() {
let key: Ident = input.parse()?;
input.parse::<Token![=]>()?;

if key == "spawn" {
let expr: Expr = input.parse()?;
if let Expr::Lit(ExprLit {
lit: Lit::Bool(b), ..
}) = expr
{
spawn = b.value;
} else {
return Err(syn::Error::new_spanned(
expr,
"expected boolean for `spawn`",
));
}
} else if key == "handlers" {
let content;
bracketed!(content in input);
let types = content.parse_terminated(Type::parse, Token![,])?;
if types.is_empty() {
return Err(syn::Error::new_spanned(
types,
"`handlers` must include at least one type",
));
}
handlers = types.into_iter().collect();
} else {
return Err(syn::Error::new_spanned(
key,
"unexpected key in `#[export(...)]`. Only supports `spawn` and `handlers`",
));
}

let data_type_name = &input.ident;
exported.extend(TokenStream::from(quote! {
hyperactor::remote!(#data_type_name);
}));
// optional trailing comma
let _ = input.parse::<Token![,]>();
}

exported
Ok(ExportAttr { spawn, handlers })
}
}

fn export_impl(which: &'static str, attr: TokenStream, input: &DeriveInput) -> TokenStream {
/// Exports handlers for this actor. The set of exported handlers
/// determine the messages that may be sent to remote references of
/// the actor ([`hyperaxtor::ActorRef`]). Only messages that implement
/// [`hyperactor::RemoteMessage`] may be exported.
///
/// Additionally, an exported actor may be remotely spawned,
/// indicated by `spawn = true`. Such actors must also ensure that
/// their parameter type implements [`hyperactor::RemoteMessage`].
///
/// # Example
///
/// In the following example, `MyActor` can be spawned remotely. It also has
/// exports handlers for two message types, `MyMessage` and `MyOtherMessage`.
/// Consequently, `ActorRef`s of the actor's type may dispatch messages of these
/// types.
///
/// ```ignore
/// #[export(
/// spawn = true,
/// handlers = [
/// MyMessage,
/// MyOtherMessage,
/// ],
/// )]
/// struct MyActor {}
/// ```
#[proc_macro_attribute]
pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
let input: DeriveInput = parse_macro_input!(item as DeriveInput);
let data_type_name = &input.ident;

let attr_args =
parse_macro_input!(attr with Punctuated::<syn::Type, Token![,]>::parse_terminated);
if attr_args.is_empty() {
return TokenStream::from(
syn::Error::new_spanned(attr_args, format!("`{}` expects one or more type path arguments\n\n= help: use `#[{}(MyType, MyOtherType<T>)]`", which, which)).to_compile_error(),
);
}
let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);

let mut handles = Vec::new();
let mut bindings = Vec::new();

for ty in &attr_args {
for ty in &handlers {
handles.push(quote! {
impl hyperactor::actor::RemoteHandles<#ty> for #data_type_name {}
});
Expand All @@ -1186,26 +1247,33 @@ fn export_impl(which: &'static str, attr: TokenStream, input: &DeriveInput) -> T
});
}

let expanded = quote! {
#input
let mut expanded = quote! {
#input

impl hyperactor::actor::RemoteActor for #data_type_name {}
impl hyperactor::actor::RemoteActor for #data_type_name {}

#(#handles)*
#(#handles)*

// Always export the `Signal` type.
impl hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name {}
// Always export the `Signal` type.
impl hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name {}

impl hyperactor::actor::Binds<#data_type_name> for #data_type_name {
fn bind(ports: &hyperactor::proc::Ports<Self>) {
#(#bindings)*
}
}
impl hyperactor::actor::Binds<#data_type_name> for #data_type_name {
fn bind(ports: &hyperactor::proc::Ports<Self>) {
#(#bindings)*
}
}

impl hyperactor::data::Named for #data_type_name {
fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name)) }
}
};

if spawn {
expanded.extend(quote! {

hyperactor::remote!(#data_type_name);
});
}

TokenStream::from(expanded)
}
2 changes: 1 addition & 1 deletion hyperactor_macros/tests/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use serde::Deserialize;
use crate::Serialize;

#[derive(Debug)]
#[hyperactor::export(TestMessage, (), MyGeneric<()>)]
#[hyperactor::export(handlers = [TestMessage, (), MyGeneric<()>,])]
struct TestActor {
// Forward the received message to this port, so it can be inspected by
// the unit test.
Expand Down
9 changes: 6 additions & 3 deletions hyperactor_mesh/examples/dining_philosophers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ enum ChopstickStatus {
}

#[derive(Debug)]
#[hyperactor::export_spawn(
Cast<PhilosopherMessage>,
IndexedErasedUnbound<Cast<PhilosopherMessage>>,
#[hyperactor::export(
spawn = true,
handlers = [
Cast<PhilosopherMessage>,
IndexedErasedUnbound<Cast<PhilosopherMessage>>,
],
)]
struct PhilosopherActor {
/// Status of left and right chopsticks
Expand Down
21 changes: 12 additions & 9 deletions hyperactor_mesh/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,15 +375,18 @@ pub(crate) mod test_util {
// 'hyperactor_mesh_test_bootstrap' for the `tests::process` actor
// mesh test suite.
#[derive(Debug)]
#[hyperactor::export_spawn(
Cast<Echo>,
Cast<GetRank>,
Cast<Error>,
GetRank,
Relay,
IndexedErasedUnbound<Cast<Echo>>,
IndexedErasedUnbound<Cast<GetRank>>,
IndexedErasedUnbound<Cast<Error>>,
#[hyperactor::export(
spawn = true,
handlers = [
Cast<Echo>,
Cast<GetRank>,
Cast<Error>,
GetRank,
Relay,
IndexedErasedUnbound<Cast<Echo>>,
IndexedErasedUnbound<Cast<GetRank>>,
IndexedErasedUnbound<Cast<Error>>,
],
)]
pub struct TestActor;

Expand Down
19 changes: 17 additions & 2 deletions hyperactor_mesh/src/comm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ struct ReceiveState {
/// This is the comm actor used for efficient and scalable message multicasting
/// and result accumulation.
#[derive(Debug)]
#[hyperactor::export_spawn(CommActorMode, CastMessage, ForwardMessage)]
#[hyperactor::export(
spawn = true,
handlers = [
CommActorMode,
CastMessage,
ForwardMessage,
],
)]
pub struct CommActor {
/// Each world will use its own seq num from this caster.
send_seq: HashMap<Slice, usize>,
Expand Down Expand Up @@ -434,7 +441,15 @@ pub mod test_utils {
}

#[derive(Debug)]
#[hyperactor::export_spawn(TestMessage, Cast<TestMessage>, IndexedErasedUnbound<TestMessage>, IndexedErasedUnbound<Cast<TestMessage>>)]
#[hyperactor::export(
spawn = true,
handlers = [
TestMessage,
Cast<TestMessage>,
IndexedErasedUnbound<TestMessage>,
IndexedErasedUnbound<Cast<TestMessage>>,
],
)]
pub struct TestActor {
// Forward the received message to this port, so it can be inspected by
// the unit test.
Expand Down
2 changes: 1 addition & 1 deletion hyperactor_mesh/src/proc_mesh/mesh_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub(crate) enum MeshAgentMessage {

/// A mesh agent is responsible for managing procs in a [`ProcMesh`].
#[derive(Debug)]
#[hyperactor::export(MeshAgentMessage)]
#[hyperactor::export(handlers=[MeshAgentMessage])]
pub struct MeshAgent {
proc: Proc,
remote: Remote,
Expand Down
Loading