Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix potential NoneType error #1247

Merged
merged 5 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 19 additions & 13 deletions examples/rag_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
)
from metagpt.utils.exceptions import handle_exception

LLM_TIP = "If you not sure, just answer I don't know."

DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
QUESTION = "What are key qualities to be a good writer?"
QUESTION = f"What are key qualities to be a good writer? {LLM_TIP}"

TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
TRAVEL_QUESTION = "What does Bob like?"

LLM_TIP = "If you not sure, just answer I don't know."
TRAVEL_QUESTION = f"What does Bob like? {LLM_TIP}"


class Player(BaseModel):
Expand All @@ -40,21 +40,21 @@ def rag_key(self) -> str:


class RAGExample:
"""Show how to use RAG.

Default engine use LLM Reranker, if the answer from the LLM is incorrect, may encounter `IndexError: list index out of range`.
"""
"""Show how to use RAG."""

def __init__(self, engine: SimpleEngine = None):
def __init__(self, engine: SimpleEngine = None, use_llm_ranker: bool = True):
self._engine = engine
self._use_llm_ranker = use_llm_ranker

@property
def engine(self):
if not self._engine:
ranker_configs = [LLMRankerConfig()] if self._use_llm_ranker else None

self._engine = SimpleEngine.from_docs(
input_files=[DOC_PATH],
retriever_configs=[FAISSRetrieverConfig()],
ranker_configs=[LLMRankerConfig()],
ranker_configs=ranker_configs,
)
return self._engine

Expand Down Expand Up @@ -105,7 +105,7 @@ async def add_docs(self):
"""
self._print_title("Add Docs")

travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}"
travel_question = f"{TRAVEL_QUESTION}"
travel_filepath = TRAVEL_DOC_PATH

logger.info("[Before add docs]")
Expand Down Expand Up @@ -240,8 +240,14 @@ async def _retrieve_and_print(self, question):


async def main():
"""RAG pipeline."""
e = RAGExample()
"""RAG pipeline.

Note:
1. If `use_llm_ranker` is True, then it will use LLM Reranker to get better result, but it is not always guaranteed that the output will be parseable for reranking,
prefer `gpt-4-turbo`, otherwise might encounter `IndexError: list index out of range` or `ValueError: invalid literal for int() with base 10`.
"""
e = RAGExample(use_llm_ranker=False)

await e.run_pipeline()
await e.add_docs()
await e.add_objects()
Expand Down
2 changes: 1 addition & 1 deletion metagpt/rag/benchmark/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def mean_reciprocal_rank(self, nodes: list[NodeWithScore], reference_docs: list[
return mrr_sum

return mrr_sum

async def semantic_similarity(self, response: str, reference: str) -> float:
result = await self.evaluator.aevaluate(
response=response,
Expand Down
10 changes: 6 additions & 4 deletions metagpt/rag/factories/ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
from metagpt.rag.schema import (
BaseRankerConfig,
BGERerankConfig,
CohereRerankConfig,
ColbertRerankConfig,
LLMRankerConfig,
ObjectRankerConfig,
CohereRerankConfig,
BGERerankConfig
)


Expand Down Expand Up @@ -60,13 +60,15 @@ def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRera

def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank:
try:
from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker
from llama_index.postprocessor.flag_embedding_reranker import (
FlagEmbeddingReranker,
)
except ImportError:
raise ImportError(
"`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`"
)
return FlagEmbeddingReranker(**config.model_dump())

def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
return ObjectSortPostprocessor(**config.model_dump())

Expand Down
6 changes: 4 additions & 2 deletions metagpt/rag/retrievers/bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
self.bm25 = BM25Okapi(self._corpus)

self._index.insert_nodes(nodes, **kwargs)
if self._index:
self._index.insert_nodes(nodes, **kwargs)

def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
self._index.storage_context.persist(persist_dir)
if self._index:
self._index.storage_context.persist(persist_dir)