Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions examples/docs_to_kg/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,19 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D
"Each relationship should be a tuple of (subject, predicate, object).")))

with chunk["relationships"]["relationships"].row() as relationship:
relationship["subject_embedding"] = relationship["subject"].transform(
cocoindex.functions.SentenceTransformerEmbed(
model="sentence-transformers/all-MiniLM-L6-v2"))
relationship["object_embedding"] = relationship["object"].transform(
cocoindex.functions.SentenceTransformerEmbed(
model="sentence-transformers/all-MiniLM-L6-v2"))
relationships.collect(
id=cocoindex.GeneratedField.UUID,
subject=relationship["subject"],
predicate=relationship["predicate"],
subject_embedding=relationship["subject_embedding"],
object=relationship["object"],
object_embedding=relationship["object_embedding"],
predicate=relationship["predicate"],
)

relationships.export(
Expand All @@ -69,14 +77,34 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D
rel_type="RELATIONSHIP",
source=cocoindex.storages.Neo4jRelationshipEndSpec(
label="Entity",
fields=[cocoindex.storages.Neo4jFieldMapping(field_name="subject", node_field_name="value")]
fields=[
cocoindex.storages.Neo4jFieldMapping(
field_name="subject", node_field_name="value"),
cocoindex.storages.Neo4jFieldMapping(
field_name="subject_embedding", node_field_name="embedding"),
]
),
target=cocoindex.storages.Neo4jRelationshipEndSpec(
label="Entity",
fields=[cocoindex.storages.Neo4jFieldMapping(field_name="object", node_field_name="value")]
fields=[
cocoindex.storages.Neo4jFieldMapping(
field_name="object", node_field_name="value"),
cocoindex.storages.Neo4jFieldMapping(
field_name="object_embedding", node_field_name="embedding"),
]
),
nodes={
"Entity": cocoindex.storages.Neo4jRelationshipNodeSpec(key_field_name="value"),
"Entity": cocoindex.storages.Neo4jRelationshipNodeSpec(
index_options=cocoindex.IndexOptions(
primary_key_fields=["value"],
vector_index_defs=[
cocoindex.VectorIndexDef(
field_name="embedding",
metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
),
],
),
),
},
),
primary_key_fields=["id"],
Expand Down
2 changes: 1 addition & 1 deletion python/cocoindex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .flow import EvaluateAndDumpOptions, GeneratedField
from .flow import update_all_flows, FlowLiveUpdater, FlowLiveUpdaterOptions
from .llm import LlmSpec, LlmApiType
from .vector import VectorSimilarityMetric
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry
from .lib import *
from ._engine import OpArgSchema
4 changes: 2 additions & 2 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dataclasses import dataclass

from . import _engine
from . import vector
from . import index
from . import op
from .convert import dump_engine_object
from .typing import encode_enriched_type
Expand Down Expand Up @@ -268,7 +268,7 @@ def collect(self, **kwargs):

