diff --git a/src/ragas/metrics/_answer_correctness.py b/src/ragas/metrics/_answer_correctness.py index 89f83005e..3eed376fd 100644 --- a/src/ragas/metrics/_answer_correctness.py +++ b/src/ragas/metrics/_answer_correctness.py @@ -142,6 +142,7 @@ def _score_batch( prediction = ( prediction if isinstance(prediction, list) else [prediction] ) + if prediction: prediction = [ item.get(key_map[k], np.nan) @@ -152,7 +153,11 @@ def _score_batch( len(item) if isinstance(item, list) else np.nan for item in prediction ] - score = tp / (tp + 0.5 * (fp + fn)) + + if any([np.isnan(i) for i in [tp, fp, fn]]): + score = np.nan + else: + score = tp / (tp + 0.5 * (fp + fn)) if tp > 0 else 0 else: score = np.nan