Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/migrating-to-0.5.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ The `finalize` and `on_cancel` hooks follow the same pattern:

```rust
impl TypedExecutor<Thumbnail> for ThumbnailExec {
async fn finalize(&self, thumb: Thumbnail, ctx: &TaskContext) -> Result<(), TaskError> {
async fn finalize(&self, thumb: Thumbnail, _memo: (), ctx: &TaskContext) -> Result<(), TaskError> {
// called after all children settle
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion docs/quick-start.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ impl TypedExecutor<MultipartUpload> for MultipartUploader {
}

async fn finalize(
&self, upload: MultipartUpload, ctx: &TaskContext,
&self, upload: MultipartUpload, _memo: (), ctx: &TaskContext,
) -> Result<(), TaskError> {
// Called after all children complete
complete_multipart_upload(&upload).await
Expand Down
3 changes: 3 additions & 0 deletions migrations/009_memo.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- Execute-to-finalize memo: typed state persisted between phases.
-- Adds a nullable `memo` BLOB column to both tables; rows that already
-- exist take the declared default of NULL.
-- NOTE(review): presumably NULL means "no memo was stored for this task" —
-- confirm against the store code that reads and writes this column.
ALTER TABLE tasks ADD COLUMN memo BLOB DEFAULT NULL;
ALTER TABLE task_history ADD COLUMN memo BLOB DEFAULT NULL;
130 changes: 107 additions & 23 deletions src/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use serde::{de::DeserializeOwned, Serialize};

use crate::module::{ExecutorOptions, ModuleExecutor, ModuleHandle};
use crate::priority::Priority;
use crate::registry::{DomainTaskContext, ErasedExecutor, TaskContext, TaskExecutor};
Expand Down Expand Up @@ -181,19 +183,29 @@ pub struct TaskTypeOptions {
/// }
/// }
/// ```
pub trait TypedExecutor<T: TypedTask>: Send + Sync + 'static {
pub trait TypedExecutor<
T: TypedTask,
Memo: Serialize + DeserializeOwned + Send + Sync + 'static = (),
>: Send + Sync + 'static
{
/// Primary execution. Called once per dispatch.
///
/// Returns a `Memo` that will be persisted and passed to [`finalize()`](Self::finalize)
/// after all children complete. For the default `Memo = ()`, the return type
/// is `Result<(), TaskError>` — identical to the pre-memo API.
fn execute<'a>(
&'a self,
payload: T,
ctx: DomainTaskContext<'a, T::Domain>,
) -> impl Future<Output = Result<(), TaskError>> + Send + 'a;
) -> impl Future<Output = Result<Memo, TaskError>> + Send + 'a;

/// Called when all child tasks spawned by this task have settled.
/// Receives the `Memo` returned by [`execute()`](Self::execute).
/// Default: no-op.
fn finalize<'a>(
&'a self,
_payload: T,
_memo: Memo,
_ctx: DomainTaskContext<'a, T::Domain>,
) -> impl Future<Output = Result<(), TaskError>> + Send + 'a {
async { Ok(()) }
Expand All @@ -212,26 +224,46 @@ pub trait TypedExecutor<T: TypedTask>: Send + Sync + 'static {

// ── TypedExecutorAdapter ─────────────────────────────────────────────

/// Internal adapter that wraps a [`TypedExecutor<T>`] into a [`TaskExecutor`]
/// Internal adapter that wraps a [`TypedExecutor<T, Memo>`] into a [`TaskExecutor`]
/// for the scheduler engine.
///
/// Handles payload deserialization before delegating to the typed executor.
struct TypedExecutorAdapter<T, E> {
/// Handles payload deserialization and memo serialization/deserialization.
struct TypedExecutorAdapter<T, M, E> {
executor: E,
_marker: PhantomData<fn() -> T>,
_marker: PhantomData<fn() -> (T, M)>,
}

impl<T: TypedTask, E: TypedExecutor<T>> TaskExecutor for TypedExecutorAdapter<T, E> {
async fn execute<'a>(&'a self, ctx: &'a TaskContext) -> Result<(), TaskError> {
impl<T, M, E> TaskExecutor for TypedExecutorAdapter<T, M, E>
where
T: TypedTask,
M: Serialize + DeserializeOwned + Send + Sync + 'static,
E: TypedExecutor<T, M>,
{
async fn execute<'a>(&'a self, ctx: &'a TaskContext) -> Result<Option<Vec<u8>>, TaskError> {
let payload: T = ctx.payload()?;
let dctx = DomainTaskContext::<T::Domain>::new(ctx);
self.executor.execute(payload, dctx).await
let memo = self.executor.execute(payload, dctx).await?;

// Don't persist () — serialize to None.
if std::any::TypeId::of::<M>() == std::any::TypeId::of::<()>() {
return Ok(None);
}

let bytes = serde_json::to_vec(&memo)
.map_err(|e| TaskError::permanent(format!("memo serialization: {e}")))?;
Ok(Some(bytes))
}

async fn finalize<'a>(&'a self, ctx: &'a TaskContext) -> Result<(), TaskError> {
let payload: T = ctx.payload()?;
let memo: M = match &ctx.record().memo {
Some(bytes) => serde_json::from_slice(bytes)
.map_err(|e| TaskError::permanent(format!("memo deserialization: {e}")))?,
None => serde_json::from_value(serde_json::Value::Null)
.map_err(|e| TaskError::permanent(format!("memo deserialization: {e}")))?,
};
let dctx = DomainTaskContext::<T::Domain>::new(ctx);
self.executor.finalize(payload, dctx).await
self.executor.finalize(payload, memo, dctx).await
}

async fn on_cancel<'a>(&'a self, ctx: &'a TaskContext) -> Result<(), TaskError> {
Expand All @@ -241,6 +273,19 @@ impl<T: TypedTask, E: TypedExecutor<T>> TaskExecutor for TypedExecutorAdapter<T,
}
}

