diff --git a/concurrency/src/tasks/error.rs b/concurrency/src/tasks/error.rs index 05653fa..498e8fb 100644 --- a/concurrency/src/tasks/error.rs +++ b/concurrency/src/tasks/error.rs @@ -1,11 +1,12 @@ #[derive(Debug)] pub enum GenServerError { - CallbackError, - ServerError, + Callback, + Initialization, + Server, } impl From> for GenServerError { fn from(_value: spawned_rt::tasks::mpsc::SendError) -> Self { - Self::ServerError + Self::Server } } diff --git a/concurrency/src/tasks/gen_server.rs b/concurrency/src/tasks/gen_server.rs index 168b3c5..f800a2f 100644 --- a/concurrency/src/tasks/gen_server.rs +++ b/concurrency/src/tasks/gen_server.rs @@ -20,7 +20,7 @@ impl Clone for GenServerHandle { } impl GenServerHandle { - pub(crate) fn new(mut initial_state: G::State) -> Self { + pub(crate) fn new(initial_state: G::State) -> Self { let (tx, mut rx) = mpsc::channel::>(); let handle = GenServerHandle { tx }; let mut gen_server: G = GenServer::new(); @@ -28,7 +28,7 @@ impl GenServerHandle { // Ignore the JoinHandle for now. Maybe we'll use it in the future let _join_handle = rt::spawn(async move { if gen_server - .run(&handle, &mut rx, &mut initial_state) + .run(&handle, &mut rx, initial_state) .await .is_err() { @@ -38,7 +38,7 @@ impl GenServerHandle { handle_clone } - pub(crate) fn new_blocking(mut initial_state: G::State) -> Self { + pub(crate) fn new_blocking(initial_state: G::State) -> Self { let (tx, mut rx) = mpsc::channel::>(); let handle = GenServerHandle { tx }; let mut gen_server: G = GenServer::new(); @@ -47,7 +47,7 @@ impl GenServerHandle { let _join_handle = rt::spawn_blocking(|| { rt::block_on(async move { if gen_server - .run(&handle, &mut rx, &mut initial_state) + .run(&handle, &mut rx, initial_state) .await .is_err() { @@ -70,34 +70,34 @@ impl GenServerHandle { })?; match oneshot_rx.await { Ok(result) => result, - Err(_) => Err(GenServerError::ServerError), + Err(_) => Err(GenServerError::Server), } } pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> { self.tx .send(GenServerInMsg::Cast { message }) - .map_err(|_error| GenServerError::ServerError) + .map_err(|_error| GenServerError::Server) } } -pub enum GenServerInMsg { +pub enum GenServerInMsg { Call { - sender: oneshot::Sender>, - message: A::CallMsg, + sender: oneshot::Sender>, + message: G::CallMsg, }, Cast { - message: A::CastMsg, + message: G::CastMsg, }, } -pub enum CallResponse { - Reply(U), - Stop(U), +pub enum CallResponse { + Reply(G::State, G::OutMsg), + Stop(G::OutMsg), } -pub enum CastResponse { - NoReply, +pub enum CastResponse { + NoReply(G::State), Stop, } @@ -109,7 +109,7 @@ where type CastMsg: Send + Sized; type OutMsg: Send + Sized; type State: Clone + Send; - type Error: Debug; + type Error: Debug + Send; fn new() -> Self; @@ -130,25 +130,46 @@ where &mut self, handle: &GenServerHandle, rx: &mut mpsc::Receiver>, - state: &mut Self::State, + state: Self::State, ) -> impl Future> + Send { async { - self.main_loop(handle, rx, state).await?; - Ok(()) + match self.init(handle, state).await { + Ok(new_state) => { + self.main_loop(handle, rx, new_state).await?; + Ok(()) + } + Err(err) => { + tracing::error!("Initialization failed: {err:?}"); + Err(GenServerError::Initialization) + } + } } } + /// Initialization function. It's called before main loop. It + /// can be overrided on implementations in case initial steps are + /// required. + fn init( + &mut self, + _handle: &GenServerHandle, + state: Self::State, + ) -> impl Future> + Send { + async { Ok(state) } + } + fn main_loop( &mut self, handle: &GenServerHandle, rx: &mut mpsc::Receiver>, - state: &mut Self::State, + mut state: Self::State, ) -> impl Future> + Send { async { loop { - if !self.receive(handle, rx, state).await? { + let (new_state, cont) = self.receive(handle, rx, state).await?; + if !cont { break; } + state = new_state; } tracing::trace!("Stopping GenServer"); Ok(()) @@ -159,26 +180,33 @@ where &mut self, handle: &GenServerHandle, rx: &mut mpsc::Receiver>, - state: &mut Self::State, - ) -> impl std::future::Future> + Send { - async { + state: Self::State, + ) -> impl std::future::Future> + Send { + async move { let message = rx.recv().await; // Save current state in case of a rollback let state_clone = state.clone(); - let (keep_running, error) = match message { + let (keep_running, new_state) = match message { Some(GenServerInMsg::Call { sender, message }) => { - let (keep_running, error, response) = + let (keep_running, new_state, response) = match AssertUnwindSafe(self.handle_call(message, handle, state)) .catch_unwind() .await { Ok(response) => match response { - CallResponse::Reply(response) => (true, None, Ok(response)), - CallResponse::Stop(response) => (false, None, Ok(response)), + CallResponse::Reply(new_state, response) => { + (true, new_state, Ok(response)) + } + CallResponse::Stop(response) => (false, state_clone, Ok(response)), }, - Err(error) => (true, Some(error), Err(GenServerError::CallbackError)), + Err(error) => { + tracing::trace!( + "Error in callback, reverting state - Error: '{error:?}'" + ); + (true, state_clone, Err(GenServerError::Callback)) + } }; // Send response back if sender.send(response).is_err() { @@ -186,7 +214,7 @@ where "GenServer failed to send response back, client must have died" ) }; - (keep_running, error) + (keep_running, new_state) } Some(GenServerInMsg::Cast { message }) => { match AssertUnwindSafe(self.handle_cast(message, handle, state)) @@ -194,23 +222,23 @@ where .await { Ok(response) => match response { - CastResponse::NoReply => (true, None), - CastResponse::Stop => (false, None), + CastResponse::NoReply(new_state) => (true, new_state), + CastResponse::Stop => (false, state_clone), }, - Err(error) => (true, Some(error)), + Err(error) => { + tracing::trace!( + "Error in callback, reverting state - Error: '{error:?}'" + ); + (true, state_clone) + } } } None => { // Channel has been closed; won't receive further messages. Stop the server. - (false, None) + (false, state) } }; - if let Some(error) = error { - tracing::trace!("Error in callback, reverting state - Error: '{error:?}'"); - // Restore initial state (ie. dismiss any change) - *state = state_clone; - }; - Ok(keep_running) + Ok((new_state, keep_running)) } } @@ -218,22 +246,22 @@ where &mut self, message: Self::CallMsg, handle: &GenServerHandle, - state: &mut Self::State, - ) -> impl std::future::Future> + Send; + state: Self::State, + ) -> impl std::future::Future> + Send; fn handle_cast( &mut self, message: Self::CastMsg, handle: &GenServerHandle, - state: &mut Self::State, - ) -> impl std::future::Future + Send; + state: Self::State, + ) -> impl std::future::Future> + Send; } #[cfg(test)] mod tests { use super::*; use crate::tasks::send_after; - use std::{process::exit, thread, time::Duration}; + use std::{thread, time::Duration}; struct BadlyBehavedTask; #[derive(Clone)] @@ -261,8 +289,8 @@ mod tests { &mut self, _: Self::CallMsg, _: &GenServerHandle, - _: &mut Self::State, - ) -> CallResponse { + _: Self::State, + ) -> CallResponse { CallResponse::Stop(()) } @@ -270,8 +298,8 @@ mod tests { &mut self, _: Self::CastMsg, _: &GenServerHandle, - _: &mut Self::State, - ) -> CastResponse { + _: Self::State, + ) -> CastResponse { rt::sleep(Duration::from_millis(20)).await; thread::sleep(Duration::from_secs(2)); CastResponse::Stop @@ -300,10 +328,13 @@ mod tests { &mut self, message: Self::CallMsg, _: &GenServerHandle, - state: &mut Self::State, - ) -> CallResponse { + state: Self::State, + ) -> CallResponse { match message { - InMessage::GetCount => CallResponse::Reply(OutMsg::Count(state.count)), + InMessage::GetCount => { + let count = state.count; + CallResponse::Reply(state, OutMsg::Count(count)) + } InMessage::Stop => CallResponse::Stop(OutMsg::Count(state.count)), } } @@ -312,12 +343,12 @@ mod tests { &mut self, _: Self::CastMsg, handle: &GenServerHandle, - state: &mut Self::State, - ) -> CastResponse { + mut state: Self::State, + ) -> CastResponse { state.count += 1; println!("{:?}: good still alive", thread::current().id()); send_after(Duration::from_millis(100), handle.to_owned(), ()); - CastResponse::NoReply + CastResponse::NoReply(state) } } diff --git a/concurrency/src/threads/error.rs b/concurrency/src/threads/error.rs index 735e37e..8834e0f 100644 --- a/concurrency/src/threads/error.rs +++ b/concurrency/src/threads/error.rs @@ -1,11 +1,12 @@ #[derive(Debug)] pub enum GenServerError { - CallbackError, - ServerError, + Callback, + Initialization, + Server, } impl From> for GenServerError { fn from(_value: spawned_rt::threads::mpsc::SendError) -> Self { - Self::ServerError + Self::Server } } diff --git a/concurrency/src/threads/gen_server.rs b/concurrency/src/threads/gen_server.rs index 1541c43..912067b 100644 --- a/concurrency/src/threads/gen_server.rs +++ b/concurrency/src/threads/gen_server.rs @@ -22,17 +22,14 @@ impl Clone for GenServerHandle { } impl GenServerHandle { - pub(crate) fn new(mut initial_state: G::State) -> Self { + pub(crate) fn new(initial_state: G::State) -> Self { let (tx, mut rx) = mpsc::channel::>(); let handle = GenServerHandle { tx }; let mut gen_server: G = GenServer::new(); let handle_clone = handle.clone(); // Ignore the JoinHandle for now. Maybe we'll use it in the future let _join_handle = rt::spawn(move || { - if gen_server - .run(&handle, &mut rx, &mut initial_state) - .is_err() - { + if gen_server.run(&handle, &mut rx, initial_state).is_err() { tracing::trace!("GenServer crashed") }; }); @@ -51,34 +48,34 @@ impl GenServerHandle { })?; match oneshot_rx.recv() { Ok(result) => result, - Err(_) => Err(GenServerError::ServerError), + Err(_) => Err(GenServerError::Server), } } pub fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> { self.tx .send(GenServerInMsg::Cast { message }) - .map_err(|_error| GenServerError::ServerError) + .map_err(|_error| GenServerError::Server) } } -pub enum GenServerInMsg { +pub enum GenServerInMsg { Call { - sender: oneshot::Sender>, - message: A::CallMsg, + sender: oneshot::Sender>, + message: G::CallMsg, }, Cast { - message: A::CastMsg, + message: G::CastMsg, }, } -pub enum CallResponse { - Reply(U), - Stop(U), +pub enum CallResponse { + Reply(G::State, G::OutMsg), + Stop(G::OutMsg), } -pub enum CastResponse { - NoReply, +pub enum CastResponse { + NoReply(G::State), Stop, } @@ -108,22 +105,43 @@ where &mut self, handle: &GenServerHandle, rx: &mut mpsc::Receiver>, - state: &mut Self::State, + state: Self::State, ) -> Result<(), GenServerError> { - self.main_loop(handle, rx, state)?; - Ok(()) + match self.init(handle, state) { + Ok(new_state) => { + self.main_loop(handle, rx, new_state)?; + Ok(()) + } + Err(err) => { + tracing::error!("Initialization failed: {err:?}"); + Err(GenServerError::Initialization) + } + } + } + + /// Initialization function. It's called before main loop. It + /// can be overrided on implementations in case initial steps are + /// required. + fn init( + &mut self, + _handle: &GenServerHandle, + state: Self::State, + ) -> Result { + Ok(state) } fn main_loop( &mut self, handle: &GenServerHandle, rx: &mut mpsc::Receiver>, - state: &mut Self::State, + mut state: Self::State, ) -> Result<(), GenServerError> { loop { - if !self.receive(handle, rx, state)? { + let (new_state, cont) = self.receive(handle, rx, state)?; + if !cont { break; } + state = new_state; } tracing::trace!("Stopping GenServer"); Ok(()) @@ -133,65 +151,71 @@ where &mut self, handle: &GenServerHandle, rx: &mut mpsc::Receiver>, - state: &mut Self::State, - ) -> Result { + state: Self::State, + ) -> Result<(Self::State, bool), GenServerError> { let message = rx.recv().ok(); // Save current state in case of a rollback let state_clone = state.clone(); - let (keep_running, error) = match message { + let (keep_running, new_state) = match message { Some(GenServerInMsg::Call { sender, message }) => { - let (keep_running, error, response) = match catch_unwind(AssertUnwindSafe(|| { - self.handle_call(message, handle, state) - })) { - Ok(response) => match response { - CallResponse::Reply(response) => (true, None, Ok(response)), - CallResponse::Stop(response) => (false, None, Ok(response)), - }, - Err(error) => (true, Some(error), Err(GenServerError::CallbackError)), - }; + let (keep_running, new_state, response) = + match catch_unwind(AssertUnwindSafe(|| { + self.handle_call(message, handle, state) + })) { + Ok(response) => match response { + CallResponse::Reply(new_state, response) => { + (true, new_state, Ok(response)) + } + CallResponse::Stop(response) => (false, state_clone, Ok(response)), + }, + Err(error) => { + tracing::trace!( + "Error in callback, reverting state - Error: '{error:?}'" + ); + (true, state_clone, Err(GenServerError::Callback)) + } + }; // Send response back if sender.send(response).is_err() { tracing::trace!("GenServer failed to send response back, client must have died") }; - (keep_running, error) + (keep_running, new_state) } Some(GenServerInMsg::Cast { message }) => { match catch_unwind(AssertUnwindSafe(|| { self.handle_cast(message, handle, state) })) { Ok(response) => match response { - CastResponse::NoReply => (true, None), - CastResponse::Stop => (false, None), + CastResponse::NoReply(new_state) => (true, new_state), + CastResponse::Stop => (false, state_clone), }, - Err(error) => (true, Some(error)), + Err(error) => { + tracing::trace!("Error in callback, reverting state - Error: '{error:?}'"); + (true, state_clone) + } } } None => { // Channel has been closed; won't receive further messages. Stop the server. - (false, None) + (false, state) } }; - if let Some(error) = error { - tracing::trace!("Error in callback, reverting state - Error: '{error:?}'"); - // Restore initial state (ie. dismiss any change) - *state = state_clone; - }; - Ok(keep_running) + Ok((new_state, keep_running)) } fn handle_call( &mut self, message: Self::CallMsg, handle: &GenServerHandle, - state: &mut Self::State, - ) -> CallResponse; + state: Self::State, + ) -> CallResponse; fn handle_cast( &mut self, message: Self::CastMsg, handle: &GenServerHandle, - state: &mut Self::State, - ) -> CastResponse; + state: Self::State, + ) -> CastResponse; } diff --git a/examples/bank/src/main.rs b/examples/bank/src/main.rs index 90f97eb..7284745 100644 --- a/examples/bank/src/main.rs +++ b/examples/bank/src/main.rs @@ -31,18 +31,33 @@ use spawned_rt::tasks as rt; fn main() { rt::run(async { + // Starting the bank let mut name_server = Bank::start(HashMap::new()); + // Testing initial balance for "main" account + let result = Bank::withdraw(&mut name_server, "main".to_string(), 15).await; + tracing::info!("Withdraw result {result:?}"); + assert_eq!( + result, + Ok(BankOutMessage::WidrawOk { + who: "main".to_string(), + amount: 985 + }) + ); + let joe = "Joe".to_string(); + // Error on deposit for an unexistent account let result = Bank::deposit(&mut name_server, joe.clone(), 10).await; tracing::info!("Deposit result {result:?}"); assert_eq!(result, Err(BankError::NotACustomer { who: joe.clone() })); + // Account creation let result = Bank::new_account(&mut name_server, "Joe".to_string()).await; tracing::info!("New account result {result:?}"); assert_eq!(result, Ok(BankOutMessage::Welcome { who: joe.clone() })); + // Deposit let result = Bank::deposit(&mut name_server, "Joe".to_string(), 10).await; tracing::info!("Deposit result {result:?}"); assert_eq!( @@ -53,6 +68,7 @@ fn main() { }) ); + // Deposit let result = Bank::deposit(&mut name_server, "Joe".to_string(), 30).await; tracing::info!("Deposit result {result:?}"); assert_eq!( @@ -63,6 +79,7 @@ fn main() { }) ); + // Withdrawal let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 15).await; tracing::info!("Withdraw result {result:?}"); assert_eq!( @@ -73,16 +90,29 @@ fn main() { }) ); + // Withdrawal with not enough balance let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 45).await; tracing::info!("Withdraw result {result:?}"); assert_eq!( result, Err(BankError::InsufficientBalance { - who: joe, + who: joe.clone(), amount: 25 }) ); + // Full withdrawal + let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 25).await; + tracing::info!("Withdraw result {result:?}"); + assert_eq!( + result, + Ok(BankOutMessage::WidrawOk { + who: joe, + amount: 0 + }) + ); + + // Stopping the bank let result = Bank::stop(&mut name_server).await; tracing::info!("Stop result {result:?}"); assert_eq!(result, Ok(BankOutMessage::Stopped)); diff --git a/examples/bank/src/server.rs b/examples/bank/src/server.rs index 13287fc..793a4ce 100644 --- a/examples/bank/src/server.rs +++ b/examples/bank/src/server.rs @@ -51,47 +51,68 @@ impl GenServer for Bank { Self {} } + // Initializing "main" account with 1000 in balance to test init() callback. + async fn init( + &mut self, + _handle: &GenServerHandle, + mut state: Self::State, + ) -> Result { + state.insert("main".to_string(), 1000); + Ok(state) + } + async fn handle_call( &mut self, message: Self::CallMsg, _handle: &BankHandle, - state: &mut Self::State, - ) -> CallResponse { + mut state: Self::State, + ) -> CallResponse { match message.clone() { Self::CallMsg::New { who } => match state.get(&who) { - Some(_amount) => CallResponse::Reply(Err(BankError::AlreadyACustomer { who })), + Some(_amount) => { + CallResponse::Reply(state, Err(BankError::AlreadyACustomer { who })) + } None => { state.insert(who.clone(), 0); - CallResponse::Reply(Ok(OutMessage::Welcome { who })) + CallResponse::Reply(state, Ok(OutMessage::Welcome { who })) } }, Self::CallMsg::Add { who, amount } => match state.get(&who) { Some(current) => { let new_amount = current + amount; state.insert(who.clone(), new_amount); - CallResponse::Reply(Ok(OutMessage::Balance { - who, - amount: new_amount, - })) + CallResponse::Reply( + state, + Ok(OutMessage::Balance { + who, + amount: new_amount, + }), + ) } - None => CallResponse::Reply(Err(BankError::NotACustomer { who })), + None => CallResponse::Reply(state, Err(BankError::NotACustomer { who })), }, Self::CallMsg::Remove { who, amount } => match state.get(&who) { - Some(current) => match current < &amount { - true => CallResponse::Reply(Err(BankError::InsufficientBalance { - who, - amount: *current, - })), + Some(¤t) => match current < amount { + true => CallResponse::Reply( + state, + Err(BankError::InsufficientBalance { + who, + amount: current, + }), + ), false => { let new_amount = current - amount; state.insert(who.clone(), new_amount); - CallResponse::Reply(Ok(OutMessage::WidrawOk { - who, - amount: new_amount, - })) + CallResponse::Reply( + state, + Ok(OutMessage::WidrawOk { + who, + amount: new_amount, + }), + ) } }, - None => CallResponse::Reply(Err(BankError::NotACustomer { who })), + None => CallResponse::Reply(state, Err(BankError::NotACustomer { who })), }, Self::CallMsg::Stop => CallResponse::Stop(Ok(OutMessage::Stopped)), } @@ -101,8 +122,8 @@ impl GenServer for Bank { &mut self, _message: Self::CastMsg, _handle: &BankHandle, - _state: &mut Self::State, - ) -> CastResponse { - CastResponse::NoReply + state: Self::State, + ) -> CastResponse { + CastResponse::NoReply(state) } } diff --git a/examples/bank_threads/src/main.rs b/examples/bank_threads/src/main.rs index c04a6ac..ced28da 100644 --- a/examples/bank_threads/src/main.rs +++ b/examples/bank_threads/src/main.rs @@ -31,18 +31,33 @@ use spawned_rt::threads as rt; fn main() { rt::run(|| { + // Starting the bank let mut name_server = Bank::start(HashMap::new()); + // Testing initial balance for "main" account + let result = Bank::withdraw(&mut name_server, "main".to_string(), 15); + tracing::info!("Withdraw result {result:?}"); + assert_eq!( + result, + Ok(BankOutMessage::WidrawOk { + who: "main".to_string(), + amount: 985 + }) + ); + let joe = "Joe".to_string(); + // Error on deposit for an unexistent account let result = Bank::deposit(&mut name_server, joe.clone(), 10); tracing::info!("Deposit result {result:?}"); assert_eq!(result, Err(BankError::NotACustomer { who: joe.clone() })); + // Account creation let result = Bank::new_account(&mut name_server, "Joe".to_string()); tracing::info!("New account result {result:?}"); assert_eq!(result, Ok(BankOutMessage::Welcome { who: joe.clone() })); + // Deposit let result = Bank::deposit(&mut name_server, "Joe".to_string(), 10); tracing::info!("Deposit result {result:?}"); assert_eq!( @@ -53,6 +68,7 @@ fn main() { }) ); + // Deposit let result = Bank::deposit(&mut name_server, "Joe".to_string(), 30); tracing::info!("Deposit result {result:?}"); assert_eq!( @@ -63,6 +79,7 @@ fn main() { }) ); + // Withdrawal let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 15); tracing::info!("Withdraw result {result:?}"); assert_eq!( @@ -73,16 +90,29 @@ fn main() { }) ); + // Withdrawal with not enough balance let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 45); tracing::info!("Withdraw result {result:?}"); assert_eq!( result, Err(BankError::InsufficientBalance { - who: joe, + who: joe.clone(), amount: 25 }) ); + // Full withdrawal + let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 25); + tracing::info!("Withdraw result {result:?}"); + assert_eq!( + result, + Ok(BankOutMessage::WidrawOk { + who: joe, + amount: 0 + }) + ); + + // Stopping the bank let result = Bank::stop(&mut name_server); tracing::info!("Stop result {result:?}"); assert_eq!(result, Ok(BankOutMessage::Stopped)); diff --git a/examples/bank_threads/src/server.rs b/examples/bank_threads/src/server.rs index d5dab95..69820a3 100644 --- a/examples/bank_threads/src/server.rs +++ b/examples/bank_threads/src/server.rs @@ -47,49 +47,70 @@ impl GenServer for Bank { Self {} } + // Initializing "main" account with 1000 in balance to test init() callback. + fn init( + &mut self, + _handle: &GenServerHandle, + mut state: Self::State, + ) -> Result { + state.insert("main".to_string(), 1000); + Ok(state) + } + fn handle_call( &mut self, message: Self::CallMsg, _handle: &BankHandle, - state: &mut Self::State, - ) -> CallResponse { + mut state: Self::State, + ) -> CallResponse { match message.clone() { - InMessage::New { who } => match state.get(&who) { - Some(_amount) => CallResponse::Reply(Err(BankError::AlreadyACustomer { who })), + Self::CallMsg::New { who } => match state.get(&who) { + Some(_amount) => { + CallResponse::Reply(state, Err(BankError::AlreadyACustomer { who })) + } None => { state.insert(who.clone(), 0); - CallResponse::Reply(Ok(OutMessage::Welcome { who })) + CallResponse::Reply(state, Ok(OutMessage::Welcome { who })) } }, - InMessage::Add { who, amount } => match state.get(&who) { + Self::CallMsg::Add { who, amount } => match state.get(&who) { Some(current) => { let new_amount = current + amount; state.insert(who.clone(), new_amount); - CallResponse::Reply(Ok(OutMessage::Balance { - who, - amount: new_amount, - })) + CallResponse::Reply( + state, + Ok(OutMessage::Balance { + who, + amount: new_amount, + }), + ) } - None => CallResponse::Reply(Err(BankError::NotACustomer { who })), + None => CallResponse::Reply(state, Err(BankError::NotACustomer { who })), }, - InMessage::Remove { who, amount } => match state.get(&who) { - Some(current) => match current < &amount { - true => CallResponse::Reply(Err(BankError::InsufficientBalance { - who, - amount: *current, - })), + Self::CallMsg::Remove { who, amount } => match state.get(&who) { + Some(¤t) => match current < amount { + true => CallResponse::Reply( + state, + Err(BankError::InsufficientBalance { + who, + amount: current, + }), + ), false => { let new_amount = current - amount; state.insert(who.clone(), new_amount); - CallResponse::Reply(Ok(OutMessage::WidrawOk { - who, - amount: new_amount, - })) + CallResponse::Reply( + state, + Ok(OutMessage::WidrawOk { + who, + amount: new_amount, + }), + ) } }, - None => CallResponse::Reply(Err(BankError::NotACustomer { who })), + None => CallResponse::Reply(state, Err(BankError::NotACustomer { who })), }, - InMessage::Stop => CallResponse::Stop(Ok(OutMessage::Stopped)), + Self::CallMsg::Stop => CallResponse::Stop(Ok(OutMessage::Stopped)), } } @@ -97,8 +118,8 @@ impl GenServer for Bank { &mut self, _message: Self::CastMsg, _handle: &BankHandle, - _state: &mut Self::State, - ) -> CastResponse { - CastResponse::NoReply + state: Self::State, + ) -> CastResponse { + CastResponse::NoReply(state) } } diff --git a/examples/blocking_genserver/main.rs b/examples/blocking_genserver/main.rs index 9a2b832..8dead78 100644 --- a/examples/blocking_genserver/main.rs +++ b/examples/blocking_genserver/main.rs @@ -34,8 +34,8 @@ impl GenServer for BadlyBehavedTask { &mut self, _: Self::CallMsg, _: &GenServerHandle, - _: &mut Self::State, - ) -> CallResponse { + _: Self::State, + ) -> CallResponse { CallResponse::Stop(()) } @@ -43,8 +43,8 @@ impl GenServer for BadlyBehavedTask { &mut self, _: Self::CastMsg, _: &GenServerHandle, - _: &mut Self::State, - ) -> CastResponse { + _: Self::State, + ) -> CastResponse { rt::sleep(Duration::from_millis(20)).await; loop { println!("{:?}: bad still alive", thread::current().id()); @@ -75,10 +75,13 @@ impl GenServer for WellBehavedTask { &mut self, message: Self::CallMsg, _: &GenServerHandle, - state: &mut Self::State, - ) -> CallResponse { + state: Self::State, + ) -> CallResponse { match message { - InMessage::GetCount => CallResponse::Reply(OutMsg::Count(state.count)), + InMessage::GetCount => { + let count = state.count; + CallResponse::Reply(state, OutMsg::Count(count)) + } InMessage::Stop => CallResponse::Stop(OutMsg::Count(state.count)), } } @@ -87,12 +90,12 @@ impl GenServer for WellBehavedTask { &mut self, _: Self::CastMsg, handle: &GenServerHandle, - state: &mut Self::State, - ) -> CastResponse { + mut state: Self::State, + ) -> CastResponse { state.count += 1; println!("{:?}: good still alive", thread::current().id()); send_after(Duration::from_millis(100), handle.to_owned(), ()); - CastResponse::NoReply + CastResponse::NoReply(state) } } diff --git a/examples/name_server/src/server.rs b/examples/name_server/src/server.rs index 77a8b58..6bf6b30 100644 --- a/examples/name_server/src/server.rs +++ b/examples/name_server/src/server.rs @@ -40,18 +40,19 @@ impl GenServer for NameServer { &mut self, message: Self::CallMsg, _handle: &NameServerHandle, - state: &mut Self::State, - ) -> CallResponse { + mut state: Self::State, + ) -> CallResponse { match message.clone() { Self::CallMsg::Add { key, value } => { state.insert(key, value); - CallResponse::Reply(Self::OutMsg::Ok) + CallResponse::Reply(state, Self::OutMsg::Ok) } Self::CallMsg::Find { key } => match state.get(&key) { - Some(value) => CallResponse::Reply(Self::OutMsg::Found { - value: value.to_string(), - }), - None => CallResponse::Reply(Self::OutMsg::NotFound), + Some(result) => { + let value = result.to_string(); + CallResponse::Reply(state, Self::OutMsg::Found { value }) + } + None => CallResponse::Reply(state, Self::OutMsg::NotFound), }, } } @@ -60,8 +61,8 @@ impl GenServer for NameServer { &mut self, _message: Self::CastMsg, _handle: &NameServerHandle, - _state: &mut Self::State, - ) -> CastResponse { - CastResponse::NoReply + state: Self::State, + ) -> CastResponse { + CastResponse::NoReply(state) } } diff --git a/examples/name_server_with_error/src/server.rs b/examples/name_server_with_error/src/server.rs index 5906349..96c1091 100644 --- a/examples/name_server_with_error/src/server.rs +++ b/examples/name_server_with_error/src/server.rs @@ -15,8 +15,8 @@ impl NameServer { pub async fn add(server: &mut NameServerHandle, key: String, value: String) -> OutMessage { match server.call(InMessage::Add { key, value }).await { Ok(_) => OutMessage::Ok, - Err(GenServerError::ServerError) => OutMessage::ServerError, - Err(GenServerError::CallbackError) => OutMessage::CallbackError, + Err(GenServerError::Callback) => OutMessage::CallbackError, + Err(_) => OutMessage::ServerError, } } @@ -43,22 +43,23 @@ impl GenServer for NameServer { &mut self, message: Self::CallMsg, _handle: &NameServerHandle, - state: &mut Self::State, - ) -> CallResponse { + mut state: Self::State, + ) -> CallResponse { match message.clone() { Self::CallMsg::Add { key, value } => { state.insert(key.clone(), value); if key == "error" { panic!("error!") } else { - CallResponse::Reply(Self::OutMsg::Ok) + CallResponse::Reply(state, Self::OutMsg::Ok) } } Self::CallMsg::Find { key } => match state.get(&key) { - Some(value) => CallResponse::Reply(Self::OutMsg::Found { - value: value.to_string(), - }), - None => CallResponse::Reply(Self::OutMsg::NotFound), + Some(result) => { + let value = result.to_string(); + CallResponse::Reply(state, Self::OutMsg::Found { value }) + } + None => CallResponse::Reply(state, Self::OutMsg::NotFound), }, } } @@ -67,8 +68,8 @@ impl GenServer for NameServer { &mut self, _message: Self::CastMsg, _handle: &NameServerHandle, - _state: &mut Self::State, - ) -> CastResponse { - CastResponse::NoReply + state: Self::State, + ) -> CastResponse { + CastResponse::NoReply(state) } } diff --git a/examples/updater/src/main.rs b/examples/updater/src/main.rs index 5119b9e..c01f3c6 100644 --- a/examples/updater/src/main.rs +++ b/examples/updater/src/main.rs @@ -8,22 +8,17 @@ mod server; use std::{thread, time::Duration}; -use messages::UpdaterOutMessage; use server::{UpdateServerState, UpdaterServer}; use spawned_concurrency::tasks::GenServer as _; use spawned_rt::tasks as rt; fn main() { rt::run(async { - let mut update_server = UpdaterServer::start(UpdateServerState { + UpdaterServer::start(UpdateServerState { url: "https://httpbin.org/ip".to_string(), periodicity: Duration::from_millis(1000), }); - let result = UpdaterServer::check(&mut update_server).await; - tracing::info!("Update check done: {result:?}"); - assert_eq!(result, UpdaterOutMessage::Ok); - // giving it some time before ending thread::sleep(Duration::from_secs(10)); }) diff --git a/examples/updater/src/server.rs b/examples/updater/src/server.rs index fa75a8b..5c610a2 100644 --- a/examples/updater/src/server.rs +++ b/examples/updater/src/server.rs @@ -15,15 +15,6 @@ pub struct UpdateServerState { } pub struct UpdaterServer {} -impl UpdaterServer { - pub async fn check(server: &mut UpdateServerHandle) -> OutMessage { - match server.cast(InMessage::Check).await { - Ok(_) => OutMessage::Ok, - Err(_) => OutMessage::Error, - } - } -} - impl GenServer for UpdaterServer { type CallMsg = (); type CastMsg = InMessage; @@ -35,21 +26,31 @@ impl GenServer for UpdaterServer { Self {} } + // Initializing GenServer to start periodic checks + async fn init( + &mut self, + handle: &GenServerHandle, + state: Self::State, + ) -> Result { + send_after(state.periodicity, handle.clone(), InMessage::Check); + Ok(state) + } + async fn handle_call( &mut self, _message: Self::CallMsg, _handle: &UpdateServerHandle, - _state: &mut Self::State, - ) -> CallResponse { - CallResponse::Reply(OutMessage::Ok) + state: Self::State, + ) -> CallResponse { + CallResponse::Reply(state, OutMessage::Ok) } async fn handle_cast( &mut self, message: Self::CastMsg, handle: &UpdateServerHandle, - state: &mut Self::State, - ) -> CastResponse { + state: Self::State, + ) -> CastResponse { match message { Self::CastMsg::Check => { send_after(state.periodicity, handle.clone(), InMessage::Check); @@ -59,7 +60,7 @@ impl GenServer for UpdaterServer { tracing::info!("Response: {resp:?}"); - CastResponse::NoReply + CastResponse::NoReply(state) } } } diff --git a/examples/updater_threads/src/main.rs b/examples/updater_threads/src/main.rs index 64236be..b4409b5 100644 --- a/examples/updater_threads/src/main.rs +++ b/examples/updater_threads/src/main.rs @@ -8,22 +8,17 @@ mod server; use std::{thread, time::Duration}; -use messages::UpdaterOutMessage; use server::{UpdateServerState, UpdaterServer}; use spawned_concurrency::threads::GenServer as _; use spawned_rt::threads as rt; fn main() { rt::run(|| { - let mut update_server = UpdaterServer::start(UpdateServerState { + UpdaterServer::start(UpdateServerState { url: "https://httpbin.org/ip".to_string(), periodicity: Duration::from_millis(1000), }); - let result = UpdaterServer::check(&mut update_server); - tracing::info!("Update check done: {result:?}"); - assert_eq!(result, UpdaterOutMessage::Ok); - // giving it some time before ending thread::sleep(Duration::from_secs(10)); }) diff --git a/examples/updater_threads/src/server.rs b/examples/updater_threads/src/server.rs index d26e447..bd5e6cd 100644 --- a/examples/updater_threads/src/server.rs +++ b/examples/updater_threads/src/server.rs @@ -16,15 +16,6 @@ pub struct UpdateServerState { } pub struct UpdaterServer {} -impl UpdaterServer { - pub fn check(server: &mut UpdateServerHandle) -> OutMessage { - match server.cast(InMessage::Check) { - Ok(_) => OutMessage::Ok, - Err(_) => OutMessage::Error, - } - } -} - impl GenServer for UpdaterServer { type CallMsg = (); type CastMsg = InMessage; @@ -36,21 +27,31 @@ impl GenServer for UpdaterServer { Self {} } + // Initializing GenServer to start periodic checks. + fn init( + &mut self, + handle: &GenServerHandle, + state: Self::State, + ) -> Result { + send_after(state.periodicity, handle.clone(), InMessage::Check); + Ok(state) + } + fn handle_call( &mut self, _message: Self::CallMsg, _handle: &UpdateServerHandle, - _state: &mut Self::State, - ) -> CallResponse { - CallResponse::Reply(OutMessage::Ok) + state: Self::State, + ) -> CallResponse { + CallResponse::Reply(state, OutMessage::Ok) } fn handle_cast( &mut self, message: Self::CastMsg, handle: &UpdateServerHandle, - state: &mut Self::State, - ) -> CastResponse { + state: Self::State, + ) -> CastResponse { match message { Self::CastMsg::Check => { send_after(state.periodicity, handle.clone(), InMessage::Check); @@ -60,7 +61,7 @@ impl GenServer for UpdaterServer { tracing::info!("Response: {resp:?}"); - CastResponse::NoReply + CastResponse::NoReply(state) } } } diff --git a/rt/src/tasks/tokio/mod.rs b/rt/src/tasks/tokio/mod.rs index 51a3877..8131b27 100644 --- a/rt/src/tasks/tokio/mod.rs +++ b/rt/src/tasks/tokio/mod.rs @@ -6,5 +6,4 @@ pub use tokio::{ runtime::Runtime, task::{spawn, spawn_blocking, JoinHandle}, time::sleep, - test, };