Skip to content

Commit

Permalink
Merge pull request #1224 from seehi/fix-rag-redundant-embedding
Browse files Browse the repository at this point in the history
Fix the potential duplicate embeddings in the RAG module
  • Loading branch information
geekan committed Apr 24, 2024
2 parents 12a7825 + 1607228 commit c779f69
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 75 deletions.
11 changes: 9 additions & 2 deletions examples/rag_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def rag_key(self) -> str:


class RAGExample:
"""Show how to use RAG."""
"""Show how to use RAG.
The default engine uses an LLM reranker; if the answer from the LLM is incorrect, you may encounter `IndexError: list index out of range`.
"""

def __init__(self, engine: SimpleEngine = None):
self._engine = engine
Expand All @@ -59,6 +62,7 @@ def engine(self):
def engine(self, value: SimpleEngine):
self._engine = value

@handle_exception
async def run_pipeline(self, question=QUESTION, print_title=True):
"""This example run rag pipeline, use faiss retriever and llm ranker, will print something like:
Expand All @@ -79,6 +83,7 @@ async def run_pipeline(self, question=QUESTION, print_title=True):
answer = await self.engine.aquery(question)
self._print_query_result(answer)

@handle_exception
async def add_docs(self):
"""This example show how to add docs.
Expand Down Expand Up @@ -148,6 +153,7 @@ async def add_objects(self, print_title=True):
except Exception as e:
logger.error(f"nodes is empty, llm don't answer correctly, exception: {e}")

@handle_exception
async def init_objects(self):
"""This example show how to from objs, will print something like:
Expand All @@ -160,6 +166,7 @@ async def init_objects(self):
await self.add_objects(print_title=False)
self.engine = pre_engine

@handle_exception
async def init_and_query_chromadb(self):
"""This example show how to use chromadb. how to save and load index. will print something like:
Expand Down Expand Up @@ -233,7 +240,7 @@ async def _retrieve_and_print(self, question):


async def main():
"""RAG pipeline"""
"""RAG pipeline."""
e = RAGExample()
await e.run_pipeline()
await e.add_docs()
Expand Down
63 changes: 49 additions & 14 deletions metagpt/rag/engines/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from typing import Any, Optional, Union

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core import SimpleDirectoryReader
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
Expand Down Expand Up @@ -63,15 +63,15 @@ def __init__(
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
index: Optional[BaseIndex] = None,
transformations: Optional[list[TransformComponent]] = None,
) -> None:
super().__init__(
retriever=retriever,
response_synthesizer=response_synthesizer,
node_postprocessors=node_postprocessors,
callback_manager=callback_manager,
)
self.index = index
self._transformations = transformations or self._default_transformations()