def export(self, name: str, target_spec: op.StorageSpec, /, *,
primary_key_fields: Sequence[str] | None = None,
vector_index: Sequence[tuple[str, vector.VectorSimilarityMetric]] = (),
vector_index: Sequence[tuple[str, index.VectorSimilarityMetric]] = (),
setup_by_user: bool = False):
"""
Export the collected data to the specified target.
Expand Down
23 changes: 23 additions & 0 deletions python/cocoindex/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from enum import Enum
from dataclasses import dataclass

class VectorSimilarityMetric(Enum):
    """
    Similarity metric used when comparing vectors.

    Each member's value is the string form of the metric, so a member can be
    looked up from that string, e.g. `VectorSimilarityMetric("L2Distance")`.
    """
    COSINE_SIMILARITY = "CosineSimilarity"
    L2_DISTANCE = "L2Distance"
    INNER_PRODUCT = "InnerProduct"

@dataclass
class VectorIndexDef:
    """
    Define a vector index on a field.
    """
    # Name of the field the vector index is defined on.
    field_name: str
    # Similarity metric associated with this index.
    metric: VectorSimilarityMetric

@dataclass
class IndexOptions:
    """
    Options for an index.

    Both fields are optional and default to `None`.
    """
    # Names of the fields that make up the primary key, if any.
    primary_key_fields: list[str] | None = None
    # Vector index definitions to attach, if any.
    vector_index_defs: list[VectorIndexDef] | None = None
10 changes: 5 additions & 5 deletions python/cocoindex/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from threading import Lock

from . import flow as fl
from . import vector
from . import index
from . import _engine

_handlers_lock = Lock()
Expand All @@ -14,7 +14,7 @@ class SimpleSemanticsQueryInfo:
"""
Additional information about the query.
"""
similarity_metric: vector.VectorSimilarityMetric
similarity_metric: index.VectorSimilarityMetric
query_vector: list[float]
vector_field_name: str

Expand All @@ -39,7 +39,7 @@ def __init__(
flow: fl.Flow,
target_name: str,
query_transform_flow: Callable[..., fl.DataSlice],
default_similarity_metric: vector.VectorSimilarityMetric = vector.VectorSimilarityMetric.COSINE_SIMILARITY) -> None:
default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY) -> None:

engine_handler = None
lock = Lock()
Expand All @@ -66,7 +66,7 @@ def internal_handler(self) -> _engine.SimpleSemanticsQueryHandler:
return self._lazy_query_handler()

def search(self, query: str, limit: int, vector_field_name: str | None = None,
similarity_matric: vector.VectorSimilarityMetric | None = None) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
similarity_matric: index.VectorSimilarityMetric | None = None) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]:
"""
Search the index with the given query, limit, vector field name, and similarity metric.
"""
Expand All @@ -76,7 +76,7 @@ def search(self, query: str, limit: int, vector_field_name: str | None = None,
fields = [field['name'] for field in internal_results['fields']]
results = [QueryResult(data=dict(zip(fields, result['data'])), score=result['score']) for result in internal_results['results']]
info = SimpleSemanticsQueryInfo(
similarity_metric=vector.VectorSimilarityMetric(internal_info['similarity_metric']),
similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']),
query_vector=internal_info['query_vector'],
vector_field_name=internal_info['vector_field_name']
)
Expand Down
3 changes: 2 additions & 1 deletion python/cocoindex/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass

from . import op
from . import index
from .auth_registry import AuthEntryReference
class Postgres(op.StorageSpec):
"""Storage powered by Postgres and pgvector."""
Expand Down Expand Up @@ -35,7 +36,7 @@ class Neo4jRelationshipEndSpec:
class Neo4jRelationshipNodeSpec:
"""Spec for a Neo4j node type."""
key_field_name: str | None = None

index_options: index.IndexOptions | None = None
class Neo4jRelationship(op.StorageSpec):
"""Graph storage powered by Neo4j."""

Expand Down
6 changes: 0 additions & 6 deletions python/cocoindex/vector.py

This file was deleted.

10 changes: 10 additions & 0 deletions src/base/spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ pub enum VectorSimilarityMetric {
InnerProduct,
}

impl std::fmt::Display for VectorSimilarityMetric {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VectorSimilarityMetric::CosineSimilarity => write!(f, "Cosine"),
VectorSimilarityMetric::L2Distance => write!(f, "L2"),
VectorSimilarityMetric::InnerProduct => write!(f, "InnerProduct"),
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct VectorIndexDef {
pub field_name: FieldName,
Expand Down
15 changes: 15 additions & 0 deletions src/base/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,21 @@ impl std::fmt::Display for KeyValue {
}

impl KeyValue {
/// Iterates over the individual key fields of this value.
///
/// When `num_fields` is 1, the value itself is yielded as the only field.
/// Otherwise the value must be a `KeyValue::Struct`, whose members are
/// yielded in order.
///
/// # Errors
///
/// Fails with "Invalid key value type" when `num_fields` is not 1 and the
/// value is not a `KeyValue::Struct`.
pub fn fields_iter<'a>(
    &'a self,
    num_fields: usize,
) -> Result<impl Iterator<Item = &'a KeyValue>> {
    let fields: &'a [KeyValue] = match (num_fields, self) {
        // A single-field key is the value itself, viewed as a one-element slice.
        (1, _) => std::slice::from_ref(self),
        (_, KeyValue::Struct(parts)) => parts,
        _ => api_bail!("Invalid key value type"),
    };
    Ok(fields.iter())
}

fn parts_from_str(
values_iter: &mut impl Iterator<Item = String>,
schema: &ValueType,
Expand Down
Loading