/// Build an erased executor from a typed executor and memo type.
fn erase_executor<T, M, E>(executor: E) -> Arc<dyn ErasedExecutor>
where
T: TypedTask,
M: Serialize + DeserializeOwned + Send + Sync + 'static,
E: TypedExecutor<T, M>,
{
Arc::new(TypedExecutorAdapter {
executor,
_marker: PhantomData::<fn() -> (T, M)>,
})
}

// ── Domain<D> ────────────────────────────────────────────────────────

/// A typed module builder that enforces the link between a [`DomainKey`],
Expand Down Expand Up @@ -317,7 +362,36 @@ impl<D: DomainKey> Domain<D> {
T: TypedTask<Domain = D>,
{
let config = T::config();
self.task_inner::<T>(executor, config.ttl, config.retry_policy)
self.task_inner::<T>(
erase_executor::<T, (), _>(executor),
config.ttl,
config.retry_policy,
)
}

/// Register a typed executor whose `execute()` returns a memo; the memo is
/// persisted and later handed to `finalize()`.
///
/// `T` and `Memo` are both inferred from the executor's
/// `TypedExecutor<T, Memo>` impl, so a turbofish is only required when the
/// executor is generic over task types.
///
/// The TTL and retry policy are taken from `T::config()`.
///
/// # Example
///
/// ```ignore
/// domain.task_memo(ScanL1Executor)
/// ```
pub fn task_memo<T, Memo>(self, executor: impl TypedExecutor<T, Memo>) -> Self
where
    T: TypedTask<Domain = D>,
    Memo: Serialize + DeserializeOwned + Send + Sync + 'static,
{
    let defaults = T::config();
    let erased = erase_executor::<T, Memo, _>(executor);
    self.task_inner::<T>(erased, defaults.ttl, defaults.retry_policy)
}

/// Register a typed executor with per-type option overrides.
Expand All @@ -332,25 +406,35 @@ impl<D: DomainKey> Domain<D> {
let config = T::config();
let ttl = options.ttl.or(config.ttl);
let retry_policy = options.retry_policy.or(config.retry_policy);
self.task_inner::<T>(executor, ttl, retry_policy)
self.task_inner::<T>(erase_executor::<T, (), _>(executor), ttl, retry_policy)
}

fn task_inner<T>(
mut self,
executor: impl TypedExecutor<T>,
ttl: Option<Duration>,
retry_policy: Option<RetryPolicy>,
/// Like [`task_with()`](Self::task_with), but for executors that produce
/// a memo (see [`task_memo()`](Self::task_memo)).
pub fn task_with_memo<T, Memo>(
self,
executor: impl TypedExecutor<T, Memo>,
options: TaskTypeOptions,
) -> Self
where
T: TypedTask,
T: TypedTask<Domain = D>,
Memo: Serialize + DeserializeOwned + Send + Sync + 'static,
{
let adapter = TypedExecutorAdapter {
executor,
_marker: PhantomData::<fn() -> T>,
};
let config = T::config();
let ttl = options.ttl.or(config.ttl);
let retry_policy = options.retry_policy.or(config.retry_policy);
self.task_inner::<T>(erase_executor::<T, Memo, _>(executor), ttl, retry_policy)
}

fn task_inner<T: TypedTask>(
mut self,
executor: Arc<dyn ErasedExecutor>,
ttl: Option<Duration>,
retry_policy: Option<RetryPolicy>,
) -> Self {
self.executors.push(ModuleExecutor {
task_type: T::TASK_TYPE.to_string(),
executor: Arc::new(adapter) as Arc<dyn ErasedExecutor>,
executor,
options: ExecutorOptions { ttl, retry_policy },
});
self
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@
//! Ok(())
//! }
//!
//! async fn finalize(&self, upload: MultipartUpload, ctx: DomainTaskContext<'_, Uploads>) -> Result<(), TaskError> {
//! async fn finalize(&self, upload: MultipartUpload, _memo: (), ctx: DomainTaskContext<'_, Uploads>) -> Result<(), TaskError> {
//! // All parts uploaded — complete the multipart upload.
//! complete_multipart(&upload).await?;
//! Ok(())
Expand Down
12 changes: 8 additions & 4 deletions src/registry/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ pub(crate) trait TaskExecutor: Send + Sync + 'static {
/// - `ctx`: Execution context with the task record, cancellation token,
/// and progress reporter.
///
/// On success, return `Ok(())`. Use [`TaskContext::record_read_bytes`]
/// On success, return `Ok(None)` or `Ok(Some(bytes))` with serialized
/// memo data to pass to `finalize()`. Use [`TaskContext::record_read_bytes`]
/// and [`TaskContext::record_write_bytes`] to report IO during execution.
/// On failure, return a [`TaskError`] indicating whether retry is appropriate.
fn execute<'a>(
&'a self,
ctx: &'a TaskContext,
) -> impl Future<Output = Result<(), TaskError>> + Send + 'a;
) -> impl Future<Output = Result<Option<Vec<u8>>, TaskError>> + Send + 'a;

/// Called after all children of a parent task have completed.
///
Expand Down Expand Up @@ -110,6 +111,9 @@ pub struct TaskTypeRegistry {
type_retry_policies: HashMap<String, RetryPolicy>,
}

/// Serialized memo bytes returned by `execute_erased`.
///
/// `Some(bytes)` carries the serialized memo data to pass to `finalize()`;
/// `None` means the execute phase returned no memo bytes.
type MemoBytes = Option<Vec<u8>>;

/// Object-safe wrapper around [`TaskExecutor`] for dynamic dispatch in the registry.
///
/// This trait exists because RPITIT (`impl Future`) in `TaskExecutor` is not
Expand All @@ -119,7 +123,7 @@ pub(crate) trait ErasedExecutor: Send + Sync + 'static {
fn execute_erased<'a>(
&'a self,
ctx: &'a TaskContext,
) -> std::pin::Pin<Box<dyn Future<Output = Result<(), TaskError>> + Send + 'a>>;
) -> std::pin::Pin<Box<dyn Future<Output = Result<MemoBytes, TaskError>> + Send + 'a>>;

