Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions examples/docs_to_kg/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from dotenv import load_dotenv
import cocoindex

# Structured result type for the document-level summary extraction.
# NOTE(review): this class is passed as `output_type` to ExtractByLlm in this
# file; the docstring may be surfaced to the model -- confirm before rewording.
@dataclasses.dataclass
class DocumentSummary:
    """Describe a summary of a document."""
    title: str  # collected as the `title` field of the document node
    summary: str  # collected as the `summary` field of the document node

@dataclasses.dataclass
class Relationship:
Expand All @@ -31,13 +36,25 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D
cocoindex.sources.LocalFile(path="../../docs/docs/core",
included_patterns=["*.md", "*.mdx"]))

relationships = data_scope.add_collector()
document_node = data_scope.add_collector()
entity_relationship = data_scope.add_collector()
entity_mention = data_scope.add_collector()

with data_scope["documents"].row() as doc:
doc["chunks"] = doc["content"].transform(
cocoindex.functions.SplitRecursively(),
language="markdown", chunk_size=10000)

doc["summary"] = doc["content"].transform(
cocoindex.functions.ExtractByLlm(
llm_spec=cocoindex.LlmSpec(
api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),
output_type=DocumentSummary,
instruction="Please summarize the content of the document."))
document_node.collect(
filename=doc["filename"], title=doc["summary"]["title"],
summary=doc["summary"]["summary"])

with doc["chunks"].row() as chunk:
chunk["relationships"] = chunk["text"].transform(
cocoindex.functions.ExtractByLlm(
Expand All @@ -59,17 +76,31 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D
relationship["object_embedding"] = relationship["object"].transform(
cocoindex.functions.SentenceTransformerEmbed(
model="sentence-transformers/all-MiniLM-L6-v2"))
relationships.collect(
entity_relationship.collect(
id=cocoindex.GeneratedField.UUID,
subject=relationship["subject"],
subject_embedding=relationship["subject_embedding"],
object=relationship["object"],
object_embedding=relationship["object_embedding"],
predicate=relationship["predicate"],
)

relationships.export(
"relationships",
entity_mention.collect(
id=cocoindex.GeneratedField.UUID, entity=relationship["subject"],
filename=doc["filename"], location=chunk["location"],
)
entity_mention.collect(
id=cocoindex.GeneratedField.UUID, entity=relationship["object"],
filename=doc["filename"], location=chunk["location"],
)
document_node.export(
"document_node",
cocoindex.storages.Neo4j(
connection=conn_spec,
mapping=cocoindex.storages.Neo4jNode(label="Document")),
primary_key_fields=["filename"],
)
entity_relationship.export(
"entity_relationship",
cocoindex.storages.Neo4j(
connection=conn_spec,
mapping=cocoindex.storages.Neo4jRelationship(
Expand Down Expand Up @@ -107,6 +138,25 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D
),
primary_key_fields=["id"],
)
entity_mention.export(
"entity_mention",
cocoindex.storages.Neo4j(
connection=conn_spec,
mapping=cocoindex.storages.Neo4jRelationship(
rel_type="MENTION",
source=cocoindex.storages.Neo4jRelationshipEnd(
label="Document",
fields=[cocoindex.storages.Neo4jFieldMapping("filename")],
),
target=cocoindex.storages.Neo4jRelationshipEnd(
label="Entity",
fields=[cocoindex.storages.Neo4jFieldMapping(
field_name="entity", node_field_name="value")],
),
),
),
primary_key_fields=["id"],
)

@cocoindex.main_fn()
def _run():
Expand Down
2 changes: 1 addition & 1 deletion python/cocoindex/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Neo4jRelationship:
rel_type: str
source: Neo4jRelationshipEnd
target: Neo4jRelationshipEnd
nodes: dict[str, Neo4jRelationshipNode]
nodes: dict[str, Neo4jRelationshipNode] | None = None

class Neo4j(op.StorageSpec):
"""Graph storage powered by Neo4j."""
Expand Down
15 changes: 11 additions & 4 deletions src/ops/storages/neo4j.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub struct RelationshipSpec {
rel_type: String,
source: RelationshipEndSpec,
target: RelationshipEndSpec,
nodes: BTreeMap<String, RelationshipNodeSpec>,
nodes: Option<BTreeMap<String, RelationshipNodeSpec>>,
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -693,7 +693,7 @@ impl RelationshipSetupState {
rel_spec.rel_type
)
})?;
for (label, node) in rel_spec.nodes.iter() {
for (label, node) in rel_spec.nodes.iter().flatten() {
sub_components.push(ComponentState {
object_label: ElementType::Node(label.clone()),
index_def: IndexDef::KeyConstraint {
Expand All @@ -720,7 +720,13 @@ impl RelationshipSetupState {
});
}
}
dependent_node_labels.extend(rel_spec.nodes.keys().cloned());
dependent_node_labels.extend(
rel_spec
.nodes
.iter()
.flat_map(|nodes| nodes.keys())
.cloned(),
);
}
};
Ok(Self {
Expand Down Expand Up @@ -1069,7 +1075,8 @@ impl<'a> DependentNodeLabelAnalyzer<'a> {
.collect(),
index_options: rel_spec
.nodes
.get(&rel_end_spec.label)
.as_ref()
.and_then(|nodes| nodes.get(&rel_end_spec.label))
.and_then(|node_spec| Some(&node_spec.index_options)),
})
}
Expand Down