In [None]:
#| default_exp embedding.base

True

## Embeddings

In [None]:
#| export

import inspect
import typing as t
from abc import ABC, abstractmethod

#TODO: Add support for other providers like HuggingFace, Cohere, etc.
#TODO: handle async calls properly and ensure that the client supports async if needed.

class BaseEmbedding(ABC):
    """Abstract interface every embedding backend has to implement.

    A concrete subclass must provide all four methods: a synchronous and an
    asynchronous variant for a single text and for a list of documents.
    Extra keyword arguments are accepted by every method; what they mean is
    up to the implementing subclass.
    """

    @abstractmethod
    def embed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:
        """Return the embedding vector for a single ``text``."""
        ...

    @abstractmethod
    async def aembed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:
        """Coroutine variant of ``embed_text``."""
        ...

    @abstractmethod
    def embed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:
        """Return one embedding vector per entry of ``documents``."""
        ...

    @abstractmethod
    async def aembed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:
        """Coroutine variant of ``embed_document``."""
        ...


class OpenAIEmbeddings(BaseEmbedding):
    """Embedding backend that delegates to ``client.embeddings.create``.

    ``client`` may be a synchronous or an asynchronous client:

    * The ``a*`` methods await the result of ``create`` when it is awaitable
      and use it directly otherwise. With a synchronous client the request
      therefore runs (blocking) inside the coroutine instead of failing with
      ``TypeError`` on ``await``.
    * The synchronous methods raise ``TypeError`` when ``create`` returns an
      awaitable, because they have no way to await it.

    Args:
        client: Object exposing ``embeddings.create(input=..., model=..., **kwargs)``
            whose response has a ``data`` sequence of items carrying an
            ``embedding`` attribute.
        model: Model name sent as ``model=`` with every request.
    """

    def __init__(self, client: t.Any, model: str):
        self.client = client
        self.model = model

    def _request(self, payload: t.Union[str, t.List[str]], **kwargs: t.Any) -> t.Any:
        """Call ``client.embeddings.create`` and return whatever it returns."""
        return self.client.embeddings.create(input=payload, model=self.model, **kwargs)

    def _request_sync(self, payload: t.Union[str, t.List[str]], **kwargs: t.Any) -> t.Any:
        """Issue a request from synchronous code; reject awaitable results."""
        response = self._request(payload, **kwargs)
        if inspect.isawaitable(response):
            # Close an un-awaited coroutine so it does not later emit a
            # "coroutine ... was never awaited" RuntimeWarning.
            close = getattr(response, "close", None)
            if callable(close):
                close()
            raise TypeError(
                "client.embeddings.create returned an awaitable; use "
                "aembed_text/aembed_document with an async client, or pass a "
                "synchronous client."
            )
        return response

    async def _request_async(self, payload: t.Union[str, t.List[str]], **kwargs: t.Any) -> t.Any:
        """Issue a request from a coroutine, awaiting the result only if needed."""
        response = self._request(payload, **kwargs)
        if inspect.isawaitable(response):
            return await response
        # Synchronous client: the call above already produced the response.
        return response

    def embed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:
        """Return the embedding of ``text`` (first item of the response).

        Raises:
            TypeError: If the client returned an awaitable (async client).
        """
        return self._request_sync(text, **kwargs).data[0].embedding

    async def aembed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:
        """Coroutine variant of ``embed_text``; works with sync or async clients."""
        response = await self._request_async(text, **kwargs)
        return response.data[0].embedding

    def embed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:
        """Return one embedding per document, in the order of ``response.data``.

        An empty ``documents`` list yields ``[]`` without issuing a request.

        Raises:
            TypeError: If the client returned an awaitable (async client).
        """
        if not documents:
            return []
        response = self._request_sync(documents, **kwargs)
        return [item.embedding for item in response.data]

    async def aembed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:
        """Coroutine variant of ``embed_document``; works with sync or async clients.

        An empty ``documents`` list yields ``[]`` without issuing a request.
        """
        if not documents:
            return []
        response = await self._request_async(documents, **kwargs)
        return [item.embedding for item in response.data]
    
    
def ragas_embedding(provider: str, model: str, client: t.Any) -> BaseEmbedding:
    """
    Factory function to create an embedding instance based on the provider.

    Args:
        provider (str): The name of the embedding provider (e.g., "openai").
            Matching is case-insensitive and ignores surrounding whitespace.
        model (str): The model name to use for embeddings.
        client (Any): The provider client instance; it is handed unchanged to
            the embedding class.

    Returns:
        BaseEmbedding: An instance of the specified embedding provider.

    Raises:
        ValueError: If ``provider`` does not name a supported provider.
    """
    normalized = provider.strip().lower()
    if normalized == "openai":
        return OpenAIEmbeddings(client=client, model=model)

    # The message reports the provider exactly as the caller passed it.
    raise ValueError(f"Unsupported provider: {provider}")

### Example Usage

In [None]:
#| eval: false

## change to this design
# Build an embedder through the factory, passing an `OpenAI()` client created
# with no arguments and the "text-embedding-3-small" model name.
from openai import OpenAI
embedding_model = ragas_embedding(provider="openai", model="text-embedding-3-small", client=OpenAI())
# Last expression of the cell: embed one string via the synchronous method.
embedding_model.embed_text("Hello, world!")


[-0.019184619188308716,
 -0.025279032066464424,
 -0.0017195191467180848,
 0.01884828321635723,
 -0.033795066177845,
 -0.01969585195183754,
 -0.02094702236354351,
 0.051580529659986496,
 -0.03212684020400047,
 -0.030377890914678574,
 -0.002145825419574976,
 -0.028978731483221054,
 -0.0024737531784921885,
 -0.031481072306632996,
 0.010332250036299229,
 0.018606122583150864,
 -0.04614533483982086,
 0.04146353527903557,
 0.0004418617463670671,
 0.04122137278318405,
 0.05367926508188248,
 0.0018733929609879851,
 0.0045674461871385574,
 0.010022819973528385,
 0.04786737635731697,
 0.0022013208363205194,
 -0.009834472090005875,
 0.03847686946392059,
 0.00089213193859905,
 -0.05211866647005081,
 0.051150016486644745,
 -0.032557349652051926,
 -0.014031948521733284,
 -0.012632790021598339,
 0.013271828182041645,
 0.018565760925412178,
 0.0016068464610725641,
 -0.0008185583865270019,
 -0.012753871269524097,
 -0.029705218970775604,
 -0.004443001933395863,
 -0.015323479659855366,
 0.025655729696154