diff --git a/src/ragas/metrics/llms.py b/src/ragas/metrics/llms.py index b361493d6..d7df01514 100644 --- a/src/ragas/metrics/llms.py +++ b/src/ragas/metrics/llms.py @@ -20,10 +20,12 @@ def generate( n: t.Optional[int] = None, ) -> LLMResult: old_n = None + n_swapped = False if n is not None: if isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI): old_n = llm.n llm.n = n + n_swapped = True else: raise Exception( f"n={n} was passed to generate but the LLM {llm} does not support it." @@ -36,6 +38,6 @@ def generate( ps = [p.format_messages() for p in prompts] result = llm.generate(ps) - if isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI): + if (isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)) and n_swapped: llm.n = old_n # type: ignore return result