Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ members = [
"examples/ping_pong",
"examples/ping_pong_threads",
"examples/updater",
"examples/updater_threads", "examples/blocking_genserver",
"examples/updater_threads",
"examples/blocking_genserver",
]

[workspace.dependencies]
Expand Down
5 changes: 3 additions & 2 deletions concurrency/src/tasks/gen_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ pub trait GenServer
where
Self: Send + Sized,
{
type CallMsg: Send + Sized;
type CastMsg: Send + Sized;
type CallMsg: Clone + Send + Sized + Sync;
type CastMsg: Clone + Send + Sized + Sync;
type OutMsg: Send + Sized;
type State: Clone + Send;
type Error: Debug + Send;
Expand Down Expand Up @@ -259,6 +259,7 @@ where

#[cfg(test)]
mod tests {

use super::*;
use crate::tasks::send_after;
use std::{thread, time::Duration};
Expand Down
5 changes: 4 additions & 1 deletion concurrency/src/tasks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ mod gen_server;
mod process;
mod time;

#[cfg(test)]
mod timer_tests;

pub use error::GenServerError;
pub use gen_server::{CallResponse, CastResponse, GenServer, GenServerHandle, GenServerInMsg};
pub use process::{send, Process, ProcessInfo};
pub use time::send_after;
pub use time::{send_after, send_interval};
63 changes: 57 additions & 6 deletions concurrency/src/tasks/time.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,72 @@
use futures::future::select;
use std::time::Duration;

use spawned_rt::tasks::{self as rt, JoinHandle};
use spawned_rt::tasks::{self as rt, CancellationToken, JoinHandle};

use super::{GenServer, GenServerHandle};

pub struct TimerHandle {
pub join_handle: JoinHandle<()>,
pub cancellation_token: CancellationToken,
}

// Sends a message after a given period to the specified GenServer. The task terminates
// once the send has completed
pub fn send_after<T>(
period: Duration,
mut handle: GenServerHandle<T>,
message: T::CastMsg,
) -> JoinHandle<()>
) -> TimerHandle
where
T: GenServer + 'static,
{
let cancellation_token = CancellationToken::new();
let cloned_token = cancellation_token.clone();
let join_handle = rt::spawn(async move {
let _ = select(
Box::pin(cloned_token.cancelled()),
Box::pin(async {
rt::sleep(period).await;
let _ = handle.cast(message.clone()).await;
}),
)
.await;
});
TimerHandle {
join_handle,
cancellation_token,
}
}

// Sends a message to the specified GenServe repeatedly after `Time` milliseconds.
pub fn send_interval<T>(
period: Duration,
mut handle: GenServerHandle<T>,
message: T::CastMsg,
) -> TimerHandle
where
T: GenServer + 'static,
{
rt::spawn(async move {
rt::sleep(period).await;
let _ = handle.cast(message).await;
})
let cancellation_token = CancellationToken::new();
let cloned_token = cancellation_token.clone();
let join_handle = rt::spawn(async move {
loop {
let result = select(
Box::pin(cloned_token.cancelled()),
Box::pin(async {
rt::sleep(period).await;
let _ = handle.cast(message.clone()).await;
}),
)
.await;
match result {
futures::future::Either::Left(_) => break,
futures::future::Either::Right(_) => (),
}
}
});
TimerHandle {
join_handle,
cancellation_token,
}
}
248 changes: 248 additions & 0 deletions concurrency/src/tasks/timer_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
use crate::tasks::{send_interval, CallResponse, CastResponse, GenServer, GenServerHandle};
use spawned_rt::tasks::{self as rt, CancellationToken};
use std::time::Duration;

use super::send_after;

type RepeaterHandle = GenServerHandle<Repeater>;

#[derive(Clone)]
struct RepeaterState {
pub(crate) count: i32,
pub(crate) cancellation_token: Option<CancellationToken>,
}

#[derive(Clone)]
enum RepeaterCastMessage {
Inc,
StopTimer,
}

#[derive(Clone)]
enum RepeaterCallMessage {
GetCount,
}

#[derive(PartialEq, Debug)]
enum RepeaterOutMessage {
Count(i32),
}

struct Repeater;

impl Repeater {
pub async fn stop_timer(server: &mut RepeaterHandle) -> Result<(), ()> {
server
.cast(RepeaterCastMessage::StopTimer)
.await
.map_err(|_| ())
}

pub async fn get_count(server: &mut RepeaterHandle) -> Result<RepeaterOutMessage, ()> {
server
.call(RepeaterCallMessage::GetCount)
.await
.map_err(|_| ())
}
}

impl GenServer for Repeater {
type CallMsg = RepeaterCallMessage;
type CastMsg = RepeaterCastMessage;
type OutMsg = RepeaterOutMessage;
type State = RepeaterState;
type Error = ();

fn new() -> Self {
Self
}

async fn init(
&mut self,
handle: &RepeaterHandle,
mut state: Self::State,
) -> Result<Self::State, Self::Error> {
let timer = send_interval(
Duration::from_millis(100),
handle.clone(),
RepeaterCastMessage::Inc,
);
state.cancellation_token = Some(timer.cancellation_token);
Ok(state)
}

async fn handle_call(
&mut self,
_message: Self::CallMsg,
_handle: &RepeaterHandle,
state: Self::State,
) -> CallResponse<Self> {
let count = state.count;
CallResponse::Reply(state, RepeaterOutMessage::Count(count))
}

async fn handle_cast(
&mut self,
message: Self::CastMsg,
_handle: &GenServerHandle<Self>,
mut state: Self::State,
) -> CastResponse<Self> {
match message {
RepeaterCastMessage::Inc => {
state.count += 1;
}
RepeaterCastMessage::StopTimer => {
if let Some(ct) = state.cancellation_token.clone() {
ct.cancel()
};
}
};
CastResponse::NoReply(state)
}
}

#[test]
pub fn test_send_interval_and_cancellation() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
// Start a Repeater
let mut repeater = Repeater::start(RepeaterState {
count: 0,
cancellation_token: None,
});

// Wait for 1 second
rt::sleep(Duration::from_secs(1)).await;

// Check count
let count = Repeater::get_count(&mut repeater).await.unwrap();

// 9 messages in 1 second (after first 100 milliseconds sleep)
assert_eq!(RepeaterOutMessage::Count(9), count);

// Pause timer
Repeater::stop_timer(&mut repeater).await.unwrap();

// Wait another second
rt::sleep(Duration::from_secs(1)).await;

// Check count again
let count2 = Repeater::get_count(&mut repeater).await.unwrap();

// As timer was paused, count should remain at 9
assert_eq!(RepeaterOutMessage::Count(9), count2);
});
}

