Create llama-index InfinityEmbeddings as langchain #111
Comments
@semoal Love the idea - and I also love LlamaIndex. I'm low on time, but I recently opened an issue myself. Also, feel free to copy anything from my langchain PR / langchain integration. I'll happily review your PR and update this README to highlight the integration. |
@semoal Any help required here? |
No, we have it done and running in production already, including the reranker for LlamaIndex, with Infinity as the backend. We've been a bit busy lately trying to hit a deadline on the 22nd, but we'll push it forward soon, before the end of the month. |
Exciting, looking forward to the LlamaIndex community contribution. |
Hi everyone! Looking forward to testing the LlamaIndex integration with Infinity; I'm also willing to test on a ROCm-based workstation. Hope it is released soon. Thanks! |
Here's our code, folks. Sorry, I haven't found time yet to publish a package for LlamaIndex, but I'm sharing the code we've been using for 4 months now, and it's working like a charm.

"""written under MIT Licence, Sergio Moreno inspired by Michael Feil 2023."""
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Tuple

import aiohttp
import numpy as np
import requests
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.base.embeddings.base import (
    BaseEmbedding,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.base.llms.generic_utils import get_from_param_or_env

__all__ = ["InfinityEmbeddings"]


class TinyAsyncOpenAIInfinityEmbeddingClient:  #: :meta private:
    """A helper tool to embed Infinity. Not part of LlamaIndex's stable API,
    direct use discouraged.
    """

    host: str
    batch_size: int
    # _aiosession: Optional[aiohttp.ClientSession] = PrivateAttr(default=None)

    def __init__(
        self,
        # Same default port as `InfinityEmbeddings.infinity_api_url` below.
        host: str = "http://localhost:7997",
        batch_size: int = 1024,
        # aiosession: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        self.host = host
        self.batch_size = batch_size
        self.embed_batch_size = batch_size
        # self._aiosession = aiosession
        if self.host is None or len(self.host) < 3:
            raise ValueError(" param `host` must be set to a valid url")

    @staticmethod
    def _permute(
        input: List[str], sorter: Callable = len
    ) -> Tuple[List[str], Callable]:
        """Sort texts in descending order of length, and
        delivers a lambda expr, which can sort a same length list
        https://github.com/UKPLab/sentence-transformers/blob/
        c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156

        Args:
            input (List[str]): Texts to sort.
            sorter (Callable, optional): Key used to measure each text.
                Defaults to len.

        Returns:
            Tuple[List[str], Callable]: The sorted texts, and a function that
                restores the original order of a list of the same length.

        Example:
            ```
            texts = ["one","three","four"]
            perm_texts, undo = self._permute(texts)
            texts == undo(perm_texts)
            ```
        """
        if len(input) == 1:
            # special case query
            return input, lambda t: t
        # Negate the key so that argsort yields longest-first order.
        length_sorted_idx = np.argsort([-sorter(sen) for sen in input])
        texts_sorted = [input[idx] for idx in length_sorted_idx]

        return texts_sorted, lambda unsorted_embeddings: [
            unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
        ]

    def _batch(self, input: List[str]) -> List[List[str]]:
        """
        Splits a list of texts into batches of at most `self.batch_size` texts.

        Args:
            input (List[str]): List of sentences

        Returns:
            List[List[str]]: Batches of List of sentences
        """
        # Batching is by number of texts, never by characters, so a single
        # long text is always sent whole as a one-element batch.
        batches = []
        for start_index in range(0, len(input), self.batch_size):
            batches.append(input[start_index : start_index + self.batch_size])
        return batches

    @staticmethod
    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
        if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
            # special case query
            return batch_of_texts[0]
        texts = []
        for sublist in batch_of_texts:
            texts.extend(sublist)
        return texts

    def _kwargs_post_request(self, model: str, input: List[str]) -> Dict[str, Any]:
        """Build the kwargs for the Post request, used by sync and async

        Args:
            model (str): Name of the embedding model served by Infinity.
            input (List[str]): Texts of one batch.

        Returns:
            Dict[str, Any]: Keyword arguments for `requests.post` or
                `aiohttp.ClientSession.post`.
        """
        return dict(
            url=f"{self.host}/embeddings",
            headers={
                "content-type": "application/json",
            },
            json=dict(
                input=input,
                model=model,
            ),
        )

    def _sync_request_embed(
        self, model: str, batch_texts: List[str]
    ) -> List[List[float]]:
        response = requests.post(
            **self._kwargs_post_request(model=model, input=batch_texts)
        )
        if response.status_code != 200:
            raise Exception(
                f"Infinity returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )
        return [e["embedding"] for e in response.json()["data"]]

    def embed(self, model: str, input: List[str]) -> List[List[float]]:
        """call the embedding of model

        Args:
            model (str): to embedding model
            input (List[str]): List of sentences to embed.

        Returns:
            List[List[float]]: List of vectors for each sentence
        """
        perm_texts, unpermute_func = self._permute(input)
        perm_texts_batched = self._batch(perm_texts)

        # Request
        map_args = (
            self._sync_request_embed,
            [model] * len(perm_texts_batched),
            perm_texts_batched,
        )
        if len(perm_texts_batched) == 1:
            embeddings_batch_perm = list(map(*map_args))
        else:
            # Send the batches concurrently; results keep the batch order.
            with ThreadPoolExecutor(32) as p:
                embeddings_batch_perm = list(p.map(*map_args))

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        embeddings = unpermute_func(embeddings_perm)
        return embeddings

    async def _async_request(
        self, session: aiohttp.ClientSession, kwargs: Dict[str, Any]
    ) -> List[List[float]]:
        async with session.post(**kwargs) as response:
            if response.status != 200:
                # In aiohttp, `text()` is a coroutine and has to be awaited.
                raise Exception(
                    f"Infinity returned an unexpected response with status "
                    f"{response.status}: {await response.text()}"
                )
            return [e["embedding"] for e in (await response.json())["data"]]

    async def aembed(self, model: str, input: List[str]) -> List[List[float]]:
        """call the embedding of model, async method

        Args:
            model (str): to embedding model
            input (List[str]): List of sentences to embed.

        Returns:
            List[List[float]]: List of vectors for each sentence
        """
        perm_texts, unpermute_func = self._permute(input)
        perm_texts_batched = self._batch(perm_texts)

        # Request
        async with aiohttp.ClientSession(trust_env=True) as session:
            embeddings_batch_perm = await asyncio.gather(
                *[
                    self._async_request(
                        session=session,
                        kwargs=self._kwargs_post_request(model=model, input=t),
                    )
                    for t in perm_texts_batched
                ]
            )

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        embeddings = unpermute_func(embeddings_perm)
        return embeddings


class InfinityEmbeddings(BaseEmbedding):
    """Infinity class for embeddings.

    Args:
        model (str): Model for embedding.
            Defaults to `jinaai/jina-embeddings-v2-base-es`
        infinity_api_url (str): Infinity API URL.
            Defaults to `http://localhost:7997`
        callback_manager (Optional[CallbackManager]): Callback manager.
            Defaults to None.
        **kwargs: Additional keyword arguments.
    """

    model: str = Field(
        default="jinaai/jina-embeddings-v2-base-es", description="Model for embedding."
    )
    infinity_api_url: Optional[str] = Field(
        default="http://localhost:7997", description="Infinity API URL."
    )
    _session: Optional[TinyAsyncOpenAIInfinityEmbeddingClient] = PrivateAttr()

    def __init__(
        self,
        model: str = "jinaai/jina-embeddings-v2-base-es",
        infinity_api_url: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            callback_manager=callback_manager,
            model=model,
            infinity_api_url=infinity_api_url,
            **kwargs,
        )
        # Explicit argument first, then the INFINITY_API_URL environment
        # variable, then the local default.
        self.infinity_api_url = get_from_param_or_env(
            "infinity_api_url",
            infinity_api_url,
            "INFINITY_API_URL",
            "http://localhost:7997",
        )
        self.model = model
        self._session = TinyAsyncOpenAIInfinityEmbeddingClient(
            host=self.infinity_api_url,
        )
        self.embed_batch_size = self._session.batch_size

    @classmethod
    def class_name(cls) -> str:
        return "InfinityEmbedding"

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get the embedding for a single text."""
        # Wrap the text in a list so the client sees one text, not a
        # sequence of characters.
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get the embedding for a single text."""
        embeddings = await self._aget_text_embeddings([text])
        return embeddings[0]

    def _get_text_embeddings(self, input: List[str]) -> List[List[float]]:
        """Call out to Infinity's embedding endpoint.

        Args:
            input: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = self._session.embed(
            model=self.model,
            input=input,
        )
        return embeddings

    async def _aget_text_embeddings(self, input: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            input: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = await self._session.aembed(
            model=self.model,
            input=input,
        )
        return embeddings

    def _get_query_embedding(self, text: str) -> List[float]:
        """Call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._get_text_embeddings([text])[0]

    async def _aget_query_embedding(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self._aget_text_embeddings([text])
        return embeddings[0] |
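For anyone reading along, here is a minimal sketch of the sort, batch, unbatch and unsort round trip that the client's `_permute` / `_batch` / `_unbatch` helpers implement. It is not the code from the comment above: the names `permute_by_length`, `batch` and `fake_embed` are made up for illustration, the HTTP call is replaced by a stand-in that returns the text length as a one-dimensional "vector", and the standard library's `sorted` is used instead of `np.argsort`.

```python
from typing import Any, Callable, List, Tuple


def permute_by_length(texts: List[str]) -> Tuple[List[str], Callable[[List[Any]], List[Any]]]:
    """Sort texts longest-first and return a function that restores the input order."""
    # order[k] is the original position of the k-th longest text.
    order = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
    sorted_texts = [texts[i] for i in order]

    def undo(items: List[Any]) -> List[Any]:
        restored: List[Any] = [None] * len(items)
        for sorted_pos, original_pos in enumerate(order):
            restored[original_pos] = items[sorted_pos]
        return restored

    return sorted_texts, undo


def batch(items: List[str], batch_size: int) -> List[List[str]]:
    """Split a list into consecutive chunks of at most `batch_size` items."""
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


def fake_embed(batch_texts: List[str]) -> List[List[float]]:
    """Stand-in for the request to Infinity: one fake vector per text."""
    return [[float(len(t))] for t in batch_texts]


texts = ["one", "three", "four", "a much longer sentence"]
sorted_texts, undo = permute_by_length(texts)
batches = batch(sorted_texts, batch_size=2)
# Flatten the per-batch results, then put them back in the caller's order.
flat = [vec for b in batches for vec in fake_embed(b)]
embeddings = undo(flat)
```

Sorting by length groups texts of similar size into the same batch, which presumably reduces padding on the server side; the `undo` step is what keeps the returned vectors aligned with the input order.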
To use it, in our case we directly do:

from rag.embeddings.infinity import InfinityEmbeddings

EMBED_MODEL = InfinityEmbeddings(
    model="jinaai/jina-embeddings-v2-base-es",
)

LlamaIndex service context:

def get_service_context(
    config: BaseConfig,
    prompt_schema: BasePromptSchema,
    transformations: List[TransformComponent] | None = None,
):
    completion_to_prompt = prompt_schema.completion_to_prompt
    messages_to_prompt = prompt_schema.messages_to_prompt

    return ServiceContext.from_defaults(
        context_window=config.context_window,
        llm=Fireworks(
            model=config.model,
            api_key=config.api_key,
            num_output=config.num_output,
            streaming=config.streaming,
            temperature=config.temperature,
            completion_to_prompt=completion_to_prompt,
            messages_to_prompt=messages_to_prompt,
            response_format=config.response_format,
        ),
        embed_model=EMBED_MODEL,
        chunk_size=CHUNK_SIZE,
        transformations=transformations
        or [SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_SIZE // 2)],
    ) |
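Note that the usage above never passes `infinity_api_url`, so the server address comes from the `INFINITY_API_URL` environment variable or falls back to `http://localhost:7997`. A rough sketch of that lookup order follows; `resolve_infinity_url` is a hypothetical stand-in written for illustration, assuming the resolution order is explicit argument, then environment variable, then default. It is not the LlamaIndex helper itself.

```python
import os
from typing import Optional


def resolve_infinity_url(
    param: Optional[str] = None,
    env_key: str = "INFINITY_API_URL",
    default: str = "http://localhost:7997",
) -> str:
    """Return the explicit argument, else a non-empty env var, else the default."""
    if param is not None:
        return param
    env_value = os.environ.get(env_key)
    if env_value:
        return env_value
    return default


# Nothing configured: the local default is used.
os.environ.pop("INFINITY_API_URL", None)
url_default = resolve_infinity_url()

# Environment variable set (the host name is a made-up example).
os.environ["INFINITY_API_URL"] = "http://embeddings.internal:7997"
url_from_env = resolve_infinity_url()

# An explicit argument takes precedence over the environment.
url_from_param = resolve_infinity_url("http://10.0.0.5:7997")
os.environ.pop("INFINITY_API_URL", None)
```

This is handy when the RAG service and the embedding machine are deployed separately: the same code runs locally and in production, and only the environment variable changes.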
Wow that was fast! Thank you @semoal, I will test this on my setup right now! |
Hi. I installed the ROCm 6.1 version of PyTorch, ran the Infinity server with it, and then adapted my code to use the InfinityEmbeddings class @semoal provided here. It is working perfectly well: I was able to reproduce the same accuracy I was getting with HuggingFaceEmbeddings directly over CUDA. Thanks!! |
@semoal thx for sharing the code. Would you share the reranker class? Thanks again~ |
import requests
from typing import Any, Dict, List, Optional

from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle
from pydantic import BaseModel, Field
from llama_index.core.base.llms.generic_utils import get_from_param_or_env


class Result(BaseModel):
    relevance_score: float
    index: int


class Usage(BaseModel):
    prompt_tokens: int
    total_tokens: int


class RankResult(BaseModel):
    model: str
    results: List[Result]
    usage: Usage


class BgeRerank(BaseNodePostprocessor):
    infinity_api_url: Optional[str] = Field(
        default="http://localhost:7997",
        description="The host of the BGE reranker service",
    )

    def __init__(
        self,
        infinity_api_url: Optional[str] = None,
    ):
        super().__init__()
        self.infinity_api_url = get_from_param_or_env(
            "infinity_api_url",
            infinity_api_url,
            "INFINITY_API_URL",
            "http://localhost:7997",
        )

    @classmethod
    def class_name(cls) -> str:
        return "BGERerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.class_name(),
                EventPayload.QUERY_STR: query_bundle.query_str,
            },
        ) as event:
            texts = [node.node.get_content() for node in nodes]
            results = self._sync_request_rank(
                query=query_bundle.query_str,
                documents=texts,
            )

            # Each result points back to its input node through `index`;
            # the nodes are returned in the order the server sent them.
            new_nodes = []
            for result in results:
                new_node_with_score = NodeWithScore(
                    node=nodes[result.index].node, score=result.relevance_score
                )
                new_nodes.append(new_node_with_score)
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes

    def _kwargs_post_request(self, query: str, documents: List[str]) -> Dict[str, Any]:
        """Build the kwargs for the Post request, used by sync

        Args:
            query (str): The query to rank the documents against.
            documents (List[str]): Texts of the candidate nodes.

        Returns:
            Dict[str, Any]: Keyword arguments for `requests.post`.
        """
        return dict(
            url=f"{self.infinity_api_url}/rerank",
            headers={
                "content-type": "application/json",
            },
            json=dict(
                query=query,
                documents=documents,
            ),
        )

    def _sync_request_rank(self, query: str, documents: List[str]) -> List[Result]:
        response = requests.post(
            **self._kwargs_post_request(query=query, documents=documents)
        )
        if response.status_code != 200:
            raise Exception(
                f"Infinity returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )
        return [Result(**result) for result in response.json()["results"]]

On Infinity we just load the model as another endpoint and that's it 👍 |
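To make the index mapping in the reranker easier to follow, here is a small offline sketch of how a rerank response can be applied to the original documents. The response shape (`results` entries with `index` and `relevance_score`) is taken from the `Result` model above; the sample scores, the `ScoredText` class and the `apply_rerank` function are invented for illustration, and the explicit sort by score is an extra step that the shared class does not perform.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ScoredText:
    text: str
    score: float


def apply_rerank(documents: List[str], response_json: Dict[str, Any]) -> List[ScoredText]:
    """Pair each result with the document its `index` refers to, best score first."""
    results = sorted(
        response_json["results"],
        key=lambda r: r["relevance_score"],
        reverse=True,
    )
    return [ScoredText(documents[r["index"]], r["relevance_score"]) for r in results]


documents = ["Paris is in France.", "Bananas are yellow.", "France borders Spain."]
# Made-up payload in the shape the `Result` model expects.
sample_response = {
    "results": [
        {"index": 1, "relevance_score": 0.02},
        {"index": 0, "relevance_score": 0.91},
        {"index": 2, "relevance_score": 0.47},
    ]
}

ranked = apply_rerank(documents, sample_response)
ranked_texts = [item.text for item in ranked]
```

Sorting on the client side makes the output independent of whatever order the server returns, at the cost of one extra pass over the results.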
Hi! Kudos for this project Michael! It is amazing.
We're migrating from a single repo with a RAG and a T40 to one repo with a CPU-only RAG plus a separate service hosting our embedding models, rerankers, etc., so that we only start/stop that machine (which is expensive) when traffic arrives.
We have seen your tool and it looks promising, and we're willing to contribute.
Did you consider supporting llama-index? I think we (my company) could work on the llama-index integration, and we can ping you to review it when it's done.
What do you think?