diff --git a/examples/docs_to_kg/main.py b/examples/docs_to_kg/main.py index 2fa49b36..fb4c23e4 100644 --- a/examples/docs_to_kg/main.py +++ b/examples/docs_to_kg/main.py @@ -5,6 +5,11 @@ from dotenv import load_dotenv import cocoindex +@dataclasses.dataclass +class DocumentSummary: + """Describe a summary of a document.""" + title: str + summary: str @dataclasses.dataclass class Relationship: @@ -31,13 +36,25 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D cocoindex.sources.LocalFile(path="../../docs/docs/core", included_patterns=["*.md", "*.mdx"])) - relationships = data_scope.add_collector() + document_node = data_scope.add_collector() + entity_relationship = data_scope.add_collector() + entity_mention = data_scope.add_collector() with data_scope["documents"].row() as doc: doc["chunks"] = doc["content"].transform( cocoindex.functions.SplitRecursively(), language="markdown", chunk_size=10000) + doc["summary"] = doc["content"].transform( + cocoindex.functions.ExtractByLlm( + llm_spec=cocoindex.LlmSpec( + api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"), + output_type=DocumentSummary, + instruction="Please summarize the content of the document.")) + document_node.collect( + filename=doc["filename"], title=doc["summary"]["title"], + summary=doc["summary"]["summary"]) + with doc["chunks"].row() as chunk: chunk["relationships"] = chunk["text"].transform( cocoindex.functions.ExtractByLlm( @@ -59,7 +76,7 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D relationship["object_embedding"] = relationship["object"].transform( cocoindex.functions.SentenceTransformerEmbed( model="sentence-transformers/all-MiniLM-L6-v2")) - relationships.collect( + entity_relationship.collect( id=cocoindex.GeneratedField.UUID, subject=relationship["subject"], subject_embedding=relationship["subject_embedding"], @@ -67,9 +84,23 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D object_embedding=relationship["object_embedding"], predicate=relationship["predicate"], ) - - relationships.export( - "relationships", + entity_mention.collect( + id=cocoindex.GeneratedField.UUID, entity=relationship["subject"], + filename=doc["filename"], location=chunk["location"], + ) + entity_mention.collect( + id=cocoindex.GeneratedField.UUID, entity=relationship["object"], + filename=doc["filename"], location=chunk["location"], + ) + document_node.export( + "document_node", + cocoindex.storages.Neo4j( + connection=conn_spec, + mapping=cocoindex.storages.Neo4jNode(label="Document")), + primary_key_fields=["filename"], + ) + entity_relationship.export( + "entity_relationship", cocoindex.storages.Neo4j( connection=conn_spec, mapping=cocoindex.storages.Neo4jRelationship( @@ -107,6 +138,25 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D ), primary_key_fields=["id"], ) + entity_mention.export( + "entity_mention", + cocoindex.storages.Neo4j( + connection=conn_spec, + mapping=cocoindex.storages.Neo4jRelationship( + rel_type="MENTION", + source=cocoindex.storages.Neo4jRelationshipEnd( + label="Document", + fields=[cocoindex.storages.Neo4jFieldMapping("filename")], + ), + target=cocoindex.storages.Neo4jRelationshipEnd( + label="Entity", + fields=[cocoindex.storages.Neo4jFieldMapping( + field_name="entity", node_field_name="value")], + ), + ), + ), + primary_key_fields=["id"], + ) @cocoindex.main_fn() def _run(): diff --git a/python/cocoindex/storages.py b/python/cocoindex/storages.py index e2e58342..c1568e2e 100644 --- a/python/cocoindex/storages.py +++ b/python/cocoindex/storages.py @@ -63,7 +63,7 @@ class Neo4jRelationship: rel_type: str source: Neo4jRelationshipEnd target: Neo4jRelationshipEnd - nodes: dict[str, Neo4jRelationshipNode] + nodes: dict[str, Neo4jRelationshipNode] | None = None class Neo4j(op.StorageSpec): """Graph storage powered by Neo4j.""" diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs index 03da3f20..bb582d59 100644 --- a/src/ops/storages/neo4j.rs +++ b/src/ops/storages/neo4j.rs @@ -56,7 +56,7 @@ pub struct RelationshipSpec { rel_type: String, source: RelationshipEndSpec, target: RelationshipEndSpec, - nodes: BTreeMap, + nodes: Option>, } #[derive(Debug, Deserialize)] @@ -693,7 +693,7 @@ impl RelationshipSetupState { rel_spec.rel_type ) })?; - for (label, node) in rel_spec.nodes.iter() { + for (label, node) in rel_spec.nodes.iter().flatten() { sub_components.push(ComponentState { object_label: ElementType::Node(label.clone()), index_def: IndexDef::KeyConstraint { @@ -720,7 +720,13 @@ impl RelationshipSetupState { }); } } - dependent_node_labels.extend(rel_spec.nodes.keys().cloned()); + dependent_node_labels.extend( + rel_spec + .nodes + .iter() + .flat_map(|nodes| nodes.keys()) + .cloned(), + ); } }; Ok(Self { @@ -1069,7 +1075,8 @@ impl<'a> DependentNodeLabelAnalyzer<'a> { .collect(), index_options: rel_spec .nodes - .get(&rel_end_spec.label) + .as_ref() + .and_then(|nodes| nodes.get(&rel_end_spec.label)) .and_then(|node_spec| Some(&node_spec.index_options)), }) }