From 0f01239f6b4c026f5c6c81394173c8730ed7c5b0 Mon Sep 17 00:00:00 2001
From: LJ
Date: Wed, 5 Mar 2025 09:37:40 -0800
Subject: [PATCH 1/2] `SentenceTransformerEmbed` support additional args for the library.

---
 python/cocoindex/functions.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py
index 066a1968..33898fc7 100644
--- a/python/cocoindex/functions.py
+++ b/python/cocoindex/functions.py
@@ -1,5 +1,5 @@
 """All builtin functions."""
-from typing import Annotated
+from typing import Annotated, Any
 import json
 
 import sentence_transformers
@@ -13,8 +13,16 @@ class SplitRecursively(op.FunctionSpec):
     language: str | None = None
 
 class SentenceTransformerEmbed(op.FunctionSpec):
-    """Run the sentence transformer"""
+    """
+    `SentenceTransformerEmbed` embeds a text into a vector space using the [SentenceTransformer](https://huggingface.co/sentence-transformers) library.
+
+    Args:
+
+        model: The name of the SentenceTransformer model to use.
+        args: Additional arguments to pass to the SentenceTransformer constructor. e.g. {"trust_remote_code": True}
+    """
     model: str
+    args: dict[str, Any] | None = None
 
 @op.executor_class(gpu=True, cache=True, behavior_version=1)
 class SentenceTransformerEmbedExecutor:
@@ -24,7 +32,8 @@ class SentenceTransformerEmbedExecutor:
     _model: sentence_transformers.SentenceTransformer
 
     def analyze(self, text = None):
-        self._model = sentence_transformers.SentenceTransformer(self.spec.model, 3)
+        args = self.spec.args or {}
+        self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
         dim = self._model.get_sentence_embedding_dimension()
         return Annotated[list[Float32], Vector(dim=dim), TypeAttr("cocoindex.io/vector_origin_text", json.loads(text.analyzed_value))]
 

From a7c4a2210a6e4ab7e92d6f060828dcddf16c2357 Mon Sep 17 00:00:00 2001
From: LJ
Date: Wed, 5 Mar 2025 09:39:17 -0800
Subject: [PATCH 2/2] Update documentation for `SentenceTransformerEmbed`.

---
 docs/docs/ops/functions.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/docs/ops/functions.md b/docs/docs/ops/functions.md
index 4da0b55c..1c93f2a4 100644
--- a/docs/docs/ops/functions.md
+++ b/docs/docs/ops/functions.md
@@ -33,6 +33,7 @@ Return type: `Table`, each row represents a chunk, with the following sub fields
 The spec takes the following fields:
 
 * `model` (type: `str`, required): The name of the SentenceTransformer model to use.
+* `args` (type: `dict[str, Any]`, optional): Additional arguments to pass to the SentenceTransformer constructor. e.g. `{"trust_remote_code": True}`
 
 Input data:
 