Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion magicblock-api/src/magic_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ impl MagicValidator {
.take()
.expect("task_scheduler should be initialized");
tokio::spawn(async move {
let join_handle = match task_scheduler.start() {
let join_handle = match task_scheduler.start().await {
Ok(join_handle) => join_handle,
Err(err) => {
error!("Failed to start task scheduler: {:?}", err);
Expand Down
62 changes: 39 additions & 23 deletions magicblock-task-scheduler/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use chrono::Utc;
use magicblock_program::args::ScheduleTaskRequest;
use rusqlite::{params, Connection};
use solana_sdk::{instruction::Instruction, pubkey::Pubkey};
use tokio::sync::Mutex;

use crate::errors::TaskSchedulerError;

Expand Down Expand Up @@ -55,7 +56,7 @@ pub struct FailedTask {
}

/// Database handle used by the scheduler's persistence methods; the methods
/// visible in this diff run their SQL statements through `conn`.
//
// Diff rendering note: this page lists a removed line directly above the line
// that replaces it, without `-`/`+` markers. The two `conn` lines below are
// therefore the pre-change and post-change versions of the same single field,
// not two fields.
pub struct SchedulerDatabase {
// Pre-change field: the `rusqlite` connection stored directly.
conn: Connection,
// Post-change field: the same connection wrapped in `tokio::sync::Mutex`
// (imported at the top of this file's diff); callers reach it with
// `self.conn.lock().await`.
// NOTE(review): rusqlite calls are presumably synchronous, so the guard is
// held while SQLite work runs on the async executor thread — confirm this is
// acceptable for the expected load, or move queries to `spawn_blocking`.
conn: Mutex<Connection>,
}

impl SchedulerDatabase {
Expand Down Expand Up @@ -101,15 +102,20 @@ impl SchedulerDatabase {
[],
)?;

Ok(Self { conn })
Ok(Self {
conn: Mutex::new(conn),
})
}

pub fn insert_task(&self, task: &DbTask) -> Result<(), TaskSchedulerError> {
pub async fn insert_task(
&self,
task: &DbTask,
) -> Result<(), TaskSchedulerError> {
let instructions_bin = bincode::serialize(&task.instructions)?;
let authority_str = task.authority.to_string();
let now = Utc::now().timestamp_millis();

self.conn.execute(
self.conn.lock().await.execute(
"INSERT OR REPLACE INTO tasks
(id, instructions, authority, execution_interval_millis, executions_left, last_execution_millis, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
Expand All @@ -128,14 +134,14 @@ impl SchedulerDatabase {
Ok(())
}

pub fn update_task_after_execution(
pub async fn update_task_after_execution(
&self,
task_id: i64,
last_execution: i64,
) -> Result<(), TaskSchedulerError> {
let now = Utc::now().timestamp_millis();

self.conn.execute(
self.conn.lock().await.execute(
"UPDATE tasks
SET executions_left = executions_left - 1,
last_execution_millis = ?,
Expand All @@ -147,56 +153,62 @@ impl SchedulerDatabase {
Ok(())
}

/// Appends one row to the `failed_scheduling` table containing the current
/// `Utc::now().timestamp_millis()` value, `task_id`, and `error`.
///
/// # Errors
/// A failure from `execute` is propagated by `?` as a `TaskSchedulerError`.
//
// Diff rendering note: where two near-identical lines are adjacent, the first
// is the removed (synchronous) line and the second is its async replacement.
pub fn insert_failed_scheduling(
pub async fn insert_failed_scheduling(
&self,
task_id: i64,
error: String,
) -> Result<(), TaskSchedulerError> {
// Pre-change call: direct use of the connection.
self.conn.execute(
// Post-change call: waits for the connection mutex, then runs the same
// statement; the guard is a temporary, so it lives until the end of this
// statement.
self.conn.lock().await.execute(
"INSERT INTO failed_scheduling (timestamp, task_id, error) VALUES (?, ?, ?)",
// Values are bound as parameters rather than formatted into the SQL text.
params![Utc::now().timestamp_millis(), task_id, error],
)?;

Ok(())
}

/// Appends one row to the `failed_tasks` table containing the current
/// `Utc::now().timestamp_millis()` value, `task_id`, and `error`.
///
/// # Errors
/// A failure from `execute` is propagated by `?` as a `TaskSchedulerError`.
//
// Diff rendering note: where two near-identical lines are adjacent, the first
// is the removed (synchronous) line and the second is its async replacement.
pub fn insert_failed_task(
pub async fn insert_failed_task(
&self,
task_id: i64,
error: String,
) -> Result<(), TaskSchedulerError> {
// Pre-change call: direct use of the connection.
self.conn.execute(
// Post-change call: waits for the connection mutex, then runs the same
// statement while holding the guard.
self.conn.lock().await.execute(
"INSERT INTO failed_tasks (timestamp, task_id, error) VALUES (?, ?, ?)",
// Values are bound as parameters rather than formatted into the SQL text.
params![Utc::now().timestamp_millis(), task_id, error],
)?;

Ok(())
}

/// Sets `executions_left` to 0 on the `tasks` row whose `id` equals `task_id`.
/// The row itself is kept; only that column is updated.
///
/// # Errors
/// A failure from `execute` is propagated by `?` as a `TaskSchedulerError`.
//
// Diff rendering note: where two near-identical lines are adjacent, the first
// is the removed (synchronous) line and the second is its async replacement.
pub fn unschedule_task(
pub async fn unschedule_task(
&self,
task_id: i64,
) -> Result<(), TaskSchedulerError> {
// Pre-change call: direct use of the connection.
self.conn.execute(
// Post-change call: waits for the connection mutex, then runs the same
// statement while holding the guard.
self.conn.lock().await.execute(
"UPDATE tasks SET executions_left = 0 WHERE id = ?",
[task_id],
)?;
// The value returned by `execute` is discarded, so an id that matches no row
// is presumably not reported to the caller — TODO confirm that is intended.

Ok(())
}

/// Deletes the row in `tasks` whose `id` equals `task_id`.
///
/// # Errors
/// A failure from `execute` is propagated by `?` as a `TaskSchedulerError`.
//
// Diff rendering note: the first signature line below is the removed one-line
// synchronous signature; the multi-line `pub async fn` that follows replaces it.
pub fn remove_task(&self, task_id: i64) -> Result<(), TaskSchedulerError> {
pub async fn remove_task(
&self,
task_id: i64,
) -> Result<(), TaskSchedulerError> {
// Acquire the connection mutex, then run the DELETE with `task_id` bound as
// the statement parameter. The guard is a temporary of this statement.
self.conn
.lock()
.await
.execute("DELETE FROM tasks WHERE id = ?", [task_id])?;
// The value returned by `execute` is discarded, so deleting an id that does
// not exist is presumably treated as success — TODO confirm.

Ok(())
}

pub fn get_task(
pub async fn get_task(
&self,
task_id: i64,
) -> Result<Option<DbTask>, TaskSchedulerError> {
let mut stmt = self.conn.prepare(
let db = self.conn.lock().await;
let mut stmt = db.prepare(
"SELECT id, instructions, authority, execution_interval_millis, executions_left, last_execution_millis
FROM tasks WHERE id = ?"
)?;
Expand Down Expand Up @@ -231,8 +243,9 @@ impl SchedulerDatabase {
Ok(rows.next().transpose()?)
}

pub fn get_tasks(&self) -> Result<Vec<DbTask>, TaskSchedulerError> {
let mut stmt = self.conn.prepare(
pub async fn get_tasks(&self) -> Result<Vec<DbTask>, TaskSchedulerError> {
let db = self.conn.lock().await;
let mut stmt = db.prepare(
"SELECT id, instructions, authority, execution_interval_millis, executions_left, last_execution_millis
FROM tasks"
)?;
Expand Down Expand Up @@ -272,8 +285,9 @@ impl SchedulerDatabase {
Ok(tasks)
}

pub fn get_task_ids(&self) -> Result<Vec<i64>, TaskSchedulerError> {
let mut stmt = self.conn.prepare(
pub async fn get_task_ids(&self) -> Result<Vec<i64>, TaskSchedulerError> {
let db = self.conn.lock().await;
let mut stmt = db.prepare(
"SELECT id
FROM tasks",
)?;
Expand All @@ -283,10 +297,11 @@ impl SchedulerDatabase {
Ok(rows.collect::<Result<Vec<i64>, rusqlite::Error>>()?)
}

pub fn get_failed_schedulings(
pub async fn get_failed_schedulings(
&self,
) -> Result<Vec<FailedScheduling>, TaskSchedulerError> {
let mut stmt = self.conn.prepare(
let db = self.conn.lock().await;
let mut stmt = db.prepare(
"SELECT *
FROM failed_scheduling",
)?;
Expand All @@ -303,10 +318,11 @@ impl SchedulerDatabase {
Ok(rows.collect::<Result<Vec<FailedScheduling>, rusqlite::Error>>()?)
}

pub fn get_failed_tasks(
pub async fn get_failed_tasks(
&self,
) -> Result<Vec<FailedTask>, TaskSchedulerError> {
let mut stmt = self.conn.prepare(
let db = self.conn.lock().await;
let mut stmt = db.prepare(
"SELECT *
FROM failed_tasks",
)?;
Expand Down
61 changes: 36 additions & 25 deletions magicblock-task-scheduler/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ impl TaskSchedulerService {
})
}

pub fn start(
pub async fn start(
mut self,
) -> TaskSchedulerResult<JoinHandle<TaskSchedulerResult<()>>> {
let tasks = self.db.get_tasks()?;
let tasks = self.db.get_tasks().await?;
let now = chrono::Utc::now().timestamp_millis();
debug!(
"Task scheduler starting at {} with {} tasks",
Expand All @@ -123,17 +123,19 @@ impl TaskSchedulerService {
Ok(tokio::spawn(self.run()))
}

fn process_request(
async fn process_request(
&mut self,
request: &TaskRequest,
) -> TaskSchedulerResult<ProcessingOutcome> {
match request {
TaskRequest::Schedule(schedule_request) => {
if let Err(e) = self.register_task(schedule_request) {
self.db.insert_failed_scheduling(
schedule_request.id,
format!("{:?}", e),
)?;
if let Err(e) = self.register_task(schedule_request).await {
self.db
.insert_failed_scheduling(
schedule_request.id,
format!("{:?}", e),
)
.await?;
error!(
"Failed to process schedule request {}: {}",
schedule_request.id, e
Expand All @@ -143,11 +145,15 @@ impl TaskSchedulerService {
}
}
TaskRequest::Cancel(cancel_request) => {
if let Err(e) = self.process_cancel_request(cancel_request) {
self.db.insert_failed_scheduling(
cancel_request.task_id,
format!("{:?}", e),
)?;
if let Err(e) =
self.process_cancel_request(cancel_request).await
{
self.db
.insert_failed_scheduling(
cancel_request.task_id,
format!("{:?}", e),
)
.await?;
error!(
"Failed to process cancel request for task {}: {}",
cancel_request.task_id, e
Expand All @@ -161,11 +167,11 @@ impl TaskSchedulerService {
Ok(ProcessingOutcome::Success)
}

fn process_cancel_request(
async fn process_cancel_request(
&mut self,
cancel_request: &CancelTaskRequest,
) -> TaskSchedulerResult<()> {
let Some(task) = self.db.get_task(cancel_request.task_id)? else {
let Some(task) = self.db.get_task(cancel_request.task_id).await? else {
// Task not found in the database, cleanup the queue
self.remove_task_from_queue(cancel_request.task_id);
return Ok(());
Expand All @@ -183,7 +189,7 @@ impl TaskSchedulerService {
self.remove_task_from_queue(cancel_request.task_id);

// Remove task from database
self.unregister_task(cancel_request.task_id)?;
self.unregister_task(cancel_request.task_id).await?;

Ok(())
}
Expand Down Expand Up @@ -212,19 +218,21 @@ impl TaskSchedulerService {
}

let current_time = chrono::Utc::now().timestamp_millis();
self.db.update_task_after_execution(task.id, current_time)?;
self.db
.update_task_after_execution(task.id, current_time)
.await?;

Ok(())
}

pub fn register_task(
pub async fn register_task(
&mut self,
task: impl Into<DbTask>,
) -> TaskSchedulerResult<()> {
let task = task.into();

// Check if the task already exists in the database
if let Some(db_task) = self.db.get_task(task.id)? {
if let Some(db_task) = self.db.get_task(task.id).await? {
if db_task.authority != task.authority {
return Err(TaskSchedulerError::UnauthorizedReplacing(
task.id,
Expand All @@ -234,16 +242,19 @@ impl TaskSchedulerService {
}
}

self.db.insert_task(&task)?;
self.db.insert_task(&task).await?;
self.task_queue
.insert(task.clone(), Duration::from_millis(0));
debug!("Registered task {} from context", task.id);

Ok(())
}

pub fn unregister_task(&self, task_id: i64) -> TaskSchedulerResult<()> {
self.db.remove_task(task_id)?;
pub async fn unregister_task(
&self,
task_id: i64,
) -> TaskSchedulerResult<()> {
self.db.remove_task(task_id).await?;
debug!("Removed task {} from database", task_id);

Ok(())
Expand All @@ -259,12 +270,12 @@ impl TaskSchedulerService {
error!("Failed to execute task {}: {}", task.id, e);

// If any instruction fails, the task is cancelled
self.db.remove_task(task.id)?;
self.db.insert_failed_task(task.id, format!("{:?}", e))?;
self.db.remove_task(task.id).await?;
self.db.insert_failed_task(task.id, format!("{:?}", e)).await?;
}
}
Some(task) = self.scheduled_tasks.recv() => {
match self.process_request(&task) {
match self.process_request(&task).await {
Ok(ProcessingOutcome::Success) => {}
Ok(ProcessingOutcome::Recoverable(e)) => {
warn!("Failed to process request ID={}: {e:?}", task.id());
Expand Down
9 changes: 5 additions & 4 deletions magicblock-task-scheduler/tests/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type SetupResult = TaskSchedulerResult<(
JoinHandle<Result<(), TaskSchedulerError>>,
)>;

fn setup() -> SetupResult {
async fn setup() -> SetupResult {
let mut env = ExecutionTestEnv::new();

init_validator_authority_if_needed(env.payer.insecure_clone());
Expand All @@ -51,14 +51,15 @@ fn setup() -> SetupResult {
env.ledger.latest_block().clone(),
token.clone(),
)?
.start()?;
.start()
.await?;

Ok((env, token, handle))
}

#[tokio::test]
pub async fn test_schedule_task() -> TaskSchedulerResult<()> {
let (env, token, handle) = setup()?;
let (env, token, handle) = setup().await?;

let account =
env.create_account_with_config(LAMPORTS_PER_SOL, 1, guinea::ID);
Expand Down Expand Up @@ -110,7 +111,7 @@ pub async fn test_schedule_task() -> TaskSchedulerResult<()> {

#[tokio::test]
pub async fn test_cancel_task() -> TaskSchedulerResult<()> {
let (env, token, handle) = setup()?;
let (env, token, handle) = setup().await?;

let account =
env.create_account_with_config(LAMPORTS_PER_SOL, 1, guinea::ID);
Expand Down
Loading
Loading