
Create llama-index InfinityEmbeddings as langchain #111

Open · semoal opened this issue Feb 21, 2024 · 11 comments
@semoal

semoal commented Feb 21, 2024

Hi! Kudos for this project Michael! It is amazing.

We're migrating from a single repo with a RAG and a T40, to one repo with a RAG running on CPU only and another service for our embedding models, rerankers, etc., and we just start/stop that machine (which is expensive) when traffic arrives.

We have seen your tool and it looks promising, and we're willing to contribute.

Did you consider supporting llama-index? I think we (my company) could work on the llama-index integration; we can ping you when it's done for review.

What do you think?

@michaelfeil
Owner

@semoal Love the idea - and I also love llamaindex. I'm low on time, but I recently opened an issue myself.

run-llama/llama_index#10628

Also - feel free to copy anything from my langchain PR / langchain integration.
Both the "server" and the pure Python implementation would be welcome.
langchain-ai/langchain#17671
For guidance, I think the AsyncEmbeddingEngine integration would have slightly more impact and usage, as people are still stuck with sentence-transformers.
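For reference (not from this thread), a rough sketch of the pure-Python AsyncEmbeddingEngine route, based on infinity's README around this time; the exact class and argument names (EngineArgs, model_name_or_path, engine) may differ between versions:

import asyncio

from infinity_emb import AsyncEmbeddingEngine, EngineArgs

# Hedged sketch: embed a few sentences in-process, without running the REST server.
# Model name and engine backend are placeholders.
engine = AsyncEmbeddingEngine.from_args(
    EngineArgs(model_name_or_path="BAAI/bge-small-en-v1.5", engine="torch")
)

async def main() -> None:
    async with engine:  # starts and stops the background batching loop
        embeddings, usage = await engine.embed(
            sentences=["Embed this sentence via Infinity.", "Paris is in France."]
        )
        print(len(embeddings), usage)

asyncio.run(main())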

I'll happily review your PR and update the README to highlight the integration.

@michaelfeil
Owner

@semoal Any help required here?

@semoal
Author

semoal commented Mar 12, 2024

No, we already have it done and running in production, including the reranker for LlamaIndex, using Infinity as the backend; we've been a bit busy lately meeting a deadline on the 22nd, but we'll push it forward soon, before the end of the month.

@michaelfeil
Owner

Exciting, looking forward to the llama-index community contribution.

@hvico

hvico commented Jun 14, 2024

Hi everyone! Looking forward to testing the LlamaIndex integration with Infinity, and willing to test on a ROCm-based workstation as well. Hope it is released soon. Thanks!

@semoal
Author

semoal commented Jun 14, 2024

This is our code, guys. Sorry I couldn't find time yet to publish a package on LlamaIndex, but here is the code we've been using for 4 months already, working like a charm.

"""written under MIT Licence, Sergio Moreno inspired by Michael Feil 2023."""

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Tuple

import aiohttp
import numpy as np
import requests

from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.base.embeddings.base import (
    BaseEmbedding,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.base.llms.generic_utils import get_from_param_or_env

__all__ = ["InfinityEmbeddings"]


class TinyAsyncOpenAIInfinityEmbeddingClient:  #: :meta private:
    """A helper tool to embed Infinity. Not part of LlamaIndex's stable API,
    direct use discouraged.
    """

    host: str = Field(default="http://localhost:7797")
    batch_size: Optional[int] = Field(
        default=128, description="Batch size for embedding."
    )
    # _aiosession: Optional[aiohttp.ClientSession] = PrivateAttr(default=None)

    def __init__(
        self,
        host: str = "http://localhost:7797",
        batch_size: Optional[int] = 1024,
        # aiosession: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        self.host = host
        self.batch_size = batch_size
        self.embed_batch_size = batch_size
        # self._aiosession = aiosession

        if self.host is None or len(self.host) < 3:
            raise ValueError(" param `host` must be set to a valid url")

    @staticmethod
    def _permute(
        input: List[str], sorter: Callable = len
    ) -> Tuple[List[str], Callable]:
        """Sort texts in ascending order, and
        delivers a lambda expr, which can sort a same length list
        https://github.com/UKPLab/sentence-transformers/blob/
        c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156

        Args:
            input (List[str]): _description_
            sorter (Callable, optional): _description_. Defaults to len.

        Returns:
            Tuple[List[str], Callable]: _description_

        Example:
            ```
            texts = ["one","three","four"]
            perm_texts, undo = self._permute(texts)
            texts == undo(perm_texts)
            ```
        """

        if len(input) == 1:
            # special case query
            return input, lambda t: t
        length_sorted_idx = np.argsort([-sorter(sen) for sen in input])
        texts_sorted = [input[idx] for idx in length_sorted_idx]

        return texts_sorted, lambda unsorted_embeddings: [
            unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
        ]

    def _batch(self, input: List[str]) -> List[List[str]]:
        """
        Splits lists of texts into batches of at most `self.batch_size` items,
        e.g. when encoding a whole vector database.

        Args:
            input (List[str]): List of sentences
            self.batch_size (int, optional): max batch size of one request.

        Returns:
            List[List[str]]: Batches of List of sentences
        """

        if len(input) == 1 and len(input[0]) < self.batch_size:
            # special case query
            return [input]
        elif len(input) == 1 and len(input[0]) > self.batch_size:
            return [
                input[0][i : i + self.batch_size]
                for i in range(0, len(input[0]), self.batch_size)
            ]
        else:
            batches = []
            for start_index in range(0, len(input), self.batch_size):
                batches.append(input[start_index : start_index + self.batch_size])
            return batches

    @staticmethod
    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
        if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
            # special case query
            return batch_of_texts[0]
        texts = []
        for sublist in batch_of_texts:
            texts.extend(sublist)
        return texts

    def _kwargs_post_request(self, model: str, input: List[str]) -> Dict[str, Any]:
        """Build the kwargs for the Post request, used by sync

        Args:
            model (str): _description_
            input (List[str]): _description_

        Returns:
            Dict[str, Collection[str]]: _description_
        """
        return dict(
            url=f"{self.host}/embeddings",
            headers={
                "content-type": "application/json",
            },
            json=dict(
                input=input,
                model=model,
            ),
        )

    def _sync_request_embed(
        self, model: str, batch_texts: List[str]
    ) -> List[List[float]]:
        response = requests.post(
            **self._kwargs_post_request(model=model, input=batch_texts)
        )
        if response.status_code != 200:
            raise Exception(
                f"Infinity returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )
        return [e["embedding"] for e in response.json()["data"]]

    def embed(self, model: str, input: List[str]) -> List[List[float]]:
        """call the embedding of model

        Args:
            model (str): to embedding model
            input (List[str]): List of sentences to embed.

        Returns:
            List[List[float]]: List of vectors for each sentence
        """
        perm_texts, unpermute_func = self._permute(input)
        perm_texts_batched = self._batch(perm_texts)

        # Request
        map_args = (
            self._sync_request_embed,
            [model] * len(perm_texts_batched),
            perm_texts_batched,
        )
        if len(perm_texts_batched) == 1:
            embeddings_batch_perm = list(map(*map_args))
        else:
            with ThreadPoolExecutor(32) as p:
                embeddings_batch_perm = list(p.map(*map_args))

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        embeddings = unpermute_func(embeddings_perm)
        return embeddings

    async def _async_request(
        self, session: aiohttp.ClientSession, kwargs: Dict[str, Any]
    ) -> List[List[float]]:
        async with session.post(**kwargs) as response:
            if response.status != 200:
                # aiohttp's `response.text` is a coroutine; read the body before raising
                detail = await response.text()
                raise Exception(
                    f"Infinity returned an unexpected response with status "
                    f"{response.status}: {detail}"
                )
            return [e["embedding"] for e in (await response.json())["data"]]

    async def aembed(self, model: str, input: List[str]) -> List[List[float]]:
        """call the embedding of model, async method

        Args:
            model (str): to embedding model
            input (List[str]): List of sentences to embed.

        Returns:
            List[List[float]]: List of vectors for each sentence
        """
        perm_texts, unpermute_func = self._permute(input)
        perm_texts_batched = self._batch(perm_texts)

        # Request

        async with aiohttp.ClientSession(trust_env=True) as session:
            embeddings_batch_perm = await asyncio.gather(
                *[
                    self._async_request(
                        session=session,
                        kwargs=self._kwargs_post_request(model=model, input=t),
                    )
                    for t in perm_texts_batched
                ]
            )

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        embeddings = unpermute_func(embeddings_perm)
        return embeddings


class InfinityEmbeddings(BaseEmbedding):
    """Infinity class for embeddings.

    Args:
        model (str): Model for embedding.
            Defaults to `jinaai/jina-embeddings-v2-base-es`
        infinity_api_url (str): Infinity API URL.
            Defaults to `http://localhost:7997`
        callback_manager (Optional[CallbackManager]): Callback manager.
            Defaults to None.
        **kwargs: Additional keyword arguments.
    """

    model: str = Field(
        default="jinaai/jina-embeddings-v2-base-es", description="Model for embedding."
    )
    infinity_api_url: Optional[str] = Field(
        default="http://localhost:7997", description="Infinity API URL."
    )
    _session: Optional[TinyAsyncOpenAIInfinityEmbeddingClient] = PrivateAttr()

    def __init__(
        self,
        model: str = "jinaai/jina-embeddings-v2-base-es",
        infinity_api_url: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            callback_manager=callback_manager,
            model=model,
            infinity_api_url=infinity_api_url,
            **kwargs,
        )
        self.infinity_api_url = get_from_param_or_env(
            "infinity_api_url",
            infinity_api_url,
            "INFINITY_API_URL",
            "http://localhost:7997",
        )
        self.model = model
        self._session = TinyAsyncOpenAIInfinityEmbeddingClient(
            host=self.infinity_api_url,
        )
        self.embed_batch_size = self._session.batch_size

    @classmethod
    def class_name(cls) -> str:
        return "InfinityEmbedding"

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding for a single text."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding for a single text."""
        result = await self._aget_text_embeddings([text])
        return result[0]

    def _get_text_embeddings(self, input: List[str]) -> List[List[float]]:
        """Call out to Infinity's embedding endpoint.

        Args:
            input: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = self._session.embed(
            model=self.model,
            input=input,
        )
        return embeddings

    async def _aget_text_embeddings(self, input: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            input: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = await self._session.aembed(
            model=self.model,
            input=input,
        )
        return embeddings

    def _get_query_embedding(self, text: str) -> List[float]:
        """Call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._get_text_embedding(text)

    async def _aget_query_embedding(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return await self._aget_text_embedding(text)

@semoal
Author

semoal commented Jun 14, 2024

To use it, in our case we just do:

from rag.embeddings.infinity import InfinityEmbeddings
EMBED_MODEL = InfinityEmbeddings(
    model="jinaai/jina-embeddings-v2-base-es",
)

LlamaIndex service context:

# llama_index imports for this snippet; the project-specific pieces
# (BaseConfig, BasePromptSchema, CHUNK_SIZE, EMBED_MODEL) are defined elsewhere in our repo.
from typing import List

from llama_index.core import ServiceContext
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TransformComponent
from llama_index.llms.fireworks import Fireworks


def get_service_context(
    config: BaseConfig,
    prompt_schema: BasePromptSchema,
    transformations: List[TransformComponent] | None = None,
):
    completion_to_prompt = prompt_schema.completion_to_prompt
    messages_to_prompt = prompt_schema.messages_to_prompt
    return ServiceContext.from_defaults(
        context_window=config.context_window,
        llm=Fireworks(
            model=config.model,
            api_key=config.api_key,
            num_output=config.num_output,
            streaming=config.streaming,
            temperature=config.temperature,
            completion_to_prompt=completion_to_prompt,
            messages_to_prompt=messages_to_prompt,
            response_format=config.response_format,
        ),
        embed_model=EMBED_MODEL,
        chunk_size=CHUNK_SIZE,
        transformations=transformations
        or [SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_SIZE // 2)],
    )

@hvico

hvico commented Jun 14, 2024

Wow that was fast! Thank you @semoal, I will test this on my setup right now!

@hvico

hvico commented Jun 15, 2024

Hi. I installed the ROCm 6.1 version of PyTorch, ran the Infinity server using that, and then adapted my code to use the InfinityEmbeddings class @semoal provided here. It is working perfectly well; I was able to reproduce the same accuracy I was getting with HuggingFaceEmbeddings directly over CUDA. Thanks!!
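For anyone doing the same swap, a minimal, untested sketch of what that adaptation can look like with LlamaIndex's global Settings; the model name and URL below are placeholders, not hvico's actual setup:

from llama_index.core import Settings

from rag.embeddings.infinity import InfinityEmbeddings  # module path as used by semoal above

# Hypothetical swap: replace a local HuggingFace embedding with the Infinity-backed class.
Settings.embed_model = InfinityEmbeddings(
    model="jinaai/jina-embeddings-v2-base-es",  # placeholder; use whatever model Infinity serves
    infinity_api_url="http://localhost:7997",   # URL of the running Infinity server
)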

@luoyangen

@semoal thx for sharing the code. Would you share the reranker class? Thanks again~

@semoal
Author

semoal commented Jul 3, 2024

@semoal thx for sharing the code. Would you share the reranker class? Thanks again~

import requests
from typing import Any, Dict, List, Optional

from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle
from pydantic import BaseModel, Field
from llama_index.core.base.llms.generic_utils import get_from_param_or_env


class Result(BaseModel):
    relevance_score: float
    index: int


class Usage(BaseModel):
    prompt_tokens: int
    total_tokens: int


class RankResult(BaseModel):
    model: str
    results: List[Result]
    usage: Usage


class BgeRerank(BaseNodePostprocessor):
    infinity_api_url: Optional[str] = Field(
        default="http://localhost:7997",
        description="The host of the BGE reranker service",
    )

    def __init__(
        self,
        infinity_api_url: Optional[str] = None,
    ):
        super().__init__()
        self.infinity_api_url = get_from_param_or_env(
            "infinity_api_url",
            infinity_api_url,
            "INFINITY_API_URL",
            "http://localhost:7997",
        )

    @classmethod
    def class_name(cls) -> str:
        return "BGERerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.class_name(),
                EventPayload.QUERY_STR: query_bundle.query_str,
            },
        ) as event:
            texts = [node.node.get_content() for node in nodes]

            results = self._sync_request_rank(
                query=query_bundle.query_str,
                documents=texts,
            )

            new_nodes = []
            for result in results:
                new_node_with_score = NodeWithScore(
                    node=nodes[result.index].node, score=result.relevance_score
                )
                new_nodes.append(new_node_with_score)
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes

    def _kwargs_post_request(self, query: str, documents: List[str]) -> Dict[str, Any]:
        """Build the kwargs for the POST request to the `/rerank` endpoint (sync path).

        Args:
            query (str): query to rank the documents against.
            documents (List[str]): documents to rerank.

        Returns:
            Dict[str, Any]: keyword arguments for `requests.post`.
        """
        return dict(
            url=f"{self.infinity_api_url}/rerank",
            headers={
                "content-type": "application/json",
            },
            json=dict(
                query=query,
                documents=documents,
            ),
        )

    def _sync_request_rank(self, query: str, documents: List[str]) -> List[Result]:
        response = requests.post(
            **self._kwargs_post_request(query=query, documents=documents)
        )
        if response.status_code != 200:
            raise Exception(
                f"Infinity returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )
        return [Result(**result) for result in response.json()["results"]]

On Infinity we just load the reranker model as another endpoint and that's it 👍
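For completeness, a minimal usage sketch of the reranker class above wired into a LlamaIndex query engine; the index construction and query here are illustrative only, not part of semoal's code:

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

# Hypothetical wiring: retrieve a larger candidate set, then let Infinity's reranker reorder it.
documents = SimpleDirectoryReader("./data").load_data()
index = VectorStoreIndex.from_documents(documents)

reranker = BgeRerank(infinity_api_url="http://localhost:7997")
query_engine = index.as_query_engine(
    similarity_top_k=10,
    node_postprocessors=[reranker],
)
response = query_engine.query("What does Infinity do?")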