type DelayedHandle = GenServerHandle<Delayed>;

#[derive(Clone)]
struct DelayedState {
pub(crate) count: i32,
}

#[derive(Clone)]
enum DelayedCastMessage {
Inc,
}

#[derive(Clone)]
enum DelayedCallMessage {
GetCount,
}

#[derive(PartialEq, Debug)]
enum DelayedOutMessage {
Count(i32),
}

struct Delayed;

impl Delayed {
pub async fn get_count(server: &mut DelayedHandle) -> Result<DelayedOutMessage, ()> {
server
.call(DelayedCallMessage::GetCount)
.await
.map_err(|_| ())
}
}

impl GenServer for Delayed {
type CallMsg = DelayedCallMessage;
type CastMsg = DelayedCastMessage;
type OutMsg = DelayedOutMessage;
type State = DelayedState;
type Error = ();

fn new() -> Self {
Self
}

async fn handle_call(
&mut self,
_message: Self::CallMsg,
_handle: &DelayedHandle,
state: Self::State,
) -> CallResponse<Self> {
let count = state.count;
CallResponse::Reply(state, DelayedOutMessage::Count(count))
}

async fn handle_cast(
&mut self,
message: Self::CastMsg,
_handle: &DelayedHandle,
mut state: Self::State,
) -> CastResponse<Self> {
match message {
DelayedCastMessage::Inc => {
state.count += 1;
}
};
CastResponse::NoReply(state)
}
}

#[test]
pub fn test_send_after_and_cancellation() {
let runtime = rt::Runtime::new().unwrap();
runtime.block_on(async move {
// Start a Delayed
let mut repeater = Delayed::start(DelayedState { count: 0 });

// Set a just once timed message
let _ = send_after(
Duration::from_millis(100),
repeater.clone(),
DelayedCastMessage::Inc,
);

// Wait for 200 milliseconds
rt::sleep(Duration::from_millis(200)).await;

// Check count
let count = Delayed::get_count(&mut repeater).await.unwrap();

// Only one message (no repetition)
assert_eq!(DelayedOutMessage::Count(1), count);

// New timer
let timer = send_after(
Duration::from_millis(100),
repeater.clone(),
DelayedCastMessage::Inc,
);

// Cancel the new timer before timeout
timer.cancellation_token.cancel();

// Wait another 200 milliseconds
rt::sleep(Duration::from_millis(200)).await;

// Check count again
let count2 = Delayed::get_count(&mut repeater).await.unwrap();

// As timer was cancelled, count should remain at 1
assert_eq!(DelayedOutMessage::Count(1), count2);
});
}
4 changes: 2 additions & 2 deletions concurrency/src/threads/gen_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ pub trait GenServer
where
Self: Send + Sized,
{
type CallMsg: Send + Sized;
type CastMsg: Send + Sized;
type CallMsg: Clone + Send + Sized;
type CastMsg: Clone + Send + Sized;
type OutMsg: Send + Sized;
type State: Clone + Send;
type Error: Debug;
Expand Down
5 changes: 4 additions & 1 deletion concurrency/src/threads/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ mod gen_server;
mod process;
mod time;

#[cfg(test)]
mod timer_tests;

pub use gen_server::{CallResponse, CastResponse, GenServer, GenServerHandle, GenServerInMsg};
pub use process::{send, Process, ProcessInfo};
pub use time::send_after;
pub use time::{send_after, send_interval};
Loading