diff --git a/migrations/007_tags.sql b/migrations/007_tags.sql new file mode 100644 index 0000000..dc3ae29 --- /dev/null +++ b/migrations/007_tags.sql @@ -0,0 +1,18 @@ +-- Task metadata tags (key-value pairs). +CREATE TABLE IF NOT EXISTS task_tags ( + task_id INTEGER NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + PRIMARY KEY (task_id, key) +); + +CREATE INDEX IF NOT EXISTS idx_task_tags_kv ON task_tags(key, value); + +CREATE TABLE IF NOT EXISTS task_history_tags ( + history_rowid INTEGER NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + PRIMARY KEY (history_rowid, key) +); + +CREATE INDEX IF NOT EXISTS idx_history_tags_kv ON task_history_tags(key, value); diff --git a/src/lib.rs b/src/lib.rs index 4ae4023..35ac374 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,6 +147,27 @@ //! (with `TtlFrom::Submission`), so a child can never outlive its parent's //! deadline. //! +//! ## Task metadata tags +//! +//! Tasks can carry schema-free key-value metadata tags for filtering, grouping, +//! and display — without deserializing the task payload. Tags are immutable +//! after submission and are persisted, indexed, and queryable. +//! +//! Set tags per-task via [`TaskSubmission::tag`], per-type via +//! [`TypedTask::tags`], or as batch defaults via +//! [`BatchSubmission::default_tag`]. Tag keys and values are validated at submit +//! time against [`MAX_TAG_KEY_LEN`], [`MAX_TAG_VALUE_LEN`], and +//! [`MAX_TAGS_PER_TASK`]. +//! +//! Child tasks inherit parent tags by default (child tags take precedence). +//! Tags are copied to history on all terminal transitions and are included in +//! [`TaskEventHeader`] for event subscribers. +//! +//! Query by tags with [`TaskStore::tasks_by_tags`] (AND semantics), +//! [`TaskStore::count_by_tag`] (grouped counts), or +//! [`TaskStore::tag_values`] (distinct values). Cancel by tag with +//! [`Scheduler::cancel_by_tag`]. +//! //! ## Delayed & scheduled tasks //! //! 
A task can declare **when** it becomes eligible for dispatch: @@ -774,7 +795,8 @@ pub use task::{ generate_dedup_key, BatchOutcome, BatchSubmission, DependencyFailurePolicy, DuplicateStrategy, HistoryStatus, IoBudget, ParentResolution, RecurringSchedule, RecurringScheduleInfo, SubmitOutcome, TaskError, TaskHistoryRecord, TaskLookup, TaskRecord, TaskStatus, - TaskSubmission, TtlFrom, TypeStats, TypedTask, + TaskSubmission, TtlFrom, TypeStats, TypedTask, MAX_TAGS_PER_TASK, MAX_TAG_KEY_LEN, + MAX_TAG_VALUE_LEN, }; #[cfg(feature = "sysinfo-monitor")] diff --git a/src/registry/child_spawner.rs b/src/registry/child_spawner.rs index 0b8e088..884d402 100644 --- a/src/registry/child_spawner.rs +++ b/src/registry/child_spawner.rs @@ -1,5 +1,6 @@ //! Child task spawning from within an executor. +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -8,6 +9,16 @@ use chrono::{DateTime, Utc}; use crate::store::{StoreError, TaskStore}; use crate::task::{SubmitOutcome, TaskSubmission, TtlFrom}; +/// Inherited parent context for child spawning: TTL and tags. +#[derive(Clone)] +pub(crate) struct ParentContext { + pub created_at: DateTime, + pub ttl_seconds: Option, + pub ttl_from: TtlFrom, + pub started_at: Option>, + pub tags: HashMap, +} + /// Handle for spawning child tasks from within an executor. 
/// /// Wraps a [`TaskStore`] reference and the parent task ID so that @@ -22,10 +33,7 @@ pub(crate) struct ChildSpawner { store: TaskStore, parent_id: i64, work_notify: Arc, - parent_created_at: DateTime, - parent_ttl_seconds: Option, - parent_ttl_from: TtlFrom, - parent_started_at: Option>, + parent: ParentContext, } impl ChildSpawner { @@ -33,19 +41,13 @@ impl ChildSpawner { store: TaskStore, parent_id: i64, work_notify: Arc, - parent_created_at: DateTime, - parent_ttl_seconds: Option, - parent_ttl_from: TtlFrom, - parent_started_at: Option>, + parent: ParentContext, ) -> Self { Self { store, parent_id, work_notify, - parent_created_at, - parent_ttl_seconds, - parent_ttl_from, - parent_started_at, + parent, } } @@ -55,15 +57,15 @@ impl ChildSpawner { if sub.ttl.is_some() { return; // Child has explicit TTL, don't override. } - let Some(parent_ttl_secs) = self.parent_ttl_seconds else { + let Some(parent_ttl_secs) = self.parent.ttl_seconds else { return; // Parent has no TTL. }; let parent_ttl = Duration::from_secs(parent_ttl_secs as u64); // Determine when the parent's TTL started. - let ttl_start = match self.parent_ttl_from { - TtlFrom::Submission => self.parent_created_at, - TtlFrom::FirstAttempt => match self.parent_started_at { + let ttl_start = match self.parent.ttl_from { + TtlFrom::Submission => self.parent.created_at, + TtlFrom::FirstAttempt => match self.parent.started_at { Some(started) => started, None => return, // Parent hasn't started yet, can't compute remaining. }, @@ -80,10 +82,18 @@ impl ChildSpawner { } } + /// Inherit parent tags into a child submission. Child tags take precedence. + fn inherit_tags(&self, sub: &mut TaskSubmission) { + for (k, v) in &self.parent.tags { + sub.tags.entry(k.clone()).or_insert_with(|| v.clone()); + } + } + /// Submit a single child task. Sets `parent_id` automatically. 
pub async fn spawn(&self, mut sub: TaskSubmission) -> Result { sub.parent_id = Some(self.parent_id); self.inherit_ttl(&mut sub); + self.inherit_tags(&mut sub); let outcome = self.store.submit(&sub).await?; self.work_notify.notify_one(); Ok(outcome) @@ -97,6 +107,7 @@ impl ChildSpawner { for sub in submissions.iter_mut() { sub.parent_id = Some(self.parent_id); self.inherit_ttl(sub); + self.inherit_tags(sub); } let outcomes = self.store.submit_batch(submissions).await?; self.work_notify.notify_one(); diff --git a/src/registry/mod.rs b/src/registry/mod.rs index a050e45..3d111e1 100644 --- a/src/registry/mod.rs +++ b/src/registry/mod.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use crate::task::TaskError; -pub(crate) use child_spawner::ChildSpawner; +pub(crate) use child_spawner::{ChildSpawner, ParentContext}; pub use context::TaskContext; pub(crate) use io_tracker::IoTracker; pub(crate) use state::{StateMap, StateSnapshot}; diff --git a/src/scheduler/dispatch.rs b/src/scheduler/dispatch.rs index 875e2a0..4b337f5 100644 --- a/src/scheduler/dispatch.rs +++ b/src/scheduler/dispatch.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex}; use tokio_util::sync::CancellationToken; use crate::priority::Priority; -use crate::registry::{ChildSpawner, IoTracker, TaskContext}; +use crate::registry::{ChildSpawner, IoTracker, ParentContext, TaskContext}; use crate::store::TaskStore; use crate::task::{IoBudget, ParentResolution, TaskRecord}; @@ -300,10 +300,13 @@ pub(crate) async fn spawn_task( store.clone(), task.id, work_notify.clone(), - task.created_at, - task.ttl_seconds, - task.ttl_from, - task.started_at, + ParentContext { + created_at: task.created_at, + ttl_seconds: task.ttl_seconds, + ttl_from: task.ttl_from, + started_at: task.started_at, + tags: task.tags.clone(), + }, ); let io = Arc::new(IoTracker::new()); diff --git a/src/scheduler/event.rs b/src/scheduler/event.rs index 89360c9..2cda5c8 100644 --- a/src/scheduler/event.rs +++ b/src/scheduler/event.rs @@ -8,6 +8,8 @@ //! 
- [`DependencyFailed { task_id, failed_dependency }`](SchedulerEvent::DependencyFailed) //! — a blocked task was cancelled because a dependency failed +use std::collections::HashMap; + use serde::{Deserialize, Serialize}; use tokio::time::Duration; @@ -62,6 +64,8 @@ pub struct TaskEventHeader { pub task_type: String, pub key: String, pub label: String, + /// Key-value metadata tags from the task record. + pub tags: HashMap, } // ── Events ────────────────────────────────────────────────────────── diff --git a/src/scheduler/queries.rs b/src/scheduler/queries.rs index 08b6704..4f685aa 100644 --- a/src/scheduler/queries.rs +++ b/src/scheduler/queries.rs @@ -63,6 +63,35 @@ impl Scheduler { .collect() } + /// Find active tasks matching all specified tag filters (AND semantics). + /// + /// Delegates to [`TaskStore::tasks_by_tags`]. + pub async fn tasks_by_tags( + &self, + filters: &[(&str, &str)], + status: Option, + ) -> Result, StoreError> { + self.inner.store.tasks_by_tags(filters, status).await + } + + /// Count active tasks grouped by a tag key's values. + /// + /// Delegates to [`TaskStore::count_by_tag`]. + pub async fn count_by_tag( + &self, + key: &str, + status: Option, + ) -> Result, StoreError> { + self.inner.store.count_by_tag(key, status).await + } + + /// List distinct values for a tag key across active tasks, with counts. + /// + /// Delegates to [`TaskStore::tag_values`]. + pub async fn tag_values(&self, key: &str) -> Result, StoreError> { + self.inner.store.tag_values(key).await + } + /// Capture a single status snapshot for dashboard UIs. 
/// /// Gathers running tasks, queue depths, progress estimates, and diff --git a/src/scheduler/submit.rs b/src/scheduler/submit.rs index c141e81..f029bfb 100644 --- a/src/scheduler/submit.rs +++ b/src/scheduler/submit.rs @@ -57,6 +57,7 @@ impl Scheduler { task_type: sub.task_type.clone(), key: sub.effective_key(), label: sub.label.clone(), + tags: sub.tags.clone(), }; let _ = self.inner.event_tx.send(SchedulerEvent::Superseded { old: old_header, @@ -110,6 +111,7 @@ impl Scheduler { task_type: sub.task_type.clone(), key: sub.effective_key(), label: sub.label.clone(), + tags: sub.tags.clone(), }; let _ = self.inner.event_tx.send(SchedulerEvent::Superseded { old: old_header, @@ -294,6 +296,25 @@ impl Scheduler { Ok(cancelled) } + /// Cancel all active tasks matching a tag key-value pair. + /// + /// Finds tasks via [`TaskStore::tasks_by_tags`] and cancels each one. + /// Returns the ids of tasks that were successfully cancelled. + pub async fn cancel_by_tag(&self, key: &str, value: &str) -> Result, StoreError> { + let tasks = self + .inner + .store + .tasks_by_tags(&[(key, value)], None) + .await?; + let mut cancelled = Vec::new(); + for task in &tasks { + if self.cancel(task.id).await? { + cancelled.push(task.id); + } + } + Ok(cancelled) + } + /// Cancel all tasks matching a predicate. pub async fn cancel_where( &self, diff --git a/src/store/hierarchy.rs b/src/store/hierarchy.rs index 6c65a78..ad57810 100644 --- a/src/store/hierarchy.rs +++ b/src/store/hierarchy.rs @@ -402,6 +402,79 @@ mod tests { assert!(t.started_at.is_some()); } + #[tokio::test] + async fn child_inherits_parent_tags() { + use crate::registry::child_spawner::{ChildSpawner, ParentContext}; + use std::sync::Arc; + + let store = test_store().await; + let notify = Arc::new(tokio::sync::Notify::new()); + + // Submit a parent with tags. 
+ let parent_sub = TaskSubmission::new("test") + .key("tagged-parent") + .tag("env", "prod") + .tag("region", "us-east"); + let parent_id = store.submit(&parent_sub).await.unwrap().id().unwrap(); + let parent = store.pop_next().await.unwrap().unwrap(); + + let ctx = ParentContext { + created_at: parent.created_at, + ttl_seconds: None, + ttl_from: crate::task::TtlFrom::Submission, + started_at: parent.started_at, + tags: parent.tags.clone(), + }; + let spawner = ChildSpawner::new(store.clone(), parent_id, notify, ctx); + + // Spawn a child without tags — should inherit parent tags. + let child_sub = TaskSubmission::new("test").key("child-no-tags"); + let outcome = spawner.spawn(child_sub).await.unwrap(); + let child_id = outcome.id().unwrap(); + + let child = store.task_by_id(child_id).await.unwrap().unwrap(); + assert_eq!(child.tags.get("env").unwrap(), "prod"); + assert_eq!(child.tags.get("region").unwrap(), "us-east"); + } + + #[tokio::test] + async fn child_overrides_parent_tag() { + use crate::registry::child_spawner::{ChildSpawner, ParentContext}; + use std::sync::Arc; + + let store = test_store().await; + let notify = Arc::new(tokio::sync::Notify::new()); + + let parent_sub = TaskSubmission::new("test") + .key("tagged-parent-2") + .tag("env", "prod") + .tag("region", "us-east"); + let parent_id = store.submit(&parent_sub).await.unwrap().id().unwrap(); + let parent = store.pop_next().await.unwrap().unwrap(); + + let ctx = ParentContext { + created_at: parent.created_at, + ttl_seconds: None, + ttl_from: crate::task::TtlFrom::Submission, + started_at: parent.started_at, + tags: parent.tags.clone(), + }; + let spawner = ChildSpawner::new(store.clone(), parent_id, notify, ctx); + + // Spawn a child that overrides "region" but inherits "env". 
+ let child_sub = TaskSubmission::new("test") + .key("child-override") + .tag("region", "eu-west") + .tag("extra", "yes"); + let outcome = spawner.spawn(child_sub).await.unwrap(); + let child_id = outcome.id().unwrap(); + + let child = store.task_by_id(child_id).await.unwrap().unwrap(); + assert_eq!(child.tags.get("env").unwrap(), "prod"); // Inherited. + assert_eq!(child.tags.get("region").unwrap(), "eu-west"); // Overridden. + assert_eq!(child.tags.get("extra").unwrap(), "yes"); // Child's own. + } + #[tokio::test] async fn recover_preserves_waiting_parents() { let store = test_store().await; diff --git a/src/store/lifecycle.rs b/src/store/lifecycle.rs index 86bf82f..7318a32 100644 --- a/src/store/lifecycle.rs +++ b/src/store/lifecycle.rs @@ -69,6 +69,20 @@ pub(crate) async fn insert_history( ) .execute(&mut **conn) .await?; + + // Copy tags from task_tags to task_history_tags. + let history_rowid = sqlx::query_scalar::<_, i64>("SELECT last_insert_rowid()") + .fetch_one(&mut **conn) + .await?; + sqlx::query( + "INSERT INTO task_history_tags (history_rowid, key, value) + SELECT ?, key, value FROM task_tags WHERE task_id = ?", + ) + .bind(history_rowid) + .bind(task.id) + .execute(&mut **conn) + .await?; + Ok(()) } @@ -98,7 +112,11 @@ impl TaskStore { .fetch_optional(&self.pool) .await?; - Ok(row.as_ref().map(row_to_task_record)) + let mut record = row.as_ref().map(row_to_task_record); + if let Some(ref mut r) = record { + self.populate_tags(std::slice::from_mut(r)).await?; + } + Ok(record) } /// Atomically claim a specific pending task by id, setting it to running. @@ -126,7 +144,11 @@ impl TaskStore { .await?; tracing::debug!(task_id = id, "store.pop_by_id: UPDATE end"); - Ok(row.as_ref().map(row_to_task_record)) + let mut record = row.as_ref().map(row_to_task_record); + if let Some(ref mut r) = record { + self.populate_tags(std::slice::from_mut(r)).await?; + } + Ok(record) } /// Pop the highest-priority pending task and mark it as running. 
@@ -157,7 +179,11 @@ impl TaskStore { .fetch_optional(&self.pool) .await?; - Ok(row.map(|r| row_to_task_record(&r))) + let mut record = row.map(|r| row_to_task_record(&r)); + if let Some(ref mut r) = record { + self.populate_tags(std::slice::from_mut(r)).await?; + } + Ok(record) } /// Atomically requeue a running task back to pending. @@ -249,6 +275,16 @@ impl TaskStore { ) .await?; + // Read tags into memory before potential deletion (needed for recurring re-creation). + let saved_tags: Vec<(String, String)> = if task.recurring_interval_secs.is_some() { + sqlx::query_as("SELECT key, value FROM task_tags WHERE task_id = ?") + .bind(task.id) + .fetch_all(&mut **conn) + .await? + } else { + Vec::new() + }; + // Try to delete (normal completion, requeue = 0). let del = sqlx::query("DELETE FROM tasks WHERE id = ? AND requeue = 0") .bind(task.id) @@ -272,6 +308,9 @@ impl TaskStore { return Ok(None); } + // Task was deleted — clean up orphaned tags. + super::delete_task_tags(conn, task.id).await?; + // Handle recurring tasks: create the next instance after deleting // the completed one (to avoid UNIQUE constraint on key). let mut recurring_info = None; @@ -306,7 +345,7 @@ impl TaskStore { _ => None, }; - sqlx::query( + let recurring_result = sqlx::query( "INSERT INTO tasks (task_type, key, label, priority, status, payload, expected_read_bytes, expected_write_bytes, expected_net_rx_bytes, expected_net_tx_bytes, @@ -339,6 +378,19 @@ impl TaskStore { .execute(&mut **conn) .await?; + // Copy tags to the new recurring instance. + let next_id = recurring_result.last_insert_rowid(); + for (key, value) in &saved_tags { + sqlx::query( + "INSERT INTO task_tags (task_id, key, value) VALUES (?, ?, ?)", + ) + .bind(next_id) + .bind(key) + .bind(value) + .execute(&mut **conn) + .await?; + } + recurring_info = Some((next_run, execution_count)); } // If existing.is_some(), skip (pile-up prevention). 
@@ -433,6 +485,7 @@ impl TaskStore { insert_history(conn, task, "failed", metrics, duration_ms, Some(error)).await?; + super::delete_task_tags(conn, task.id).await?; sqlx::query("DELETE FROM tasks WHERE id = ?") .bind(task.id) .execute(&mut **conn) @@ -576,6 +629,7 @@ impl TaskStore { .execute(&mut **conn) .await?; + super::delete_task_tags(conn, dep_id).await?; sqlx::query("DELETE FROM tasks WHERE id = ?") .bind(dep_id) .execute(&mut **conn) @@ -672,6 +726,7 @@ impl TaskStore { .execute(&mut *conn) .await?; + super::delete_task_tags(&mut conn, id).await?; sqlx::query("DELETE FROM tasks WHERE id = ?") .bind(id) .execute(&mut *conn) @@ -709,6 +764,7 @@ impl TaskStore { .execute(&mut *conn) .await?; + super::delete_task_tags(&mut conn, task.id).await?; sqlx::query("DELETE FROM tasks WHERE id = ?") .bind(task.id) .execute(&mut *conn) @@ -748,7 +804,8 @@ impl TaskStore { let mut expired = Vec::with_capacity(rows.len()); for row in &rows { - let task = row_to_task_record(row); + let mut task = row_to_task_record(row); + task.tags = super::load_task_tags(&mut conn, task.id).await?; // Record in history as expired. insert_history( @@ -771,7 +828,8 @@ impl TaskStore { .await?; for child_row in &child_rows { - let child = row_to_task_record(child_row); + let mut child = row_to_task_record(child_row); + child.tags = super::load_task_tags(&mut conn, child.id).await?; insert_history( &mut conn, &child, @@ -781,6 +839,7 @@ impl TaskStore { None, ) .await?; + super::delete_task_tags(&mut conn, child.id).await?; sqlx::query("DELETE FROM tasks WHERE id = ?") .bind(child.id) .execute(&mut *conn) @@ -796,6 +855,7 @@ impl TaskStore { .await?; // Delete the expired task itself. 
+ super::delete_task_tags(&mut conn, task.id).await?; sqlx::query("DELETE FROM tasks WHERE id = ?") .bind(task.id) .execute(&mut *conn) @@ -838,7 +898,8 @@ impl TaskStore { return Ok(None); }; - let task = row_to_task_record(&row); + let mut task = row_to_task_record(&row); + task.tags = super::load_task_tags(&mut conn, task.id).await?; insert_history( &mut conn, @@ -850,6 +911,7 @@ impl TaskStore { ) .await?; + super::delete_task_tags(&mut conn, task.id).await?; sqlx::query("DELETE FROM tasks WHERE id = ?") .bind(task.id) .execute(&mut *conn) @@ -1112,4 +1174,120 @@ mod tests { assert!(store.peek_next().await.unwrap().is_none()); } + + // ── Tag lifecycle tests ─────────────────────────────────────────── + + #[tokio::test] + async fn tags_copied_to_history_on_complete() { + let store = test_store().await; + let sub = TaskSubmission::new("test") + .key("hist-tags-complete") + .tag("env", "staging") + .tag("owner", "alice"); + + store.submit(&sub).await.unwrap(); + let task = store.pop_next().await.unwrap().unwrap(); + store.complete(task.id, &IoBudget::default()).await.unwrap(); + + let hist = store.history_by_key(&sub.effective_key()).await.unwrap(); + assert_eq!(hist.len(), 1); + assert_eq!(hist[0].tags.get("env").unwrap(), "staging"); + assert_eq!(hist[0].tags.get("owner").unwrap(), "alice"); + } + + #[tokio::test] + async fn tags_copied_to_history_on_fail() { + let store = test_store().await; + let sub = TaskSubmission::new("test") + .key("hist-tags-fail") + .tag("region", "us-west"); + + store.submit(&sub).await.unwrap(); + let task = store.pop_next().await.unwrap().unwrap(); + store + .fail(task.id, "boom", false, 0, &IoBudget::default()) + .await + .unwrap(); + + let hist = store.failed_tasks(10).await.unwrap(); + assert_eq!(hist.len(), 1); + assert_eq!(hist[0].tags.get("region").unwrap(), "us-west"); + } + + #[tokio::test] + async fn tags_copied_to_history_on_cancel() { + let store = test_store().await; + let sub = TaskSubmission::new("test") + 
.key("hist-tags-cancel") + .tag("priority_class", "low"); + + let id = store.submit(&sub).await.unwrap().id().unwrap(); + store.cancel_to_history(id).await.unwrap(); + + let hist = store.history_by_key(&sub.effective_key()).await.unwrap(); + assert_eq!(hist.len(), 1); + assert_eq!(hist[0].status, HistoryStatus::Cancelled); + assert_eq!(hist[0].tags.get("priority_class").unwrap(), "low"); + } + + #[tokio::test] + async fn tags_copied_to_history_on_expire() { + use std::time::Duration; + + let store = test_store().await; + let sub = TaskSubmission::new("test") + .key("hist-tags-expire") + .tag("source", "cron") + .ttl(Duration::from_secs(0)); // Expire immediately. + + store.submit(&sub).await.unwrap(); + + // Small delay so expires_at is in the past. + tokio::time::sleep(Duration::from_millis(50)).await; + + let expired = store.expire_tasks().await.unwrap(); + assert!(!expired.is_empty()); + + let hist = store.history_by_key(&sub.effective_key()).await.unwrap(); + assert_eq!(hist.len(), 1); + assert_eq!(hist[0].status, HistoryStatus::Expired); + assert_eq!(hist[0].tags.get("source").unwrap(), "cron"); + } + + #[tokio::test] + async fn tags_preserved_on_recurring_requeue() { + use std::time::Duration; + + let store = test_store().await; + let sub = TaskSubmission::new("test") + .key("recurring-tags") + .tag("schedule", "hourly") + .recurring(Duration::from_secs(3600)); + + store.submit(&sub).await.unwrap(); + let task = store.pop_next().await.unwrap().unwrap(); + assert_eq!(task.tags.get("schedule").unwrap(), "hourly"); + + store + .complete_with_record(&task, &IoBudget::default()) + .await + .unwrap(); + + // The next recurring instance should have the same tags. 
+ let key = sub.effective_key(); + let next = store.task_by_key(&key).await.unwrap().unwrap(); + assert_eq!(next.tags.get("schedule").unwrap(), "hourly"); + } + + #[tokio::test] + async fn tags_in_pop_next() { + let store = test_store().await; + let sub = TaskSubmission::new("test") + .key("pop-tags") + .tag("color", "blue"); + + store.submit(&sub).await.unwrap(); + let task = store.pop_next().await.unwrap().unwrap(); + assert_eq!(task.tags.get("color").unwrap(), "blue"); + } } diff --git a/src/store/mod.rs b/src/store/mod.rs index b580e83..29cacdd 100644 --- a/src/store/mod.rs +++ b/src/store/mod.rs @@ -54,6 +54,8 @@ pub enum StoreError { DependencyFailed(i64), #[error("circular dependency detected")] CyclicDependency, + #[error("invalid tag: {0}")] + InvalidTag(String), } impl From for StoreError { @@ -245,6 +247,8 @@ impl TaskStore { include_str!("../../migrations/006_dependencies.sql"), ) .await?; + Self::run_alter_migration(&self.pool, include_str!("../../migrations/007_tags.sql")) + .await?; Ok(()) } @@ -419,10 +423,55 @@ impl TaskStore { /// Delete a task from the active queue by id. Returns true if a row was deleted. pub async fn delete(&self, id: i64) -> Result { + let mut conn = self.begin_write().await?; + delete_task_tags(&mut conn, id).await?; let result = sqlx::query("DELETE FROM tasks WHERE id = ?") .bind(id) - .execute(&self.pool) + .execute(&mut *conn) .await?; + sqlx::query("COMMIT").execute(&mut *conn).await?; Ok(result.rows_affected() > 0) } } + +/// Delete tags for a task. Called before or after deleting the task row itself. +pub(crate) async fn delete_task_tags( + conn: &mut sqlx::pool::PoolConnection, + task_id: i64, +) -> Result<(), StoreError> { + sqlx::query("DELETE FROM task_tags WHERE task_id = ?") + .bind(task_id) + .execute(&mut **conn) + .await?; + Ok(()) +} + +/// Load tags for a single task within an existing connection/transaction. 
+pub(crate) async fn load_task_tags( + conn: &mut sqlx::pool::PoolConnection, + task_id: i64, +) -> Result, StoreError> { + let rows: Vec<(String, String)> = + sqlx::query_as("SELECT key, value FROM task_tags WHERE task_id = ?") + .bind(task_id) + .fetch_all(&mut **conn) + .await?; + Ok(rows.into_iter().collect()) +} + +/// Insert tags for a task into the task_tags table. +pub(crate) async fn insert_tags( + conn: &mut sqlx::pool::PoolConnection, + task_id: i64, + tags: &std::collections::HashMap, +) -> Result<(), StoreError> { + for (key, value) in tags { + sqlx::query("INSERT INTO task_tags (task_id, key, value) VALUES (?, ?, ?)") + .bind(task_id) + .bind(key) + .bind(value) + .execute(&mut **conn) + .await?; + } + Ok(()) +} diff --git a/src/store/query.rs b/src/store/query.rs index 09fa0c8..05e8041 100644 --- a/src/store/query.rs +++ b/src/store/query.rs @@ -1,5 +1,7 @@ //! Read-only query methods for the active queue and history tables. +use std::collections::HashMap; + use sqlx::Row; use crate::task::{TaskHistoryRecord, TaskLookup, TaskRecord, TypeStats}; @@ -8,6 +10,66 @@ use super::row_mapping::{row_to_history_record, row_to_task_record}; use super::{StoreError, TaskStore}; impl TaskStore { + // ── Tag population ───────────────────────────────────────────── + + /// Populate tags for a slice of task records from the task_tags table. 
+ pub(crate) async fn populate_tags(&self, records: &mut [TaskRecord]) -> Result<(), StoreError> { + if records.is_empty() { + return Ok(()); + } + let ids: Vec = records.iter().map(|r| r.id).collect(); + let placeholders = ids.iter().map(|_| "?").collect::>().join(","); + let query = + format!("SELECT task_id, key, value FROM task_tags WHERE task_id IN ({placeholders})"); + let mut q = sqlx::query_as::<_, (i64, String, String)>(&query); + for id in &ids { + q = q.bind(id); + } + let tag_rows = q.fetch_all(&self.pool).await?; + + let mut tag_map: HashMap> = HashMap::new(); + for (task_id, key, value) in tag_rows { + tag_map.entry(task_id).or_default().insert(key, value); + } + for record in records { + if let Some(tags) = tag_map.remove(&record.id) { + record.tags = tags; + } + } + Ok(()) + } + + /// Populate tags for a slice of history records from the task_history_tags table. + pub(crate) async fn populate_history_tags( + &self, + records: &mut [TaskHistoryRecord], + ) -> Result<(), StoreError> { + if records.is_empty() { + return Ok(()); + } + let ids: Vec = records.iter().map(|r| r.id).collect(); + let placeholders = ids.iter().map(|_| "?").collect::>().join(","); + let query = format!( + "SELECT history_rowid, key, value FROM task_history_tags WHERE history_rowid IN ({placeholders})" + ); + let mut q = sqlx::query_as::<_, (i64, String, String)>(&query); + for id in &ids { + q = q.bind(id); + } + let tag_rows = q.fetch_all(&self.pool).await?; + + let mut tag_map: HashMap> = HashMap::new(); + for (history_rowid, key, value) in tag_rows { + tag_map.entry(history_rowid).or_default().insert(key, value); + } + for record in records { + if let Some(tags) = tag_map.remove(&record.id) { + record.tags = tags; + } + } + Ok(()) + } + // ── Query: active queue ───────────────────────────────────────── /// All currently running tasks. 
@@ -17,7 +79,9 @@ impl TaskStore { ) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_task_record).collect()) + let mut records: Vec = rows.iter().map(row_to_task_record).collect(); + self.populate_tags(&mut records).await?; + Ok(records) } /// Count of running tasks. @@ -36,7 +100,9 @@ impl TaskStore { .bind(limit) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_task_record).collect()) + let mut records: Vec = rows.iter().map(row_to_task_record).collect(); + self.populate_tags(&mut records).await?; + Ok(records) } /// Count of pending tasks. @@ -55,7 +121,9 @@ impl TaskStore { .bind(task_type) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_task_record).collect()) + let mut records: Vec = rows.iter().map(row_to_task_record).collect(); + self.populate_tags(&mut records).await?; + Ok(records) } /// Count of paused tasks. @@ -73,7 +141,9 @@ impl TaskStore { ) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_task_record).collect()) + let mut records: Vec = rows.iter().map(row_to_task_record).collect(); + self.populate_tags(&mut records).await?; + Ok(records) } /// Look up an active task by its row id. Returns `None` if no active @@ -83,7 +153,11 @@ impl TaskStore { .bind(id) .fetch_optional(&self.pool) .await?; - Ok(row.as_ref().map(row_to_task_record)) + let mut record = row.as_ref().map(row_to_task_record); + if let Some(ref mut r) = record { + self.populate_tags(std::slice::from_mut(r)).await?; + } + Ok(record) } /// Look up an active task by its dedup key. Returns `None` if no active @@ -93,7 +167,11 @@ impl TaskStore { .bind(key) .fetch_optional(&self.pool) .await?; - Ok(row.as_ref().map(row_to_task_record)) + let mut record = row.as_ref().map(row_to_task_record); + if let Some(ref mut r) = record { + self.populate_tags(std::slice::from_mut(r)).await?; + } + Ok(record) } /// Sum of expected read/write bytes for all running tasks. 
@@ -136,7 +214,11 @@ impl TaskStore { .bind(id) .fetch_optional(&self.pool) .await?; - Ok(row.as_ref().map(row_to_history_record)) + let mut record = row.as_ref().map(row_to_history_record); + if let Some(ref mut r) = record { + self.populate_history_tags(std::slice::from_mut(r)).await?; + } + Ok(record) } /// Recent history entries, newest first. @@ -151,7 +233,9 @@ impl TaskStore { .bind(offset) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_history_record).collect()) + let mut records: Vec = rows.iter().map(row_to_history_record).collect(); + self.populate_history_tags(&mut records).await?; + Ok(records) } /// History filtered by task type. @@ -167,7 +251,9 @@ impl TaskStore { .bind(limit) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_history_record).collect()) + let mut records: Vec = rows.iter().map(row_to_history_record).collect(); + self.populate_history_tags(&mut records).await?; + Ok(records) } /// History for a specific key (all past runs of that key). @@ -177,7 +263,9 @@ impl TaskStore { .bind(key) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_history_record).collect()) + let mut records: Vec = rows.iter().map(row_to_history_record).collect(); + self.populate_history_tags(&mut records).await?; + Ok(records) } /// Failed tasks from history. @@ -188,7 +276,9 @@ impl TaskStore { .bind(limit) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_history_record).collect()) + let mut records: Vec = rows.iter().map(row_to_history_record).collect(); + self.populate_history_tags(&mut records).await?; + Ok(records) } /// Aggregate stats for a task type from completed history. @@ -252,6 +342,7 @@ impl TaskStore { /// or [`TaskSubmission::effective_key`]). pub async fn task_lookup(&self, key: &str) -> Result { // Check active queue first (pending / running / paused). + // task_by_key already populates tags. if let Some(record) = self.task_by_key(key).await? 
{ return Ok(TaskLookup::Active(record)); } @@ -265,7 +356,12 @@ impl TaskStore { .await?; match row { - Some(r) => Ok(TaskLookup::History(row_to_history_record(&r))), + Some(r) => { + let mut hist = row_to_history_record(&r); + self.populate_history_tags(std::slice::from_mut(&mut hist)) + .await?; + Ok(TaskLookup::History(hist)) + } None => Ok(TaskLookup::NotFound), } } @@ -287,7 +383,9 @@ impl TaskStore { ) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_task_record).collect()) + let mut records: Vec = rows.iter().map(row_to_task_record).collect(); + self.populate_tags(&mut records).await?; + Ok(records) } /// All active tasks in a specific group. @@ -296,7 +394,9 @@ impl TaskStore { .bind(group_key) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_task_record).collect()) + let mut records: Vec = rows.iter().map(row_to_task_record).collect(); + self.populate_tags(&mut records).await?; + Ok(records) } /// All active tasks of a specific type. @@ -305,7 +405,9 @@ impl TaskStore { .bind(task_type) .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_task_record).collect()) + let mut records: Vec = rows.iter().map(row_to_task_record).collect(); + self.populate_tags(&mut records).await?; + Ok(records) } /// All active tasks (any status). @@ -313,7 +415,122 @@ impl TaskStore { let rows = sqlx::query("SELECT * FROM tasks ORDER BY id ASC") .fetch_all(&self.pool) .await?; - Ok(rows.iter().map(row_to_task_record).collect()) + let mut records: Vec = rows.iter().map(row_to_task_record).collect(); + self.populate_tags(&mut records).await?; + Ok(records) + } + + // ── Tag-based queries ─────────────────────────────────────────── + + /// Find active tasks matching all specified tag filters (AND semantics). + /// + /// Each `(key, value)` pair adds an INNER JOIN, so only tasks matching + /// **all** filters are returned. Optionally filter by status. 
+    pub async fn tasks_by_tags(
+        &self,
+        filters: &[(&str, &str)],
+        status: Option<TaskStatus>,
+    ) -> Result<Vec<TaskRecord>, StoreError> {
+        // With no filters there is nothing to join on; report no matches
+        // instead of silently degrading to "all tasks".
+        if filters.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        let mut sql = Self::tag_filter_sql("SELECT t.*", filters.len(), status.is_some());
+        sql.push_str(" ORDER BY t.priority ASC, t.id ASC");
+
+        // Bind order mirrors placeholder order: one (key, value) pair per
+        // join, then the optional status.
+        let mut q = sqlx::query(&sql);
+        for &(key, value) in filters {
+            q = q.bind(key).bind(value);
+        }
+        if let Some(ref s) = status {
+            q = q.bind(s.as_str());
+        }
+        let rows = q.fetch_all(&self.pool).await?;
+        let mut records: Vec<TaskRecord> = rows.iter().map(row_to_task_record).collect();
+        self.populate_tags(&mut records).await?;
+        Ok(records)
+    }
+
+    /// Count active tasks matching all specified tag filters (AND semantics).
+    ///
+    /// Uses the same joins as [`TaskStore::tasks_by_tags`]. An empty `filters`
+    /// slice returns `0` without querying. Optionally filter by status.
+    pub async fn count_by_tags(
+        &self,
+        filters: &[(&str, &str)],
+        status: Option<TaskStatus>,
+    ) -> Result<i64, StoreError> {
+        if filters.is_empty() {
+            return Ok(0);
+        }
+
+        let sql = Self::tag_filter_sql("SELECT COUNT(*)", filters.len(), status.is_some());
+
+        let mut q = sqlx::query_as::<_, (i64,)>(&sql);
+        for &(key, value) in filters {
+            q = q.bind(key).bind(value);
+        }
+        if let Some(ref s) = status {
+            q = q.bind(s.as_str());
+        }
+        let (count,) = q.fetch_one(&self.pool).await?;
+        Ok(count)
+    }
+
+    /// Build `<select> FROM tasks t` followed by one `task_tags` join per filter.
+    ///
+    /// Each join contributes two `?` placeholders (key, then value). When
+    /// `filter_status` is set, a final `?` placeholder for the status follows,
+    /// so the status is bound like every other value instead of being
+    /// formatted into the SQL text.
+    fn tag_filter_sql(select: &str, filter_count: usize, filter_status: bool) -> String {
+        let mut sql = format!("{select} FROM tasks t");
+        for i in 0..filter_count {
+            sql.push_str(&format!(
+                " INNER JOIN task_tags tt{i} ON t.id = tt{i}.task_id AND tt{i}.key = ? AND tt{i}.value = ?"
+            ));
+        }
+        if filter_status {
+            sql.push_str(" WHERE t.status = ?");
+        }
+        sql
+    }
+
+    /// List distinct values for a tag key across active tasks, with counts.
+    ///
+    /// Reads the `task_tags` table only. Returns `(value, count)` pairs sorted
+    /// by count descending; an unknown key yields an empty list.
+    pub async fn tag_values(&self, key: &str) -> Result<Vec<(String, i64)>, StoreError> {
+        let rows: Vec<(String, i64)> = sqlx::query_as(
+            "SELECT value, COUNT(*) as cnt FROM task_tags WHERE key = ? GROUP BY value ORDER BY cnt DESC",
+        )
+        .bind(key)
+        .fetch_all(&self.pool)
+        .await?;
+        Ok(rows)
+    }
+
+    /// Count active tasks grouped by a tag key's values.
+    ///
+    /// Returns `(tag_value, count)` pairs sorted by count descending.
+    /// Optionally filter by task status.
+    pub async fn count_by_tag(
+        &self,
+        key: &str,
+        status: Option<TaskStatus>,
+    ) -> Result<Vec<(String, i64)>, StoreError> {
+        // One statement text, extended with the status predicate only when a
+        // status was requested.
+        let mut sql = String::from(
+            "SELECT tt.value, COUNT(*) as cnt FROM task_tags tt \
+             JOIN tasks t ON t.id = tt.task_id \
+             WHERE tt.key = ?",
+        );
+        if status.is_some() {
+            sql.push_str(" AND t.status = ?");
+        }
+        sql.push_str(" GROUP BY tt.value ORDER BY cnt DESC");
+
+        let mut q = sqlx::query_as::<_, (String, i64)>(&sql).bind(key);
+        if let Some(ref s) = status {
+            q = q.bind(s.as_str());
+        }
+        let rows = q.fetch_all(&self.pool).await?;
+        Ok(rows)
     }
 
     // ── Scheduling ─────────────────────────────────────────────────
