Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions concurrency/src/tasks/error.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#[derive(Debug)]
pub enum GenServerError {
CallbackError,
ServerError,
Callback,
Initialization,
Server,
}

impl<T> From<spawned_rt::tasks::mpsc::SendError<T>> for GenServerError {
fn from(_value: spawned_rt::tasks::mpsc::SendError<T>) -> Self {
Self::ServerError
Self::Server
}
}
141 changes: 86 additions & 55 deletions concurrency/src/tasks/gen_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ impl<G: GenServer> Clone for GenServerHandle<G> {
}

impl<G: GenServer> GenServerHandle<G> {
pub(crate) fn new(mut initial_state: G::State) -> Self {
pub(crate) fn new(initial_state: G::State) -> Self {
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
let handle = GenServerHandle { tx };
let mut gen_server: G = GenServer::new();
let handle_clone = handle.clone();
// Ignore the JoinHandle for now. Maybe we'll use it in the future
let _join_handle = rt::spawn(async move {
if gen_server
.run(&handle, &mut rx, &mut initial_state)
.run(&handle, &mut rx, initial_state)
.await
.is_err()
{
Expand All @@ -38,7 +38,7 @@ impl<G: GenServer> GenServerHandle<G> {
handle_clone
}

pub(crate) fn new_blocking(mut initial_state: G::State) -> Self {
pub(crate) fn new_blocking(initial_state: G::State) -> Self {
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
let handle = GenServerHandle { tx };
let mut gen_server: G = GenServer::new();
Expand All @@ -47,7 +47,7 @@ impl<G: GenServer> GenServerHandle<G> {
let _join_handle = rt::spawn_blocking(|| {
rt::block_on(async move {
if gen_server
.run(&handle, &mut rx, &mut initial_state)
.run(&handle, &mut rx, initial_state)
.await
.is_err()
{
Expand All @@ -70,34 +70,34 @@ impl<G: GenServer> GenServerHandle<G> {
})?;
match oneshot_rx.await {
Ok(result) => result,
Err(_) => Err(GenServerError::ServerError),
Err(_) => Err(GenServerError::Server),
}
}

pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
self.tx
.send(GenServerInMsg::Cast { message })
.map_err(|_error| GenServerError::ServerError)
.map_err(|_error| GenServerError::Server)
}
}

pub enum GenServerInMsg<A: GenServer> {
pub enum GenServerInMsg<G: GenServer> {
Call {
sender: oneshot::Sender<Result<A::OutMsg, GenServerError>>,
message: A::CallMsg,
sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
message: G::CallMsg,
},
Cast {
message: A::CastMsg,
message: G::CastMsg,
},
}

pub enum CallResponse<U> {
Reply(U),
Stop(U),
pub enum CallResponse<G: GenServer> {
Reply(G::State, G::OutMsg),
Stop(G::OutMsg),
}

pub enum CastResponse {
NoReply,
pub enum CastResponse<G: GenServer> {
NoReply(G::State),
Stop,
}

Expand All @@ -109,7 +109,7 @@ where
type CastMsg: Send + Sized;
type OutMsg: Send + Sized;
type State: Clone + Send;
type Error: Debug;
type Error: Debug + Send;

fn new() -> Self;

Expand All @@ -130,25 +130,46 @@ where
&mut self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
state: &mut Self::State,
state: Self::State,
) -> impl Future<Output = Result<(), GenServerError>> + Send {
async {
self.main_loop(handle, rx, state).await?;
Ok(())
match self.init(handle, state).await {
Ok(new_state) => {
self.main_loop(handle, rx, new_state).await?;
Ok(())
}
Err(err) => {
tracing::error!("Initialization failed: {err:?}");
Err(GenServerError::Initialization)
}
}
}
}

/// Initialization function. It's called before main loop. It
/// can be overrided on implementations in case initial steps are
/// required.
fn init(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should probably be documented / have an example that utilizes it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added documentation and implemented it in bank examples

&mut self,
_handle: &GenServerHandle<Self>,
state: Self::State,
) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
async { Ok(state) }
}

fn main_loop(
&mut self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
state: &mut Self::State,
mut state: Self::State,
) -> impl Future<Output = Result<(), GenServerError>> + Send {
async {
loop {
if !self.receive(handle, rx, state).await? {
let (new_state, cont) = self.receive(handle, rx, state).await?;
if !cont {
break;
}
state = new_state;
}
tracing::trace!("Stopping GenServer");
Ok(())
Expand All @@ -159,81 +180,88 @@ where
&mut self,
handle: &GenServerHandle<Self>,
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
state: &mut Self::State,
) -> impl std::future::Future<Output = Result<bool, GenServerError>> + Send {
async {
state: Self::State,
) -> impl std::future::Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
async move {
let message = rx.recv().await;

// Save current state in case of a rollback
let state_clone = state.clone();

let (keep_running, error) = match message {
let (keep_running, new_state) = match message {
Some(GenServerInMsg::Call { sender, message }) => {
let (keep_running, error, response) =
let (keep_running, new_state, response) =
match AssertUnwindSafe(self.handle_call(message, handle, state))
.catch_unwind()
.await
{
Ok(response) => match response {
CallResponse::Reply(response) => (true, None, Ok(response)),
CallResponse::Stop(response) => (false, None, Ok(response)),
CallResponse::Reply(new_state, response) => {
(true, new_state, Ok(response))
}
CallResponse::Stop(response) => (false, state_clone, Ok(response)),
},
Err(error) => (true, Some(error), Err(GenServerError::CallbackError)),
Err(error) => {
tracing::trace!(
"Error in callback, reverting state - Error: '{error:?}'"
);
(true, state_clone, Err(GenServerError::Callback))
}
};
// Send response back
if sender.send(response).is_err() {
tracing::trace!(
"GenServer failed to send response back, client must have died"
)
};
(keep_running, error)
(keep_running, new_state)
}
Some(GenServerInMsg::Cast { message }) => {
match AssertUnwindSafe(self.handle_cast(message, handle, state))
.catch_unwind()
.await
{
Ok(response) => match response {
CastResponse::NoReply => (true, None),
CastResponse::Stop => (false, None),
CastResponse::NoReply(new_state) => (true, new_state),
CastResponse::Stop => (false, state_clone),
},
Err(error) => (true, Some(error)),
Err(error) => {
tracing::trace!(
"Error in callback, reverting state - Error: '{error:?}'"
);
(true, state_clone)
}
}
}
None => {
// Channel has been closed; won't receive further messages. Stop the server.
(false, None)
(false, state)
}
};
if let Some(error) = error {
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
// Restore initial state (ie. dismiss any change)
*state = state_clone;
};
Ok(keep_running)
Ok((new_state, keep_running))
}
}

fn handle_call(
&mut self,
message: Self::CallMsg,
handle: &GenServerHandle<Self>,
state: &mut Self::State,
) -> impl std::future::Future<Output = CallResponse<Self::OutMsg>> + Send;
state: Self::State,
) -> impl std::future::Future<Output = CallResponse<Self>> + Send;

fn handle_cast(
&mut self,
message: Self::CastMsg,
handle: &GenServerHandle<Self>,
state: &mut Self::State,
) -> impl std::future::Future<Output = CastResponse> + Send;
state: Self::State,
) -> impl std::future::Future<Output = CastResponse<Self>> + Send;
}

#[cfg(test)]
mod tests {
use super::*;
use crate::tasks::send_after;
use std::{process::exit, thread, time::Duration};
use std::{thread, time::Duration};
struct BadlyBehavedTask;

#[derive(Clone)]
Expand Down Expand Up @@ -261,17 +289,17 @@ mod tests {
&mut self,
_: Self::CallMsg,
_: &GenServerHandle<Self>,
_: &mut Self::State,
) -> CallResponse<Self::OutMsg> {
_: Self::State,
) -> CallResponse<Self> {
CallResponse::Stop(())
}

async fn handle_cast(
&mut self,
_: Self::CastMsg,
_: &GenServerHandle<Self>,
_: &mut Self::State,
) -> CastResponse {
_: Self::State,
) -> CastResponse<Self> {
rt::sleep(Duration::from_millis(20)).await;
thread::sleep(Duration::from_secs(2));
CastResponse::Stop
Expand Down Expand Up @@ -300,10 +328,13 @@ mod tests {
&mut self,
message: Self::CallMsg,
_: &GenServerHandle<Self>,
state: &mut Self::State,
) -> CallResponse<Self::OutMsg> {
state: Self::State,
) -> CallResponse<Self> {
match message {
InMessage::GetCount => CallResponse::Reply(OutMsg::Count(state.count)),
InMessage::GetCount => {
let count = state.count;
CallResponse::Reply(state, OutMsg::Count(count))
}
InMessage::Stop => CallResponse::Stop(OutMsg::Count(state.count)),
}
}
Expand All @@ -312,12 +343,12 @@ mod tests {
&mut self,
_: Self::CastMsg,
handle: &GenServerHandle<Self>,
state: &mut Self::State,
) -> CastResponse {
mut state: Self::State,
) -> CastResponse<Self> {
state.count += 1;
println!("{:?}: good still alive", thread::current().id());
send_after(Duration::from_millis(100), handle.to_owned(), ());
CastResponse::NoReply
CastResponse::NoReply(state)
}
}

Expand Down
7 changes: 4 additions & 3 deletions concurrency/src/threads/error.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#[derive(Debug)]
pub enum GenServerError {
CallbackError,
ServerError,
Callback,
Initialization,
Server,
}

impl<T> From<spawned_rt::threads::mpsc::SendError<T>> for GenServerError {
fn from(_value: spawned_rt::threads::mpsc::SendError<T>) -> Self {
Self::ServerError
Self::Server
}
}
Loading