Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 53 additions & 29 deletions src/execution/live_updater.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{execution::stats::UpdateStats, prelude::*};
use super::stats;
use futures::future::try_join_all;
use sqlx::PgPool;
use tokio::{sync::watch, time::MissedTickBehavior};
use tokio::{sync::watch, task::JoinSet, time::MissedTickBehavior};

pub struct FlowLiveUpdaterUpdates {
pub active_sources: Vec<String>,
Expand All @@ -22,7 +22,8 @@ struct UpdateReceiveState {

pub struct FlowLiveUpdater {
flow_ctx: Arc<FlowContext>,
tasks: Vec<(tokio::task::JoinHandle<Result<()>>, Arc<stats::UpdateStats>)>,
join_set: Mutex<Option<JoinSet<Result<()>>>>,
stats_per_task: Vec<Arc<stats::UpdateStats>>,
recv_state: tokio::sync::Mutex<UpdateReceiveState>,
num_remaining_tasks_rx: watch::Receiver<usize>,

Expand Down Expand Up @@ -267,7 +268,11 @@ impl SourceUpdateTask {
.boxed()
});

try_join_all(futs).await?;
let join_result = try_join_all(futs).await;
if let Err(err) = join_result {
error!("Error in source `{}`: {:?}", import_op.name, err);
return Err(err);
}
Ok(())
}
}
Expand All @@ -288,27 +293,30 @@ impl FlowLiveUpdater {

let (num_remaining_tasks_tx, num_remaining_tasks_rx) =
watch::channel(plan.import_ops.len());
let tasks = (0..plan.import_ops.len())
.map(|source_idx| {
let source_update_stats = Arc::new(stats::UpdateStats::default());
let source_update_task = SourceUpdateTask {
source_idx,
flow: flow_ctx.flow.clone(),
plan: plan.clone(),
execution_ctx: execution_ctx.clone(),
source_update_stats: source_update_stats.clone(),
pool: pool.clone(),
options: options.clone(),
status_tx: status_tx.clone(),
num_remaining_tasks_tx: num_remaining_tasks_tx.clone(),
};
let task = tokio::spawn(source_update_task.run());
(task, source_update_stats)
})
.collect();

let mut join_set = JoinSet::new();
let mut stats_per_task = Vec::new();

for source_idx in 0..plan.import_ops.len() {
let source_update_stats = Arc::new(stats::UpdateStats::default());
let source_update_task = SourceUpdateTask {
source_idx,
flow: flow_ctx.flow.clone(),
plan: plan.clone(),
execution_ctx: execution_ctx.clone(),
source_update_stats: source_update_stats.clone(),
pool: pool.clone(),
options: options.clone(),
status_tx: status_tx.clone(),
num_remaining_tasks_tx: num_remaining_tasks_tx.clone(),
};
join_set.spawn(source_update_task.run());
stats_per_task.push(source_update_stats);
}
Ok(Self {
flow_ctx,
tasks,
join_set: Mutex::new(Some(join_set)),
stats_per_task,
recv_state: tokio::sync::Mutex::new(UpdateReceiveState {
status_rx,
last_num_source_updates: vec![0; plan.import_ops.len()],
Expand All @@ -322,27 +330,43 @@ impl FlowLiveUpdater {
}

pub async fn wait(&self) -> Result<()> {
let mut rx = self.num_remaining_tasks_rx.clone();
if *rx.borrow() == 0 {
{
let mut rx = self.num_remaining_tasks_rx.clone();
rx.wait_for(|v| *v == 0).await?;
}

let Some(mut join_set) = self.join_set.lock().unwrap().take() else {
return Ok(());
};
while let Some(task_result) = join_set.join_next().await {
match task_result {
Ok(Ok(_)) => {}
Ok(Err(err)) => {
return Err(err);
}
Err(err) if err.is_cancelled() => {}
Err(err) => {
return Err(err.into());
}
}
}
rx.wait_for(|v| *v == 0).await?;
Ok(())
}

pub fn abort(&self) {
for (task, _) in &self.tasks {
task.abort();
let mut join_set = self.join_set.lock().unwrap();
if let Some(join_set) = &mut *join_set {
join_set.abort_all();
}
}

pub fn index_update_info(&self) -> stats::IndexUpdateInfo {
stats::IndexUpdateInfo {
sources: std::iter::zip(
self.flow_ctx.flow.flow_instance.import_ops.iter(),
self.tasks.iter(),
self.stats_per_task.iter(),
)
.map(|(import_op, (_, stats))| stats::SourceUpdateInfo {
.map(|(import_op, stats)| stats::SourceUpdateInfo {
source_name: import_op.name.clone(),
stats: stats.as_ref().clone(),
})
Expand Down
Loading