diff --git a/magicblock-api/src/magic_validator.rs b/magicblock-api/src/magic_validator.rs index c4da8a53b..dc6c9c297 100644 --- a/magicblock-api/src/magic_validator.rs +++ b/magicblock-api/src/magic_validator.rs @@ -699,7 +699,7 @@ impl MagicValidator { .take() .expect("task_scheduler should be initialized"); tokio::spawn(async move { - let join_handle = match task_scheduler.start() { + let join_handle = match task_scheduler.start().await { Ok(join_handle) => join_handle, Err(err) => { error!("Failed to start task scheduler: {:?}", err); diff --git a/magicblock-task-scheduler/src/db.rs b/magicblock-task-scheduler/src/db.rs index a39fb23f2..d69aa63e4 100644 --- a/magicblock-task-scheduler/src/db.rs +++ b/magicblock-task-scheduler/src/db.rs @@ -4,6 +4,7 @@ use chrono::Utc; use magicblock_program::args::ScheduleTaskRequest; use rusqlite::{params, Connection}; use solana_sdk::{instruction::Instruction, pubkey::Pubkey}; +use tokio::sync::Mutex; use crate::errors::TaskSchedulerError; @@ -55,7 +56,7 @@ pub struct FailedTask { } pub struct SchedulerDatabase { - conn: Connection, + conn: Mutex, } impl SchedulerDatabase { @@ -101,15 +102,20 @@ impl SchedulerDatabase { [], )?; - Ok(Self { conn }) + Ok(Self { + conn: Mutex::new(conn), + }) } - pub fn insert_task(&self, task: &DbTask) -> Result<(), TaskSchedulerError> { + pub async fn insert_task( + &self, + task: &DbTask, + ) -> Result<(), TaskSchedulerError> { let instructions_bin = bincode::serialize(&task.instructions)?; let authority_str = task.authority.to_string(); let now = Utc::now().timestamp_millis(); - self.conn.execute( + self.conn.lock().await.execute( "INSERT OR REPLACE INTO tasks (id, instructions, authority, execution_interval_millis, executions_left, last_execution_millis, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", @@ -128,14 +134,14 @@ impl SchedulerDatabase { Ok(()) } - pub fn update_task_after_execution( + pub async fn update_task_after_execution( &self, task_id: i64, last_execution: i64, ) -> Result<(), TaskSchedulerError> { let now = Utc::now().timestamp_millis(); - self.conn.execute( + self.conn.lock().await.execute( "UPDATE tasks SET executions_left = executions_left - 1, last_execution_millis = ?, @@ -147,12 +153,12 @@ impl SchedulerDatabase { Ok(()) } - pub fn insert_failed_scheduling( + pub async fn insert_failed_scheduling( &self, task_id: i64, error: String, ) -> Result<(), TaskSchedulerError> { - self.conn.execute( + self.conn.lock().await.execute( "INSERT INTO failed_scheduling (timestamp, task_id, error) VALUES (?, ?, ?)", params![Utc::now().timestamp_millis(), task_id, error], )?; @@ -160,12 +166,12 @@ impl SchedulerDatabase { Ok(()) } - pub fn insert_failed_task( + pub async fn insert_failed_task( &self, task_id: i64, error: String, ) -> Result<(), TaskSchedulerError> { - self.conn.execute( + self.conn.lock().await.execute( "INSERT INTO failed_tasks (timestamp, task_id, error) VALUES (?, ?, ?)", params![Utc::now().timestamp_millis(), task_id, error], )?; @@ -173,11 +179,11 @@ impl SchedulerDatabase { Ok(()) } - pub fn unschedule_task( + pub async fn unschedule_task( &self, task_id: i64, ) -> Result<(), TaskSchedulerError> { - self.conn.execute( + self.conn.lock().await.execute( "UPDATE tasks SET executions_left = 0 WHERE id = ?", [task_id], )?; @@ -185,18 +191,24 @@ impl SchedulerDatabase { Ok(()) } - pub fn remove_task(&self, task_id: i64) -> Result<(), TaskSchedulerError> { + pub async fn remove_task( + &self, + task_id: i64, + ) -> Result<(), TaskSchedulerError> { self.conn + .lock() + .await .execute("DELETE FROM tasks WHERE id = ?", [task_id])?; Ok(()) } - pub fn get_task( + pub async fn get_task( &self, task_id: i64, ) -> Result, TaskSchedulerError> { - let mut stmt = self.conn.prepare( + let db = self.conn.lock().await; + let mut stmt = db.prepare( "SELECT id, instructions, authority, execution_interval_millis, executions_left, last_execution_millis FROM tasks WHERE id = ?" )?; @@ -231,8 +243,9 @@ impl SchedulerDatabase { Ok(rows.next().transpose()?) } - pub fn get_tasks(&self) -> Result, TaskSchedulerError> { - let mut stmt = self.conn.prepare( + pub async fn get_tasks(&self) -> Result, TaskSchedulerError> { + let db = self.conn.lock().await; + let mut stmt = db.prepare( "SELECT id, instructions, authority, execution_interval_millis, executions_left, last_execution_millis FROM tasks" )?; @@ -272,8 +285,9 @@ impl SchedulerDatabase { Ok(tasks) } - pub fn get_task_ids(&self) -> Result, TaskSchedulerError> { - let mut stmt = self.conn.prepare( + pub async fn get_task_ids(&self) -> Result, TaskSchedulerError> { + let db = self.conn.lock().await; + let mut stmt = db.prepare( "SELECT id FROM tasks", )?; @@ -283,10 +297,11 @@ impl SchedulerDatabase { Ok(rows.collect::, rusqlite::Error>>()?) } - pub fn get_failed_schedulings( + pub async fn get_failed_schedulings( &self, ) -> Result, TaskSchedulerError> { - let mut stmt = self.conn.prepare( + let db = self.conn.lock().await; + let mut stmt = db.prepare( "SELECT * FROM failed_scheduling", )?; @@ -303,10 +318,11 @@ impl SchedulerDatabase { Ok(rows.collect::, rusqlite::Error>>()?) } - pub fn get_failed_tasks( + pub async fn get_failed_tasks( &self, ) -> Result, TaskSchedulerError> { - let mut stmt = self.conn.prepare( + let db = self.conn.lock().await; + let mut stmt = db.prepare( "SELECT * FROM failed_tasks", )?; diff --git a/magicblock-task-scheduler/src/service.rs b/magicblock-task-scheduler/src/service.rs index 7fe432e97..bb8ddb0ac 100644 --- a/magicblock-task-scheduler/src/service.rs +++ b/magicblock-task-scheduler/src/service.rs @@ -99,10 +99,10 @@ impl TaskSchedulerService { }) } - pub fn start( + pub async fn start( mut self, ) -> TaskSchedulerResult>> { - let tasks = self.db.get_tasks()?; + let tasks = self.db.get_tasks().await?; let now = chrono::Utc::now().timestamp_millis(); debug!( "Task scheduler starting at {} with {} tasks", @@ -123,17 +123,19 @@ impl TaskSchedulerService { Ok(tokio::spawn(self.run())) } - fn process_request( + async fn process_request( &mut self, request: &TaskRequest, ) -> TaskSchedulerResult { match request { TaskRequest::Schedule(schedule_request) => { - if let Err(e) = self.register_task(schedule_request) { - self.db.insert_failed_scheduling( - schedule_request.id, - format!("{:?}", e), - )?; + if let Err(e) = self.register_task(schedule_request).await { + self.db + .insert_failed_scheduling( + schedule_request.id, + format!("{:?}", e), + ) + .await?; error!( "Failed to process schedule request {}: {}", schedule_request.id, e @@ -143,11 +145,15 @@ impl TaskSchedulerService { } } TaskRequest::Cancel(cancel_request) => { - if let Err(e) = self.process_cancel_request(cancel_request) { - self.db.insert_failed_scheduling( - cancel_request.task_id, - format!("{:?}", e), - )?; + if let Err(e) = + self.process_cancel_request(cancel_request).await + { + self.db + .insert_failed_scheduling( + cancel_request.task_id, + format!("{:?}", e), + ) + .await?; error!( "Failed to process cancel request for task {}: {}", cancel_request.task_id, e @@ -161,11 +167,11 @@ impl TaskSchedulerService { Ok(ProcessingOutcome::Success) } - fn process_cancel_request( + async fn process_cancel_request( &mut self, cancel_request: &CancelTaskRequest, ) -> TaskSchedulerResult<()> { - let Some(task) = self.db.get_task(cancel_request.task_id)? else { + let Some(task) = self.db.get_task(cancel_request.task_id).await? else { // Task not found in the database, cleanup the queue self.remove_task_from_queue(cancel_request.task_id); return Ok(()); @@ -183,7 +189,7 @@ impl TaskSchedulerService { self.remove_task_from_queue(cancel_request.task_id); // Remove task from database - self.unregister_task(cancel_request.task_id)?; + self.unregister_task(cancel_request.task_id).await?; Ok(()) } @@ -212,19 +218,21 @@ impl TaskSchedulerService { } let current_time = chrono::Utc::now().timestamp_millis(); - self.db.update_task_after_execution(task.id, current_time)?; + self.db + .update_task_after_execution(task.id, current_time) + .await?; Ok(()) } - pub fn register_task( + pub async fn register_task( &mut self, task: impl Into, ) -> TaskSchedulerResult<()> { let task = task.into(); // Check if the task already exists in the database - if let Some(db_task) = self.db.get_task(task.id)? { + if let Some(db_task) = self.db.get_task(task.id).await? { if db_task.authority != task.authority { return Err(TaskSchedulerError::UnauthorizedReplacing( task.id, @@ -234,7 +242,7 @@ impl TaskSchedulerService { } } - self.db.insert_task(&task)?; + self.db.insert_task(&task).await?; self.task_queue .insert(task.clone(), Duration::from_millis(0)); debug!("Registered task {} from context", task.id); @@ -242,8 +250,11 @@ impl TaskSchedulerService { Ok(()) } - pub fn unregister_task(&self, task_id: i64) -> TaskSchedulerResult<()> { - self.db.remove_task(task_id)?; + pub async fn unregister_task( + &self, + task_id: i64, + ) -> TaskSchedulerResult<()> { + self.db.remove_task(task_id).await?; debug!("Removed task {} from database", task_id); Ok(()) @@ -259,12 +270,12 @@ impl TaskSchedulerService { error!("Failed to execute task {}: {}", task.id, e); // If any instruction fails, the task is cancelled - self.db.remove_task(task.id)?; - self.db.insert_failed_task(task.id, format!("{:?}", e))?; + self.db.remove_task(task.id).await?; + self.db.insert_failed_task(task.id, format!("{:?}", e)).await?; } } Some(task) = self.scheduled_tasks.recv() => { - match self.process_request(&task) { + match self.process_request(&task).await { Ok(ProcessingOutcome::Success) => {} Ok(ProcessingOutcome::Recoverable(e)) => { warn!("Failed to process request ID={}: {e:?}", task.id()); diff --git a/magicblock-task-scheduler/tests/service.rs b/magicblock-task-scheduler/tests/service.rs index e85bafbdf..398ee77c2 100644 --- a/magicblock-task-scheduler/tests/service.rs +++ b/magicblock-task-scheduler/tests/service.rs @@ -25,7 +25,7 @@ type SetupResult = TaskSchedulerResult<( JoinHandle>, )>; -fn setup() -> SetupResult { +async fn setup() -> SetupResult { let mut env = ExecutionTestEnv::new(); init_validator_authority_if_needed(env.payer.insecure_clone()); @@ -51,14 +51,15 @@ fn setup() -> SetupResult { env.ledger.latest_block().clone(), token.clone(), )? - .start()?; + .start() + .await?; Ok((env, token, handle)) } #[tokio::test] pub async fn test_schedule_task() -> TaskSchedulerResult<()> { - let (env, token, handle) = setup()?; + let (env, token, handle) = setup().await?; let account = env.create_account_with_config(LAMPORTS_PER_SOL, 1, guinea::ID); @@ -110,7 +111,7 @@ pub async fn test_schedule_task() -> TaskSchedulerResult<()> { #[tokio::test] pub async fn test_cancel_task() -> TaskSchedulerResult<()> { - let (env, token, handle) = setup()?; + let (env, token, handle) = setup().await?; let account = env.create_account_with_config(LAMPORTS_PER_SOL, 1, guinea::ID); diff --git a/test-integration/test-task-scheduler/tests/test_cancel_ongoing_task.rs b/test-integration/test-task-scheduler/tests/test_cancel_ongoing_task.rs index 245dcd510..5b109426b 100644 --- a/test-integration/test-task-scheduler/tests/test_cancel_ongoing_task.rs +++ b/test-integration/test-task-scheduler/tests/test_cancel_ongoing_task.rs @@ -12,6 +12,7 @@ use solana_sdk::{ use test_task_scheduler::{ create_delegated_counter, send_noop_tx, setup_validator, }; +use tokio::runtime::Runtime; #[test] fn test_cancel_ongoing_task() { @@ -103,8 +104,10 @@ fn test_cancel_ongoing_task() { // Check that the task was cancelled let db = expect!(SchedulerDatabase::new(db_path), validator); + let runtime = expect!(Runtime::new(), validator); - let failed_scheduling = expect!(db.get_failed_schedulings(), validator); + let failed_scheduling = + expect!(runtime.block_on(db.get_failed_schedulings()), validator); assert_eq!( failed_scheduling.len(), 0, @@ -113,7 +116,8 @@ fn test_cancel_ongoing_task() { failed_scheduling, ); - let failed_tasks = expect!(db.get_failed_tasks(), validator); + let failed_tasks = + expect!(runtime.block_on(db.get_failed_tasks()), validator); assert_eq!( failed_tasks.len(), 0, @@ -122,7 +126,7 @@ fn test_cancel_ongoing_task() { failed_tasks ); - let tasks = expect!(db.get_task_ids(), validator); + let tasks = expect!(runtime.block_on(db.get_task_ids()), validator); assert_eq!( tasks.len(), 0, @@ -131,7 +135,7 @@ fn test_cancel_ongoing_task() { tasks ); - let task = expect!(db.get_task(task_id), validator); + let task = expect!(runtime.block_on(db.get_task(task_id)), validator); assert!(task.is_none(), cleanup(&mut validator)); // Check that the counter was incremented but not as much as the number of executions diff --git a/test-integration/test-task-scheduler/tests/test_reschedule_task.rs b/test-integration/test-task-scheduler/tests/test_reschedule_task.rs index f2532bef7..3b21ee189 100644 --- a/test-integration/test-task-scheduler/tests/test_reschedule_task.rs +++ b/test-integration/test-task-scheduler/tests/test_reschedule_task.rs @@ -12,6 +12,7 @@ use solana_sdk::{ use test_task_scheduler::{ create_delegated_counter, send_noop_tx, setup_validator, }; +use tokio::runtime::Runtime; #[test] fn test_reschedule_task() { @@ -103,8 +104,10 @@ fn test_reschedule_task() { // Check that the task was scheduled in the database let db = expect!(SchedulerDatabase::new(db_path), validator); + let runtime = expect!(Runtime::new(), validator); - let failed_scheduling = expect!(db.get_failed_schedulings(), validator); + let failed_scheduling = + expect!(runtime.block_on(db.get_failed_schedulings()), validator); assert_eq!( failed_scheduling.len(), 0, @@ -113,7 +116,8 @@ fn test_reschedule_task() { failed_scheduling, ); - let failed_tasks = expect!(db.get_failed_tasks(), validator); + let failed_tasks = + expect!(runtime.block_on(db.get_failed_tasks()), validator); assert_eq!( failed_tasks.len(), 0, @@ -122,11 +126,12 @@ fn test_reschedule_task() { failed_tasks ); - let tasks = expect!(db.get_task_ids(), validator); + let tasks = expect!(runtime.block_on(db.get_task_ids()), validator); assert_eq!(tasks.len(), 1, cleanup(&mut validator)); let task = expect!( - db.get_task(task_id) + runtime + .block_on(db.get_task(task_id)) .ok() .flatten() .ok_or(anyhow::anyhow!("Task not found")), @@ -184,7 +189,7 @@ fn test_reschedule_task() { expect!(ctx.wait_for_delta_slot_ephem(5), validator); // Check that the task was cancelled - let tasks = expect!(db.get_task_ids(), validator); + let tasks = expect!(runtime.block_on(db.get_task_ids()), validator); assert_eq!(tasks.len(), 0, cleanup(&mut validator)); cleanup(&mut validator); diff --git a/test-integration/test-task-scheduler/tests/test_schedule_error.rs b/test-integration/test-task-scheduler/tests/test_schedule_error.rs index 970cf7acc..45e3d5b7f 100644 --- a/test-integration/test-task-scheduler/tests/test_schedule_error.rs +++ b/test-integration/test-task-scheduler/tests/test_schedule_error.rs @@ -12,6 +12,7 @@ use solana_sdk::{ use test_task_scheduler::{ create_delegated_counter, send_noop_tx, setup_validator, }; +use tokio::runtime::Runtime; // Test that a task with an error is unscheduled #[test] @@ -70,8 +71,10 @@ fn test_schedule_error() { // Check that the task was scheduled in the database let db = expect!(SchedulerDatabase::new(db_path), validator); + let runtime = expect!(Runtime::new(), validator); - let failed_scheduling = expect!(db.get_failed_schedulings(), validator); + let failed_scheduling = + expect!(runtime.block_on(db.get_failed_schedulings()), validator); assert_eq!( failed_scheduling.len(), 0, @@ -80,7 +83,8 @@ fn test_schedule_error() { failed_scheduling, ); - let failed_tasks = expect!(db.get_failed_tasks(), validator); + let failed_tasks = + expect!(runtime.block_on(db.get_failed_tasks()), validator); assert_eq!( failed_tasks.len(), 1, @@ -89,7 +93,7 @@ fn test_schedule_error() { failed_tasks, ); - let tasks = expect!(db.get_task_ids(), validator); + let tasks = expect!(runtime.block_on(db.get_task_ids()), validator); assert_eq!( tasks.len(), 0, @@ -99,7 +103,7 @@ fn test_schedule_error() { ); assert!( - expect!(db.get_task(task_id), validator).is_none(), + expect!(runtime.block_on(db.get_task(task_id)), validator).is_none(), cleanup(&mut validator) ); @@ -145,7 +149,7 @@ fn test_schedule_error() { expect!(ctx.wait_for_delta_slot_ephem(2), validator); // Check that the task was cancelled - let tasks = expect!(db.get_task_ids(), validator); + let tasks = expect!(runtime.block_on(db.get_task_ids()), validator); assert_eq!( tasks.len(), 0, diff --git a/test-integration/test-task-scheduler/tests/test_schedule_task.rs b/test-integration/test-task-scheduler/tests/test_schedule_task.rs index 4040d074d..72d72e830 100644 --- a/test-integration/test-task-scheduler/tests/test_schedule_task.rs +++ b/test-integration/test-task-scheduler/tests/test_schedule_task.rs @@ -12,6 +12,7 @@ use solana_sdk::{ use test_task_scheduler::{ create_delegated_counter, send_noop_tx, setup_validator, }; +use tokio::runtime::Runtime; #[test] fn test_schedule_task() { @@ -69,8 +70,10 @@ fn test_schedule_task() { // Check that the task was scheduled in the database let db = expect!(SchedulerDatabase::new(db_path), validator); + let runtime = expect!(Runtime::new(), validator); - let failed_scheduling = expect!(db.get_failed_schedulings(), validator); + let failed_scheduling = + expect!(runtime.block_on(db.get_failed_schedulings()), validator); assert_eq!( failed_scheduling.len(), 0, @@ -79,7 +82,8 @@ fn test_schedule_task() { failed_scheduling, ); - let failed_tasks = expect!(db.get_failed_tasks(), validator); + let failed_tasks = + expect!(runtime.block_on(db.get_failed_tasks()), validator); assert_eq!( failed_tasks.len(), 0, @@ -88,11 +92,12 @@ fn test_schedule_task() { failed_tasks ); - let tasks = expect!(db.get_task_ids(), validator); + let tasks = expect!(runtime.block_on(db.get_task_ids()), validator); assert_eq!(tasks.len(), 1, cleanup(&mut validator)); let task = expect!( - db.get_task(task_id) + runtime + .block_on(db.get_task(task_id)) .ok() .flatten() .ok_or(anyhow::anyhow!("Task not found")), @@ -150,7 +155,7 @@ fn test_schedule_task() { expect!(ctx.wait_for_delta_slot_ephem(5), validator); // Check that the task was cancelled - let tasks = expect!(db.get_task_ids(), validator); + let tasks = expect!(runtime.block_on(db.get_task_ids()), validator); assert_eq!(tasks.len(), 0, cleanup(&mut validator)); cleanup(&mut validator); diff --git a/test-integration/test-task-scheduler/tests/test_unauthorized_reschedule.rs b/test-integration/test-task-scheduler/tests/test_unauthorized_reschedule.rs index 37eb25125..d1cf96241 100644 --- a/test-integration/test-task-scheduler/tests/test_unauthorized_reschedule.rs +++ b/test-integration/test-task-scheduler/tests/test_unauthorized_reschedule.rs @@ -11,6 +11,7 @@ use solana_sdk::{ use test_task_scheduler::{ create_delegated_counter, send_noop_tx, setup_validator, }; +use tokio::runtime::Runtime; #[test] fn test_unauthorized_reschedule() { @@ -108,8 +109,10 @@ fn test_unauthorized_reschedule() { // Check that one task is scheduled but another one is failed to schedule let db = expect!(SchedulerDatabase::new(db_path), validator); + let runtime = expect!(Runtime::new(), validator); - let failed_scheduling = expect!(db.get_failed_schedulings(), validator); + let failed_scheduling = + expect!(runtime.block_on(db.get_failed_schedulings()), validator); assert_eq!( failed_scheduling.len(), 1, @@ -118,7 +121,8 @@ fn test_unauthorized_reschedule() { failed_scheduling, ); - let failed_tasks = expect!(db.get_failed_tasks(), validator); + let failed_tasks = + expect!(runtime.block_on(db.get_failed_tasks()), validator); assert_eq!( failed_tasks.len(), 0, @@ -127,11 +131,12 @@ fn test_unauthorized_reschedule() { failed_tasks ); - let tasks = expect!(db.get_task_ids(), validator); + let tasks = expect!(runtime.block_on(db.get_task_ids()), validator); assert_eq!(tasks.len(), 1, cleanup(&mut validator)); let task = expect!( - db.get_task(task_id) + runtime + .block_on(db.get_task(task_id)) .ok() .flatten() .ok_or(anyhow::anyhow!("Task not found")),