diff --git a/src/builder/analyzer.rs b/src/builder/analyzer.rs
index 56343368..d41f2211 100644
--- a/src/builder/analyzer.rs
+++ b/src/builder/analyzer.rs
@@ -816,20 +816,11 @@ impl AnalyzerContext<'_> {
         &self,
         scope: &mut DataScopeBuilder,
         export_op: NamedSpec<ExportOpSpec>,
+        export_factory: Arc<dyn ExportTargetFactory + Send + Sync>,
         setup_state: Option<&mut FlowSetupState<DesiredMode>>,
         existing_target_states: &HashMap<&ResourceIdentifier, Vec<&TargetSetupState>>,
     ) -> Result<impl Future<Output = Result<AnalyzedExportOp>> + Send> {
         let export_target = export_op.spec.target;
-        let export_factory = match self.registry.get(&export_target.kind) {
-            Some(ExecutorFactory::ExportTarget(export_executor)) => export_executor,
-            _ => {
-                return Err(anyhow::anyhow!(
-                    "Export target kind not found: {}",
-                    export_target.kind
-                ))
-            }
-        };
-
         let spec = serde_json::Value::Object(export_target.spec.clone());
         let (local_collector_ref, collector_schema) =
             scope.consume_collector(&export_op.spec.collector_name)?;
@@ -986,8 +977,8 @@ impl AnalyzerContext<'_> {
             .unwrap_or(false);
         Ok(async move {
             trace!("Start building executor for export op `{}`", export_op.name);
-            let (executor, query_target) = setup_output
-                .executor
+            let executors = setup_output
+                .executors
                 .await
                 .with_context(|| format!("Analyzing export op: {}", export_op.name))?;
             trace!(
@@ -999,8 +990,8 @@ impl AnalyzerContext<'_> {
                 name,
                 target_id: target_id.unwrap_or_default(),
                 input: local_collector_ref,
-                executor,
-                query_target,
+                export_context: executors.export_context,
+                query_target: executors.query_target,
                 primary_key_def,
                 primary_key_type,
                 value_fields: value_fields_idx,
@@ -1127,18 +1118,36 @@ pub fn analyze_flow(
         &flow_inst.reactive_ops,
         RefList::Nil,
     )?;
-    let export_ops_futs = flow_inst
-        .export_ops
-        .iter()
-        .map(|export_op| {
-            analyzer_ctx.analyze_export_op(
-                root_exec_scope.data,
-                export_op.clone(),
-                Some(&mut setup_state),
-                &target_states_by_name_type,
-            )
-        })
-        .collect::<Result<Vec<_>>>()?;
+
+    let mut target_groups = IndexMap::<String, AnalyzedExportTargetOpGroup>::new();
+    let mut export_ops_futs = vec![];
+    for (idx, export_op) in flow_inst.export_ops.iter().enumerate() {
+        let target_kind = export_op.spec.target.kind.clone();
+        let export_factory = match registry.get(&target_kind) {
+            Some(ExecutorFactory::ExportTarget(export_executor)) => export_executor,
+            _ => {
+                return Err(anyhow::anyhow!(
+                    "Export target kind not found: {}",
+                    export_op.spec.target.kind
+                ))
+            }
+        };
+        export_ops_futs.push(analyzer_ctx.analyze_export_op(
+            root_exec_scope.data,
+            export_op.clone(),
+            export_factory.clone(),
+            Some(&mut setup_state),
+            &target_states_by_name_type,
+        )?);
+        target_groups
+            .entry(target_kind)
+            .or_insert_with(|| AnalyzedExportTargetOpGroup {
+                target_factory: export_factory.clone(),
+                op_idx: vec![],
+            })
+            .op_idx
+            .push(idx);
+    }
 
     let tracking_table_setup = setup_state.tracking_table.clone();
     let data_schema = root_data_scope.into_data_schema()?;
@@ -1160,6 +1169,7 @@ pub fn analyze_flow(
             import_ops,
             op_scope,
             export_ops,
+            export_op_groups: target_groups.into_values().collect(),
         })
     };
 
diff --git a/src/builder/plan.rs b/src/builder/plan.rs
index bfca2864..133bdd0b 100644
--- a/src/builder/plan.rs
+++ b/src/builder/plan.rs
@@ -100,7 +100,7 @@ pub struct AnalyzedExportOp {
     pub name: String,
     pub target_id: i32,
     pub input: AnalyzedLocalCollectorReference,
-    pub executor: Arc<dyn ExportTargetExecutor>,
+    pub export_context: Arc<dyn Any + Send + Sync>,
     pub query_target: Option<Arc<dyn QueryTarget>>,
     pub primary_key_def: AnalyzedPrimaryKeyDef,
     pub primary_key_type: schema::ValueType,
@@ -111,6 +111,11 @@ pub struct AnalyzedExportOp {
     pub value_stable: bool,
 }
 
+pub struct AnalyzedExportTargetOpGroup {
+    pub target_factory: Arc<dyn ExportTargetFactory + Send + Sync>,
+    pub op_idx: Vec<usize>,
+}
+
 pub enum AnalyzedReactiveOp {
     Transform(AnalyzedTransformOp),
     ForEach(AnalyzedForEachOp),
@@ -128,6 +133,7 @@ pub struct ExecutionPlan {
     pub import_ops: Vec<AnalyzedImportOp>,
     pub op_scope: AnalyzedOpScope,
     pub export_ops: Vec<AnalyzedExportOp>,
+    pub export_op_groups: Vec<AnalyzedExportTargetOpGroup>,
 }
 
 pub struct TransientExecutionPlan {
diff --git a/src/execution/row_indexer.rs b/src/execution/row_indexer.rs
index 96ced938..b5039d59 100644
--- a/src/execution/row_indexer.rs
+++ b/src/execution/row_indexer.rs
@@ -554,16 +554,26 @@ pub async fn update_source_row(
 
     // Phase 3: Apply changes to the target storage, including upserting new target records and removing existing ones.
     let mut target_mutations = precommit_output.target_mutations;
-    let apply_futs = plan.export_ops.iter().filter_map(|export_op| {
-        target_mutations
-            .remove(&export_op.target_id)
-            .and_then(|mutation| {
-                if !mutation.is_empty() {
-                    Some(export_op.executor.apply_mutation(mutation))
-                } else {
-                    None
-                }
+    let apply_futs = plan.export_op_groups.iter().filter_map(|export_op_group| {
+        let mutations_w_ctx: Vec<_> = export_op_group
+            .op_idx
+            .iter()
+            .filter_map(|export_op_idx| {
+                let export_op = &plan.export_ops[*export_op_idx];
+                target_mutations
+                    .remove(&export_op.target_id)
+                    .filter(|m| !m.is_empty())
+                    .map(|mutation| interface::ExportTargetMutationWithContext {
+                        mutation,
+                        export_context: export_op.export_context.as_ref(),
+                    })
             })
+            .collect();
+        (!mutations_w_ctx.is_empty()).then(|| {
+            export_op_group
+                .target_factory
+                .apply_mutation(mutations_w_ctx)
+        })
     });
 
     // TODO: Handle errors.
diff --git a/src/ops/factory_bases.rs b/src/ops/factory_bases.rs
index b9055c7c..2e509ab0 100644
--- a/src/ops/factory_bases.rs
+++ b/src/ops/factory_bases.rs
@@ -264,17 +264,23 @@ impl<T: SimpleFunctionFactoryBase> SimpleFunctionFactory for T {
     }
 }
 
-pub struct ExportTargetBuildOutput<F: StorageFactoryBase + ?Sized> {
-    pub executor:
-        BoxFuture<'static, Result<(Arc<dyn ExportTargetExecutor>, Option<Arc<dyn QueryTarget>>)>>,
+pub struct TypedExportTargetExecutors<F: StorageFactoryBase + ?Sized> {
+    pub export_context: Arc<F::ExportContext>,
+    pub query_target: Option<Arc<dyn QueryTarget>>,
+}
+
+pub struct TypedExportTargetBuildOutput<F: StorageFactoryBase + ?Sized> {
+    pub executors: BoxFuture<'static, Result<TypedExportTargetExecutors<F>>>,
     pub setup_key: F::Key,
     pub desired_setup_state: F::SetupState,
 }
 
+#[async_trait]
 pub trait StorageFactoryBase: ExportTargetFactory + Send + Sync + 'static {
     type Spec: DeserializeOwned + Send + Sync;
     type Key: Debug + Clone + Serialize + DeserializeOwned + Eq + Hash + Send + Sync;
     type SetupState: Debug + Clone + Serialize + DeserializeOwned + Send + Sync;
+    type ExportContext: Send + Sync + 'static;
 
     fn name(&self) -> &str;
 
@@ -286,7 +292,7 @@ pub trait StorageFactoryBase: ExportTargetFactory + Send + Sync + 'static {
         value_fields_schema: Vec<FieldSchema>,
         storage_options: IndexOptions,
         context: Arc<FlowInstanceContext>,
-    ) -> Result<ExportTargetBuildOutput<Self>>;
+    ) -> Result<TypedExportTargetBuildOutput<Self>>;
 
     /// Will not be called if it's setup by user.
     /// It returns an error if the target only supports setup by user.
@@ -315,8 +321,14 @@ pub trait StorageFactoryBase: ExportTargetFactory + Send + Sync + 'static {
             ExecutorFactory::ExportTarget(Arc::new(self)),
         )
     }
+
+    async fn apply_mutation(
+        &self,
+        mutations: Vec<ExportTargetMutationWithContext<'async_trait, Self::ExportContext>>,
+    ) -> Result<()>;
 }
 
+#[async_trait]
 impl<T: StorageFactoryBase> ExportTargetFactory for T {
     fn build(
         self: Arc<Self>,
@@ -337,10 +349,17 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
             storage_options,
             context,
         )?;
+        let executors = async move {
+            let executors = build_output.executors.await?;
+            Ok(interface::ExportTargetExecutors {
+                export_context: executors.export_context,
+                query_target: executors.query_target,
+            })
+        };
         Ok(interface::ExportTargetBuildOutput {
-            executor: build_output.executor,
             setup_key: serde_json::to_value(build_output.setup_key)?,
             desired_setup_state: serde_json::to_value(build_output.desired_setup_state)?,
+            executors: executors.boxed(),
         })
     }
 
@@ -383,6 +402,25 @@ impl<T: StorageFactoryBase> ExportTargetFactory for T {
         )?;
         Ok(result)
     }
+
+    async fn apply_mutation(
+        &self,
+        mutations: Vec<ExportTargetMutationWithContext<'async_trait, dyn Any + Send + Sync>>,
+    ) -> Result<()> {
+        let mutations = mutations
+            .into_iter()
+            .map(|m| {
+                anyhow::Ok(ExportTargetMutationWithContext {
+                    mutation: m.mutation,
+                    export_context: m
+                        .export_context
+                        .downcast_ref::<T::ExportContext>()
+                        .ok_or_else(|| anyhow!("Unexpected export context type"))?,
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
+        StorageFactoryBase::apply_mutation(self, mutations).await
+    }
 }
 
 fn from_json_combined_state(
diff --git a/src/ops/interface.rs b/src/ops/interface.rs
index e9ffc17f..927fd86c 100644
--- a/src/ops/interface.rs
+++ b/src/ops/interface.rs
@@ -138,9 +138,10 @@ impl ExportTargetMutation {
     }
 }
 
-#[async_trait]
-pub trait ExportTargetExecutor: Send + Sync {
-    async fn apply_mutation(&self, mutation: ExportTargetMutation) -> Result<()>;
+#[derive(Debug)]
+pub struct ExportTargetMutationWithContext<'ctx, T: ?Sized + Send + Sync> {
+    pub mutation: ExportTargetMutation,
+    pub export_context: &'ctx T,
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -156,14 +157,18 @@ pub enum SetupStateCompatibility {
     NotCompatible,
 }
 
+pub struct ExportTargetExecutors {
+    pub export_context: Arc<dyn Any + Send + Sync>,
+    pub query_target: Option<Arc<dyn QueryTarget>>,
+}
 pub struct ExportTargetBuildOutput {
-    pub executor:
-        BoxFuture<'static, Result<(Arc<dyn ExportTargetExecutor>, Option<Arc<dyn QueryTarget>>)>>,
+    pub executors: BoxFuture<'static, Result<ExportTargetExecutors>>,
     pub setup_key: serde_json::Value,
     pub desired_setup_state: serde_json::Value,
 }
 
-pub trait ExportTargetFactory {
+#[async_trait]
+pub trait ExportTargetFactory: Send + Sync {
     fn build(
         self: Arc<Self>,
         name: String,
@@ -191,6 +196,11 @@ pub trait ExportTargetFactory {
     ) -> Result<SetupStateCompatibility>;
 
     fn describe_resource(&self, key: &serde_json::Value) -> Result<String>;
+
+    async fn apply_mutation(
+        &self,
+        mutations: Vec<ExportTargetMutationWithContext<'async_trait, dyn Any + Send + Sync>>,
+    ) -> Result<()>;
 }
 
 #[derive(Clone)]
diff --git a/src/ops/sdk.rs b/src/ops/sdk.rs
index f135cd7a..803c67ba 100644
--- a/src/ops/sdk.rs
+++ b/src/ops/sdk.rs
@@ -11,7 +11,7 @@ pub use crate::base::spec::*;
 pub use crate::base::value::*;
 // Disambiguate the ExportTargetBuildOutput type.
-pub use super::factory_bases::ExportTargetBuildOutput;
+pub use super::factory_bases::TypedExportTargetBuildOutput;
 
 /// Defined for all types convertible to ValueType, to ease creation for ValueType in various operation factories.
 pub trait TypeCore {
     fn into_type(self) -> ValueType;
diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs
index 29726501..055eb181 100644
--- a/src/ops/storages/neo4j.rs
+++ b/src/ops/storages/neo4j.rs
@@ -132,8 +132,10 @@ struct AnalyzedNodeLabelInfo {
     key_fields: Vec<AnalyzedGraphFieldMapping>,
     value_fields: Vec<AnalyzedGraphFieldMapping>,
 }
 
-struct RelationshipStorageExecutor {
+pub struct RelationshipExportContext {
+    connection_ref: AuthEntryReference,
     graph: Arc<Graph>,
+
     delete_cypher: String,
     insert_cypher: String,
@@ -300,7 +302,7 @@ const SRC_PROPS_PARAM: &str = "source_props";
 const TGT_KEY_PARAM_PREFIX: &str = "target_key";
 const TGT_PROPS_PARAM: &str = "target_props";
 
-impl RelationshipStorageExecutor {
+impl RelationshipExportContext {
     fn build_key_field_params_n_literal<'a>(
         param_prefix: &str,
         key_fields: impl Iterator,
@@ -401,6 +403,7 @@ FINISH
         },
     );
     Ok(Self {
+        connection_ref: spec.connection,
         graph,
         delete_cypher,
         insert_cypher,
@@ -441,95 +444,79 @@ FINISH
         Ok(query)
     }
 
-    fn build_queries_to_apply_mutation(
+    fn add_upsert_queries(
         &self,
-        mutation: &ExportTargetMutation,
-    ) -> Result<Vec<neo4rs::Query>> {
-        let mut queries = vec![];
-        for upsert in mutation.upserts.iter() {
-            queries.push(
-                self.bind_rel_key_field_params(neo4rs::query(&self.delete_cypher), &upsert.key)?,
-            );
+        upsert: &ExportTargetUpsertEntry,
+        queries: &mut Vec<neo4rs::Query>,
+    ) -> Result<()> {
+        queries
+            .push(self.bind_rel_key_field_params(neo4rs::query(&self.delete_cypher), &upsert.key)?);
+
+        let value = &upsert.value;
+        let mut insert_cypher =
+            self.bind_rel_key_field_params(neo4rs::query(&self.insert_cypher), &upsert.key)?;
+        insert_cypher = Self::bind_key_field_params(
+            insert_cypher,
+            &self.src_key_field_params,
+            self.src_fields
+                .key_fields
+                .iter()
+                .map(|f| (&f.value_type, &value.fields[f.field_idx])),
+        )?;
+        insert_cypher = Self::bind_key_field_params(
+            insert_cypher,
+            &self.tgt_key_field_params,
+            self.tgt_fields
+                .key_fields
+                .iter()
+                .map(|f| (&f.value_type, &value.fields[f.field_idx])),
+        )?;
 
-            let value = &upsert.value;
-            let mut insert_cypher =
-                self.bind_rel_key_field_params(neo4rs::query(&self.insert_cypher), &upsert.key)?;
-            insert_cypher = Self::bind_key_field_params(
-                insert_cypher,
-                &self.src_key_field_params,
-                self.src_fields
-                    .key_fields
-                    .iter()
-                    .map(|f| (&f.value_type, &value.fields[f.field_idx])),
-            )?;
-            insert_cypher = Self::bind_key_field_params(
-                insert_cypher,
-                &self.tgt_key_field_params,
-                self.tgt_fields
-                    .key_fields
-                    .iter()
-                    .map(|f| (&f.value_type, &value.fields[f.field_idx])),
-            )?;
-
-            if !self.src_fields.value_fields.is_empty() {
-                insert_cypher = insert_cypher.param(
-                    SRC_PROPS_PARAM,
-                    mapped_field_values_to_bolt(
-                        self.src_fields
-                            .value_fields
-                            .iter()
-                            .map(|f| &value.fields[f.field_idx]),
-                        self.src_fields.value_fields.iter(),
-                    )?,
-                );
-            }
-            if !self.tgt_fields.value_fields.is_empty() {
-                insert_cypher = insert_cypher.param(
-                    TGT_PROPS_PARAM,
-                    mapped_field_values_to_bolt(
-                        self.tgt_fields
-                            .value_fields
-                            .iter()
-                            .map(|f| &value.fields[f.field_idx]),
-                        self.tgt_fields.value_fields.iter(),
-                    )?,
-                );
-            }
-            if !self.value_fields.is_empty() {
-                insert_cypher = insert_cypher.param(
-                    REL_PROPS_PARAM,
-                    mapped_field_values_to_bolt(
-                        self.value_fields.iter().map(|f| &value.fields[f.field_idx]),
-                        self.value_fields.iter(),
-                    )?,
-                );
-            }
-            queries.push(insert_cypher);
+        if !self.src_fields.value_fields.is_empty() {
+            insert_cypher = insert_cypher.param(
+                SRC_PROPS_PARAM,
+                mapped_field_values_to_bolt(
+                    self.src_fields
+                        .value_fields
+                        .iter()
+                        .map(|f| &value.fields[f.field_idx]),
+                    self.src_fields.value_fields.iter(),
+                )?,
+            );
+        }
+        if !self.tgt_fields.value_fields.is_empty() {
+            insert_cypher = insert_cypher.param(
+                TGT_PROPS_PARAM,
+                mapped_field_values_to_bolt(
+                    self.tgt_fields
+                        .value_fields
+                        .iter()
+                        .map(|f| &value.fields[f.field_idx]),
+                    self.tgt_fields.value_fields.iter(),
+                )?,
+            );
         }
-        for delete_key in mutation.delete_keys.iter() {
-            queries.push(
-                self.bind_rel_key_field_params(neo4rs::query(&self.delete_cypher), delete_key)?,
+        if !self.value_fields.is_empty() {
+            insert_cypher = insert_cypher.param(
+                REL_PROPS_PARAM,
+                mapped_field_values_to_bolt(
+                    self.value_fields.iter().map(|f| &value.fields[f.field_idx]),
+                    self.value_fields.iter(),
+                )?,
             );
         }
-        Ok(queries)
+        queries.push(insert_cypher);
+        Ok(())
     }
-}
 
-#[async_trait]
-impl ExportTargetExecutor for RelationshipStorageExecutor {
-    async fn apply_mutation(&self, mutation: ExportTargetMutation) -> Result<()> {
-        retriable::run(
-            || async {
-                let queries = self.build_queries_to_apply_mutation(&mutation)?;
-                let mut txn = self.graph.start_txn().await?;
-                txn.run_queries(queries.clone()).await?;
-                txn.commit().await?;
-                retriable::Ok(())
-            },
-            retriable::RunOptions::default(),
-        )
-        .await
-        .map_err(Into::<anyhow::Error>::into)
+    fn add_delete_queries(
+        &self,
+        delete_key: &value::KeyValue,
+        queries: &mut Vec<neo4rs::Query>,
+    ) -> Result<()> {
+        queries
+            .push(self.bind_rel_key_field_params(neo4rs::query(&self.delete_cypher), delete_key)?);
+        Ok(())
     }
 }
@@ -1016,10 +1003,12 @@ impl<'a> NodeLabelAnalyzer<'a> {
     }
 }
 
+#[async_trait]
 impl StorageFactoryBase for RelationshipFactory {
     type Spec = RelationshipSpec;
     type SetupState = RelationshipSetupState;
     type Key = GraphRelationship;
+    type ExportContext = RelationshipExportContext;
 
     fn name(&self) -> &str {
         "Neo4jRelationship"
@@ -1033,7 +1022,7 @@ impl StorageFactoryBase for RelationshipFactory {
         value_fields_schema: Vec<FieldSchema>,
         index_options: IndexOptions,
         context: Arc<FlowInstanceContext>,
-    ) -> Result<ExportTargetBuildOutput<Self>> {
+    ) -> Result<TypedExportTargetBuildOutput<Self>> {
         let setup_key = GraphRelationship::from_spec(&spec);
 
         let mut src_label_analyzer = NodeLabelAnalyzer::new(&spec, &spec.source)?;
@@ -1065,9 +1054,9 @@ impl StorageFactoryBase for RelationshipFactory {
         let conn_spec = context
             .auth_registry
             .get::<ConnectionSpec>(&spec.connection)?;
-        let executor = async move {
+        let executors = async move {
             let graph = self.graph_pool.get_graph(&conn_spec).await?;
-            let executor = Arc::new(RelationshipStorageExecutor::new(
+            let executor = Arc::new(RelationshipExportContext::new(
                 graph,
                 spec,
                 key_fields_schema,
@@ -1075,11 +1064,14 @@ impl StorageFactoryBase for RelationshipFactory {
                 src_label_info,
                 tgt_label_info,
             )?);
-            Ok((executor as Arc<dyn ExportTargetExecutor>, None))
+            Ok(TypedExportTargetExecutors {
+                export_context: executor,
+                query_target: None,
+            })
         }
         .boxed();
-        Ok(ExportTargetBuildOutput {
-            executor,
+        Ok(TypedExportTargetBuildOutput {
+            executors,
             setup_key,
             desired_setup_state,
         })
@@ -1122,4 +1114,45 @@ impl StorageFactoryBase for RelationshipFactory {
     fn describe_resource(&self, key: &GraphRelationship) -> Result<String> {
         Ok(format!("Neo4j relationship {}", key.relationship))
     }
+
+    async fn apply_mutation(
+        &self,
+        mutations: Vec<ExportTargetMutationWithContext<'async_trait, RelationshipExportContext>>,
+    ) -> Result<()> {
+        let mut muts_by_graph = HashMap::new();
+        for mut_with_ctx in mutations.iter() {
+            muts_by_graph
+                .entry(&mut_with_ctx.export_context.connection_ref)
+                .or_insert_with(Vec::new)
+                .push(mut_with_ctx);
+        }
+        for muts in muts_by_graph.values() {
+            let graph = &muts[0].export_context.graph;
+            retriable::run(
+                || async {
+                    let mut queries = vec![];
+                    for mut_with_ctx in muts.iter() {
+                        let export_ctx = &mut_with_ctx.export_context;
+                        for upsert in mut_with_ctx.mutation.upserts.iter() {
+                            export_ctx.add_upsert_queries(upsert, &mut queries)?;
+                        }
+                    }
+                    for mut_with_ctx in muts.iter() {
+                        let export_ctx = &mut_with_ctx.export_context;
+                        for delete_key in mut_with_ctx.mutation.delete_keys.iter() {
+                            export_ctx.add_delete_queries(delete_key, &mut queries)?;
+                        }
+                    }
+                    let mut txn = graph.start_txn().await?;
+                    txn.run_queries(queries).await?;
+                    txn.commit().await?;
+                    retriable::Ok(())
+                },
+                retriable::RunOptions::default(),
+            )
+            .await
+            .map_err(Into::<anyhow::Error>::into)?
+        }
+        Ok(())
+    }
 }
diff --git a/src/ops/storages/postgres.rs b/src/ops/storages/postgres.rs
index 431df401..05397365 100644
--- a/src/ops/storages/postgres.rs
+++ b/src/ops/storages/postgres.rs
@@ -262,8 +262,8 @@ fn from_pg_value(row: &PgRow, field_idx: usize, typ: &ValueType) -> Result<Value
     }
 }
 
-pub struct Executor {
-    db_pool: PgPool,
+pub struct ExportContext {
+    database_url: Option<String>,
     table_name: ValidIdentifier,
     key_fields_schema: Vec<FieldSchema>,
     value_fields_schema: Vec<FieldSchema>,
@@ -274,9 +274,9 @@ pub struct ExportContext {
     delete_sql_prefix: String,
 }
 
-impl Executor {
+impl ExportContext {
     fn new(
-        db_pool: PgPool,
+        database_url: Option<String>,
         table_name: String,
         key_fields_schema: Vec<FieldSchema>,
         value_fields_schema: Vec<FieldSchema>,
@@ -304,7 +304,7 @@ impl Executor {
             .collect::<Vec<_>>();
         let table_name = ValidIdentifier::try_from(table_name)?;
         Ok(Self {
-            db_pool,
+            database_url,
             key_fields_schema,
             value_fields_schema,
             all_fields_comma_separated: all_fields
@@ -325,13 +325,14 @@ impl Executor {
     }
 }
 
-#[async_trait]
-impl ExportTargetExecutor for Executor {
-    async fn apply_mutation(&self, mutation: ExportTargetMutation) -> Result<()> {
+impl ExportContext {
+    async fn upsert(
+        &self,
+        upserts: &[interface::ExportTargetUpsertEntry],
+        txn: &mut sqlx::PgTransaction<'_>,
+    ) -> Result<()> {
         let num_parameters = self.key_fields_schema.len() + self.value_fields_schema.len();
-        let mut txn = self.db_pool.begin().await?;
-
-        for upsert_chunk in mutation.upserts.chunks(BIND_LIMIT / num_parameters) {
+        for upsert_chunk in upserts.chunks(BIND_LIMIT / num_parameters) {
             let mut query_builder = sqlx::QueryBuilder::new(&self.upsert_sql_prefix);
             for (i, upsert) in upsert_chunk.iter().enumerate() {
                 if i > 0 {
@@ -365,11 +366,18 @@ impl ExportTargetExecutor for Executor {
                 query_builder.push(")");
             }
             query_builder.push(&self.upsert_sql_suffix);
-            query_builder.build().execute(&mut *txn).await?;
+            query_builder.build().execute(&mut **txn).await?;
         }
+        Ok(())
+    }
 
+    async fn delete(
+        &self,
+        delete_keys: &[KeyValue],
+        txn: &mut sqlx::PgTransaction<'_>,
+    ) -> Result<()> {
         // TODO: Find a way to batch delete.
-        for delete_key in mutation.delete_keys.iter() {
+        for delete_key in delete_keys.iter() {
             let mut query_builder = sqlx::QueryBuilder::new("");
             query_builder.push(&self.delete_sql_prefix);
             for (i, (schema, value)) in self
@@ -385,26 +393,28 @@ impl ExportTargetExecutor for Executor {
                 query_builder.push("=");
                 bind_key_field(&mut query_builder, value)?;
             }
-            query_builder.build().execute(&mut *txn).await?;
+            query_builder.build().execute(&mut **txn).await?;
         }
-
-        txn.commit().await?;
-
         Ok(())
     }
 }
 
 static SCORE_FIELD_NAME: &str = "__score";
 
+struct PostgresQueryTarget {
+    db_pool: PgPool,
+    context: Arc<ExportContext>,
+}
+
 #[async_trait]
-impl QueryTarget for Executor {
+impl QueryTarget for PostgresQueryTarget {
     async fn search(&self, query: VectorMatchQuery) -> Result<QueryResults> {
         let query_str = format!(
             "SELECT {} {} $1 AS {SCORE_FIELD_NAME}, {} FROM {} ORDER BY {SCORE_FIELD_NAME} LIMIT $2",
             ValidIdentifier::try_from(query.vector_field_name)?,
             to_distance_operator(query.similarity_metric),
-            self.all_fields_comma_separated,
-            self.table_name,
+            self.context.all_fields_comma_separated,
+            self.context.table_name,
         );
         let results = sqlx::query(&query_str)
             .bind(pgvector::Vector::from(query.vector))
@@ -415,9 +425,10 @@ impl QueryTarget for Executor {
             .map(|r| -> Result<QueryResult> {
                 let score: f64 = distance_to_similarity(query.similarity_metric, r.try_get(0)?);
                 let data = self
+                    .context
                     .key_fields_schema
                     .iter()
-                    .chain(self.value_fields_schema.iter())
+                    .chain(self.context.value_fields_schema.iter())
                     .enumerate()
                     .map(|(idx, schema)| from_pg_value(&r, idx + 1, &schema.value_type.typ))
                     .collect::<Result<Vec<_>>>()?;
@@ -427,7 +438,7 @@ impl QueryTarget for Executor {
             .collect::<Result<Vec<_>>>()?;
 
         Ok(QueryResults {
-            fields: self.all_fields.clone(),
+            fields: self.context.all_fields.clone(),
             results,
         })
     }
@@ -641,7 +652,8 @@ impl SetupStatusCheck {
                 .filter(|(name, def)| {
                     !existing
                         .current
-                        .as_ref().is_some_and(|v| v.vector_indexes.get(*name) != Some(def))
+                        .as_ref()
+                        .is_some_and(|v| v.vector_indexes.get(*name) != Some(def))
                 })
                 .map(|(k, v)| (k.clone(), v.clone()))
                 .collect(),
@@ -821,7 +833,7 @@ impl setup::ResourceSetupStatusCheck for SetupStatusCheck {
     async fn apply_change(&self) -> Result<()> {
         let db_pool = self
             .factory
-            .get_db_pool(self.table_id.database_url.clone())
+            .get_db_pool(&self.table_id.database_url)
             .await?;
         let table_name = &self.table_id.table_name;
         if self.drop_existing {
@@ -889,10 +901,12 @@ impl setup::ResourceSetupStatusCheck for SetupStatusCheck {
     }
 }
 
+#[async_trait]
 impl StorageFactoryBase for Arc<Factory> {
     type Spec = Spec;
     type SetupState = SetupState;
    type Key = TableId;
+    type ExportContext = ExportContext;
 
     fn name(&self) -> &str {
         "Postgres"
@@ -906,7 +920,7 @@ impl StorageFactoryBase for Arc<Factory> {
         value_fields_schema: Vec<FieldSchema>,
         storage_options: IndexOptions,
         context: Arc<FlowInstanceContext>,
-    ) -> Result<ExportTargetBuildOutput<Self>> {
+    ) -> Result<TypedExportTargetBuildOutput<Self>> {
         let table_id = TableId {
             database_url: spec.database_url.clone(),
             table_name: spec
@@ -920,23 +934,26 @@ impl StorageFactoryBase for Arc<Factory> {
             &storage_options,
         );
         let table_name = table_id.table_name.clone();
+        let export_context = Arc::new(ExportContext::new(
+            spec.database_url.clone(),
+            table_name,
+            key_fields_schema,
+            value_fields_schema,
+        )?);
         let executors = async move {
-            let executor = Arc::new(Executor::new(
-                self.get_db_pool(spec.database_url).await?,
-                table_name,
-                key_fields_schema,
-                value_fields_schema,
-            )?);
-            let query_target = executor.clone();
-            Ok((
-                executor as Arc<dyn ExportTargetExecutor>,
-                Some(query_target as Arc<dyn QueryTarget>),
-            ))
+            let query_target = Arc::new(PostgresQueryTarget {
+                db_pool: self.get_db_pool(&spec.database_url).await?,
+                context: export_context.clone(),
+            });
+            Ok(TypedExportTargetExecutors {
+                export_context: export_context.clone(),
+                query_target: Some(query_target as Arc<dyn QueryTarget>),
+            })
         };
-        Ok(ExportTargetBuildOutput {
-            executor: executors.boxed(),
+        Ok(TypedExportTargetBuildOutput {
             setup_key: table_id,
             desired_setup_state: setup_state,
+            executors: executors.boxed(),
         })
     }
 
@@ -979,27 +996,59 @@ impl StorageFactoryBase for Arc<Factory> {
     fn describe_resource(&self, key: &TableId) -> Result<String> {
         Ok(format!("Postgres table {}", key.table_name))
     }
+
+    async fn apply_mutation(
+        &self,
+        mutations: Vec<ExportTargetMutationWithContext<'async_trait, ExportContext>>,
+    ) -> Result<()> {
+        let mut mut_groups_by_db_url = HashMap::new();
+        for mutation in mutations.iter() {
+            mut_groups_by_db_url
+                .entry(mutation.export_context.database_url.clone())
+                .or_insert_with(Vec::new)
+                .push(mutation);
+        }
+        for (db_url, mut_groups) in mut_groups_by_db_url.iter() {
+            let db_pool = self.get_db_pool(db_url).await?;
+            let mut txn = db_pool.begin().await?;
+            for mut_group in mut_groups.iter() {
+                mut_group
+                    .export_context
+                    .upsert(&mut_group.mutation.upserts, &mut txn)
+                    .await?;
+            }
+            for mut_group in mut_groups.iter() {
+                mut_group
+                    .export_context
+                    .delete(&mut_group.mutation.delete_keys, &mut txn)
+                    .await?;
+            }
+            txn.commit().await?;
+        }
+        Ok(())
+    }
 }
 
 impl Factory {
-    async fn get_db_pool(&self, database_url: Option<String>) -> Result<PgPool> {
+    async fn get_db_pool(&self, database_url: &Option<String>) -> Result<PgPool> {
         let pool_fut = {
             let mut db_pools = self.db_pools.lock().unwrap();
-            match db_pools.entry(database_url) {
-                std::collections::hash_map::Entry::Vacant(entry) => {
-                    let database_url = entry.key().clone();
-                    let pool_fut = async {
+            if let Some(shared_fut) = db_pools.get(database_url) {
+                shared_fut.clone()
+            } else {
+                let pool_fut = {
+                    let database_url = database_url.clone();
+                    async move {
                         shared_ok(if let Some(database_url) = database_url {
                             PgPool::connect(&database_url).await?
                         } else {
                             get_lib_context().map_err(SharedError::new)?.pool.clone()
                         })
-                    };
-                    let shared_fut = pool_fut.boxed().shared();
-                    entry.insert(shared_fut.clone());
-                    shared_fut
-                }
-                std::collections::hash_map::Entry::Occupied(entry) => entry.get().clone(),
+                    }
+                };
+                let shared_fut = pool_fut.boxed().shared();
+                db_pools.insert(database_url.clone(), shared_fut.clone());
+                shared_fut
             }
         };
         Ok(pool_fut.await.std_result()?)
diff --git a/src/ops/storages/qdrant.rs b/src/ops/storages/qdrant.rs
index f3d047b2..1a48f477 100644
--- a/src/ops/storages/qdrant.rs
+++ b/src/ops/storages/qdrant.rs
@@ -24,14 +24,14 @@ pub struct Spec {
     api_key: Option<String>,
 }
 
-pub struct Executor {
+pub struct ExportContext {
     client: Qdrant,
     collection_name: String,
     value_fields_schema: Vec<FieldSchema>,
     all_fields: Vec<FieldSchema>,
 }
 
-impl Executor {
+impl ExportContext {
     fn new(
         url: String,
         collection_name: String,
@@ -60,10 +60,7 @@ impl Executor {
             collection_name,
         })
     }
-}
 
-#[async_trait]
-impl ExportTargetExecutor for Executor {
     async fn apply_mutation(&self, mutation: ExportTargetMutation) -> Result<()> {
         let mut points: Vec<PointStruct> = Vec::with_capacity(mutation.upserts.len());
         for upsert in mutation.upserts.iter() {
@@ -280,7 +277,7 @@ fn into_value(point: &ScoredPoint, schema: &FieldSchema) -> Result<Value> {
 }
 
 #[async_trait]
-impl QueryTarget for Executor {
+impl QueryTarget for ExportContext {
     async fn search(&self, query: VectorMatchQuery) -> Result<QueryResults> {
         let points = self
             .client
@@ -329,10 +326,12 @@ impl Display for CollectionId {
     }
 }
 
+#[async_trait]
 impl StorageFactoryBase for Arc<Factory> {
     type Spec = Spec;
     type SetupState = ();
     type Key = String;
+    type ExportContext = ExportContext;
 
     fn name(&self) -> &str {
         "Qdrant"
@@ -346,7 +345,7 @@ impl StorageFactoryBase for Arc<Factory> {
         value_fields_schema: Vec<FieldSchema>,
         _storage_options: IndexOptions,
         _context: Arc<FlowInstanceContext>,
-    ) -> Result<ExportTargetBuildOutput<Self>> {
+    ) -> Result<TypedExportTargetBuildOutput<Self>> {
         if key_fields_schema.len() != 1 {
             api_bail!(
                 "Expected one primary key field for the point ID. Got {}.",
@@ -356,22 +355,22 @@ impl StorageFactoryBase for Arc<Factory> {
 
         let collection_name = spec.collection_name.clone();
 
+        let export_context = Arc::new(ExportContext::new(
+            spec.grpc_url,
+            spec.collection_name.clone(),
+            spec.api_key,
+            key_fields_schema,
+            value_fields_schema,
+        )?);
+        let query_target = export_context.clone();
         let executors = async move {
-            let executor = Arc::new(Executor::new(
-                spec.grpc_url,
-                spec.collection_name.clone(),
-                spec.api_key,
-                key_fields_schema,
-                value_fields_schema,
-            )?);
-            let query_target = executor.clone();
-            Ok((
-                executor as Arc<dyn ExportTargetExecutor>,
-                Some(query_target as Arc<dyn QueryTarget>),
-            ))
+            Ok(TypedExportTargetExecutors {
+                export_context,
+                query_target: Some(query_target as Arc<dyn QueryTarget>),
+            })
         };
-        Ok(ExportTargetBuildOutput {
-            executor: executors.boxed(),
+        Ok(TypedExportTargetBuildOutput {
+            executors: executors.boxed(),
             setup_key: collection_name,
             desired_setup_state: (),
         })
@@ -398,4 +397,17 @@ impl StorageFactoryBase for Arc<Factory> {
     fn describe_resource(&self, key: &String) -> Result<String> {
         Ok(format!("Qdrant collection {}", key))
     }
+
+    async fn apply_mutation(
+        &self,
+        mutations: Vec<ExportTargetMutationWithContext<'async_trait, ExportContext>>,
+    ) -> Result<()> {
+        for mutation_w_ctx in mutations.into_iter() {
+            mutation_w_ctx
+                .export_context
+                .apply_mutation(mutation_w_ctx.mutation)
+                .await?;
+        }
+        Ok(())
+    }
 }
diff --git a/src/prelude.rs b/src/prelude.rs
index 55ed4192..976bd3f2 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -12,6 +12,7 @@ pub(crate) use futures::{FutureExt, StreamExt};
 pub(crate) use indexmap::{IndexMap, IndexSet};
 pub(crate) use itertools::Itertools;
 pub(crate) use serde::{de::DeserializeOwned, Deserialize, Serialize};
+pub(crate) use std::any::Any;
 pub(crate) use std::borrow::Cow;
 pub(crate) use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
 pub(crate) use std::hash::Hash;
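For readers of this patch, a minimal, self-contained sketch of the pattern it introduces may help. It is illustrative only and not part of the diff: the `Mini*` types and the free `apply_mutation` function below are hypothetical stand-ins (synchronous, with `println!` in place of real transactions), not names from this codebase. The idea they mirror: each export op carries a lightweight export context instead of its own executor, and the target factory receives every pending mutation for its target kind in one call, so it can group them by connection and apply each group together.

```rust
// Hypothetical stand-ins for ExportTargetMutation / ExportTargetMutationWithContext.
use std::collections::HashMap;

/// Per-export-op context, analogous to `ExportContext` in the patch:
/// it carries the connection identity plus op-specific state.
struct MiniContext {
    connection_url: String,
    table: String,
}

struct MiniMutation {
    upserts: Vec<(String, String)>, // (key, value)
    delete_keys: Vec<String>,
}

struct MiniMutationWithContext<'ctx> {
    mutation: MiniMutation,
    export_context: &'ctx MiniContext,
}

/// Factory-level entry point: receives *all* mutations for this target kind
/// at once, groups them by connection, and applies each group in a single
/// "transaction" (here just prints).
fn apply_mutation(mutations: Vec<MiniMutationWithContext<'_>>) {
    let mut by_conn: HashMap<&str, Vec<&MiniMutationWithContext<'_>>> = HashMap::new();
    for m in mutations.iter() {
        by_conn
            .entry(m.export_context.connection_url.as_str())
            .or_default()
            .push(m);
    }
    for (conn, group) in by_conn {
        // One transaction per connection, covering every export op that
        // shares it -- the same shape as the Postgres/Neo4j impls above.
        println!("BEGIN on {conn}");
        for m in &group {
            for (k, v) in &m.mutation.upserts {
                println!("  upsert {}.{k} = {v}", m.export_context.table);
            }
        }
        for m in &group {
            for k in &m.mutation.delete_keys {
                println!("  delete {}.{k}", m.export_context.table);
            }
        }
        println!("COMMIT on {conn}");
    }
}

fn main() {
    // Two export ops sharing one connection commit in the same group.
    let ctx_a = MiniContext { connection_url: "db://a".into(), table: "t1".into() };
    let ctx_b = MiniContext { connection_url: "db://a".into(), table: "t2".into() };
    apply_mutation(vec![
        MiniMutationWithContext {
            mutation: MiniMutation { upserts: vec![("k1".into(), "v1".into())], delete_keys: vec![] },
            export_context: &ctx_a,
        },
        MiniMutationWithContext {
            mutation: MiniMutation { upserts: vec![], delete_keys: vec!["k2".into()] },
            export_context: &ctx_b,
        },
    ]);
}
```

The upserts-then-deletes ordering within a group follows the Postgres and Neo4j implementations above; grouping by connection is what lets several export ops that share a database or graph commit in one transaction.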