From 1825ff41aab8bb4297196f48e30637641c48b371 Mon Sep 17 00:00:00 2001 From: LJ Date: Sat, 12 Apr 2025 15:58:11 -0700 Subject: [PATCH] Correctly set field names for nodes. --- src/ops/storages/neo4j.rs | 70 +++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs index 6d2053d58..353b418ba 100644 --- a/src/ops/storages/neo4j.rs +++ b/src/ops/storages/neo4j.rs @@ -25,6 +25,12 @@ pub struct FieldMapping { node_field_name: Option, } +impl FieldMapping { + fn get_node_field_name(&self) -> &FieldName { + self.node_field_name.as_ref().unwrap_or(&self.field_name) + } +} + #[derive(Debug, Deserialize)] pub struct RelationshipEndSpec { label: String, @@ -133,7 +139,8 @@ impl GraphPool { #[derive(Debug, Clone)] struct AnalyzedGraphFieldMapping { field_idx: usize, - field_schema: FieldSchema, + field_name: String, + value_type: schema::ValueType, } struct AnalyzedGraphFields { @@ -203,6 +210,23 @@ fn field_values_to_bolt<'a>( Ok(bolt_value) } +fn mapped_field_values_to_bolt<'a>( + field_values: impl IntoIterator, + field_mappings: impl IntoIterator, +) -> Result { + let bolt_value = BoltType::Map(neo4rs::BoltMap { + value: std::iter::zip(field_mappings, field_values) + .map(|(mapping, value)| { + Ok(( + neo4rs::BoltString::new(&mapping.field_name), + value_to_bolt(value, &mapping.value_type)?, + )) + }) + .collect::>()?, + }); + Ok(bolt_value) +} + fn basic_value_to_bolt(value: &BasicValue, schema: &BasicValueType) -> Result { let bolt_value = match value { BasicValue::Bytes(v) => { @@ -390,46 +414,46 @@ FINISH SRC_ID_PARAM, value_to_bolt( &value.fields[self.src_fields.key_field.field_idx], - &self.src_fields.key_field.field_schema.value_type.typ, + &self.src_fields.key_field.value_type, )?, ) .param( TGT_ID_PARAM, value_to_bolt( &value.fields[self.tgt_fields.key_field.field_idx], - &self.tgt_fields.key_field.field_schema.value_type.typ, + &self.tgt_fields.key_field.value_type, )?, ); if !self.src_fields.value_fields.is_empty() { insert_cypher = insert_cypher.param( SRC_PROPS_PARAM, - field_values_to_bolt( + mapped_field_values_to_bolt( self.src_fields .value_fields .iter() .map(|f| &value.fields[f.field_idx]), - self.src_fields.value_fields.iter().map(|f| &f.field_schema), + self.src_fields.value_fields.iter(), )?, ); } if !self.tgt_fields.value_fields.is_empty() { insert_cypher = insert_cypher.param( TGT_PROPS_PARAM, - field_values_to_bolt( + mapped_field_values_to_bolt( self.tgt_fields .value_fields .iter() .map(|f| &value.fields[f.field_idx]), - self.tgt_fields.value_fields.iter().map(|f| &f.field_schema), + self.tgt_fields.value_fields.iter(), )?, ); } if !self.value_fields.is_empty() { insert_cypher = insert_cypher.param( REL_PROPS_PARAM, - field_values_to_bolt( + mapped_field_values_to_bolt( self.value_fields.iter().map(|f| &value.fields[f.field_idx]), - self.value_fields.iter().map(|f| &f.field_schema), + self.value_fields.iter(), )?, ); } @@ -819,34 +843,44 @@ impl StorageFactoryBase for RelationshipFactory { for (field_idx, field_schema) in value_fields_schema.into_iter().enumerate() { let src_field_info = field_name_to_src_field_info.remove(field_schema.name.as_str()); let tgt_field_info = field_name_to_tgt_field_info.remove(field_schema.name.as_str()); - let field_mapping = AnalyzedGraphFieldMapping { - field_idx, - field_schema, - }; if let Some(src_field_info) = src_field_info { + let field_mapping = AnalyzedGraphFieldMapping { + field_idx, + field_name: src_field_info.get_node_field_name().clone(), + value_type: field_schema.value_type.typ.clone(), + }; let node_field_name = src_field_info .node_field_name .as_ref() .unwrap_or(&src_field_info.field_name); if &src_label_info.key_field_name == node_field_name { - src_key_field_info = Some(field_mapping.clone()); + src_key_field_info = Some(field_mapping); } else { - src_value_fields_info.push(field_mapping.clone()); + src_value_fields_info.push(field_mapping); } } if let Some(tgt_field_info) = tgt_field_info { + let field_mapping = AnalyzedGraphFieldMapping { + field_idx, + field_name: tgt_field_info.get_node_field_name().clone(), + value_type: field_schema.value_type.typ.clone(), + }; let node_field_name = tgt_field_info .node_field_name .as_ref() .unwrap_or(&tgt_field_info.field_name); if &tgt_label_info.key_field_name == node_field_name { - tgt_key_field_info = Some(field_mapping.clone()); + tgt_key_field_info = Some(field_mapping); } else { - tgt_value_fields_info.push(field_mapping.clone()); + tgt_value_fields_info.push(field_mapping); } } if src_field_info.is_none() && tgt_field_info.is_none() { - rel_value_fields_info.push(field_mapping); + rel_value_fields_info.push(AnalyzedGraphFieldMapping { + field_idx, + field_name: field_schema.name.clone(), + value_type: field_schema.value_type.typ.clone(), + }); } } if !field_name_to_src_field_info.is_empty() {