diff --git a/examples/docs_to_kg/main.py b/examples/docs_to_kg/main.py index acb5a0ce6..50aa33134 100644 --- a/examples/docs_to_kg/main.py +++ b/examples/docs_to_kg/main.py @@ -55,11 +55,19 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D "Each relationship should be a tuple of (subject, predicate, object)."))) with chunk["relationships"]["relationships"].row() as relationship: + relationship["subject_embedding"] = relationship["subject"].transform( + cocoindex.functions.SentenceTransformerEmbed( + model="sentence-transformers/all-MiniLM-L6-v2")) + relationship["object_embedding"] = relationship["object"].transform( + cocoindex.functions.SentenceTransformerEmbed( + model="sentence-transformers/all-MiniLM-L6-v2")) relationships.collect( id=cocoindex.GeneratedField.UUID, subject=relationship["subject"], - predicate=relationship["predicate"], + subject_embedding=relationship["subject_embedding"], object=relationship["object"], + object_embedding=relationship["object_embedding"], + predicate=relationship["predicate"], ) relationships.export( @@ -69,14 +77,34 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D rel_type="RELATIONSHIP", source=cocoindex.storages.Neo4jRelationshipEndSpec( label="Entity", - fields=[cocoindex.storages.Neo4jFieldMapping(field_name="subject", node_field_name="value")] + fields=[ + cocoindex.storages.Neo4jFieldMapping( + field_name="subject", node_field_name="value"), + cocoindex.storages.Neo4jFieldMapping( + field_name="subject_embedding", node_field_name="embedding"), + ] ), target=cocoindex.storages.Neo4jRelationshipEndSpec( label="Entity", - fields=[cocoindex.storages.Neo4jFieldMapping(field_name="object", node_field_name="value")] + fields=[ + cocoindex.storages.Neo4jFieldMapping( + field_name="object", node_field_name="value"), + cocoindex.storages.Neo4jFieldMapping( + field_name="object_embedding", node_field_name="embedding"), + ] ), nodes={ - "Entity": cocoindex.storages.Neo4jRelationshipNodeSpec(key_field_name="value"), + "Entity": cocoindex.storages.Neo4jRelationshipNodeSpec( + index_options=cocoindex.IndexOptions( + primary_key_fields=["value"], + vector_index_defs=[ + cocoindex.VectorIndexDef( + field_name="embedding", + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + ), + ], + ), + ), }, ), primary_key_fields=["id"], diff --git a/python/cocoindex/__init__.py b/python/cocoindex/__init__.py index ae8248353..b3c5b0100 100644 --- a/python/cocoindex/__init__.py +++ b/python/cocoindex/__init__.py @@ -6,7 +6,7 @@ from .flow import EvaluateAndDumpOptions, GeneratedField from .flow import update_all_flows, FlowLiveUpdater, FlowLiveUpdaterOptions from .llm import LlmSpec, LlmApiType -from .vector import VectorSimilarityMetric +from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry from .lib import * from ._engine import OpArgSchema \ No newline at end of file diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 77bec8406..eecb6a805 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from . import _engine -from . import vector +from . import index from . import op from .convert import dump_engine_object from .typing import encode_enriched_type @@ -268,7 +268,7 @@ def collect(self, **kwargs): def export(self, name: str, target_spec: op.StorageSpec, /, *, primary_key_fields: Sequence[str] | None = None, - vector_index: Sequence[tuple[str, vector.VectorSimilarityMetric]] = (), + vector_index: Sequence[tuple[str, index.VectorSimilarityMetric]] = (), setup_by_user: bool = False): """ Export the collected data to the specified target. diff --git a/python/cocoindex/index.py b/python/cocoindex/index.py new file mode 100644 index 000000000..90e12aa05 --- /dev/null +++ b/python/cocoindex/index.py @@ -0,0 +1,23 @@ +from enum import Enum +from dataclasses import dataclass + +class VectorSimilarityMetric(Enum): + COSINE_SIMILARITY = "CosineSimilarity" + L2_DISTANCE = "L2Distance" + INNER_PRODUCT = "InnerProduct" + +@dataclass +class VectorIndexDef: + """ + Define a vector index on a field. + """ + field_name: str + metric: VectorSimilarityMetric + +@dataclass +class IndexOptions: + """ + Options for an index. + """ + primary_key_fields: list[str] | None = None + vector_index_defs: list[VectorIndexDef] | None = None diff --git a/python/cocoindex/query.py b/python/cocoindex/query.py index 59e8d822f..cf9f57fbc 100644 --- a/python/cocoindex/query.py +++ b/python/cocoindex/query.py @@ -3,7 +3,7 @@ from threading import Lock from . import flow as fl -from . import vector +from . import index from . import _engine _handlers_lock = Lock() @@ -14,7 +14,7 @@ class SimpleSemanticsQueryInfo: """ Additional information about the query. """ - similarity_metric: vector.VectorSimilarityMetric + similarity_metric: index.VectorSimilarityMetric query_vector: list[float] vector_field_name: str @@ -39,7 +39,7 @@ def __init__( flow: fl.Flow, target_name: str, query_transform_flow: Callable[..., fl.DataSlice], - default_similarity_metric: vector.VectorSimilarityMetric = vector.VectorSimilarityMetric.COSINE_SIMILARITY) -> None: + default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY) -> None: engine_handler = None lock = Lock() @@ -66,7 +66,7 @@ def internal_handler(self) -> _engine.SimpleSemanticsQueryHandler: return self._lazy_query_handler() def search(self, query: str, limit: int, vector_field_name: str | None = None, - similarity_matric: vector.VectorSimilarityMetric | None = None) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]: + similarity_matric: index.VectorSimilarityMetric | None = None) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]: """ Search the index with the given query, limit, vector field name, and similarity metric. """ @@ -76,7 +76,7 @@ def search(self, query: str, limit: int, vector_field_name: str | None = None, fields = [field['name'] for field in internal_results['fields']] results = [QueryResult(data=dict(zip(fields, result['data'])), score=result['score']) for result in internal_results['results']] info = SimpleSemanticsQueryInfo( - similarity_metric=vector.VectorSimilarityMetric(internal_info['similarity_metric']), + similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']), query_vector=internal_info['query_vector'], vector_field_name=internal_info['vector_field_name'] ) diff --git a/python/cocoindex/storages.py b/python/cocoindex/storages.py index 77a5ea194..ec5d7326b 100644 --- a/python/cocoindex/storages.py +++ b/python/cocoindex/storages.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from . import op +from . import index from .auth_registry import AuthEntryReference class Postgres(op.StorageSpec): """Storage powered by Postgres and pgvector.""" @@ -35,7 +36,7 @@ class Neo4jRelationshipEndSpec: class Neo4jRelationshipNodeSpec: """Spec for a Neo4j node type.""" key_field_name: str | None = None - + index_options: index.IndexOptions | None = None class Neo4jRelationship(op.StorageSpec): """Graph storage powered by Neo4j.""" diff --git a/python/cocoindex/vector.py b/python/cocoindex/vector.py deleted file mode 100644 index 301426989..000000000 --- a/python/cocoindex/vector.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - -class VectorSimilarityMetric(Enum): - COSINE_SIMILARITY = "CosineSimilarity" - L2_DISTANCE = "L2Distance" - INNER_PRODUCT = "InnerProduct" diff --git a/src/base/spec.rs b/src/base/spec.rs index 2321610ff..0989cacf8 100644 --- a/src/base/spec.rs +++ b/src/base/spec.rs @@ -211,6 +211,16 @@ pub enum VectorSimilarityMetric { InnerProduct, } +impl std::fmt::Display for VectorSimilarityMetric { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VectorSimilarityMetric::CosineSimilarity => write!(f, "Cosine"), + VectorSimilarityMetric::L2Distance => write!(f, "L2"), + VectorSimilarityMetric::InnerProduct => write!(f, "InnerProduct"), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct VectorIndexDef { pub field_name: FieldName, diff --git a/src/base/value.rs b/src/base/value.rs index a95a70ca0..5a754f8ca 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -174,6 +174,21 @@ impl std::fmt::Display for KeyValue { } impl KeyValue { + pub fn fields_iter<'a>( + &'a self, + num_fields: usize, + ) -> Result> { + let slice = if num_fields == 1 { + std::slice::from_ref(self) + } else { + match self { + KeyValue::Struct(v) => v, + _ => api_bail!("Invalid key value type"), + } + }; + Ok(slice.iter()) + } + fn parts_from_str( values_iter: &mut impl Iterator, schema: &ValueType, diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs index 353b418ba..ea291da7d 100644 --- a/src/ops/storages/neo4j.rs +++ b/src/ops/storages/neo4j.rs @@ -39,8 +39,7 @@ pub struct RelationshipEndSpec { #[derive(Debug, Deserialize)] pub struct RelationshipNodeSpec { - #[serde(default)] - key_field_name: String, + index_options: spec::IndexOptions, } #[derive(Debug, Deserialize)] @@ -52,22 +51,6 @@ pub struct RelationshipSpec { nodes: BTreeMap, } -impl RelationshipSpec { - fn get_src_label_info(&self) -> Result<&RelationshipNodeSpec> { - Ok(self - .nodes - .get(self.source.label.as_str()) - .ok_or_else(|| api_error!("Source label `{}` not found", self.source.label))?) - } - - fn get_tgt_label_info(&self) -> Result<&RelationshipNodeSpec> { - Ok(self - .nodes - .get(self.target.label.as_str()) - .ok_or_else(|| api_error!("Target label `{}` not found", self.target.label))?) - } -} - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] struct GraphKey { uri: String, @@ -143,8 +126,8 @@ struct AnalyzedGraphFieldMapping { value_type: schema::ValueType, } -struct AnalyzedGraphFields { - key_field: AnalyzedGraphFieldMapping, +struct AnalyzedNodeLabelInfo { + key_fields: Vec, value_fields: Vec, } struct RelationshipStorageExecutor { @@ -152,11 +135,15 @@ struct RelationshipStorageExecutor { delete_cypher: String, insert_cypher: String, - key_field: FieldSchema, + key_field_params: Vec, + key_fields: Vec, value_fields: Vec, - src_fields: AnalyzedGraphFields, - tgt_fields: AnalyzedGraphFields, + src_key_field_params: Vec, + src_fields: AnalyzedNodeLabelInfo, + + tgt_key_field_params: Vec, + tgt_fields: AnalyzedNodeLabelInfo, } fn json_value_to_bolt_value(value: &serde_json::Value) -> Result { @@ -304,25 +291,54 @@ fn value_to_bolt(value: &Value, schema: &schema::ValueType) -> Result Ok(bolt_value) } -const REL_ID_PARAM: &str = "rel_id"; +const REL_KEY_PARAM_PREFIX: &str = "rel_key"; const REL_PROPS_PARAM: &str = "rel_props"; -const SRC_ID_PARAM: &str = "source_id"; +const SRC_KEY_PARAM_PREFIX: &str = "source_key"; const SRC_PROPS_PARAM: &str = "source_props"; -const TGT_ID_PARAM: &str = "target_id"; +const TGT_KEY_PARAM_PREFIX: &str = "target_key"; const TGT_PROPS_PARAM: &str = "target_props"; impl RelationshipStorageExecutor { + fn build_key_field_params_n_literal<'a>( + param_prefix: &str, + key_fields: impl Iterator, + ) -> (Vec, String) { + let (params, items): (Vec, Vec) = key_fields + .into_iter() + .enumerate() + .map(|(i, name)| { + let param = format!("{}_{}", param_prefix, i); + let item = format!("{}: ${}", name, param); + (param, item) + }) + .unzip(); + (params, format!("{{{}}}", items.into_iter().join(", "))) + } + fn new( graph: Arc, spec: RelationshipSpec, - key_field: FieldSchema, + key_fields: Vec, value_fields: Vec, - src_fields: AnalyzedGraphFields, - tgt_fields: AnalyzedGraphFields, + src_fields: AnalyzedNodeLabelInfo, + tgt_fields: AnalyzedNodeLabelInfo, ) -> Result { + let (key_field_params, key_fields_literal) = Self::build_key_field_params_n_literal( + REL_KEY_PARAM_PREFIX, + key_fields.iter().map(|f| &f.name), + ); + let (src_key_field_params, src_key_fields_literal) = Self::build_key_field_params_n_literal( + SRC_KEY_PARAM_PREFIX, + src_fields.key_fields.iter().map(|f| &f.field_name), + ); + let (tgt_key_field_params, tgt_key_fields_literal) = Self::build_key_field_params_n_literal( + TGT_KEY_PARAM_PREFIX, + tgt_fields.key_fields.iter().map(|f| &f.field_name), + ); + let delete_cypher = format!( r#" -OPTIONAL MATCH (old_src)-[old_rel:{rel_type} {{{rel_key_field_name}: ${REL_ID_PARAM}}}]->(old_tgt) +OPTIONAL MATCH (old_src)-[old_rel:{rel_type} {key_fields_literal}]->(old_tgt) DELETE old_rel @@ -348,38 +364,34 @@ CALL {{ FINISH "#, rel_type = spec.rel_type, - rel_key_field_name = key_field.name, ); let insert_cypher = format!( r#" -MERGE (new_src:{src_node_label} {{{src_node_key_field_name}: ${SRC_ID_PARAM}}}) +MERGE (new_src:{src_node_label} {src_key_fields_literal}) {optional_set_src_props} -MERGE (new_tgt:{tgt_node_label} {{{tgt_node_key_field_name}: ${TGT_ID_PARAM}}}) +MERGE (new_tgt:{tgt_node_label} {tgt_key_fields_literal}) {optional_set_tgt_props} -MERGE (new_src)-[new_rel:{rel_type} {{{rel_key_field_name}: ${REL_ID_PARAM}}}]->(new_tgt) +MERGE (new_src)-[new_rel:{rel_type} {key_fields_literal}]->(new_tgt) {optional_set_rel_props} FINISH "#, src_node_label = spec.source.label, - src_node_key_field_name = spec.get_src_label_info()?.key_field_name, optional_set_src_props = if src_fields.value_fields.is_empty() { "".to_string() } else { format!("SET new_src += ${SRC_PROPS_PARAM}\n") }, tgt_node_label = spec.target.label, - tgt_node_key_field_name = spec.get_tgt_label_info()?.key_field_name, optional_set_tgt_props = if tgt_fields.value_fields.is_empty() { "".to_string() } else { format!("SET new_tgt += ${TGT_PROPS_PARAM}\n") }, rel_type = spec.rel_type, - rel_key_field_name = key_field.name, optional_set_rel_props = if value_fields.is_empty() { "".to_string() } else { @@ -390,40 +402,73 @@ FINISH graph, delete_cypher, insert_cypher, - key_field, + key_field_params, + key_fields, value_fields, + src_key_field_params, src_fields, + tgt_key_field_params, tgt_fields, }) } + fn bind_key_field_params<'a>( + query: neo4rs::Query, + params: &[String], + type_val: impl Iterator, + ) -> Result { + let mut query = query; + for (i, (typ, val)) in type_val.enumerate() { + query = query.param(¶ms[i], value_to_bolt(val, typ)?); + } + Ok(query) + } + + fn bind_rel_key_field_params( + &self, + query: neo4rs::Query, + val: &KeyValue, + ) -> Result { + let mut query = query; + for (i, val) in val.fields_iter(self.key_fields.len())?.enumerate() { + query = query.param( + &self.key_field_params[i], + key_to_bolt(val, &self.key_fields[i].value_type.typ)?, + ); + } + Ok(query) + } + fn build_queries_to_apply_mutation( &self, mutation: &ExportTargetMutation, ) -> Result> { let mut queries = vec![]; for upsert in mutation.upserts.iter() { - let rel_id_bolt = key_to_bolt(&upsert.key, &self.key_field.value_type.typ)?; - queries - .push(neo4rs::query(&self.delete_cypher).param(REL_ID_PARAM, rel_id_bolt.clone())); + queries.push( + self.bind_rel_key_field_params(neo4rs::query(&self.delete_cypher), &upsert.key)?, + ); let value = &upsert.value; - let mut insert_cypher = neo4rs::query(&self.insert_cypher) - .param(REL_ID_PARAM, rel_id_bolt) - .param( - SRC_ID_PARAM, - value_to_bolt( - &value.fields[self.src_fields.key_field.field_idx], - &self.src_fields.key_field.value_type, - )?, - ) - .param( - TGT_ID_PARAM, - value_to_bolt( - &value.fields[self.tgt_fields.key_field.field_idx], - &self.tgt_fields.key_field.value_type, - )?, - ); + let mut insert_cypher = + self.bind_rel_key_field_params(neo4rs::query(&self.insert_cypher), &upsert.key)?; + insert_cypher = Self::bind_key_field_params( + insert_cypher, + &self.src_key_field_params, + self.src_fields + .key_fields + .iter() + .map(|f| (&f.value_type, &value.fields[f.field_idx])), + )?; + insert_cypher = Self::bind_key_field_params( + insert_cypher, + &self.tgt_key_field_params, + self.tgt_fields + .key_fields + .iter() + .map(|f| (&f.value_type, &value.fields[f.field_idx])), + )?; + if !self.src_fields.value_fields.is_empty() { insert_cypher = insert_cypher.param( SRC_PROPS_PARAM, @@ -460,10 +505,9 @@ FINISH queries.push(insert_cypher); } for delete_key in mutation.delete_keys.iter() { - queries.push(neo4rs::query(&self.delete_cypher).param( - REL_ID_PARAM, - key_to_bolt(delete_key, &self.key_field.value_type.typ)?, - )); + queries.push( + self.bind_rel_key_field_params(neo4rs::query(&self.delete_cypher), delete_key)?, + ); } Ok(queries) } @@ -489,35 +533,51 @@ impl ExportTargetExecutor for RelationshipStorageExecutor { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeLabelSetupState { - key_field_name: String, + key_field_names: Vec, key_constraint_name: String, + vector_indexes: HashMap, } impl NodeLabelSetupState { fn from_spec(label: &str, spec: &RelationshipNodeSpec) -> Self { - let key_constraint_name = format!("n__{}__{}", label, spec.key_field_name); + let key_constraint_name = format!("n__{}__unique", label); Self { - key_field_name: spec.key_field_name.clone(), + key_field_names: spec + .index_options + .primary_key_fields + .clone() + .unwrap_or_default(), key_constraint_name, + vector_indexes: spec + .index_options + .vector_index_defs + .iter() + .map(|v| { + ( + format!("n__{}__{}__{}", label, v.field_name.clone(), v.metric), + v.clone(), + ) + }) + .collect(), } } fn is_compatible(&self, other: &Self) -> bool { - self.key_field_name == other.key_field_name + self.key_field_names == other.key_field_names } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RelationshipSetupState { - key_field_name: String, + key_field_names: Vec, key_constraint_name: String, #[serde(default)] nodes: BTreeMap, } impl RelationshipSetupState { - fn from_spec(spec: &RelationshipSpec, key_field_name: String) -> Self { + fn from_spec(spec: &RelationshipSpec, key_field_names: Vec) -> Self { Self { - key_field_name, + key_field_names, key_constraint_name: format!("r__{}__key", spec.rel_type), nodes: spec .nodes @@ -528,7 +588,7 @@ impl RelationshipSetupState { } fn check_compatible(&self, existing: &Self) -> SetupStateCompatibility { - if self.key_field_name != existing.key_field_name { + if self.key_field_names != existing.key_field_names { SetupStateCompatibility::NotCompatible } else if existing.nodes.iter().any(|(label, existing_node)| { !self @@ -553,14 +613,14 @@ struct DataClearAction { #[derive(Debug)] struct KeyConstraint { label: String, - field_name: String, + field_names: Vec, } impl KeyConstraint { fn new(label: String, state: &NodeLabelSetupState) -> Self { Self { label: label, - field_name: state.key_field_name.clone(), + field_names: state.key_field_names.clone(), } } } @@ -617,20 +677,20 @@ impl SetupStatusCheck { if let Some(desired_state) = desired_state { let rel_constraint = KeyConstraint { label: key.relationship.clone(), - field_name: desired_state.key_field_name.clone(), + field_names: desired_state.key_field_names.clone(), }; - old_rel_constraints.swap_remove(&desired_state.key_constraint_name); + old_rel_constraints.shift_remove(&desired_state.key_constraint_name); if !existing .current .as_ref() - .map(|c| rel_constraint.field_name == c.key_field_name) + .map(|c| rel_constraint.field_names == c.key_field_names) .unwrap_or(false) { rel_constraint_to_create.insert(desired_state.key_constraint_name, rel_constraint); } for (label, node) in desired_state.nodes.iter() { - old_node_constraints.swap_remove(&node.key_constraint_name); + old_node_constraints.shift_remove(&node.key_constraint_name); if !existing .current .as_ref() @@ -700,7 +760,9 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { for (name, rel_constraint) in self.rel_constraint_to_create.iter() { result.push(format!( "Create KEY CONSTRAINT {} ON RELATIONSHIP {} (key: {})", - name, rel_constraint.label, rel_constraint.field_name, + name, + rel_constraint.label, + rel_constraint.field_names.join(", "), )); } for name in &self.node_constraint_to_delete { @@ -709,7 +771,9 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { for (name, node_constraint) in self.node_constraint_to_create.iter() { result.push(format!( "Create KEY CONSTRAINT {} ON NODE {} (key: {})", - name, node_constraint.label, node_constraint.field_name, + name, + node_constraint.label, + node_constraint.field_names.join(", "), )); } result @@ -720,6 +784,18 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { } async fn apply_change(&self) -> Result<()> { + let build_composite_field_names = |qualifier: &str, field_names: &[String]| -> String { + let strs = field_names + .iter() + .map(|name| format!("{qualifier}.{name}")) + .join(", "); + if field_names.len() == 1 { + strs + } else { + format!("({})", strs) + } + }; + let graph = self.graph_pool.get_graph(&self.conn_spec).await?; if let Some(data_clear) = &self.data_clear { @@ -761,9 +837,9 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { for (name, constraint) in self.node_constraint_to_create.iter() { graph .run(neo4rs::query(&format!( - "CREATE CONSTRAINT {name} IF NOT EXISTS FOR (n:{label}) REQUIRE n.{field_name} IS UNIQUE", + "CREATE CONSTRAINT {name} IF NOT EXISTS FOR (n:{label}) REQUIRE {field_names} IS UNIQUE", label = constraint.label, - field_name = constraint.field_name + field_names = build_composite_field_names("n", &constraint.field_names) ))) .await?; } @@ -771,9 +847,9 @@ impl ResourceSetupStatusCheck for SetupStatusCheck { for (name, constraint) in self.rel_constraint_to_create.iter() { graph .run(neo4rs::query(&format!( - "CREATE CONSTRAINT {name} IF NOT EXISTS FOR ()-[e:{label}]-() REQUIRE e.{field_name} IS UNIQUE", + "CREATE CONSTRAINT {name} IF NOT EXISTS FOR ()-[e:{label}]-() REQUIRE {field_names} IS UNIQUE", label = constraint.label, - field_name = constraint.field_name + field_names = build_composite_field_names("e", &constraint.field_names) ))) .await?; } @@ -791,6 +867,87 @@ impl RelationshipFactory { } } +struct NodeLabelAnalyzer<'a> { + label_name: &'a str, + fields: IndexMap<&'a str, AnalyzedGraphFieldMapping>, + remaining_fields: HashMap<&'a str, &'a FieldMapping>, + index_options: &'a IndexOptions, +} + +impl<'a> NodeLabelAnalyzer<'a> { + fn new(rel_spec: &'a RelationshipSpec, rel_end_spec: &'a RelationshipEndSpec) -> Result { + let node_spec = rel_spec.nodes.get(&rel_end_spec.label).ok_or_else(|| { + anyhow!( + "Node label `{}` not found in relationship spec", + rel_end_spec.label + ) + })?; + Ok(Self { + label_name: rel_end_spec.label.as_str(), + fields: IndexMap::new(), + remaining_fields: rel_end_spec + .fields + .iter() + .map(|f| (f.field_name.as_str(), f)) + .collect(), + index_options: &node_spec.index_options, + }) + } + + fn process_field(&mut self, field_idx: usize, field_schema: &FieldSchema) -> bool { + let field_info = match self.remaining_fields.remove(field_schema.name.as_str()) { + Some(field_info) => field_info, + None => return false, + }; + self.fields.insert( + field_info.get_node_field_name().as_str(), + AnalyzedGraphFieldMapping { + field_idx, + field_name: field_info.get_node_field_name().clone(), + value_type: field_schema.value_type.typ.clone(), + }, + ); + true + } + + fn build(self) -> Result { + if !self.remaining_fields.is_empty() { + anyhow::bail!( + "Fields not mapped for Node label `{}`: {}", + self.label_name, + self.remaining_fields.keys().join(", ") + ); + } + let mut fields = self.fields; + let mut key_fields = vec![]; + for key_field in self + .index_options + .primary_key_fields + .iter() + .flat_map(|f| f.iter()) + { + let e = fields.shift_remove(key_field.as_str()).ok_or_else(|| { + anyhow!( + "Key field `{}` not mapped in Node label `{}`", + key_field, + self.label_name + ) + })?; + key_fields.push(e); + } + if key_fields.is_empty() { + anyhow::bail!( + "No key fields specified for Node label `{}`", + self.label_name + ); + } + Ok(AnalyzedNodeLabelInfo { + key_fields, + value_fields: fields.into_values().collect(), + }) + } +} + impl StorageFactoryBase for RelationshipFactory { type Spec = RelationshipSpec; type SetupState = RelationshipSetupState; @@ -810,72 +967,18 @@ impl StorageFactoryBase for RelationshipFactory { context: Arc, ) -> Result> { let setup_key = GraphRelationship::from_spec(&spec); - let key_field_schema = { - if key_fields_schema.len() != 1 { - anyhow::bail!("Neo4j only supports a single key field"); - } - key_fields_schema.into_iter().next().unwrap() - }; - let desired_setup_state = - RelationshipSetupState::from_spec(&spec, key_field_schema.name.clone()); + let desired_setup_state = RelationshipSetupState::from_spec( + &spec, + key_fields_schema.iter().map(|f| f.name.clone()).collect(), + ); + let mut src_label_analyzer = NodeLabelAnalyzer::new(&spec, &spec.source)?; + let mut tgt_label_analyzer = NodeLabelAnalyzer::new(&spec, &spec.target)?; let mut rel_value_fields_info = vec![]; - let mut src_key_field_info = None; - let mut src_value_fields_info = vec![]; - let mut tgt_key_field_info = None; - let mut tgt_value_fields_info = vec![]; - - let mut field_name_to_src_field_info = spec - .source - .fields - .iter() - .map(|field| (field.field_name.as_str(), field)) - .collect::>(); - let mut field_name_to_tgt_field_info = spec - .target - .fields - .iter() - .map(|field| (field.field_name.as_str(), field)) - .collect::>(); - - let src_label_info = spec.get_src_label_info()?; - let tgt_label_info = spec.get_tgt_label_info()?; for (field_idx, field_schema) in value_fields_schema.into_iter().enumerate() { - let src_field_info = field_name_to_src_field_info.remove(field_schema.name.as_str()); - let tgt_field_info = field_name_to_tgt_field_info.remove(field_schema.name.as_str()); - if let Some(src_field_info) = src_field_info { - let field_mapping = AnalyzedGraphFieldMapping { - field_idx, - field_name: src_field_info.get_node_field_name().clone(), - value_type: field_schema.value_type.typ.clone(), - }; - let node_field_name = src_field_info - .node_field_name - .as_ref() - .unwrap_or(&src_field_info.field_name); - if &src_label_info.key_field_name == node_field_name { - src_key_field_info = Some(field_mapping); - } else { - src_value_fields_info.push(field_mapping); - } - } - if let Some(tgt_field_info) = tgt_field_info { - let field_mapping = AnalyzedGraphFieldMapping { - field_idx, - field_name: tgt_field_info.get_node_field_name().clone(), - value_type: field_schema.value_type.typ.clone(), - }; - let node_field_name = tgt_field_info - .node_field_name - .as_ref() - .unwrap_or(&tgt_field_info.field_name); - if &tgt_label_info.key_field_name == node_field_name { - tgt_key_field_info = Some(field_mapping); - } else { - tgt_value_fields_info.push(field_mapping); - } - } - if src_field_info.is_none() && tgt_field_info.is_none() { + if !src_label_analyzer.process_field(field_idx, &field_schema) + && !tgt_label_analyzer.process_field(field_idx, &field_schema) + { rel_value_fields_info.push(AnalyzedGraphFieldMapping { field_idx, field_name: field_schema.name.clone(), @@ -883,30 +986,9 @@ impl StorageFactoryBase for RelationshipFactory { }); } } - if !field_name_to_src_field_info.is_empty() { - anyhow::bail!( - "Source field not found: {}", - field_name_to_src_field_info.keys().join(", ") - ); - } - if !field_name_to_tgt_field_info.is_empty() { - anyhow::bail!( - "Target field not found: {}", - field_name_to_tgt_field_info.keys().join(", ") - ); - } - let src_key_field_info = src_key_field_info.ok_or_else(|| { - anyhow::anyhow!( - "Source key field not found: {}", - src_label_info.key_field_name - ) - })?; - let tgt_key_field_info = tgt_key_field_info.ok_or_else(|| { - anyhow::anyhow!( - "Target key field not found: {}", - tgt_label_info.key_field_name - ) - })?; + let src_label_info = src_label_analyzer.build()?; + let tgt_label_info = tgt_label_analyzer.build()?; + let conn_spec = context .auth_registry .get::(&spec.connection)?; @@ -915,16 +997,10 @@ impl StorageFactoryBase for RelationshipFactory { let executor = Arc::new(RelationshipStorageExecutor::new( graph, spec, - key_field_schema, + key_fields_schema, rel_value_fields_info, - AnalyzedGraphFields { - key_field: src_key_field_info, - value_fields: src_value_fields_info, - }, - AnalyzedGraphFields { - key_field: tgt_key_field_info, - value_fields: tgt_value_fields_info, - }, + src_label_info, + tgt_label_info, )?); Ok((executor as Arc, None)) }