Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/docs/ops/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,11 @@ The spec takes the following fields:
* `api_type` ([`cocoindex.LlmApiType`](/docs/ai/llm#llm-api-types)): The type of LLM API to use for embedding.
* `model` (`str`): The name of the embedding model to use.
* `address` (`str`, optional): The address of the LLM API. If not specified, uses the default address for the API type.
* `output_dimension` (`int`, optional): The expected dimension of the output embedding vector. If not specified, use the default dimension of the model.
* `output_dimension` (`int`, optional): The dimension to request from the embedding API. Some APIs support specifying the output dimension (e.g., OpenAI's models support dimension reduction). If not specified, the API will use its default dimension.
* `expected_output_dimension` (`int`, optional): The expected dimension of the output embedding vector for validation and type schema. If not specified, falls back to `output_dimension`, then to the default dimension of the model.

For most API types, the function internally keeps a registry for the default output dimension of known model.
You need to explicitly specify the `output_dimension` if you want to use a new model that is not in the registry yet.
For most API types, the function internally keeps a registry for the default output dimension of known models.
You need to explicitly specify `expected_output_dimension` (or `output_dimension`) if you want to use a new model that is not in the registry yet.

* `task_type` (`str`, optional): The task type for embedding, used by some embedding models to optimize the embedding for specific use cases.

Expand Down
1 change: 1 addition & 0 deletions python/cocoindex/functions/_engine_builtin_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class EmbedText(op.FunctionSpec):
model: str
address: str | None = None
output_dimension: int | None = None
expected_output_dimension: int | None = None
task_type: str | None = None
api_config: llm.VertexAiConfig | None = None
api_key: TransientAuthEntryReference[str] | None = None
Expand Down
28 changes: 20 additions & 8 deletions rust/cocoindex/src/ops/functions/embed_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ struct Spec {
address: Option<String>,
api_config: Option<LlmApiConfig>,
output_dimension: Option<u32>,
expected_output_dimension: Option<u32>,
task_type: Option<String>,
api_key: Option<AuthEntryReference<String>>,
}
Expand Down Expand Up @@ -129,23 +130,33 @@ impl SimpleFunctionFactoryBase for Factory {
spec.api_config.clone(),
)
.await?;
let output_dimension = match spec.output_dimension {
Some(output_dimension) => output_dimension,
None => {
client.get_default_embedding_dimension(spec.model.as_str())
.ok_or_else(|| api_error!("model \"{}\" is unknown for {:?}, needs to specify `output_dimension` explicitly", spec.model, spec.api_type))?

// Warn if both parameters are specified but have different values
if let (Some(expected), Some(output)) =
(spec.expected_output_dimension, spec.output_dimension)
{
if expected != output {
warn!(
"Both `expected_output_dimension` ({expected}) and `output_dimension` ({output}) are specified but have different values. \
`expected_output_dimension` will be used for output schema and validation, while `output_dimension` will be sent to the embedding API."
);
}
};
}

let expected_output_dimension = spec.expected_output_dimension
.or(spec.output_dimension)
.or_else(|| client.get_default_embedding_dimension(spec.model.as_str()))
.ok_or_else(|| api_error!("model \"{}\" is unknown for {:?}, needs to specify `expected_output_dimension` (or `output_dimension`) explicitly", spec.model, spec.api_type))? as usize;
let output_schema = make_output_type(BasicValueType::Vector(VectorTypeSchema {
dimension: Some(output_dimension as usize),
dimension: Some(expected_output_dimension),
element_type: Box::new(BasicValueType::Float32),
}));
Ok(SimpleFunctionAnalysisOutput {
behavior_version: client.behavior_version(),
resolved_args: Args {
client,
text,
expected_output_dimension: output_dimension as usize,
expected_output_dimension,
},
output_schema,
})
Expand Down Expand Up @@ -179,6 +190,7 @@ mod tests {
address: None,
api_config: None,
output_dimension: None,
expected_output_dimension: None,
task_type: None,
api_key: None,
};
Expand Down
Loading