@classmethod
def from_docs(
Expand Down Expand Up @@ -103,12 +103,17 @@ def from_docs(
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
cls._fix_document_metadata(documents)

index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations = transformations or cls._default_transformations()
nodes = run_transformations(documents, transformations=transformations)

return cls._from_nodes(
nodes=nodes,
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)

@classmethod
def from_objs(
Expand Down Expand Up @@ -137,12 +142,15 @@ def from_objs(
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")

nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(

return cls._from_nodes(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)

@classmethod
def from_index(
Expand Down Expand Up @@ -183,7 +191,7 @@ def add_docs(self, input_files: list[str]):
documents = SimpleDirectoryReader(input_files=input_files).load_data()
self._fix_document_metadata(documents)

nodes = run_transformations(documents, transformations=self.index._transformations)
nodes = run_transformations(documents, transformations=self._transformations)
self._save_nodes(nodes)

def add_objs(self, objs: list[RAGObject]):
Expand All @@ -199,6 +207,29 @@ def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):

self._persist(str(persist_dir), **kwargs)

@classmethod
def _from_nodes(
    cls,
    nodes: list[BaseNode],
    transformations: Optional[list[TransformComponent]] = None,
    embed_model: BaseEmbedding = None,
    llm: LLM = None,
    retriever_configs: list[BaseRetrieverConfig] = None,
    ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
    """Assemble an engine from already-prepared nodes.

    Args:
        nodes: Nodes handed to `get_retriever`.
        transformations: Passed through unchanged to the engine constructor.
        embed_model: Resolved through `_resolve_embed_model` together with `retriever_configs`.
        llm: LLM for the rankers and the response synthesizer; falls back to `get_rag_llm()` when falsy.
        retriever_configs: Configs forwarded to `get_retriever`.
        ranker_configs: Configs forwarded to `get_rankers`.

    Returns:
        A new instance of `cls` wired with the retriever, rankers and response synthesizer.
    """
    # Resolve the shared dependencies first; the retriever and the rankers both use them.
    resolved_embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
    resolved_llm = llm or get_rag_llm()

    return cls(
        retriever=get_retriever(configs=retriever_configs, nodes=nodes, embed_model=resolved_embed_model),
        node_postprocessors=get_rankers(configs=ranker_configs, llm=resolved_llm),  # Default []
        response_synthesizer=get_response_synthesizer(llm=resolved_llm),
        transformations=transformations,
    )

@classmethod
def _from_index(
cls,
Expand All @@ -208,14 +239,14 @@ def _from_index(
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()

retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []

return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)

def _ensure_retriever_modifiable(self):
Expand Down Expand Up @@ -266,3 +297,7 @@ def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] =
return MockEmbedding(embed_dim=1)

return embed_model or get_rag_embedding()

@staticmethod
def _default_transformations():
    """Return the fallback transformation list: a single fresh `SentenceSplitter`."""
    default_splitter = SentenceSplitter()
    return [default_splitter]
17 changes: 11 additions & 6 deletions metagpt/rag/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
"""Designed to get objects based on object type."""

def get_instance(self, key: Any, **kwargs) -> Any:
"""Key is config, such as a pydantic model.
"""Get instance by the type of key.
Call func by the type of key, and the key will be passed to func.
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
Raise Exception if key not found.
"""
creator = self._creators.get(type(key))
if creator:
return creator(key, **kwargs)

self._raise_for_key(key)

def _raise_for_key(self, key: Any):
raise ValueError(f"Unknown config: `{type(key)}`, {key}")

@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
Return None if not found.
"""
if config is not None and hasattr(config, key):
val = getattr(config, key)
if val is not None:
Expand All @@ -57,6 +64,4 @@ def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any
if key in kwargs:
return kwargs[key]

raise KeyError(
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
)
return None
89 changes: 69 additions & 20 deletions metagpt/rag/factories/retriever.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""RAG Retriever Factory."""

import copy

from functools import wraps

import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
Expand All @@ -24,10 +27,25 @@
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)


def get_or_build_index(build_index_func):
    """Decorator that reuses an existing index and only builds one when none is found.

    The wrapped method first calls `self._extract_index(config, **kwargs)`. Any result
    other than None is returned as-is; otherwise `build_index_func` is invoked with
    the same arguments and its result is returned.
    """

    @wraps(build_index_func)
    def wrapper(self, config, **kwargs):
        existing = self._extract_index(config, **kwargs)
        # Compare against None explicitly so that falsy-but-present indexes are still reused.
        return build_index_func(self, config, **kwargs) if existing is None else existing

    return wrapper


class RetrieverFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""

Expand All @@ -54,48 +72,79 @@ def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) ->
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]

def _create_default(self, **kwargs) -> RAGRetriever:
return self._extract_index(**kwargs).as_retriever()
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)

return index.as_retriever()

def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_faiss_index(config, **kwargs)

return FAISSRetriever(**config.model_dump())

def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
index = self._extract_index(config, **kwargs)
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)

return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())

def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)

vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_chroma_index(config, **kwargs)

return ChromaRetriever(**config.model_dump())

def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_es_index(config, **kwargs)

return ElasticsearchRetriever(**config.model_dump())

def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
    """Fetch the "index" value by delegating to `_val_from_config_or_kwargs`."""
    found_index = self._val_from_config_or_kwargs("index", config, **kwargs)
    return found_index

def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
    """Fetch the "nodes" value by delegating to `_val_from_config_or_kwargs`."""
    found_nodes = self._val_from_config_or_kwargs("nodes", config, **kwargs)
    return found_nodes

def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
    """Fetch the "embed_model" value by delegating to `_val_from_config_or_kwargs`."""
    found_embed_model = self._val_from_config_or_kwargs("embed_model", config, **kwargs)
    return found_embed_model

def _build_default_index(self, **kwargs) -> VectorStoreIndex:
    """Create a `VectorStoreIndex` from the nodes and embed model extracted from kwargs.

    No config object is consulted here: both extractors are called with kwargs only.
    """
    return VectorStoreIndex(
        nodes=self._extract_nodes(**kwargs),
        embed_model=self._extract_embed_model(**kwargs),
    )

@get_or_build_index
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on a `FaissVectorStore` wrapping `faiss.IndexFlatL2(config.dimensions)`.

    The `get_or_build_index` decorator returns an already-available index instead,
    so this body runs only when none was found.
    """
    flat_l2_index = faiss.IndexFlatL2(config.dimensions)
    faiss_store = FaissVectorStore(faiss_index=flat_l2_index)

    return self._build_index_from_vector_store(config, faiss_store, **kwargs)

@get_or_build_index
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on a `ChromaVectorStore` backed by a persistent Chroma client.

    The client is opened at `config.persist_path`, and the collection named
    `config.collection_name` is fetched or created with `config.metadata`.
    The `get_or_build_index` decorator returns an already-available index instead,
    so this body runs only when none was found.
    """
    client = chromadb.PersistentClient(path=str(config.persist_path))
    collection = client.get_or_create_collection(config.collection_name, metadata=config.metadata)

    return self._build_index_from_vector_store(
        config,
        ChromaVectorStore(chroma_collection=collection),
        **kwargs,
    )

@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
    """Build an index on an `ElasticsearchStore` configured from `config.store_config`.

    The `get_or_build_index` decorator returns an already-available index instead,
    so this body runs only when none was found.
    """
    store_options = config.store_config.model_dump()
    es_store = ElasticsearchStore(**store_options)

    return self._build_index_from_vector_store(config, es_store, **kwargs)

def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
index = VectorStoreIndex(
nodes=self._extract_nodes(config, **kwargs),
storage_context=storage_context,
embed_model=old_index._embed_model,
embed_model=self._extract_embed_model(config, **kwargs),
)
return new_index

return index


get_retriever = RetrieverFactory().get_retriever
Loading

0 comments on commit c779f69

Please sign in to comment.