diff --git a/src/execution/live_updater.rs b/src/execution/live_updater.rs index 4c3c391b..4cf4aad6 100644 --- a/src/execution/live_updater.rs +++ b/src/execution/live_updater.rs @@ -3,7 +3,7 @@ use crate::{execution::stats::UpdateStats, prelude::*}; use super::stats; use futures::future::try_join_all; use sqlx::PgPool; -use tokio::{sync::watch, time::MissedTickBehavior}; +use tokio::{sync::watch, task::JoinSet, time::MissedTickBehavior}; pub struct FlowLiveUpdaterUpdates { pub active_sources: Vec, @@ -22,7 +22,8 @@ struct UpdateReceiveState { pub struct FlowLiveUpdater { flow_ctx: Arc, - tasks: Vec<(tokio::task::JoinHandle>, Arc)>, + join_set: Mutex>>>, + stats_per_task: Vec>, recv_state: tokio::sync::Mutex, num_remaining_tasks_rx: watch::Receiver, @@ -267,7 +268,11 @@ impl SourceUpdateTask { .boxed() }); - try_join_all(futs).await?; + let join_result = try_join_all(futs).await; + if let Err(err) = join_result { + error!("Error in source `{}`: {:?}", import_op.name, err); + return Err(err); + } Ok(()) } } @@ -288,27 +293,30 @@ impl FlowLiveUpdater { let (num_remaining_tasks_tx, num_remaining_tasks_rx) = watch::channel(plan.import_ops.len()); - let tasks = (0..plan.import_ops.len()) - .map(|source_idx| { - let source_update_stats = Arc::new(stats::UpdateStats::default()); - let source_update_task = SourceUpdateTask { - source_idx, - flow: flow_ctx.flow.clone(), - plan: plan.clone(), - execution_ctx: execution_ctx.clone(), - source_update_stats: source_update_stats.clone(), - pool: pool.clone(), - options: options.clone(), - status_tx: status_tx.clone(), - num_remaining_tasks_tx: num_remaining_tasks_tx.clone(), - }; - let task = tokio::spawn(source_update_task.run()); - (task, source_update_stats) - }) - .collect(); + + let mut join_set = JoinSet::new(); + let mut stats_per_task = Vec::new(); + + for source_idx in 0..plan.import_ops.len() { + let source_update_stats = Arc::new(stats::UpdateStats::default()); + let source_update_task = SourceUpdateTask { + source_idx, + flow: flow_ctx.flow.clone(), + plan: plan.clone(), + execution_ctx: execution_ctx.clone(), + source_update_stats: source_update_stats.clone(), + pool: pool.clone(), + options: options.clone(), + status_tx: status_tx.clone(), + num_remaining_tasks_tx: num_remaining_tasks_tx.clone(), + }; + join_set.spawn(source_update_task.run()); + stats_per_task.push(source_update_stats); + } Ok(Self { flow_ctx, - tasks, + join_set: Mutex::new(Some(join_set)), + stats_per_task, recv_state: tokio::sync::Mutex::new(UpdateReceiveState { status_rx, last_num_source_updates: vec![0; plan.import_ops.len()], @@ -322,17 +330,33 @@ impl FlowLiveUpdater { } pub async fn wait(&self) -> Result<()> { - let mut rx = self.num_remaining_tasks_rx.clone(); - if *rx.borrow() == 0 { + { + let mut rx = self.num_remaining_tasks_rx.clone(); + rx.wait_for(|v| *v == 0).await?; + } + + let Some(mut join_set) = self.join_set.lock().unwrap().take() else { return Ok(()); + }; + while let Some(task_result) = join_set.join_next().await { + match task_result { + Ok(Ok(_)) => {} + Ok(Err(err)) => { + return Err(err); + } + Err(err) if err.is_cancelled() => {} + Err(err) => { + return Err(err.into()); + } + } } - rx.wait_for(|v| *v == 0).await?; Ok(()) } pub fn abort(&self) { - for (task, _) in &self.tasks { - task.abort(); + let mut join_set = self.join_set.lock().unwrap(); + if let Some(join_set) = &mut *join_set { + join_set.abort_all(); } } @@ -340,9 +364,9 @@ impl FlowLiveUpdater { stats::IndexUpdateInfo { sources: std::iter::zip( self.flow_ctx.flow.flow_instance.import_ops.iter(), - self.tasks.iter(), + self.stats_per_task.iter(), ) - .map(|(import_op, (_, stats))| stats::SourceUpdateInfo { + .map(|(import_op, stats)| stats::SourceUpdateInfo { source_name: import_op.name.clone(), stats: stats.as_ref().clone(), })