# Custom Embeddings

This guide shows how to create a custom embedding model using LangChain's `Embeddings` interface. Embeddings convert text into numerical vectors that algorithms can compare, which enables applications such as similarity search, text classification, and clustering.

Implementing embeddings using the standard `Embeddings` interface allows your model to be used in existing LangChain abstractions (e.g., as the embeddings for a particular `VectorStore`, or cached using `CacheBackedEmbeddings`).

## Interface

The current `Embeddings` abstraction in LangChain is designed to operate on text data. Inputs are either a single string or a list of strings. The output is one embedding vector (a list of floats) for a single string, or a list of such vectors for a list of strings. Each vector represents the input text in some n-dimensional space.

A custom embedding model implements the following methods, of which the first two are required:

| Method/Property                 | Description                                                                | Required/Optional |
|---------------------------------|----------------------------------------------------------------------------|-------------------|
| `embed_documents(texts)`        | Generates embeddings for a list of documents.                              | Required          |
| `embed_query(text)`             | Generates an embedding for a single text query.                            | Required          |
| `aembed_documents(texts)`       | Asynchronously generates embeddings for a list of documents.               | Optional          |
| `aembed_query(text)`            | Asynchronously generates an embedding for a single text query.             | Optional          |

Implementing the two required methods is enough for your model to be used wherever LangChain accepts an `Embeddings` object. The async methods are optional because LangChain provides defaults that delegate to the synchronous ones; override them only when you have a native async implementation.
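
The shape of the interface can be approximated with a small abstract base class. This is a stand-in written for illustration, not the library source: `EmbeddingsSketch` and `LengthEmbeddings` are hypothetical names, and the async fallback shown here (running the sync method in the default executor) is an assumption about how the optional methods behave when they are not overridden.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import List


class EmbeddingsSketch(ABC):
    """Hypothetical stand-in mirroring the four methods in the table."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Required: embed a list of plain-text strings."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Required: embed a single query string."""

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Optional: by default, run the sync method in a worker thread."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_documents, texts)

    async def aembed_query(self, text: str) -> List[float]:
        """Optional: by default, run the sync method in a worker thread."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.embed_query, text)


class LengthEmbeddings(EmbeddingsSketch):
    """Toy model: a 1-dimensional embedding holding the text length."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text))]


model = LengthEmbeddings()
sync_result = model.embed_documents(["hi", "hello"])
# Only the sync methods were implemented; the async call uses the fallback.
async_result = asyncio.run(model.aembed_documents(["hi", "hello"]))
```

Both calls return the same vectors, which is why a subclass with no native async code path can leave the async methods out.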

:::{.callout-note}
`embed_documents` takes a list of plain-text strings, not a list of LangChain `Document` objects. The name of this method
may change in future versions of LangChain.
:::
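
Because `embed_documents` expects strings, `Document` objects need to be reduced to their text first. A minimal sketch, using a hypothetical `SimpleDoc` dataclass as a stand-in for LangChain's `Document` (which exposes its text as `page_content`):

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SimpleDoc:
    """Hypothetical stand-in for a LangChain `Document`."""

    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


docs = [
    SimpleDoc("Hello world", {"source": "a.txt"}),
    SimpleDoc("Test", {"source": "b.txt"}),
]

# Pass only the text to `embed_documents`, not the objects themselves.
texts: List[str] = [doc.page_content for doc in docs]
```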


:::{.callout-important}
`Embeddings` do not currently implement the `Runnable` interface and are also **not** instances of pydantic `BaseModel`.
:::

## Implementation

As an example, we'll implement a simple embeddings model that will count the characters in the text and generate a fixed size vector containing the character counts. The model will be case insensitive, and either count the characters from a-z or only the vowels (a, e, i, o, u). This model is for illustrative purposes only.

In [1]:
from collections import Counter
from typing import List

from langchain_core.embeddings import Embeddings


class CharCountEmbeddings(Embeddings):
    """Embedding model that counts occurrences of characters in text.

    When contributing an implementation to LangChain, carefully document
    the embedding model including the initialization parameters, include
    an example of how to initialize the model and include any relevant
    links to the underlying models documentation or API.

    Example:

        .. code-block:: python

            from langchain_community.embeddings import CharCountEmbeddings

            embeddings = CharCountEmbeddings(only_vowels=True)
            print(embeddings.embed_documents(["Hello world", "Test"]))
            print(embeddings.embed_query("Quick Brown Fox"))
    """

    def __init__(self, *, only_vowels: bool = False) -> None:
        """Initialize the embedding model.

        Args:
            only_vowels: If True, the embedding will count only the
                vowels (a, e, i, o, u) and produce a 5-dimensional vector.
                If False, counts all lowercase alphabetic characters,
                producing a 26-dimensional vector.
        """

        self.only_vowels = only_vowels

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed multiple documents by counting specific character sets."""
        return [self._embed_text(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query by counting specific character sets."""
        return self._embed_text(text)

    def _embed_text(self, text: str) -> List[float]:
        """Helper function to create a character count vector from text."""
        text = text.lower()  # Normalize text to lowercase for case insensitivity.
        count = Counter(text)
        if self.only_vowels:
            # Embed only vowels
            vowels = "aeiou"
            return [count.get(vowel, 0) for vowel in vowels]
        else:
            # Embed all letters from 'a' to 'z'
            return [count.get(chr(i), 0) for i in range(ord("a"), ord("z") + 1)]

    # The async methods are optional.
    # Delete them if you do not have an actual async implementation.
    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous embed search docs."""
        # This implementation is only for illustrative purposes.
        # If you're connecting to an API, you should provide
        # an actual async implementation (e.g., using httpx AsyncClient
        # https://www.python-httpx.org/async/).
        # If you do not have a native async implementation, please
        # DELETE this method: LangChain already provides a default
        # that delegates to the sync method.
        return [self._embed_text(text) for text in texts]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous embed query text."""
        # See comment above for the aembed_documents regarding
        # native async implementation
        return self._embed_text(text)

### Let's test it 🧪

In [2]:
embeddings = CharCountEmbeddings(only_vowels=True)
print(embeddings.embed_documents(["abce", "eee", "hello", "fox"]))
print(embeddings.embed_query("eeee"))

[[1, 1, 0, 0, 0], [0, 3, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 1, 0]]
[0, 4, 0, 0, 0]
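
Vectors like these are what a similarity search compares. As a rough illustration, the sketch below ranks the same documents against the query by cosine similarity. To stay runnable on its own it uses a standalone `vowel_counts` helper that mirrors `_embed_text` with `only_vowels=True`; `vowel_counts` and `cosine_similarity` are illustrative names, not LangChain APIs.

```python
import math
from collections import Counter
from typing import List


def vowel_counts(text: str) -> List[float]:
    """Mirror of `_embed_text` with only_vowels=True (standalone copy)."""
    count = Counter(text.lower())
    return [float(count.get(vowel, 0)) for vowel in "aeiou"]


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


docs = ["abce", "eee", "hello", "fox"]
query_vector = vowel_counts("eeee")

# Sort documents from most to least similar to the query.
ranked = sorted(
    docs,
    key=lambda doc: cosine_similarity(query_vector, vowel_counts(doc)),
    reverse=True,
)
```

Here `"eee"` ranks first because its vector points in the same direction as the query's, and `"fox"` ranks last because it shares no vowels with the query.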


## Contributing

We welcome contributions of Embedding models to the LangChain code base!

Here's a checklist to help make sure your contribution gets added to LangChain:

Documentation:

* [ ] The model contains docstrings for all initialization arguments, as these will be surfaced in the [API Reference](https://api.python.langchain.com/en/stable/langchain_api_reference.html).
* [ ] The class docstring for the model contains a link to the model API if the model is powered by a service.

Tests:

* [ ] Add integration tests that exercise the integration with the API or model.
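
As a sketch of what such tests might check, the functions below assert a few basic invariants: one vector per input, a consistent dimension, and numeric entries. `FakeEmbeddings` is a hypothetical stand-in so the block runs on its own; in a real integration test you would instantiate your model instead. The invariants are suggestions, not LangChain requirements.

```python
from typing import List


class FakeEmbeddings:
    """Hypothetical stand-in; replace with the model under test."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        # Two toy features: text length and number of spaces.
        return [float(len(text)), float(text.count(" "))]


def check_embed_documents(model) -> None:
    texts = ["Hello world", "Test", ""]
    vectors = model.embed_documents(texts)
    # One vector per input text.
    assert len(vectors) == len(texts)
    # Every vector has the same dimension.
    assert len({len(vector) for vector in vectors}) == 1
    # Every entry is numeric.
    assert all(isinstance(x, (int, float)) for vector in vectors for x in vector)


def check_embed_query(model) -> None:
    vector = model.embed_query("Quick Brown Fox")
    # Query vectors live in the same space as document vectors.
    assert len(vector) == len(model.embed_documents(["x"])[0])


model = FakeEmbeddings()
check_embed_documents(model)
check_embed_query(model)
```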

Optimizations:

If your implementation is an integration with an API, consider providing native async support (e.g., via httpx `AsyncClient`).

* [ ] Provided a native async implementation of `aembed_documents`
* [ ] Provided a native async implementation of `aembed_query`
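
A native async implementation typically awaits its remote calls concurrently instead of occupying a thread per request. The sketch below assumes a hypothetical `fake_remote_embed` coroutine in place of a real HTTP request (for example, one made with httpx `AsyncClient`). It also leaves out the `Embeddings` base class and the sync methods so that it runs without LangChain installed.

```python
import asyncio
from typing import List


async def fake_remote_embed(text: str) -> List[float]:
    """Hypothetical stand-in for an awaitable API request."""
    await asyncio.sleep(0)  # Yield control, as a real network call would.
    return [float(len(text))]


class AsyncSketchEmbeddings:
    """Illustrative only; a real model would subclass `Embeddings`."""

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        # Issue all requests concurrently; gather preserves input order.
        results = await asyncio.gather(*(fake_remote_embed(t) for t in texts))
        return list(results)

    async def aembed_query(self, text: str) -> List[float]:
        return await fake_remote_embed(text)


vectors = asyncio.run(AsyncSketchEmbeddings().aembed_documents(["hi", "hello"]))
```

If the underlying service accepts batches, a single awaited batch request may be preferable to one request per text.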