@@ -392,7 +609,9 @@ impl TaskStore {
         )
         .fetch_all(&self.pool)
         .await?;
-        Ok(rows.iter().map(row_to_task_record).collect())
+        let mut records: Vec<TaskRecord> = rows.iter().map(row_to_task_record).collect();
+        self.populate_tags(&mut records).await?;
+        Ok(records)
     }
 
     /// Count of blocked tasks.
@@ -598,4 +817,146 @@ mod tests { let hist = store.history(100, 0).await.unwrap(); assert_eq!(hist.len(), 3); } + + // ── Tag query tests ─────────────────────────────────────────────── + + #[tokio::test] + async fn tasks_by_tags_single_filter() { + let store = test_store().await; + + store + .submit(&TaskSubmission::new("test").key("tbt-1").tag("env", "prod")) + .await + .unwrap(); + store + .submit( + &TaskSubmission::new("test") + .key("tbt-2") + .tag("env", "staging"), + ) + .await + .unwrap(); + store + .submit(&TaskSubmission::new("test").key("tbt-3").tag("env", "prod")) + .await + .unwrap(); + + let results = store.tasks_by_tags(&[("env", "prod")], None).await.unwrap(); + assert_eq!(results.len(), 2); + + let results = store + .tasks_by_tags(&[("env", "staging")], None) + .await + .unwrap(); + assert_eq!(results.len(), 1); + } + + #[tokio::test] + async fn tasks_by_tags_multiple_filters_and() { + let store = test_store().await; + + store + .submit( + &TaskSubmission::new("test") + .key("multi-1") + .tag("env", "prod") + .tag("region", "us"), + ) + .await + .unwrap(); + store + .submit( + &TaskSubmission::new("test") + .key("multi-2") + .tag("env", "prod") + .tag("region", "eu"), + ) + .await + .unwrap(); + store + .submit( + &TaskSubmission::new("test") + .key("multi-3") + .tag("env", "staging") + .tag("region", "us"), + ) + .await + .unwrap(); + + // AND semantics: only task matching both filters. + let results = store + .tasks_by_tags(&[("env", "prod"), ("region", "us")], None) + .await + .unwrap(); + assert_eq!(results.len(), 1); + + // With status filter. 
+ let results = store + .tasks_by_tags(&[("env", "prod")], Some(TaskStatus::Pending)) + .await + .unwrap(); + assert_eq!(results.len(), 2); + } + + #[tokio::test] + async fn count_by_tag_groups() { + let store = test_store().await; + + for i in 0..3 { + store + .submit( + &TaskSubmission::new("test") + .key(format!("free-{i}")) + .tag("tier", "free"), + ) + .await + .unwrap(); + } + for i in 0..2 { + store + .submit( + &TaskSubmission::new("test") + .key(format!("pro-{i}")) + .tag("tier", "pro"), + ) + .await + .unwrap(); + } + + let groups = store.count_by_tag("tier", None).await.unwrap(); + assert_eq!(groups.len(), 2); + // Sorted by count descending. + assert_eq!(groups[0].0, "free"); + assert_eq!(groups[0].1, 3); + assert_eq!(groups[1].0, "pro"); + assert_eq!(groups[1].1, 2); + } + + #[tokio::test] + async fn tag_values_distinct() { + let store = test_store().await; + + store + .submit(&TaskSubmission::new("test").key("tv-1").tag("color", "red")) + .await + .unwrap(); + store + .submit(&TaskSubmission::new("test").key("tv-2").tag("color", "red")) + .await + .unwrap(); + store + .submit(&TaskSubmission::new("test").key("tv-3").tag("color", "blue")) + .await + .unwrap(); + + let values = store.tag_values("color").await.unwrap(); + assert_eq!(values.len(), 2); + // Sorted by count descending. + assert_eq!(values[0], ("red".to_string(), 2)); + assert_eq!(values[1], ("blue".to_string(), 1)); + + // Non-existent key returns empty. + let empty = store.tag_values("nonexistent").await.unwrap(); + assert!(empty.is_empty()); + } } diff --git a/src/store/row_mapping.rs b/src/store/row_mapping.rs index 8f1a80f..ba20a2a 100644 --- a/src/store/row_mapping.rs +++ b/src/store/row_mapping.rs @@ -69,6 +69,8 @@ pub(crate) fn row_to_task_record(row: &sqlx::sqlite::SqliteRow) -> TaskRecord { on_dependency_failure: on_dep_failure_str .parse() .unwrap_or(DependencyFailurePolicy::Cancel), + // Tags are populated separately from the task_tags table. 
+ tags: std::collections::HashMap::new(), } } @@ -125,5 +127,7 @@ pub(crate) fn row_to_history_record(row: &sqlx::sqlite::SqliteRow) -> TaskHistor ttl_from: ttl_from_str.parse().unwrap_or(TtlFrom::Submission), expires_at: expires_at_str.map(|s| parse_datetime(&s)), run_after: run_after_str.map(|s| parse_datetime(&s)), + // Tags are populated separately from the task_history_tags table. + tags: std::collections::HashMap::new(), } } diff --git a/src/store/submit.rs b/src/store/submit.rs index aca485c..4eaf06e 100644 --- a/src/store/submit.rs +++ b/src/store/submit.rs @@ -7,7 +7,7 @@ use sqlx::Row; use crate::task::{ DependencyFailurePolicy, DuplicateStrategy, SubmitOutcome, TaskSubmission, TtlFrom, - MAX_PAYLOAD_BYTES, + MAX_PAYLOAD_BYTES, MAX_TAGS_PER_TASK, MAX_TAG_KEY_LEN, MAX_TAG_VALUE_LEN, }; use super::row_mapping::row_to_task_record; @@ -18,6 +18,31 @@ use super::{StoreError, TaskStore}; /// lock for too long. const BATCH_CHUNK_SIZE: usize = 10_000; +/// Validate tag constraints: key length, value length, max count. +fn validate_tags(tags: &HashMap) -> Result<(), StoreError> { + if tags.len() > MAX_TAGS_PER_TASK { + return Err(StoreError::InvalidTag(format!( + "too many tags: {} > {MAX_TAGS_PER_TASK}", + tags.len() + ))); + } + for (k, v) in tags { + if k.len() > MAX_TAG_KEY_LEN { + return Err(StoreError::InvalidTag(format!( + "tag key too long: {} > {MAX_TAG_KEY_LEN}", + k.len() + ))); + } + if v.len() > MAX_TAG_VALUE_LEN { + return Err(StoreError::InvalidTag(format!( + "tag value too long: {} > {MAX_TAG_VALUE_LEN}", + v.len() + ))); + } + } + Ok(()) +} + /// Core dedup logic for a single task submission within an existing connection. 
/// /// Performs the three-step dedup: INSERT OR IGNORE → upgrade priority on @@ -31,6 +56,8 @@ pub(crate) async fn submit_one( return Err(StoreError::Serialization(err.clone())); } + validate_tags(&sub.tags)?; + let key = sub.effective_key(); let priority = sub.priority.value() as i32; let fail_fast_val: i32 = if sub.fail_fast { 1 } else { 0 }; @@ -95,6 +122,9 @@ pub(crate) async fn submit_one( if result.rows_affected() > 0 { let task_id = result.last_insert_rowid(); + // Insert tags. + super::insert_tags(conn, task_id, &sub.tags).await?; + // Handle dependencies if any. if !sub.dependencies.is_empty() { // First resolve which deps are active (need edges) vs already @@ -256,6 +286,10 @@ pub(crate) async fn supersede_existing( .execute(&mut **conn) .await?; + // Replace tags: delete old, insert new. + super::delete_task_tags(conn, replaced_id).await?; + super::insert_tags(conn, replaced_id, &sub.tags).await?; + Ok(SubmitOutcome::Superseded { new_task_id: replaced_id, replaced_task_id: replaced_id, @@ -263,6 +297,7 @@ pub(crate) async fn supersede_existing( } crate::task::TaskStatus::Running | crate::task::TaskStatus::Waiting => { // Delete existing and insert new. 
+ super::delete_task_tags(conn, replaced_id).await?; sqlx::query("DELETE FROM tasks WHERE id = ?") .bind(replaced_id) .execute(&mut **conn) @@ -293,8 +328,11 @@ pub(crate) async fn supersede_existing( .execute(&mut **conn) .await?; + let new_task_id = result.last_insert_rowid(); + super::insert_tags(conn, new_task_id, &sub.tags).await?; + Ok(SubmitOutcome::Superseded { - new_task_id: result.last_insert_rowid(), + new_task_id, replaced_task_id: replaced_id, }) } @@ -421,6 +459,7 @@ impl TaskStore { return Err(StoreError::PayloadTooLarge); } } + validate_tags(&sub.tags)?; let mut conn = self.begin_write().await?; tracing::debug!(task_type = %sub.task_type, "store.submit: INSERT start"); @@ -449,7 +488,7 @@ impl TaskStore { &self, submissions: &[TaskSubmission], ) -> Result, StoreError> { - // Pre-validate all payloads before starting the transaction + // Pre-validate all payloads and tags before starting the transaction // to avoid partial inserts on validation errors. for sub in submissions { if let Some(ref p) = sub.payload { @@ -457,6 +496,7 @@ impl TaskStore { return Err(StoreError::PayloadTooLarge); } } + validate_tags(&sub.tags)?; } // Intra-batch dedup: last-wins. 
Map each effective key to its last @@ -783,4 +823,176 @@ mod tests { let count = store.pending_count().await.unwrap(); assert_eq!(count, 0); } + + // ── Tag tests ───────────────────────────────────────────────────── + + #[tokio::test] + async fn submit_with_tags() { + let store = test_store().await; + let sub = TaskSubmission::new("test") + .key("tagged-1") + .tag("profile", "default") + .tag("source", "upload"); + + let outcome = store.submit(&sub).await.unwrap(); + let id = outcome.id().unwrap(); + + let task = store.task_by_id(id).await.unwrap().unwrap(); + assert_eq!(task.tags.len(), 2); + assert_eq!(task.tags.get("profile").unwrap(), "default"); + assert_eq!(task.tags.get("source").unwrap(), "upload"); + } + + #[tokio::test] + async fn submit_batch_with_tags() { + let store = test_store().await; + let subs: Vec<_> = (0..3) + .map(|i| { + TaskSubmission::new("test") + .key(format!("batch-tag-{i}")) + .tag("batch", "true") + .tag("index", i.to_string()) + }) + .collect(); + + let results = store.submit_batch(&subs).await.unwrap(); + assert!(results.iter().all(|r| r.is_inserted())); + + for (i, result) in results.iter().enumerate() { + let task = store + .task_by_id(result.id().unwrap()) + .await + .unwrap() + .unwrap(); + assert_eq!(task.tags.get("batch").unwrap(), "true"); + assert_eq!(task.tags.get("index").unwrap(), &i.to_string()); + } + } + + #[tokio::test] + async fn submit_with_default_tags() { + use crate::task::BatchSubmission; + + let store = test_store().await; + let subs = BatchSubmission::new() + .default_tag("env", "prod") + .default_tag("region", "us-east") + .task(TaskSubmission::new("test").key("dt-1")) + .task( + TaskSubmission::new("test") + .key("dt-2") + .tag("region", "eu-west"), + ) + .build(); + + let results = store.submit_batch(&subs).await.unwrap(); + + // First task gets both defaults. 
+ let t1 = store + .task_by_id(results[0].id().unwrap()) + .await + .unwrap() + .unwrap(); + assert_eq!(t1.tags.get("env").unwrap(), "prod"); + assert_eq!(t1.tags.get("region").unwrap(), "us-east"); + + // Second task overrides "region" but inherits "env". + let t2 = store + .task_by_id(results[1].id().unwrap()) + .await + .unwrap() + .unwrap(); + assert_eq!(t2.tags.get("env").unwrap(), "prod"); + assert_eq!(t2.tags.get("region").unwrap(), "eu-west"); + } + + #[tokio::test] + async fn tags_validation_key_too_long() { + use crate::store::StoreError; + use crate::task::MAX_TAG_KEY_LEN; + + let store = test_store().await; + let long_key = "x".repeat(MAX_TAG_KEY_LEN + 1); + let sub = TaskSubmission::new("test") + .key("bad-key") + .tag(long_key, "value"); + + let err = store.submit(&sub).await.unwrap_err(); + assert!(matches!(err, StoreError::InvalidTag(_))); + } + + #[tokio::test] + async fn tags_validation_value_too_long() { + use crate::store::StoreError; + use crate::task::MAX_TAG_VALUE_LEN; + + let store = test_store().await; + let long_val = "x".repeat(MAX_TAG_VALUE_LEN + 1); + let sub = TaskSubmission::new("test") + .key("bad-val") + .tag("key", long_val); + + let err = store.submit(&sub).await.unwrap_err(); + assert!(matches!(err, StoreError::InvalidTag(_))); + } + + #[tokio::test] + async fn tags_validation_too_many() { + use crate::store::StoreError; + use crate::task::MAX_TAGS_PER_TASK; + + let store = test_store().await; + let mut sub = TaskSubmission::new("test").key("too-many"); + for i in 0..=MAX_TAGS_PER_TASK { + sub = sub.tag(format!("key-{i}"), "value"); + } + + let err = store.submit(&sub).await.unwrap_err(); + assert!(matches!(err, StoreError::InvalidTag(_))); + } + + #[tokio::test] + async fn tags_preserved_on_supersede() { + use crate::task::DuplicateStrategy; + + let store = test_store().await; + let sub1 = TaskSubmission::new("test") + .key("supersede-tags") + .tag("version", "1") + .on_duplicate(DuplicateStrategy::Supersede); + let id1 = 
store.submit(&sub1).await.unwrap().id().unwrap(); + + let t1 = store.task_by_id(id1).await.unwrap().unwrap(); + assert_eq!(t1.tags.get("version").unwrap(), "1"); + + // Supersede with new tags. + let sub2 = TaskSubmission::new("test") + .key("supersede-tags") + .tag("version", "2") + .tag("extra", "yes") + .on_duplicate(DuplicateStrategy::Supersede); + let outcome = store.submit(&sub2).await.unwrap(); + let id2 = outcome.id().unwrap(); + + let t2 = store.task_by_id(id2).await.unwrap().unwrap(); + assert_eq!(t2.tags.get("version").unwrap(), "2"); + assert_eq!(t2.tags.get("extra").unwrap(), "yes"); + } + + #[tokio::test] + async fn tags_dedup_no_change() { + let store = test_store().await; + let sub = TaskSubmission::new("test") + .key("dedup-tags") + .tag("env", "prod"); + + store.submit(&sub).await.unwrap(); + let outcome = store.submit(&sub).await.unwrap(); + assert_eq!(outcome, SubmitOutcome::Duplicate); + + // Tags should still be intact on the original task. + let key = sub.effective_key(); + let task = store.task_by_key(&key).await.unwrap().unwrap(); + assert_eq!(task.tags.get("env").unwrap(), "prod"); + } } diff --git a/src/task/mod.rs b/src/task/mod.rs index e9cbbc1..22d7e3e 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -35,6 +35,8 @@ mod submission; mod tests; pub mod typed; +use std::collections::HashMap; + use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -44,7 +46,7 @@ pub use dedup::{generate_dedup_key, MAX_PAYLOAD_BYTES}; pub use error::TaskError; pub use submission::{ BatchOutcome, BatchSubmission, DependencyFailurePolicy, DuplicateStrategy, RecurringSchedule, - SubmitOutcome, TaskSubmission, + SubmitOutcome, TaskSubmission, MAX_TAGS_PER_TASK, MAX_TAG_KEY_LEN, MAX_TAG_VALUE_LEN, }; pub use typed::TypedTask; @@ -220,6 +222,8 @@ pub struct TaskRecord { pub dependencies: Vec, /// What happens when a dependency fails. 
pub on_dependency_failure: DependencyFailurePolicy, + /// Key-value metadata tags for filtering, grouping, and display. + pub tags: HashMap, } impl TaskRecord { @@ -242,6 +246,7 @@ impl TaskRecord { task_type: self.task_type.clone(), key: self.key.clone(), label: self.label.clone(), + tags: self.tags.clone(), } } } @@ -281,6 +286,8 @@ pub struct TaskHistoryRecord { pub expires_at: Option>, /// Delayed dispatch timestamp at submission time (diagnostic). pub run_after: Option>, + /// Key-value metadata tags for filtering, grouping, and display. + pub tags: HashMap, } /// IO budget for a task: expected or actual disk and network IO bytes. diff --git a/src/task/submission.rs b/src/task/submission.rs index fd2f261..60b7e0c 100644 --- a/src/task/submission.rs +++ b/src/task/submission.rs @@ -14,6 +14,7 @@ //! A task with dependencies enters [`Blocked`](crate::TaskStatus::Blocked) status //! and transitions to `Pending` only after all dependencies complete successfully. +use std::collections::HashMap; use std::time::Duration; use chrono::{DateTime, Utc}; @@ -25,6 +26,13 @@ use super::dedup::generate_dedup_key; use super::typed::TypedTask; use super::{IoBudget, TtlFrom}; +/// Maximum length of a tag key in bytes. +pub const MAX_TAG_KEY_LEN: usize = 64; +/// Maximum length of a tag value in bytes. +pub const MAX_TAG_VALUE_LEN: usize = 256; +/// Maximum number of tags per task. +pub const MAX_TAGS_PER_TASK: usize = 32; + /// Configuration for recurring task schedules. 
#[derive(Debug, Clone, Serialize, Deserialize)] pub struct RecurringSchedule { @@ -252,6 +260,7 @@ pub struct BatchSubmission { default_group: Option, default_priority: Option, default_ttl: Option, + default_tags: HashMap, tasks: Vec, } @@ -262,6 +271,7 @@ impl BatchSubmission { default_group: None, default_priority: None, default_ttl: None, + default_tags: HashMap::new(), tasks: Vec::new(), } } @@ -284,6 +294,12 @@ impl BatchSubmission { self } + /// Set a default tag applied to tasks that don't already have this key. + pub fn default_tag(mut self, key: impl Into, value: impl Into) -> Self { + self.default_tags.insert(key.into(), value.into()); + self + } + /// Add a single task to the batch. pub fn task(mut self, sub: TaskSubmission) -> Self { self.tasks.push(sub); @@ -317,6 +333,9 @@ impl BatchSubmission { task.ttl = Some(ttl); } } + for (k, v) in &self.default_tags { + task.tags.entry(k.clone()).or_insert_with(|| v.clone()); + } } self.tasks } @@ -340,7 +359,9 @@ impl Default for BatchSubmission { /// .key("img-001") /// .priority(Priority::HIGH) /// .payload_json(&my_payload)? -/// .expected_io(IoBudget::disk(4096, 1024)); +/// .expected_io(IoBudget::disk(4096, 1024)) +/// .tag("profile", "default") +/// .tag("source", "upload"); /// ``` /// /// For strongly-typed tasks, prefer [`TaskSubmission::from_typed`] or @@ -399,6 +420,11 @@ pub struct TaskSubmission { /// What happens when a dependency fails. Default: [`DependencyFailurePolicy::Cancel`]. #[serde(default)] pub on_dependency_failure: DependencyFailurePolicy, + /// Key-value metadata tags for filtering, grouping, and display. + /// Immutable after submission. Validated against [`MAX_TAG_KEY_LEN`], + /// [`MAX_TAG_VALUE_LEN`], and [`MAX_TAGS_PER_TASK`] at submit time. 
+ #[serde(default)] + pub tags: HashMap, } impl TaskSubmission { @@ -437,6 +463,7 @@ impl TaskSubmission { recurring: None, dependencies: Vec::new(), on_dependency_failure: DependencyFailurePolicy::default(), + tags: HashMap::new(), } } @@ -590,6 +617,22 @@ impl TaskSubmission { self } + /// Add a single metadata tag (key-value pair). + /// + /// Tags are schema-free metadata for filtering and grouping. They are + /// validated at submit time against [`MAX_TAG_KEY_LEN`], + /// [`MAX_TAG_VALUE_LEN`], and [`MAX_TAGS_PER_TASK`]. + pub fn tag(mut self, key: impl Into, value: impl Into) -> Self { + self.tags.insert(key.into(), value.into()); + self + } + + /// Set all metadata tags at once, replacing any previously set tags. + pub fn tags(mut self, tags: HashMap) -> Self { + self.tags = tags; + self + } + /// Make this a recurring task with full schedule control. pub fn recurring_schedule(mut self, schedule: RecurringSchedule) -> Self { if let Some(delay) = schedule.initial_delay { @@ -641,6 +684,10 @@ impl TaskSubmission { if let Some(sched) = task.recurring() { sub = sub.recurring_schedule(sched); } + let task_tags = task.tags(); + if !task_tags.is_empty() { + sub = sub.tags(task_tags); + } sub } } diff --git a/src/task/tests.rs b/src/task/tests.rs index 166c158..4a8f96b 100644 --- a/src/task/tests.rs +++ b/src/task/tests.rs @@ -182,3 +182,71 @@ fn batch_submission_builder_no_defaults() { assert!(subs[0].group_key.is_none()); assert_eq!(subs[0].priority, Priority::NORMAL); } + +#[test] +fn typed_task_with_tags() { + use std::collections::HashMap; + + #[derive(Serialize, Deserialize)] + struct TaggedTask { + profile: String, + } + + impl TypedTask for TaggedTask { + const TASK_TYPE: &'static str = "tagged"; + + fn tags(&self) -> HashMap { + HashMap::from([("profile".into(), self.profile.clone())]) + } + } + + let task = TaggedTask { + profile: "default".into(), + }; + let sub = TaskSubmission::from_typed(&task); + assert_eq!(sub.tags.get("profile").unwrap(), 
"default"); +} + +#[test] +fn event_header_includes_tags() { + use std::collections::HashMap; + + let mut record = super::TaskRecord { + id: 42, + task_type: "test".into(), + key: "abc".into(), + label: "Test task".into(), + priority: Priority::NORMAL, + status: super::TaskStatus::Running, + payload: None, + expected_io: IoBudget::default(), + retry_count: 0, + last_error: None, + created_at: chrono::Utc::now(), + started_at: Some(chrono::Utc::now()), + parent_id: None, + fail_fast: true, + requeue: false, + requeue_priority: None, + group_key: None, + ttl_seconds: None, + ttl_from: super::TtlFrom::Submission, + expires_at: None, + run_after: None, + recurring_interval_secs: None, + recurring_max_executions: None, + recurring_execution_count: 0, + recurring_paused: false, + tags: HashMap::new(), + dependencies: Vec::new(), + on_dependency_failure: super::submission::DependencyFailurePolicy::Cancel, + }; + record.tags.insert("env".into(), "prod".into()); + record.tags.insert("owner".into(), "alice".into()); + + let header = record.event_header(); + assert_eq!(header.task_id, 42); + assert_eq!(header.tags.get("env").unwrap(), "prod"); + assert_eq!(header.tags.get("owner").unwrap(), "alice"); + assert_eq!(header.tags.len(), 2); +} diff --git a/src/task/typed.rs b/src/task/typed.rs index cbf10ab..386a2a9 100644 --- a/src/task/typed.rs +++ b/src/task/typed.rs @@ -1,5 +1,6 @@ //! The [`TypedTask`] trait for strongly-typed task payloads. 
+use std::collections::HashMap; use std::time::Duration; use serde::de::DeserializeOwned; @@ -23,15 +24,19 @@ use super::{IoBudget, TtlFrom}; /// # Example /// /// ```ignore +/// use std::collections::HashMap; /// use serde::{Serialize, Deserialize}; /// use taskmill::{TypedTask, IoBudget, Priority}; /// /// #[derive(Serialize, Deserialize)] -/// struct Thumbnail { path: String, size: u32 } +/// struct Thumbnail { path: String, size: u32, profile: String } /// /// impl TypedTask for Thumbnail { /// const TASK_TYPE: &'static str = "thumbnail"; /// fn expected_io(&self) -> IoBudget { IoBudget::disk(4096, 1024) } +/// fn tags(&self) -> HashMap { +/// HashMap::from([("profile".into(), self.profile.clone())]) +/// } /// } /// ``` pub trait TypedTask: Serialize + DeserializeOwned + Send + 'static { @@ -87,4 +92,9 @@ pub trait TypedTask: Serialize + DeserializeOwned + Send + 'static { fn recurring(&self) -> Option { None } + + /// Metadata tags for filtering and grouping. Default: empty. + fn tags(&self) -> HashMap { + HashMap::new() + } }