diff --git a/Cargo.toml b/Cargo.toml index c0f2fa66..62d3d451 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -104,3 +104,4 @@ async-stream = "0.3.6" neo4rs = "0.8.0" bytes = "1.10.1" rand = "0.9.0" +indoc = "2.0.6" diff --git a/examples/docs_to_kg/main.py b/examples/docs_to_kg/main.py index 4952a8f1..2fa49b36 100644 --- a/examples/docs_to_kg/main.py +++ b/examples/docs_to_kg/main.py @@ -21,7 +21,7 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D conn_spec = cocoindex.add_auth_entry( "Neo4jConnection", - cocoindex.storages.Neo4jConnectionSpec( + cocoindex.storages.Neo4jConnection( uri="bolt://localhost:7687", user="neo4j", password="cocoindex", @@ -70,38 +70,40 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D relationships.export( "relationships", - cocoindex.storages.Neo4jRelationship( + cocoindex.storages.Neo4j( connection=conn_spec, - rel_type="RELATIONSHIP", - source=cocoindex.storages.Neo4jRelationshipEndSpec( - label="Entity", - fields=[ - cocoindex.storages.Neo4jFieldMapping( - field_name="subject", node_field_name="value"), - cocoindex.storages.Neo4jFieldMapping( - field_name="subject_embedding", node_field_name="embedding"), - ] - ), - target=cocoindex.storages.Neo4jRelationshipEndSpec( - label="Entity", - fields=[ - cocoindex.storages.Neo4jFieldMapping( - field_name="object", node_field_name="value"), - cocoindex.storages.Neo4jFieldMapping( - field_name="object_embedding", node_field_name="embedding"), - ] - ), - nodes={ - "Entity": cocoindex.storages.Neo4jRelationshipNodeSpec( - primary_key_fields=["value"], - vector_indexes=[ - cocoindex.VectorIndexDef( - field_name="embedding", - metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, - ), - ], + mapping=cocoindex.storages.Neo4jRelationship( + rel_type="RELATIONSHIP", + source=cocoindex.storages.Neo4jRelationshipEnd( + label="Entity", + fields=[ + cocoindex.storages.Neo4jFieldMapping( + field_name="subject", node_field_name="value"), 
+ cocoindex.storages.Neo4jFieldMapping( + field_name="subject_embedding", node_field_name="embedding"), + ] ), - }, + target=cocoindex.storages.Neo4jRelationshipEnd( + label="Entity", + fields=[ + cocoindex.storages.Neo4jFieldMapping( + field_name="object", node_field_name="value"), + cocoindex.storages.Neo4jFieldMapping( + field_name="object_embedding", node_field_name="embedding"), + ] + ), + nodes={ + "Entity": cocoindex.storages.Neo4jRelationshipNode( + primary_key_fields=["value"], + vector_indexes=[ + cocoindex.VectorIndexDef( + field_name="embedding", + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + ), + ], + ), + }, + ), ), primary_key_fields=["id"], ) diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index a543d515..056e6cd4 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -117,7 +117,10 @@ def dump_engine_object(v: Any) -> Any: nanos = int((total_secs - secs) * 1e9) return {'secs': secs, 'nanos': nanos} elif hasattr(v, '__dict__'): - return {k: dump_engine_object(v) for k, v in v.__dict__.items()} + s = {k: dump_engine_object(v) for k, v in v.__dict__.items()} + if hasattr(v, 'kind') and 'kind' not in s: + s['kind'] = v.kind + return s elif isinstance(v, (list, tuple)): return [dump_engine_object(item) for item in v] elif isinstance(v, dict): diff --git a/python/cocoindex/storages.py b/python/cocoindex/storages.py index bdb2e7d1..e2e58342 100644 --- a/python/cocoindex/storages.py +++ b/python/cocoindex/storages.py @@ -21,7 +21,7 @@ class Qdrant(op.StorageSpec): api_key: str | None = None @dataclass -class Neo4jConnectionSpec: +class Neo4jConnection: """Connection spec for Neo4j.""" uri: str user: str @@ -37,22 +37,36 @@ class Neo4jFieldMapping: node_field_name: str | None = None @dataclass -class Neo4jRelationshipEndSpec: +class Neo4jRelationshipEnd: """Spec for a Neo4j node type.""" label: str fields: list[Neo4jFieldMapping] @dataclass -class Neo4jRelationshipNodeSpec: +class 
Neo4jRelationshipNode: """Spec for a Neo4j node type.""" primary_key_fields: Sequence[str] vector_indexes: Sequence[index.VectorIndexDef] = () -class Neo4jRelationship(op.StorageSpec): +@dataclass +class Neo4jNode: + """Spec for a Neo4j node type.""" + kind = "Node" + + label: str + +@dataclass +class Neo4jRelationship: + """Spec for a Neo4j relationship.""" + kind = "Relationship" + + rel_type: str + source: Neo4jRelationshipEnd + target: Neo4jRelationshipEnd + nodes: dict[str, Neo4jRelationshipNode] + +class Neo4j(op.StorageSpec): """Graph storage powered by Neo4j.""" connection: AuthEntryReference - rel_type: str - source: Neo4jRelationshipEndSpec - target: Neo4jRelationshipEndSpec - nodes: dict[str, Neo4jRelationshipNodeSpec] + mapping: Neo4jNode | Neo4jRelationship diff --git a/src/ops/registration.rs b/src/ops/registration.rs index 97ab6e32..a6195ea0 100644 --- a/src/ops/registration.rs +++ b/src/ops/registration.rs @@ -15,8 +15,7 @@ fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result Arc::new(storages::postgres::Factory::default()).register(registry)?; Arc::new(storages::qdrant::Factory::default()).register(registry)?; - let neo4j_pool = Arc::new(storages::neo4j::GraphPool::default()); - storages::neo4j::RelationshipFactory::new(neo4j_pool).register(registry)?; + storages::neo4j::Factory::new().register(registry)?; Ok(()) } diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs index 055eb181..03da3f20 100644 --- a/src/ops/storages/neo4j.rs +++ b/src/ops/storages/neo4j.rs @@ -3,7 +3,9 @@ use crate::setup::components::{self, State}; use crate::setup::{ResourceSetupStatusCheck, SetupChangeType}; use crate::{ops::sdk::*, setup::CombinedState}; +use indoc::formatdoc; use neo4rs::{BoltType, ConfigBuilder, Graph}; +use std::fmt::Write; use tokio::sync::OnceCell; const DEFAULT_DB: &str = "neo4j"; @@ -44,15 +46,32 @@ pub struct RelationshipNodeSpec { index_options: spec::IndexOptions, } +#[derive(Debug, Deserialize)] +pub 
struct NodeSpec { + label: String, +} + #[derive(Debug, Deserialize)] pub struct RelationshipSpec { - connection: AuthEntryReference, rel_type: String, source: RelationshipEndSpec, target: RelationshipEndSpec, nodes: BTreeMap, } +#[derive(Debug, Deserialize)] +#[serde(tag = "kind")] +pub enum RowMappingSpec { + Relationship(RelationshipSpec), + Node(NodeSpec), +} + +#[derive(Debug, Deserialize)] +pub struct Spec { + connection: AuthEntryReference, + mapping: RowMappingSpec, +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] struct GraphKey { uri: String, @@ -68,17 +87,55 @@ impl GraphKey { } } +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)] +enum ElementType { + Node(String), + Relationship(String), +} + +impl ElementType { + fn label(&self) -> &str { + match self { + ElementType::Node(label) => label, + ElementType::Relationship(label) => label, + } + } + + fn from_mapping_spec(spec: &RowMappingSpec) -> Self { + match spec { + RowMappingSpec::Relationship(spec) => ElementType::Relationship(spec.rel_type.clone()), + RowMappingSpec::Node(spec) => ElementType::Node(spec.label.clone()), + } + } + + fn matcher(&self, var_name: &str) -> String { + match self { + ElementType::Relationship(label) => format!("()-[{var_name}:{label}]->()"), + ElementType::Node(label) => format!("({var_name}:{label})"), + } + } +} + +impl std::fmt::Display for ElementType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ElementType::Node(label) => write!(f, "Node(label:{label})"), + ElementType::Relationship(rel_type) => write!(f, "Relationship(type:{rel_type})"), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub struct GraphRelationship { +pub struct GraphElement { connection: AuthEntryReference, - relationship: String, + typ: ElementType, } -impl GraphRelationship { - fn from_spec(spec: &RelationshipSpec) -> Self { +impl GraphElement { + fn from_spec(spec: &Spec) -> 
Self { Self { connection: spec.connection.clone(), - relationship: spec.rel_type.clone(), + typ: ElementType::from_mapping_spec(&spec.mapping), } } } @@ -132,22 +189,26 @@ struct AnalyzedNodeLabelInfo { key_fields: Vec, value_fields: Vec, } -pub struct RelationshipExportContext { + +pub struct ExportContext { connection_ref: AuthEntryReference, graph: Arc, + create_order: u8, + delete_cypher: String, insert_cypher: String, + delete_before_upsert: bool, key_field_params: Vec, key_fields: Vec, value_fields: Vec, + src_fields: Option, src_key_field_params: Vec, - src_fields: AnalyzedNodeLabelInfo, + tgt_fields: Option, tgt_key_field_params: Vec, - tgt_fields: AnalyzedNodeLabelInfo, } fn json_value_to_bolt_value(value: &serde_json::Value) -> Result { @@ -295,14 +356,15 @@ fn value_to_bolt(value: &Value, schema: &schema::ValueType) -> Result Ok(bolt_value) } -const REL_KEY_PARAM_PREFIX: &str = "rel_key"; -const REL_PROPS_PARAM: &str = "rel_props"; +const CORE_KEY_PARAM_PREFIX: &str = "key"; +const CORE_PROPS_PARAM: &str = "props"; const SRC_KEY_PARAM_PREFIX: &str = "source_key"; const SRC_PROPS_PARAM: &str = "source_props"; const TGT_KEY_PARAM_PREFIX: &str = "target_key"; const TGT_PROPS_PARAM: &str = "target_props"; +const CORE_ELEMENT_MATCHER_VAR: &str = "e"; -impl RelationshipExportContext { +impl ExportContext { fn build_key_field_params_n_literal<'a>( param_prefix: &str, key_fields: impl Iterator, @@ -321,100 +383,148 @@ impl RelationshipExportContext { fn new( graph: Arc, - spec: RelationshipSpec, + spec: Spec, key_fields: Vec, value_fields: Vec, - src_fields: AnalyzedNodeLabelInfo, - tgt_fields: AnalyzedNodeLabelInfo, + end_node_fields: Option<(AnalyzedNodeLabelInfo, AnalyzedNodeLabelInfo)>, ) -> Result { let (key_field_params, key_fields_literal) = Self::build_key_field_params_n_literal( - REL_KEY_PARAM_PREFIX, + CORE_KEY_PARAM_PREFIX, key_fields.iter().map(|f| &f.name), ); - let (src_key_field_params, src_key_fields_literal) = 
Self::build_key_field_params_n_literal( - SRC_KEY_PARAM_PREFIX, - src_fields.key_fields.iter().map(|f| &f.field_name), - ); - let (tgt_key_field_params, tgt_key_fields_literal) = Self::build_key_field_params_n_literal( - TGT_KEY_PARAM_PREFIX, - tgt_fields.key_fields.iter().map(|f| &f.field_name), - ); - - let delete_cypher = format!( - r#" -OPTIONAL MATCH (old_src)-[old_rel:{rel_type} {key_fields_literal}]->(old_tgt) - -DELETE old_rel - -WITH old_src, old_tgt -CALL {{ - WITH old_src - OPTIONAL MATCH (old_src)-[r]-() - WITH old_src, count(r) AS rels - WHERE rels = 0 - DELETE old_src - RETURN 0 AS _1 -}} - -CALL {{ - WITH old_tgt - OPTIONAL MATCH (old_tgt)-[r]-() - WITH old_tgt, count(r) AS rels - WHERE rels = 0 - DELETE old_tgt - RETURN 0 AS _2 -}} - -FINISH - "#, - rel_type = spec.rel_type, - ); - - let insert_cypher = format!( - r#" -MERGE (new_src:{src_node_label} {src_key_fields_literal}) -{optional_set_src_props} + let result = match spec.mapping { + RowMappingSpec::Node(node_spec) => { + let delete_cypher = formatdoc! {" + OPTIONAL MATCH (old_node:{label} {key_fields_literal}) + DELETE old_node + FINISH + ", + label = node_spec.label, + }; + + let insert_cypher = formatdoc! {" + MERGE (new_node:{label} {key_fields_literal}) + {optional_set_props} + FINISH + ", + label = node_spec.label, + optional_set_props = if value_fields.is_empty() { + "".to_string() + } else { + format!("SET new_node += ${CORE_PROPS_PARAM}\n") + }, + }; + + Self { + connection_ref: spec.connection, + graph, + create_order: 0, + delete_cypher, + insert_cypher, + delete_before_upsert: false, + key_field_params, + key_fields, + value_fields, + src_key_field_params: vec![], + src_fields: None, + tgt_key_field_params: vec![], + tgt_fields: None, + } + } + RowMappingSpec::Relationship(rel_spec) => { + let delete_cypher = formatdoc! 
{" + OPTIONAL MATCH (old_src)-[old_rel:{rel_type} {key_fields_literal}]->(old_tgt) -MERGE (new_tgt:{tgt_node_label} {tgt_key_fields_literal}) -{optional_set_tgt_props} + DELETE old_rel -MERGE (new_src)-[new_rel:{rel_type} {key_fields_literal}]->(new_tgt) -{optional_set_rel_props} + WITH old_src, old_tgt + CALL {{ + WITH old_src + OPTIONAL MATCH (old_src)-[r]-() + WITH old_src, count(r) AS rels + WHERE rels = 0 + DELETE old_src + RETURN 0 AS _1 + }} -FINISH - "#, - src_node_label = spec.source.label, - optional_set_src_props = if src_fields.value_fields.is_empty() { - "".to_string() - } else { - format!("SET new_src += ${SRC_PROPS_PARAM}\n") - }, - tgt_node_label = spec.target.label, - optional_set_tgt_props = if tgt_fields.value_fields.is_empty() { - "".to_string() - } else { - format!("SET new_tgt += ${TGT_PROPS_PARAM}\n") - }, - rel_type = spec.rel_type, - optional_set_rel_props = if value_fields.is_empty() { - "".to_string() - } else { - format!("SET new_rel += ${REL_PROPS_PARAM}\n") - }, - ); - Ok(Self { - connection_ref: spec.connection, - graph, - delete_cypher, - insert_cypher, - key_field_params, - key_fields, - value_fields, - src_key_field_params, - src_fields, - tgt_key_field_params, - tgt_fields, - }) + CALL {{ + WITH old_tgt + OPTIONAL MATCH (old_tgt)-[r]-() + WITH old_tgt, count(r) AS rels + WHERE rels = 0 + DELETE old_tgt + RETURN 0 AS _2 + }} + + FINISH + ", + rel_type = rel_spec.rel_type, + }; + + let (src_fields, tgt_fields) = match end_node_fields { + Some((src_fields, tgt_fields)) => (src_fields, tgt_fields), + None => anyhow::bail!("Relationship spec requires source / target fields"), + }; + let (src_key_field_params, src_key_fields_literal) = + Self::build_key_field_params_n_literal( + SRC_KEY_PARAM_PREFIX, + src_fields.key_fields.iter().map(|f| &f.field_name), + ); + let (tgt_key_field_params, tgt_key_fields_literal) = + Self::build_key_field_params_n_literal( + TGT_KEY_PARAM_PREFIX, + tgt_fields.key_fields.iter().map(|f| &f.field_name), + ); 
+ + let insert_cypher = formatdoc! {" + MERGE (new_src:{src_node_label} {src_key_fields_literal}) + {optional_set_src_props} + + MERGE (new_tgt:{tgt_node_label} {tgt_key_fields_literal}) + {optional_set_tgt_props} + + MERGE (new_src)-[new_rel:{rel_type} {key_fields_literal}]->(new_tgt) + {optional_set_rel_props} + + FINISH + ", + src_node_label = rel_spec.source.label, + optional_set_src_props = if src_fields.value_fields.is_empty() { + "".to_string() + } else { + format!("SET new_src += ${SRC_PROPS_PARAM}\n") + }, + tgt_node_label = rel_spec.target.label, + optional_set_tgt_props = if tgt_fields.value_fields.is_empty() { + "".to_string() + } else { + format!("SET new_tgt += ${TGT_PROPS_PARAM}\n") + }, + rel_type = rel_spec.rel_type, + optional_set_rel_props = if value_fields.is_empty() { + "".to_string() + } else { + format!("SET new_rel += ${CORE_PROPS_PARAM}\n") + }, + }; + Self { + connection_ref: spec.connection, + graph, + create_order: 1, + delete_cypher, + insert_cypher, + delete_before_upsert: true, + key_field_params, + key_fields, + value_fields, + src_key_field_params, + src_fields: Some(src_fields), + tgt_key_field_params, + tgt_fields: Some(tgt_fields), + } + } + }; + Ok(result) } fn bind_key_field_params<'a>( @@ -449,56 +559,67 @@ FINISH upsert: &ExportTargetUpsertEntry, queries: &mut Vec, ) -> Result<()> { - queries - .push(self.bind_rel_key_field_params(neo4rs::query(&self.delete_cypher), &upsert.key)?); + if self.delete_before_upsert { + queries.push( + self.bind_rel_key_field_params(neo4rs::query(&self.delete_cypher), &upsert.key)?, + ); + } let value = &upsert.value; let mut insert_cypher = self.bind_rel_key_field_params(neo4rs::query(&self.insert_cypher), &upsert.key)?; - insert_cypher = Self::bind_key_field_params( - insert_cypher, - &self.src_key_field_params, - self.src_fields - .key_fields - .iter() - .map(|f| (&f.value_type, &value.fields[f.field_idx])), - )?; - insert_cypher = Self::bind_key_field_params( - insert_cypher, - 
&self.tgt_key_field_params, - self.tgt_fields - .key_fields - .iter() - .map(|f| (&f.value_type, &value.fields[f.field_idx])), - )?; - if !self.src_fields.value_fields.is_empty() { - insert_cypher = insert_cypher.param( - SRC_PROPS_PARAM, - mapped_field_values_to_bolt( - self.src_fields - .value_fields - .iter() - .map(|f| &value.fields[f.field_idx]), - self.src_fields.value_fields.iter(), - )?, - ); + if let Some(src_fields) = &self.src_fields { + insert_cypher = Self::bind_key_field_params( + insert_cypher, + &self.src_key_field_params, + src_fields + .key_fields + .iter() + .map(|f| (&f.value_type, &value.fields[f.field_idx])), + )?; + + if !src_fields.value_fields.is_empty() { + insert_cypher = insert_cypher.param( + SRC_PROPS_PARAM, + mapped_field_values_to_bolt( + src_fields + .value_fields + .iter() + .map(|f| &value.fields[f.field_idx]), + src_fields.value_fields.iter(), + )?, + ); + } } - if !self.tgt_fields.value_fields.is_empty() { - insert_cypher = insert_cypher.param( - TGT_PROPS_PARAM, - mapped_field_values_to_bolt( - self.tgt_fields - .value_fields - .iter() - .map(|f| &value.fields[f.field_idx]), - self.tgt_fields.value_fields.iter(), - )?, - ); + + if let Some(tgt_fields) = &self.tgt_fields { + insert_cypher = Self::bind_key_field_params( + insert_cypher, + &self.tgt_key_field_params, + tgt_fields + .key_fields + .iter() + .map(|f| (&f.value_type, &value.fields[f.field_idx])), + )?; + + if !tgt_fields.value_fields.is_empty() { + insert_cypher = insert_cypher.param( + TGT_PROPS_PARAM, + mapped_field_values_to_bolt( + tgt_fields + .value_fields + .iter() + .map(|f| &value.fields[f.field_idx]), + tgt_fields.value_fields.iter(), + )?, + ); + } } + if !self.value_fields.is_empty() { insert_cypher = insert_cypher.param( - REL_PROPS_PARAM, + CORE_PROPS_PARAM, mapped_field_values_to_bolt( self.value_fields.iter().map(|f| &value.fields[f.field_idx]), self.value_fields.iter(), @@ -524,33 +645,32 @@ FINISH pub struct RelationshipSetupState { key_field_names: 
Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] - node_labels: Vec, + dependent_node_labels: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] sub_components: Vec, } impl RelationshipSetupState { fn new( - spec: &RelationshipSpec, + spec: &Spec, key_field_names: Vec, index_options: &IndexOptions, - rel_value_fields_info: &[AnalyzedGraphFieldMapping], - src_label_info: &AnalyzedNodeLabelInfo, - tgt_label_info: &AnalyzedNodeLabelInfo, + value_fields_info: &[AnalyzedGraphFieldMapping], + end_nodes_label_info: Option<&(AnalyzedNodeLabelInfo, AnalyzedNodeLabelInfo)>, ) -> Result { let mut sub_components = vec![]; sub_components.push(ComponentState { - object_label: ObjectLabel::Relationship(spec.rel_type.clone()), + object_label: ElementType::from_mapping_spec(&spec.mapping), index_def: IndexDef::KeyConstraint { field_names: key_field_names.clone(), }, }); for index_def in index_options.vector_indexes.iter() { sub_components.push(ComponentState { - object_label: ObjectLabel::Relationship(spec.rel_type.clone()), + object_label: ElementType::from_mapping_spec(&spec.mapping), index_def: IndexDef::from_vector_index_def( index_def, - &rel_value_fields_info + &value_fields_info .iter() .find(|f| f.field_name == index_def.field_name) .ok_or_else(|| { @@ -563,36 +683,49 @@ impl RelationshipSetupState { )?, }); } - for (label, node) in spec.nodes.iter() { - sub_components.push(ComponentState { - object_label: ObjectLabel::Node(label.clone()), - index_def: IndexDef::KeyConstraint { - field_names: key_field_names.clone(), - }, - }); - for index_def in &node.index_options.vector_indexes { - sub_components.push(ComponentState { - object_label: ObjectLabel::Node(label.clone()), - index_def: IndexDef::from_vector_index_def( - index_def, - [src_label_info, tgt_label_info] - .into_iter() - .flat_map(|v| v.key_fields.iter().chain(v.value_fields.iter())) - .find(|f| f.field_name == index_def.field_name) - .map(|f| &f.value_type) - .ok_or_else(|| { - api_error!( -
"Unknown field name for vector index: {}", - index_def.field_name - ) - })?, - )?, - }); + let mut dependent_node_labels = vec![]; + match &spec.mapping { + RowMappingSpec::Node(_) => {} + RowMappingSpec::Relationship(rel_spec) => { + let (src_label_info, tgt_label_info) = end_nodes_label_info.ok_or_else(|| { + anyhow!( + "Expect `end_nodes_label_info` existing for relationship `{}`", + rel_spec.rel_type + ) + })?; + for (label, node) in rel_spec.nodes.iter() { + sub_components.push(ComponentState { + object_label: ElementType::Node(label.clone()), + index_def: IndexDef::KeyConstraint { + field_names: node.index_options.primary_key_fields.clone().ok_or_else(|| api_error!("Primary key fields are not specified for node label `{}`", label))?, + }, + }); + for index_def in &node.index_options.vector_indexes { + sub_components.push(ComponentState { + object_label: ElementType::Node(label.clone()), + index_def: IndexDef::from_vector_index_def( + index_def, + [src_label_info, tgt_label_info] + .into_iter() + .flat_map(|v| v.key_fields.iter().chain(v.value_fields.iter())) + .find(|f| f.field_name == index_def.field_name) + .map(|f| &f.value_type) + .ok_or_else(|| { + api_error!( + "Unknown field name for vector index: {}", + index_def.field_name + ) + })?, + )?, + }); + } + } + dependent_node_labels.extend(rel_spec.nodes.keys().cloned()); } - } + }; Ok(Self { key_field_names, - node_labels: spec.nodes.keys().cloned().collect(), + dependent_node_labels, sub_components, }) } @@ -616,8 +749,8 @@ impl IntoIterator for RelationshipSetupState { } #[derive(Debug)] struct DataClearAction { - rel_type: String, - node_labels: Vec, + core_elem_type: ElementType, + dependent_node_labels: Vec, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -640,21 +773,6 @@ struct ComponentKey { name: String, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -enum ObjectLabel { - Node(String), - Relationship(String), -} - -impl ObjectLabel { - fn label(&self) -> &str { - match self { - ObjectLabel::Node(label) => label, - ObjectLabel::Relationship(label) => label, - } - } -} -
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] enum IndexDef { KeyConstraint { @@ -690,15 +808,15 @@ impl IndexDef { #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub struct ComponentState { - object_label: ObjectLabel, + object_label: ElementType, index_def: IndexDef, } impl components::State for ComponentState { fn key(&self) -> ComponentKey { let prefix = match &self.object_label { - ObjectLabel::Relationship(_) => "r", - ObjectLabel::Node(_) => "n", + ElementType::Relationship(_) => "r", + ElementType::Node(_) => "n", }; let label = self.object_label.label(); match &self.index_def { @@ -755,10 +873,8 @@ impl components::Operator for SetupComponentOperator { async fn create(&self, state: &ComponentState) -> Result<()> { let graph = self.graph_pool.get_graph(&self.conn_spec).await?; let key = state.key(); - let (matcher, qualifier) = match &state.object_label { - ObjectLabel::Relationship(label) => (format!("()-[r:{label}]->()"), "r"), - ObjectLabel::Node(label) => (format!("(n:{label})"), "n"), - }; + let qualifier = CORE_ELEMENT_MATCHER_VAR; + let matcher = state.object_label.matcher(qualifier); let query = neo4rs::query(&match &state.index_def { IndexDef::KeyConstraint { field_names } => { format!( @@ -819,7 +935,7 @@ struct SetupStatusCheck { impl SetupStatusCheck { fn new( - key: GraphRelationship, + key: GraphElement, graph_pool: Arc, conn_spec: ConnectionSpec, desired_state: Option<&RelationshipSetupState>, @@ -835,8 +951,8 @@ impl SetupStatusCheck { }) }) .map(|existing_current| DataClearAction { - rel_type: key.relationship.clone(), - node_labels: existing_current.node_labels.clone(), + core_elem_type: key.typ.clone(), + dependent_node_labels: existing_current.dependent_node_labels.clone(), }); let change_type = match (desired_state, existing.possible_versions().next()) { @@ -866,11 +982,20 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { fn describe_changes(&self) -> Vec { let mut result = vec![]; if let 
Some(data_clear) = &self.data_clear { - result.push(format!( - "Clear data for relationship {}; nodes {}", - data_clear.rel_type, - data_clear.node_labels.iter().join(", "), - )); + let mut desc = format!("Clear data for {}", data_clear.core_elem_type); + if !data_clear.dependent_node_labels.is_empty() { + write!( + &mut desc, + "; dependents {}", + data_clear + .dependent_node_labels + .iter() + .map(|l| ElementType::Node(l.clone()).to_string()) + .join(", ") + ) + .unwrap(); + } + result.push(desc); } result } @@ -885,16 +1010,17 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { let delete_rel_query = neo4rs::query(&format!( r#" CALL {{ - MATCH ()-[r:{rel_type}]->() - WITH r - DELETE r + MATCH {matcher} + WITH {var_name} + DELETE {var_name} }} IN TRANSACTIONS "#, - rel_type = data_clear.rel_type + matcher = data_clear.core_elem_type.matcher(CORE_ELEMENT_MATCHER_VAR), + var_name = CORE_ELEMENT_MATCHER_VAR, )); graph.run(delete_rel_query).await?; - for node_label in &data_clear.node_labels { + for node_label in &data_clear.dependent_node_labels { let delete_node_query = neo4rs::query(&format!( r#" CALL {{ @@ -912,31 +1038,27 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { } } /// Factory for Neo4j relationships -pub struct RelationshipFactory { +pub struct Factory { graph_pool: Arc, } -impl RelationshipFactory { - pub fn new(graph_pool: Arc) -> Self { - Self { graph_pool } +impl Factory { + pub fn new() -> Self { + Self { + graph_pool: Arc::default(), + } } } -struct NodeLabelAnalyzer<'a> { +struct DependentNodeLabelAnalyzer<'a> { label_name: &'a str, fields: IndexMap<&'a str, AnalyzedGraphFieldMapping>, remaining_fields: HashMap<&'a str, &'a FieldMapping>, - index_options: &'a IndexOptions, + index_options: Option<&'a IndexOptions>, } -impl<'a> NodeLabelAnalyzer<'a> { +impl<'a> DependentNodeLabelAnalyzer<'a> { fn new(rel_spec: &'a RelationshipSpec, rel_end_spec: &'a RelationshipEndSpec) -> Result { - let node_spec =
rel_spec.nodes.get(&rel_end_spec.label).ok_or_else(|| { - anyhow!( - "Node label `{}` not found in relationship spec", - rel_end_spec.label - ) - })?; Ok(Self { label_name: rel_end_spec.label.as_str(), fields: IndexMap::new(), @@ -945,7 +1067,10 @@ impl<'a> NodeLabelAnalyzer<'a> { .iter() .map(|f| (f.field_name.as_str(), f)) .collect(), - index_options: &node_spec.index_options, + index_options: rel_spec + .nodes + .get(&rel_end_spec.label) + .map(|node_spec| &node_spec.index_options), }) } @@ -975,20 +1100,23 @@ impl<'a> NodeLabelAnalyzer<'a> { } let mut fields = self.fields; let mut key_fields = vec![]; - for key_field in self - .index_options - .primary_key_fields - .iter() - .flat_map(|f| f.iter()) - { - let e = fields.shift_remove(key_field.as_str()).ok_or_else(|| { - anyhow!( - "Key field `{}` not mapped in Node label `{}`", - key_field, - self.label_name - ) - })?; - key_fields.push(e); + if let Some(index_options) = self.index_options { + for key_field in index_options + .primary_key_fields + .iter() + .flat_map(|f| f.iter()) + { + let e = fields.shift_remove(key_field.as_str()).ok_or_else(|| { + anyhow!( + "Key field `{}` not mapped in Node label `{}`", + key_field, + self.label_name + ) + })?; + key_fields.push(e); + } + } else { + key_fields = std::mem::take(&mut fields).into_values().collect(); } if key_fields.is_empty() { anyhow::bail!( @@ -1004,51 +1132,69 @@ impl<'a> NodeLabelAnalyzer<'a> { } #[async_trait] -impl StorageFactoryBase for RelationshipFactory { - type Spec = RelationshipSpec; +impl StorageFactoryBase for Factory { + type Spec = Spec; type SetupState = RelationshipSetupState; - type Key = GraphRelationship; - type ExportContext = RelationshipExportContext; + type Key = GraphElement; + type ExportContext = ExportContext; fn name(&self) -> &str { - "Neo4jRelationship" + "Neo4j" } fn build( self: Arc, _name: String, - spec: RelationshipSpec, + spec: Spec, key_fields_schema: Vec, value_fields_schema: Vec, index_options:
IndexOptions, context: Arc, ) -> Result> { - let setup_key = GraphRelationship::from_spec(&spec); - - let mut src_label_analyzer = NodeLabelAnalyzer::new(&spec, &spec.source)?; - let mut tgt_label_analyzer = NodeLabelAnalyzer::new(&spec, &spec.target)?; - let mut rel_value_fields_info = vec![]; - for (field_idx, field_schema) in value_fields_schema.iter().enumerate() { - if !src_label_analyzer.process_field(field_idx, field_schema) - && !tgt_label_analyzer.process_field(field_idx, field_schema) - { - rel_value_fields_info.push(AnalyzedGraphFieldMapping { - field_idx, - field_name: field_schema.name.clone(), - value_type: field_schema.value_type.typ.clone(), - }); + let setup_key = GraphElement::from_spec(&spec); + + let (value_fields_info, rel_end_label_info) = match &spec.mapping { + RowMappingSpec::Node(_) => ( + value_fields_schema + .into_iter() + .enumerate() + .map(|(field_idx, field_schema)| AnalyzedGraphFieldMapping { + field_idx, + field_name: field_schema.name.clone(), + value_type: field_schema.value_type.typ.clone(), + }) + .collect(), + None, + ), + RowMappingSpec::Relationship(rel_spec) => { + let mut src_label_analyzer = + DependentNodeLabelAnalyzer::new(rel_spec, &rel_spec.source)?; + let mut tgt_label_analyzer = + DependentNodeLabelAnalyzer::new(rel_spec, &rel_spec.target)?; + let mut value_fields_info = vec![]; + for (field_idx, field_schema) in value_fields_schema.iter().enumerate() { + if !src_label_analyzer.process_field(field_idx, field_schema) + && !tgt_label_analyzer.process_field(field_idx, field_schema) + { + value_fields_info.push(AnalyzedGraphFieldMapping { + field_idx, + field_name: field_schema.name.clone(), + value_type: field_schema.value_type.typ.clone(), + }); + } + } + let src_label_info = src_label_analyzer.build()?; + let tgt_label_info = tgt_label_analyzer.build()?; + (value_fields_info, Some((src_label_info, tgt_label_info))) } - } - let src_label_info = src_label_analyzer.build()?; - let tgt_label_info =
tgt_label_analyzer.build()?; + }; let desired_setup_state = RelationshipSetupState::new( &spec, key_fields_schema.iter().map(|f| f.name.clone()).collect(), &index_options, - &rel_value_fields_info, - &src_label_info, - &tgt_label_info, + &value_fields_info, + rel_end_label_info.as_ref(), )?; let conn_spec = context @@ -1056,13 +1202,12 @@ impl StorageFactoryBase for RelationshipFactory { .get::(&spec.connection)?; let executors = async move { let graph = self.graph_pool.get_graph(&conn_spec).await?; - let executor = Arc::new(RelationshipExportContext::new( + let executor = Arc::new(ExportContext::new( graph, spec, key_fields_schema, - rel_value_fields_info, - src_label_info, - tgt_label_info, + value_fields_info, + rel_end_label_info, )?); Ok(TypedExportTargetExecutors { export_context: executor, @@ -1079,7 +1224,7 @@ impl StorageFactoryBase for RelationshipFactory { fn check_setup_status( &self, - key: GraphRelationship, + key: GraphElement, desired: Option, existing: CombinedState, auth_registry: &Arc, @@ -1111,13 +1256,13 @@ impl StorageFactoryBase for RelationshipFactory { Ok(desired.check_compatible(existing)) } - fn describe_resource(&self, key: &GraphRelationship) -> Result { - Ok(format!("Neo4j relationship {}", key.relationship)) + fn describe_resource(&self, key: &GraphElement) -> Result { + Ok(format!("Neo4j {}", key.typ)) } async fn apply_mutation( &self, - mutations: Vec>, + mutations: Vec>, ) -> Result<()> { let mut muts_by_graph = HashMap::new(); for mut_with_ctx in mutations.iter() { @@ -1126,7 +1271,8 @@ impl StorageFactoryBase for RelationshipFactory { .or_insert_with(Vec::new) .push(mut_with_ctx); } - for muts in muts_by_graph.values() { + for muts in muts_by_graph.values_mut() { + muts.sort_by_key(|m| m.export_context.create_order); let graph = &muts[0].export_context.graph; retriable::run( || async { @@ -1137,7 +1283,7 @@ impl StorageFactoryBase for RelationshipFactory { export_ctx.add_upsert_queries(upsert, &mut queries)?; } } - for 
mut_with_ctx in muts.iter() { + for mut_with_ctx in muts.iter().rev() { let export_ctx = &mut_with_ctx.export_context; for delete_key in mut_with_ctx.mutation.delete_keys.iter() { export_ctx.add_delete_queries(delete_key, &mut queries)?;