From 7a46111980b29da9ffc0b07e592816de19f578aa Mon Sep 17 00:00:00 2001
From: Shawn Liu
Date: Sun, 18 Feb 2024 16:00:54 -0500
Subject: [PATCH] made embeddings and llms tied to the metric in evaluate
 function

---
 src/ragas/evaluation.py | 39 ++++++++++++++++++---------------------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/src/ragas/evaluation.py b/src/ragas/evaluation.py
index fe238e4f7..c8b40417d 100644
--- a/src/ragas/evaluation.py
+++ b/src/ragas/evaluation.py
@@ -10,7 +10,8 @@
 
 from ragas._analytics import EvaluationEvent, track
 from ragas.callbacks import new_group
-from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
+from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper, embedding_factory
+from ragas.llms import llm_factory
 from ragas.exceptions import ExceptionInRunner
 from ragas.executor import Executor
 from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
@@ -57,11 +58,11 @@ def evaluate(
         evaluation on the best set of metrics to give a complete view.
     llm: BaseRagasLLM, optional
         The language model to use for the metrics. If not provided then ragas will use
-        the default language model. This can we overridden by the llm specified in
+        the default language model for metrics which require an LLM. This can be overridden by the llm specified in
         the metric level with `metric.llm`.
     embeddings: BaseRagasEmbeddings, optional
         The embeddings to use for the metrics. If not provided then ragas will use
-        the default embeddings. This can we overridden by the embeddings specified in
+        the default embeddings for metrics which require embeddings. This can be overridden by the embeddings specified in
         the metric level with `metric.embeddings`.
     callbacks: Callbacks, optional
         Lifecycle Langchain Callbacks to run during evaluation. Check the
@@ -144,34 +145,30 @@ def evaluate(
     validate_column_dtypes(dataset)
 
     # set the llm and embeddings
-    if llm is None:
-        from ragas.llms import llm_factory
-
-        llm = llm_factory()
-    elif isinstance(llm, LangchainLLM):
+    if isinstance(llm, LangchainLLM):
         llm = LangchainLLMWrapper(llm, run_config=run_config)
-    if embeddings is None:
-        from ragas.embeddings.base import embedding_factory
-
-        embeddings = embedding_factory()
-    elif isinstance(embeddings, LangchainEmbeddings):
+    if isinstance(embeddings, LangchainEmbeddings):
         embeddings = LangchainEmbeddingsWrapper(embeddings)
 
+    # init llms and embeddings
     binary_metrics = []
     llm_changed: t.List[int] = []
     embeddings_changed: t.List[int] = []
     answer_correctness_is_set = -1
+
     for i, metric in enumerate(metrics):
         if isinstance(metric, AspectCritique):
             binary_metrics.append(metric.name)
-        if isinstance(metric, MetricWithLLM):
-            if metric.llm is None:
-                metric.llm = llm
-                llm_changed.append(i)
-        if isinstance(metric, MetricWithEmbeddings):
-            if metric.embeddings is None:
-                metric.embeddings = embeddings
-                embeddings_changed.append(i)
+        if isinstance(metric, MetricWithLLM) and metric.llm is None:
+            if llm is None:
+                llm = llm_factory()
+            metric.llm = llm
+            llm_changed.append(i)
+        if isinstance(metric, MetricWithEmbeddings) and metric.embeddings is None:
+            if embeddings is None:
+                embeddings = embedding_factory()
+            metric.embeddings = embeddings
+            embeddings_changed.append(i)
         if isinstance(metric, AnswerCorrectness):
            if metric.answer_similarity is None:
                answer_correctness_is_set = i
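
Note for reviewers: the behavioural change is easiest to see in isolation. Before this patch, evaluate() built the default LLM and embeddings unconditionally; after it, the factories run lazily, at most once, and only when some metric actually needs a default. The sketch below is a minimal, self-contained reproduction of that wiring; the stub metric classes, the call counters, and the wire_metrics helper are illustrative stand-ins, not ragas code.

    import typing as t

    # Illustrative stand-ins for the ragas base classes; the real ones
    # live in ragas.metrics.base.
    class MetricWithLLM:
        llm: t.Optional[object] = None

    class MetricWithEmbeddings:
        embeddings: t.Optional[object] = None

    calls = {"llm": 0, "embeddings": 0}

    def llm_factory() -> object:
        # Stand-in for ragas.llms.llm_factory, which builds the default LLM.
        calls["llm"] += 1
        return object()

    def embedding_factory() -> object:
        # Stand-in for ragas.embeddings.base.embedding_factory.
        calls["embeddings"] += 1
        return object()

    def wire_metrics(metrics, llm=None, embeddings=None):
        # Mirrors the patched loop: a default is created lazily, at most
        # once, and only for metrics that need one and bring none of
        # their own.
        for metric in metrics:
            if isinstance(metric, MetricWithLLM) and metric.llm is None:
                if llm is None:
                    llm = llm_factory()
                metric.llm = llm
            if isinstance(metric, MetricWithEmbeddings) and metric.embeddings is None:
                if embeddings is None:
                    embeddings = embedding_factory()
                metric.embeddings = embeddings

    # Two LLM-only metrics: llm_factory runs once, embedding_factory never.
    # The pre-patch code would have constructed both defaults up front.
    wire_metrics([MetricWithLLM(), MetricWithLLM()])
    assert calls == {"llm": 1, "embeddings": 0}

One practical consequence: passing only metrics that already carry their own llm and embeddings never constructs the defaults at all, so no credentials for the default models are required in that case.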