diff --git a/docs/docs/ops/functions.md b/docs/docs/ops/functions.md
index 4da0b55c..1c93f2a4 100644
--- a/docs/docs/ops/functions.md
+++ b/docs/docs/ops/functions.md
@@ -33,6 +33,7 @@ Return type: `Table`, each row represents a chunk, with the following sub fields
 The spec takes the following fields:

 * `model` (type: `str`, required): The name of the SentenceTransformer model to use.
+* `args` (type: `dict[str, Any]`, optional): Additional arguments to pass to the SentenceTransformer constructor, e.g. `{"trust_remote_code": True}`.

 Input data:

diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py
index 066a1968..33898fc7 100644
--- a/python/cocoindex/functions.py
+++ b/python/cocoindex/functions.py
@@ -1,5 +1,5 @@
 """All builtin functions."""
-from typing import Annotated
+from typing import Annotated, Any
 import json

 import sentence_transformers
@@ -13,8 +13,16 @@ class SplitRecursively(op.FunctionSpec):
     language: str | None = None

 class SentenceTransformerEmbed(op.FunctionSpec):
-    """Run the sentence transformer"""
+    """
+    `SentenceTransformerEmbed` embeds a text into a vector space using the [SentenceTransformer](https://huggingface.co/sentence-transformers) library.
+
+    Args:
+
+        model: The name of the SentenceTransformer model to use.
+        args: Additional arguments to pass to the SentenceTransformer constructor, e.g. {"trust_remote_code": True}
+    """
     model: str
+    args: dict[str, Any] | None = None

 @op.executor_class(gpu=True, cache=True, behavior_version=1)
 class SentenceTransformerEmbedExecutor:
@@ -24,6 +32,20 @@ class SentenceTransformerEmbedExecutor:
     _model: sentence_transformers.SentenceTransformer

     def analyze(self, text = None):
-        self._model = sentence_transformers.SentenceTransformer(self.spec.model, 3)
+        """Build the embedding model and describe the output vector type.
+
+        Constructs the SentenceTransformer named by `spec.model`, forwarding
+        `spec.args` (if any) as keyword arguments, and returns the annotated
+        vector type sized to the model's embedding dimension.
+
+        Raises:
+            ValueError: If `text` is not provided.
+        """
+        if text is None:
+            # `text.analyzed_value` is read below; reject a missing input up
+            # front, before the model is constructed.
+            raise ValueError("SentenceTransformerEmbed requires a `text` input")
+        args = self.spec.args or {}
+        self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
         dim = self._model.get_sentence_embedding_dimension()
         return Annotated[list[Float32], Vector(dim=dim), TypeAttr("cocoindex.io/vector_origin_text", json.loads(text.analyzed_value))]