# embed

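> Embedding models: a registry of supported backends (OpenAI and SentenceTransformer), with client initializers and single-text embedding functions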

In [None]:
# | default_exp embed

In [None]:
# | hide
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel while developing.
%load_ext autoreload
%autoreload 2

In [None]:
# | hide
from nbdev.showdoc import *

In [None]:
# | export
from functools import partial
from typing import Awaitable, Callable, List, Union

import numpy as np
from openai import AsyncOpenAI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

In [None]:
# | export
# Batch-size ceilings used as ``max_batch`` for the registered embedding models.
# NOTE(review): 2048 presumably mirrors the OpenAI embeddings endpoint's
# per-request input-list limit; 32 looks like a local-inference batch choice.
MAX_BATCH_OPENAI: int = 2_048
MAX_BATCH_SENTENCE_TRANSFORMER: int = 32

In [None]:
# | export
# Client objects that an ``EmbedModel.fn_init`` factory may produce.
EmbedClient = Union[AsyncOpenAI, SentenceTransformer]
# A single embedding vector. The OpenAI path returns ``response.data[0].embedding``
# (a list of floats); the SentenceTransformer path returns one row of
# ``client.encode([...])`` (a NumPy array with the library's default settings).
# ``np.ndarray`` is the array *type*; ``np.array`` is only a factory function.
EmbedResult = Union[List[float], np.ndarray]
# Maps one text to its embedding, either directly (sync backends) or via an
# awaitable (async backends such as the OpenAI client).
EmbedFn = Callable[[str], Union[Awaitable[EmbedResult], EmbedResult]]

In [None]:
# | export


class EmbedModel(BaseModel):
    """Describe how to build a client for one embedding backend and embed text with it.

    Entries of this type are collected in ``embed_supported``. A caller builds a
    client with ``fn_init``, turns it into a per-text embedding function with
    ``fn_embed``, and uses ``async_`` to decide whether the result must be awaited.
    """

    # True when the function produced by ``fn_embed`` returns an awaitable
    # (the OpenAI entries); False when it returns the embedding directly.
    async_: bool
    # Identifier for the client; the OpenAI entries share "openai".
    # NOTE(review): presumably used as a client-cache key — confirm with caller.
    init_key: Union[str, None]
    # Largest accepted input per text.
    # NOTE(review): presumably measured in tokens — confirm.
    max_input_size: int
    # Factory that builds the client. The factories registered in this module
    # (e.g. ``partial(AsyncOpenAI)``) are complete without extra arguments, and
    # AsyncOpenAI takes only keyword options, so the signature is left open
    # instead of promising a single positional ``str | None`` argument.
    fn_init: Callable[..., EmbedClient]
    # Given a client, returns a function that embeds a single string.
    fn_embed: Callable[[EmbedClient], EmbedFn]
    # Largest number of inputs a caller should submit in one batch.
    max_batch: int = Field(default=1)

In [None]:
# | export


class EmbedRequest(BaseModel):
    """Request payload asking for embeddings of one text or a list of texts."""

    # A single string or a list of strings to embed.
    text: Union[str, List[str]]
    # Key selecting the backend; the default matches the
    # "openai_text_embedding_3_small" entry of ``embed_supported``.
    model_str: str = Field(default="openai_text_embedding_3_small")
    # Optional requested output size; None leaves it unset.
    # NOTE(review): presumably the embedding dimensionality — not consumed in this module.
    output_size: Union[int, None] = Field(default=None)

In [None]:
# | export


def init_openai() -> AsyncOpenAI:
    """Create an async OpenAI client using the library's default configuration.

    No options are passed, so credentials and endpoint come from whatever the
    ``openai`` package picks up by default (presumably environment variables).
    """
    client = AsyncOpenAI()
    return client

In [None]:
# | export


def init_sentence_transformer(model: str) -> SentenceTransformer:
    """Load a SentenceTransformer from the given model name or path.

    Args:
        model: Name or path handed unchanged to ``SentenceTransformer``.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    loaded = SentenceTransformer(model)
    return loaded

In [None]:
# | export


def embed_openai(client: AsyncOpenAI, model: str) -> EmbedFn:
    """Bind an async OpenAI client and model name into a single-text embedder.

    Args:
        client: Async OpenAI client used for the request.
        model: OpenAI embedding model name, e.g. "text-embedding-3-small".

    Returns:
        A coroutine function that takes one string and resolves to its embedding.
    """

    async def _embed_one(text: str) -> EmbedResult:
        # The API accepts a list of inputs; send a one-element batch and
        # unwrap the only result.
        reply = await client.embeddings.create(input=[text], model=model)
        only_item = reply.data[0]
        return only_item.embedding

    return _embed_one

In [None]:
# | export


def embed_sentence_transformer(client: SentenceTransformer) -> EmbedFn:
    """Wrap a SentenceTransformer into a synchronous single-text embedder.

    Args:
        client: Loaded SentenceTransformer model.

    Returns:
        A function taking one string and returning the first (only) row of
        ``client.encode([text])``.
    """

    def _encode_single(text: str):
        # encode() works on batches; wrap the text and take row 0.
        rows = client.encode([text])
        return rows[0]

    return _encode_single

In [None]:
# | export


# Registry of supported embedding backends, keyed by the ``model_str`` values
# accepted in ``EmbedRequest``. Each entry says how to build a client
# (``fn_init``), how to get a single-text embedder from it (``fn_embed``),
# whether that embedder is async, and its input-size and batch limits.
embed_supported = {
    "openai_text_embedding_3_small": EmbedModel(
        init_key="openai",
        max_input_size=8191,
        fn_init=partial(AsyncOpenAI),
        fn_embed=partial(embed_openai, model="text-embedding-3-small"),
        async_=True,
        max_batch=MAX_BATCH_OPENAI,
    ),
    "openai_text_embedding_3_large": EmbedModel(
        init_key="openai",
        max_input_size=8191,
        fn_init=partial(AsyncOpenAI),
        fn_embed=partial(embed_openai, model="text-embedding-3-large"),
        async_=True,
        max_batch=MAX_BATCH_OPENAI,
    ),
    "jina-embeddings-v2-base-en": EmbedModel(
        init_key="jina-embeddings-v2-base-en",
        max_input_size=8191,
        # The Jina v2 checkpoint relies on custom modelling code published with
        # the model; its model card requires trust_remote_code=True when loading
        # through SentenceTransformer. Without it the weights are not loaded
        # into the intended architecture.
        # NOTE(review): this runs code fetched from the Hugging Face Hub — pin a
        # revision if that is a concern. Requires sentence-transformers >= 2.3.
        fn_init=partial(
            SentenceTransformer,
            "jinaai/jina-embeddings-v2-base-en",
            trust_remote_code=True,
        ),
        fn_embed=partial(embed_sentence_transformer),
        async_=False,
        max_batch=MAX_BATCH_SENTENCE_TRANSFORMER,
    ),
}

In [None]:
# | hide

# Export the `# | export` cells of the notebooks into the Python package.
import nbdev

nbdev.nbdev_export()