fn finalize_erased<'a>(
&'a self,
Expand All @@ -136,7 +140,7 @@ impl<T: TaskExecutor> ErasedExecutor for T {
fn execute_erased<'a>(
&'a self,
ctx: &'a TaskContext,
) -> std::pin::Pin<Box<dyn Future<Output = Result<(), TaskError>> + Send + 'a>> {
) -> std::pin::Pin<Box<dyn Future<Output = Result<MemoBytes, TaskError>> + Send + 'a>> {
Box::pin(self.execute(ctx))
}

Expand Down
7 changes: 5 additions & 2 deletions src/scheduler/spawn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ pub(crate) async fn spawn_task(

let result = match phase {
ExecutionPhase::Execute => executor.execute_erased(&prepared.ctx).await,
ExecutionPhase::Finalize => executor.finalize_erased(&prepared.ctx).await,
ExecutionPhase::Finalize => {
executor.finalize_erased(&prepared.ctx).await.map(|()| None)
} // finalize doesn't produce a memo
};

// Read IO bytes from the context tracker.
Expand All @@ -121,11 +123,12 @@ pub(crate) async fn spawn_task(
drop(prepared.ctx);

match result {
Ok(()) => {
Ok(memo) => {
completion::handle_success(
&task,
phase,
&metrics,
memo,
&completion_deps,
decrement_module,
)
Expand Down
3 changes: 2 additions & 1 deletion src/scheduler/spawn/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub(crate) async fn handle_success(
task: &TaskRecord,
phase: ExecutionPhase,
metrics: &IoBudget,
memo: Option<Vec<u8>>,
deps: &CompletionDeps,
decrement_module: impl FnOnce(),
) {
Expand All @@ -46,7 +47,7 @@ pub(crate) async fn handle_success(
{
match deps.store.active_children_count(task_id).await {
Ok(count) if count > 0 => {
if let Err(e) = deps.store.set_waiting(task_id).await {
if let Err(e) = deps.store.set_waiting(task_id, memo.as_deref()).await {
tracing::error!(task_id, error = %e, "failed to set task to waiting");
}
decrement_module();
Expand Down
1 change: 1 addition & 0 deletions src/scheduler/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ impl TypedExecutor<ParentTask> for FinalizeTrackingExecutor {
async fn finalize<'a>(
&'a self,
_payload: ParentTask,
_memo: (),
_ctx: DomainTaskContext<'a, ParentDomain>,
) -> Result<(), TaskError> {
self.finalized
Expand Down
Loading
Loading