From 352ea686a7d0789ed1adf9e77e2b401859115e96 Mon Sep 17 00:00:00 2001 From: LJ Date: Wed, 23 Apr 2025 09:11:50 -0700 Subject: [PATCH] feat(neo4j)!: add `ReferencedNode` as declaration --- examples/docs_to_knowledge_graph/main.py | 29 ++- python/cocoindex/flow.py | 12 +- python/cocoindex/op.py | 3 - python/cocoindex/storages.py | 11 +- src/ops/storages/neo4j.rs | 304 +++++++++++++---------- src/ops/storages/spec.rs | 20 +- 6 files changed, 213 insertions(+), 166 deletions(-) diff --git a/examples/docs_to_knowledge_graph/main.py b/examples/docs_to_knowledge_graph/main.py index c20acbd0..b2c4b4a8 100644 --- a/examples/docs_to_knowledge_graph/main.py +++ b/examples/docs_to_knowledge_graph/main.py @@ -92,6 +92,7 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D id=cocoindex.GeneratedField.UUID, entity=relationship["object"], filename=doc["filename"], location=chunk["location"], ) + document_node.export( "document_node", cocoindex.storages.Neo4j( @@ -99,6 +100,23 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D mapping=cocoindex.storages.NodeMapping(label="Document")), primary_key_fields=["filename"], ) + flow_builder.declare( + cocoindex.storages.Neo4jDeclarations( + connection=conn_spec, + referenced_nodes=[ + cocoindex.storages.ReferencedNode( + label="Entity", + primary_key_fields=["value"], + vector_indexes=[ + cocoindex.VectorIndexDef( + field_name="embedding", + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + ), + ], + ) + ] + ) + ) entity_relationship.export( "entity_relationship", cocoindex.storages.Neo4j( @@ -123,17 +141,6 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D source="object_embedding", target="embedding"), ] ), - nodes_storage_spec={ - "Entity": cocoindex.storages.NodeStorageSpec( - primary_key_fields=["value"], - vector_indexes=[ - cocoindex.VectorIndexDef( - field_name="embedding", - metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, - ), - ], - ), - }, ), ), primary_key_fields=["id"], diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index f6502623..064a1c52 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -290,12 +290,6 @@ def export(self, name: str, target_spec: op.StorageSpec, /, *, name, _spec_kind(target_spec), dump_engine_object(target_spec), dump_engine_object(index_options), self._engine_data_collector, setup_by_user) - def declare(self, spec: op.DeclarationSpec): - """ - Add a declaration to the flow. - """ - self._flow_builder_state.engine_flow_builder.declare(dump_engine_object(spec)) - _flow_name_builder = _NameBuilder() @@ -361,6 +355,12 @@ def add_source(self, spec: op.SourceSpec, /, *, name ) + def declare(self, spec: op.DeclarationSpec): + """ + Add a declaration to the flow. + """ + self._state.engine_flow_builder.declare(dump_engine_object(spec)) + @dataclass class FlowLiveUpdaterOptions: """ diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 90259034..e524595b 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -7,7 +7,6 @@ from typing import get_type_hints, Protocol, Any, Callable, Awaitable, dataclass_transform from enum import Enum -from functools import partial from .typing import encode_enriched_type from .convert import to_engine_value, make_engine_value_converter @@ -43,8 +42,6 @@ class StorageSpec(metaclass=SpecMeta, category=OpCategory.STORAGE): # pylint: di class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION): # pylint: disable=too-few-public-methods """A declaration spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)""" - kind: str - class Executor(Protocol): """An executor for an operation.""" op_category: OpCategory diff --git a/python/cocoindex/storages.py b/python/cocoindex/storages.py index b143f3a4..8a038cd5 100644 --- a/python/cocoindex/storages.py +++ b/python/cocoindex/storages.py @@ -43,8 +43,9 @@ class NodeReferenceMapping: fields: list[TargetFieldMapping] @dataclass -class NodeStorageSpec: +class ReferencedNode: """Storage spec for a graph node.""" + label: str primary_key_fields: Sequence[str] vector_indexes: Sequence[index.VectorIndexDef] = () @@ -63,10 +64,16 @@ class RelationshipMapping: rel_type: str source: NodeReferenceMapping target: NodeReferenceMapping - nodes_storage_spec: dict[str, NodeStorageSpec] | None = None class Neo4j(op.StorageSpec): """Graph storage powered by Neo4j.""" connection: AuthEntryReference mapping: NodeMapping | RelationshipMapping + +class Neo4jDeclarations(op.DeclarationSpec): + """Declarations for Neo4j.""" + + kind = "Neo4j" + connection: AuthEntryReference + referenced_nodes: Sequence[ReferencedNode] = () diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs index 098b5f15..cd5b09de 100644 --- a/src/ops/storages/neo4j.rs +++ b/src/ops/storages/neo4j.rs @@ -1,7 +1,7 @@ use crate::prelude::*; use super::spec::{ - GraphElementMapping, NodeReferenceMapping, RelationshipMapping, TargetFieldMapping, + GraphDeclarations, GraphElementMapping, NodeReferenceMapping, TargetFieldMapping, }; use crate::setup::components::{self, State}; use crate::setup::{ResourceSetupStatusCheck, SetupChangeType}; @@ -28,6 +28,13 @@ pub struct Spec { mapping: GraphElementMapping, } +#[derive(Debug, Deserialize)] +pub struct Declaration { + connection: spec::AuthEntryReference, + #[serde(flatten)] + decl: GraphDeclarations, +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] struct GraphKey { uri: String, @@ -602,7 +609,7 @@ impl ExportContext { } #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct RelationshipSetupState { +pub struct SetupState { key_field_names: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] dependent_node_labels: Vec, @@ -610,90 +617,38 @@ pub struct RelationshipSetupState { sub_components: Vec, } -impl RelationshipSetupState { +impl SetupState { fn new( - spec: &Spec, + object_label: &ElementType, key_field_names: Vec, index_options: &IndexOptions, - value_fields_info: &[AnalyzedGraphFieldMapping], - end_nodes_label_info: Option<&(AnalyzedNodeLabelInfo, AnalyzedNodeLabelInfo)>, + field_types: &HashMap, + dependent_node_labels: Option>, ) -> Result { let mut sub_components = vec![]; sub_components.push(ComponentState { - object_label: ElementType::from_mapping_spec(&spec.mapping), + object_label: object_label.clone(), index_def: IndexDef::KeyConstraint { field_names: key_field_names.clone(), }, }); for index_def in index_options.vector_indexes.iter() { sub_components.push(ComponentState { - object_label: ElementType::from_mapping_spec(&spec.mapping), + object_label: object_label.clone(), index_def: IndexDef::from_vector_index_def( index_def, - &value_fields_info - .iter() - .find(|f| f.field_name == index_def.field_name) - .ok_or_else(|| { - api_error!( - "Unknown field name for vector index: {}", - index_def.field_name - ) - })? - .value_type, + field_types.get(&index_def.field_name).ok_or_else(|| { + api_error!( + "Unknown field name for vector index: {}", + index_def.field_name + ) + })?, )?, }); } - let mut dependent_node_labels = vec![]; - match &spec.mapping { - GraphElementMapping::Node(_) => {} - GraphElementMapping::Relationship(rel_spec) => { - let (src_label_info, tgt_label_info) = end_nodes_label_info.ok_or_else(|| { - anyhow!( - "Expect `end_nodes_label_info` existing for relationship `{}`", - rel_spec.rel_type - ) - })?; - for (label, node) in rel_spec.nodes_storage_spec.iter().flatten() { - if let Some(primary_key_fields) = &node.index_options.primary_key_fields { - sub_components.push(ComponentState { - object_label: ElementType::Node(label.clone()), - index_def: IndexDef::KeyConstraint { - field_names: primary_key_fields.clone(), - }, - }); - } - for index_def in &node.index_options.vector_indexes { - sub_components.push(ComponentState { - object_label: ElementType::Node(label.clone()), - index_def: IndexDef::from_vector_index_def( - index_def, - [src_label_info, tgt_label_info] - .into_iter() - .flat_map(|v| v.key_fields.iter().chain(v.value_fields.iter())) - .find(|f| f.field_name == index_def.field_name) - .map(|f| &f.value_type) - .ok_or_else(|| { - api_error!( - "Unknown field name for vector index: {}", - index_def.field_name - ) - })?, - )?, - }); - } - } - dependent_node_labels.extend( - rel_spec - .nodes_storage_spec - .iter() - .flat_map(|nodes| nodes.keys()) - .cloned(), - ); - } - }; Ok(Self { key_field_names, - dependent_node_labels, + dependent_node_labels: dependent_node_labels.unwrap_or_default(), sub_components, }) } @@ -707,7 +662,7 @@ impl RelationshipSetupState { } } -impl IntoIterator for RelationshipSetupState { +impl IntoIterator for SetupState { type Item = ComponentState; type IntoIter = std::vec::IntoIter; @@ -715,9 +670,8 @@ impl IntoIterator for RelationshipSetupState { self.sub_components.into_iter() } } -#[derive(Debug)] +#[derive(Debug, Default)] struct DataClearAction { - core_elem_type: ElementType, dependent_node_labels: Vec, } @@ -811,7 +765,7 @@ struct SetupComponentOperator { impl components::Operator for SetupComponentOperator { type Key = ComponentKey; type State = ComponentState; - type SetupState = RelationshipSetupState; + type SetupState = SetupState; fn describe_key(&self, key: &Self::Key) -> String { format!("{} {}", key.kind.describe(), key.name) @@ -904,6 +858,7 @@ fn build_composite_field_names(qualifier: &str, field_names: &[String]) -> Strin #[derive(Derivative)] #[derivative(Debug)] struct SetupStatusCheck { + key: GraphElement, #[derivative(Debug = "ignore")] graph_pool: Arc, conn_spec: ConnectionSpec, @@ -916,25 +871,20 @@ impl SetupStatusCheck { key: GraphElement, graph_pool: Arc, conn_spec: ConnectionSpec, - desired_state: Option<&RelationshipSetupState>, - existing: &CombinedState, + desired_state: Option<&SetupState>, + existing: &CombinedState, ) -> Self { - let mut core_elem_type_to_clear = None; - let mut dependent_node_labels_to_clear = IndexSet::new(); + let mut data_clear: Option = None; for v in existing.possible_versions() { if desired_state.as_ref().is_none_or(|desired| { desired.check_compatible(v) == SetupStateCompatibility::NotCompatible }) { - if core_elem_type_to_clear.is_none() { - core_elem_type_to_clear = Some(key.typ.clone()); - } - dependent_node_labels_to_clear.extend(v.dependent_node_labels.iter().cloned()); + data_clear + .get_or_insert_default() + .dependent_node_labels + .extend(v.dependent_node_labels.iter().cloned()); } } - let data_clear = core_elem_type_to_clear.map(|core_elem_type| DataClearAction { - core_elem_type, - dependent_node_labels: dependent_node_labels_to_clear.into_iter().collect(), - }); let change_type = match (desired_state, existing.possible_versions().next()) { (Some(_), Some(_)) => { @@ -950,6 +900,7 @@ impl SetupStatusCheck { }; Self { + key, graph_pool, conn_spec, data_clear, @@ -963,7 +914,7 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { fn describe_changes(&self) -> Vec { let mut result = vec![]; if let Some(data_clear) = &self.data_clear { - let mut desc = format!("Clear data for {}", data_clear.core_elem_type); + let mut desc = "Clear data".to_string(); if !data_clear.dependent_node_labels.is_empty() { write!( &mut desc, @@ -996,9 +947,9 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { DELETE {var_name} }} IN TRANSACTIONS ", - matcher = data_clear.core_elem_type.matcher(CORE_ELEMENT_MATCHER_VAR), + matcher = self.key.typ.matcher(CORE_ELEMENT_MATCHER_VAR), var_name = CORE_ELEMENT_MATCHER_VAR, - optional_orphan_condition = match data_clear.core_elem_type { + optional_orphan_condition = match self.key.typ { ElementType::Node(_) => format!("WHERE NOT ({CORE_ELEMENT_MATCHER_VAR})--()"), _ => "".to_string(), }, @@ -1038,13 +989,13 @@ struct DependentNodeLabelAnalyzer<'a> { label_name: &'a str, fields: IndexMap<&'a str, AnalyzedGraphFieldMapping>, remaining_fields: HashMap<&'a str, &'a TargetFieldMapping>, - index_options: Option<&'a IndexOptions>, + primary_key_fields: &'a Vec, } impl<'a> DependentNodeLabelAnalyzer<'a> { fn new( - rel_spec: &'a RelationshipMapping, rel_end_spec: &'a NodeReferenceMapping, + index_options_map: &'a HashMap, ) -> Result { Ok(Self { label_name: rel_end_spec.label.as_str(), @@ -1054,31 +1005,38 @@ impl<'a> DependentNodeLabelAnalyzer<'a> { .iter() .map(|f| (f.source.as_str(), f)) .collect(), - index_options: rel_spec - .nodes_storage_spec - .as_ref() - .and_then(|nodes| nodes.get(&rel_end_spec.label)) - .and_then(|node_spec| Some(&node_spec.index_options)), + primary_key_fields: index_options_map + .get(&rel_end_spec.label) + .and_then(|o| o.primary_key_fields.as_ref()) + .ok_or_else(|| { + anyhow::anyhow!( + "No key fields specified for Node label `{}`", + rel_end_spec.label + ) + })?, }) } fn process_field(&mut self, field_idx: usize, field_schema: &FieldSchema) -> bool { - let field_info = match self.remaining_fields.remove(field_schema.name.as_str()) { - Some(field_info) => field_info, + let field_mapping = match self.remaining_fields.remove(field_schema.name.as_str()) { + Some(field_mapping) => field_mapping, None => return false, }; self.fields.insert( - field_info.get_target().as_str(), + field_mapping.get_target().as_str(), AnalyzedGraphFieldMapping { field_idx, - field_name: field_info.get_target().clone(), + field_name: field_mapping.get_target().clone(), value_type: field_schema.value_type.typ.clone(), }, ); true } - fn build(self) -> Result { + fn build( + self, + label_value_field_types: &mut HashMap>, + ) -> Result { if !self.remaining_fields.is_empty() { anyhow::bail!( "Fields not mapped for Node label `{}`: {}", @@ -1087,13 +1045,10 @@ impl<'a> DependentNodeLabelAnalyzer<'a> { ); } let mut fields = self.fields; - let mut key_fields = vec![]; - if let Some(index_options) = self.index_options { - for key_field in index_options - .primary_key_fields - .iter() - .flat_map(|f| f.iter()) - { + let key_fields = self + .primary_key_fields + .iter() + .map(|key_field| { let e = fields.shift_remove(key_field.as_str()).ok_or_else(|| { anyhow!( "Key field `{}` not mapped in Node label `{}`", @@ -1101,17 +1056,17 @@ impl<'a> DependentNodeLabelAnalyzer<'a> { self.label_name ) })?; - key_fields.push(e); - } - } else { - key_fields = std::mem::take(&mut fields).into_values().collect(); - } - if key_fields.is_empty() { - anyhow::bail!( - "No key fields specified for Node label `{}`", - self.label_name + Ok(e) + }) + .collect::>>()?; + label_value_field_types + .entry(self.label_name.to_string()) + .or_insert_with(HashMap::new) + .extend( + fields + .values() + .map(|f| (f.field_name.clone(), f.value_type.clone())), ); - } Ok(AnalyzedNodeLabelInfo { key_fields, value_fields: fields.into_values().collect(), @@ -1122,8 +1077,8 @@ impl<'a> DependentNodeLabelAnalyzer<'a> { #[async_trait] impl StorageFactoryBase for Factory { type Spec = Spec; - type DeclarationSpec = (); - type SetupState = RelationshipSetupState; + type DeclarationSpec = Declaration; + type SetupState = SetupState; type Key = GraphElement; type ExportContext = ExportContext; @@ -1134,21 +1089,39 @@ impl StorageFactoryBase for Factory { fn build( self: Arc, data_collections: Vec>, - _declarations: Vec<()>, + declarations: Vec, context: Arc, ) -> Result<( Vec>, - Vec<(GraphElement, RelationshipSetupState)>, + Vec<(GraphElement, SetupState)>, )> { + let node_labels_index_options = data_collections + .iter() + .filter_map(|d| match &d.spec.mapping { + GraphElementMapping::Node(n) => Some((n.label.clone(), d.index_options.clone())), + _ => None, + }) + .chain( + declarations + .iter() + .flat_map(|d| d.decl.referenced_nodes.iter()) + .map(|n| (n.label.clone(), n.index_options.clone())), + ) + .collect::>(); + let mut label_value_field_types = + HashMap::>::new(); let data_coll_output = data_collections .into_iter() .map(|d| { let setup_key = GraphElement::from_spec(&d.spec); - let (value_fields_info, rel_end_label_info) = match &d.spec.mapping { + let (value_fields_info, rel_end_label_info, dependent_node_labels) = match &d + .spec + .mapping + { GraphElementMapping::Node(_) => ( d.value_fields_schema - .into_iter() + .iter() .enumerate() .map(|(field_idx, field_schema)| AnalyzedGraphFieldMapping { field_idx, @@ -1157,12 +1130,17 @@ impl StorageFactoryBase for Factory { }) .collect(), None, + None, ), GraphElementMapping::Relationship(rel_spec) => { - let mut src_label_analyzer = - DependentNodeLabelAnalyzer::new(&rel_spec, &rel_spec.source)?; - let mut tgt_label_analyzer = - DependentNodeLabelAnalyzer::new(&rel_spec, &rel_spec.target)?; + let mut src_label_analyzer = DependentNodeLabelAnalyzer::new( + &rel_spec.source, + &node_labels_index_options, + )?; + let mut tgt_label_analyzer = DependentNodeLabelAnalyzer::new( + &rel_spec.target, + &node_labels_index_options, + )?; let mut value_fields_info = vec![]; for (field_idx, field_schema) in d.value_fields_schema.iter().enumerate() { if !src_label_analyzer.process_field(field_idx, field_schema) @@ -1175,18 +1153,35 @@ impl StorageFactoryBase for Factory { }); } } - let src_label_info = src_label_analyzer.build()?; - let tgt_label_info = tgt_label_analyzer.build()?; - (value_fields_info, Some((src_label_info, tgt_label_info))) + let src_label_info = + src_label_analyzer.build(&mut label_value_field_types)?; + let tgt_label_info = + tgt_label_analyzer.build(&mut label_value_field_types)?; + let dependent_node_labels: Vec = IndexSet::<&String>::from_iter([ + &rel_spec.source.label, + &rel_spec.target.label, + ]) + .into_iter() + .cloned() + .collect(); + ( + value_fields_info, + Some((src_label_info, tgt_label_info)), + Some(dependent_node_labels), + ) } }; - let desired_setup_state = RelationshipSetupState::new( - &d.spec, + let value_field_types = value_fields_info + .iter() + .map(|f| (f.field_name.clone(), f.value_type.clone())) + .collect::>(); + let desired_setup_state = SetupState::new( + &setup_key.typ, d.key_fields_schema.iter().map(|f| f.name.clone()).collect(), &d.index_options, - &value_fields_info, - rel_end_label_info.as_ref(), + &value_field_types, + dependent_node_labels, )?; let conn_spec = context @@ -1215,14 +1210,49 @@ impl StorageFactoryBase for Factory { }) }) .collect::>>()?; - Ok((data_coll_output, vec![])) + let decl_output = declarations + .into_iter() + .flat_map(|d| { + let label_value_field_types = &label_value_field_types; + d.decl.referenced_nodes.into_iter().map(move |n| { + let setup_key = GraphElement { + connection: d.connection.clone(), + typ: ElementType::Node(n.label.clone()), + }; + let primary_key_fields = n + .index_options + .primary_key_fields + .as_ref() + .ok_or_else(|| { + api_error!( + "No primary key fields specified for node label `{}`", + &n.label + ) + })? + .iter() + .map(|f| f.clone()) + .collect(); + let setup_state = SetupState::new( + &setup_key.typ, + primary_key_fields, + &n.index_options, + label_value_field_types.get(&n.label).ok_or_else(|| { + api_error!("Data for nodes with label `{}` not provided", n.label) + })?, + None, + )?; + Ok((setup_key, setup_state)) + }) + }) + .collect::>>()?; + Ok((data_coll_output, decl_output)) } fn check_setup_status( &self, key: GraphElement, - desired: Option, - existing: CombinedState, + desired: Option, + existing: CombinedState, auth_registry: &Arc, ) -> Result { let conn_spec = auth_registry.get::(&key.connection)?; @@ -1246,8 +1276,8 @@ impl StorageFactoryBase for Factory { fn check_state_compatibility( &self, - desired: &RelationshipSetupState, - existing: &RelationshipSetupState, + desired: &SetupState, + existing: &SetupState, ) -> Result { Ok(desired.check_compatible(existing)) } diff --git a/src/ops/storages/spec.rs b/src/ops/storages/spec.rs index c1d8772f..26ce8c39 100644 --- a/src/ops/storages/spec.rs +++ b/src/ops/storages/spec.rs @@ -22,12 +22,6 @@ pub struct NodeReferenceMapping { pub fields: Vec, } -#[derive(Debug, Deserialize)] -pub struct NodeStorageSpec { - #[serde(flatten)] - pub index_options: spec::IndexOptions, -} - #[derive(Debug, Deserialize)] pub struct NodeMapping { pub label: String, @@ -38,7 +32,6 @@ pub struct RelationshipMapping { pub rel_type: String, pub source: NodeReferenceMapping, pub target: NodeReferenceMapping, - pub nodes_storage_spec: Option>, } #[derive(Debug, Deserialize)] @@ -47,3 +40,16 @@ pub enum GraphElementMapping { Relationship(RelationshipMapping), Node(NodeMapping), } + +#[derive(Debug, Deserialize)] +pub struct ReferencedNodeSpec { + pub label: String, + + #[serde(flatten)] + pub index_options: spec::IndexOptions, +} + +#[derive(Debug, Deserialize)] +pub struct GraphDeclarations { + pub referenced_nodes: Vec, +}