diff --git a/src/ragas/metrics/_answer_correctness.py b/src/ragas/metrics/_answer_correctness.py
index 6a959d46d..b5b22275e 100644
--- a/src/ragas/metrics/_answer_correctness.py
+++ b/src/ragas/metrics/_answer_correctness.py
@@ -65,12 +65,10 @@ class AnswerCorrectness(MetricWithLLM):
     batch_size: int
         batch size for evaluation
     weights:
-        a list of two weights corresponding to semantic similarity and factuality
-        Defaults [0.5, 0.5]
+        a list of two weights corresponding to factuality and semantic similarity
+        Defaults [0.75, 0.25]
     answer_similarity:
         The AnswerSimilarity object
-    faithfulness
-        The faithfulness object
     """

     name: str = "answer_correctness"  # type: ignore[reportIncompatibleMethodOverride]
@@ -108,41 +106,41 @@ def _score_batch(
                 )
                 prompts.append(ChatPromptTemplate.from_messages([human_prompt]))

-            result = self.llm.generate(prompts, callbacks=batch_group)
-            outputs = result.generations
-            key_map = {
-                "TP": "statements that are present in both the answer and the ground truth",
-                "FP": "statements present in the answer but not found in the ground truth",
-                "FN": "relevant statements found in the ground truth but omitted in the answer",  # noqa: E501
-            }
-
-            f1_score = []
-            for prediction in outputs:
-                prediction = json_loader.safe_load(prediction[0].text, self.llm)
-                prediction = prediction if isinstance(prediction, list) else []
-                if prediction:
-                    prediction = [
-                        item.get(key_map[k], np.nan)
-                        for item in prediction
-                        for k in key_map.keys()
-                    ]
-                    tp, fp, fn = [
-                        len(item) if isinstance(item, list) else np.nan
-                        for item in prediction
-                    ]
-                    score = tp / (tp + 0.5 * (fp + fn))
-                else:
-                    score = np.nan
-
-                f1_score.append(score)
-
-            similarity_scores = self.answer_similarity._score_batch(dataset)  # type: ignore
-            scores_stacked = np.vstack([f1_score, similarity_scores])
-            scores = np.average(
-                scores_stacked,
-                axis=0,
-                weights=self.weights,
-            )
+            result = self.llm.generate(prompts, callbacks=batch_group)
+            outputs = result.generations
+            key_map = {
+                "TP": "statements that are present in both the answer and the ground truth",
+                "FP": "statements present in the answer but not found in the ground truth",
+                "FN": "relevant statements found in the ground truth but omitted in the answer",  # noqa: E501
+            }
+
+            f1_score = []
+            for prediction in outputs:
+                prediction = json_loader.safe_load(prediction[0].text, self.llm)
+                prediction = prediction if isinstance(prediction, list) else []
+                if prediction:
+                    prediction = [
+                        item.get(key_map[k], np.nan)
+                        for item in prediction
+                        for k in key_map.keys()
+                    ]
+                    tp, fp, fn = [
+                        len(item) if isinstance(item, list) else np.nan
+                        for item in prediction
+                    ]
+                    score = tp / (tp + 0.5 * (fp + fn))
+                else:
+                    score = np.nan
+
+                f1_score.append(score)
+
+            similarity_scores = self.answer_similarity._score_batch(dataset, callbacks=batch_group)  # type: ignore
+            scores_stacked = np.vstack([f1_score, similarity_scores])
+            scores = np.average(
+                scores_stacked,
+                axis=0,
+                weights=self.weights,
+            )
         return scores.tolist()
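
For reviewers, here is a minimal sketch (illustrative values only, not part of the patch) of how the combined score behaves after this change: the factuality F1 computed from the TP/FP/FN statement counts is averaged with the answer-similarity score using the new default weights [0.75, 0.25], so factuality now dominates the result.

```python
import numpy as np

def factuality_f1(tp: int, fp: int, fn: int) -> float:
    # Same formula as in the diff: tp / (tp + 0.5 * (fp + fn))
    return tp / (tp + 0.5 * (fp + fn))

# Hypothetical single-row batch: 3 true-positive, 1 false-positive,
# 1 false-negative statement, plus a semantic similarity of 0.9.
f1_score = [factuality_f1(tp=3, fp=1, fn=1)]   # -> 0.75
similarity_scores = [0.9]                      # from AnswerSimilarity
weights = [0.75, 0.25]                         # new defaults: factuality first

scores_stacked = np.vstack([f1_score, similarity_scores])
scores = np.average(scores_stacked, axis=0, weights=weights)
print(scores.tolist())                         # [0.7875]
```

With the old [0.5, 0.5] defaults the same example would have scored 0.825, so answers that are semantically close but factually wrong are penalised more heavily under the new weighting.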