From 8f209bfeb5cd43b4de7dab5601c5a42c935fde6c Mon Sep 17 00:00:00 2001 From: Wei-Jianan Date: Thu, 9 May 2024 21:11:59 +0800 Subject: [PATCH] [fix] stream field in llmconfig not work --- metagpt/provider/base_llm.py | 5 ++++- metagpt/rag/benchmark/base.py | 2 +- metagpt/rag/factories/ranker.py | 10 ++++++---- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 6387e3936..f444e32ba 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -22,6 +22,7 @@ wait_random_exponential, ) +from metagpt.config2 import config from metagpt.configs.llm_config import LLMConfig from metagpt.const import LLM_API_TIMEOUT, USE_CONFIG_TIMEOUT from metagpt.logs import logger @@ -132,7 +133,7 @@ async def aask( format_msgs: Optional[list[dict[str, str]]] = None, images: Optional[Union[str, list[str]]] = None, timeout=USE_CONFIG_TIMEOUT, - stream=True, + stream=None, ) -> str: if system_msgs: message = self._system_msgs(system_msgs) @@ -146,6 +147,8 @@ async def aask( message.append(self._user_msg(msg, images=images)) else: message.extend(msg) + if stream is None: + stream = config.llm.stream logger.debug(message) rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout)) return rsp diff --git a/metagpt/rag/benchmark/base.py b/metagpt/rag/benchmark/base.py index c1fd297d9..b5d265b35 100644 --- a/metagpt/rag/benchmark/base.py +++ b/metagpt/rag/benchmark/base.py @@ -121,7 +121,7 @@ def mean_reciprocal_rank(self, nodes: list[NodeWithScore], reference_docs: list[ return mrr_sum return mrr_sum - + async def semantic_similarity(self, response: str, reference: str) -> float: result = await self.evaluator.aevaluate( response=response, diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index b75745a1f..7abda162a 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -8,11 +8,11 @@ from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor from metagpt.rag.schema import ( BaseRankerConfig, + BGERerankConfig, + CohereRerankConfig, ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig, - CohereRerankConfig, - BGERerankConfig ) @@ -60,13 +60,15 @@ def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRera def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank: try: - from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker + from llama_index.postprocessor.flag_embedding_reranker import ( + FlagEmbeddingReranker, + ) except ImportError: raise ImportError( "`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`" ) return FlagEmbeddingReranker(**config.model_dump()) - + def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank: return ObjectSortPostprocessor(**config.model_dump())