Add passing list of results as option to analyser class #87

Merged
3 commits merged on Feb 4, 2024
119 changes: 107 additions & 12 deletions src/rank_llm/analysis/response_analysis_verbose.py
@@ -1,26 +1,83 @@
import argparse
import json
import os
import re
from typing import Dict, List
import sys
from typing import Dict, List, Tuple, Union

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
parent = os.path.dirname(SCRIPT_DIR)
parent = os.path.dirname(parent)
sys.path.append(parent)

from rank_llm.result import Result


class ResponseAnalyzer:
def __init__(
self,
files: List[str],
data: Union[List[str], List[Result]],
) -> None:
self._files = files
self._data = data

@staticmethod
def from_inline_results(results: List[Result]) -> "ResponseAnalyzer":
"""
Method to create a ResponseAnalyzer instance from a list of Result objects.

Args:
results (List[Result]): A list of Result objects.

Returns:
ResponseAnalyzer: An instance of the ResponseAnalyzer.
"""
return ResponseAnalyzer(data=results)

@staticmethod
def from_stored_files(filenames: List[str]) -> "ResponseAnalyzer":
"""
Method to create a ResponseAnalyzer instance from a list of filenames.

Args:
filenames (List[str]): A list of filenames where each file contains data to be analyzed.

Returns:
ResponseAnalyzer: An instance of the ResponseAnalyzer.
"""
data = []
for filename in filenames:
with open(filename, "r") as file:
file_data = json.load(file)
data.extend(file_data)
return ResponseAnalyzer(data=data)

def read_saved_responses(self) -> List[str]:
def read_results_responses(self) -> Tuple[List[str], List[int]]:
"""
Reads responses from the specified files and produces the total number of passages.
Reads responses from the specified list of Result objects and produces the number of passages in each corresponding prompt.

Returns:
Tuple[List[str], List[int]]: A tuple object containing a list of responses and a list of corresponding numbers of passages.
"""
num_passages = []
responses = []
for filename in self._files:
with open(filename) as f:
for result in self._data:
for exec_info in result.ranking_exec_summary:
responses.append(exec_info.response)
num_passage = self._get_num_passages(exec_info.prompt)
num_passages.append(int(num_passage))
return responses, num_passages

def read_saved_responses(self) -> Tuple[List[str], List[int]]:
"""
Reads responses from the specified list of files and produces the number of passages in each corresponding prompt.

Returns:
Tuple[List[str], List[int]]: A tuple object containing a list of responses and a list of corresponding numbers of passages.
"""
num_passages = []
responses = []
for result in self._data:
with open(result) as f:
ranking_exec_summaries = json.load(f)
for summary in ranking_exec_summaries:
for exec_info in summary["ranking_exec_summary"]:
@@ -29,6 +86,22 @@ def read_saved_responses(self) -> List[str]:
num_passages.append(int(num_passage))
return responses, num_passages

def read_responses(self) -> Tuple[List[str], List[int]]:
"""
Selects which response-reading method to call based on the input type.

Returns:
Tuple[List[str], List[int]]: A tuple object containing a list of responses and a list of corresponding numbers of passages.
"""
if all(isinstance(item, str) for item in self._data):
return self.read_saved_responses()
elif all(isinstance(item, Result) for item in self._data):
return self.read_results_responses()
else:
raise ValueError(
"Input data must be a list of file paths or a list of Result objects."
)

def _validate_format(self, response: str) -> bool:
for c in response:
if not c.isdigit() and c != "[" and c != "]" and c != ">" and c != " ":
@@ -51,20 +124,18 @@ def _get_num_passages(self, prompt) -> int:
raise ValueError(f"Unsupported prompt format.")
return int(match.group(2))

def count_errors(
self, responses: List[str], num_passages: List[int], verbose: bool = False
) -> Dict[str, int]:
def count_errors(self, verbose: bool = False) -> Dict[str, int]:
"""
Counts the different types of errors in the given responses.

Args:
responses (List[str]): A list of response strings.
num_passages (List[int]): A list of the expected number of passages in each response.
verbose (bool, optional): If True, prints the erroneous responses. Defaults to False.

Returns:
Dict[str, int]: A dictionary object containing counts of different types of errors.
"""
responses, num_passages = self.read_responses()

stats_dict = {
"ok": 0,
"wrong_format": 0,
@@ -105,3 +176,27 @@ def count_errors(
# Round to two decimal places
normalized_stats_dict[key] = round(normalized_stats_dict[key], 2)
return normalized_stats_dict


def main(args):
if args.files:
response_analyzer = ResponseAnalyzer.from_stored_files(args.files)
else:
print("Error: Please specify the files containing ranking summaries.")
sys.exit(1)

error_counts = response_analyzer.count_errors(args.verbose)
print("Normalized scores:", error_counts)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files", nargs="+", help="Filenames of ranking summaries", required=False
)
parser.add_argument(
"--verbose", action="store_true", help="Verbose output of errors"
)
args = parser.parse_args()

main(args)
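
For reviewers, a minimal usage sketch of the two construction paths this file now supports (the file name and the sample Result are hypothetical placeholders; the imports follow the package layout used elsewhere in this PR):

from rank_llm.analysis.response_analysis_verbose import ResponseAnalyzer
from rank_llm.result import RankingExecInfo, Result

# Path 1: analyze ranking execution summaries previously saved to disk.
# "ranking_execution_summary.json" is a placeholder file name.
file_analyzer = ResponseAnalyzer.from_stored_files(["ranking_execution_summary.json"])
print(file_analyzer.count_errors(verbose=True))

# Path 2: analyze in-memory Result objects straight from a rerank run.
results = [
    Result(
        query="sample query",
        hits=[],
        ranking_exec_summary=[
            RankingExecInfo(
                prompt="I will provide you with 2 passages",
                response="2 > 1",
                input_token_count=80,
                output_token_count=40,
            )
        ],
    )
]
print(ResponseAnalyzer.from_inline_results(results).count_errors())

In both cases count_errors() now calls read_responses() internally, which dispatches to read_saved_responses() for file paths or read_results_responses() for Result objects.
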
32 changes: 32 additions & 0 deletions src/rank_llm/evaluation/trec_eval.py
@@ -2,13 +2,45 @@
import platform
import subprocess
import tempfile
from typing import List

import pandas as pd
from pyserini.search import get_qrels_file
from pyserini.util import download_evaluation_script

from src.rank_llm.result import Result


class EvalFunction:
@staticmethod
def from_results(results: List[Result], qrels: str) -> str:
"""
This method processes a list of Result objects and immediately evaluates them,
returning the evaluation result as a string.

Args:
results (List[Result]): A list of Result objects.
qrels (str): Path to the qrels file.

Returns:
str: Evaluation results as a string.
"""
# Convert the list of Result objects to a temporary run file format
temp_run_file = tempfile.NamedTemporaryFile(
delete=False, mode="w", suffix=".txt"
).name
with open(temp_run_file, "w") as file:
for result in results:
for hit in result.hits:
file.write(
f"{result.query} Q0 {hit['doc_id']} {hit['rank']} {hit['score']} RUN\n"
)

eval_result = EvalFunction.eval(temp_run_file, qrels, trunc=True)
os.remove(temp_run_file)

return eval_result

@staticmethod
def trunc(qrels: str, run: str):
"""
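
A usage sketch of the new EvalFunction.from_results entry point (the qrels path and the hits are placeholder values; the src.rank_llm import prefix matches the one used in this file):

from src.rank_llm.evaluation.trec_eval import EvalFunction
from src.rank_llm.result import Result

results = [
    Result(
        query="sample query",
        hits=[
            {"doc_id": "D1", "rank": 1, "score": 0.9},
            {"doc_id": "D2", "rank": 2, "score": 0.8},
        ],
    )
]
# Each hit becomes one line of a temporary TREC run file,
# formatted as "<query> Q0 <doc_id> <rank> <score> RUN",
# which is then passed to EvalFunction.eval with trunc=True.
print(EvalFunction.from_results(results, "path/to/qrels.txt"))

As in the unit test added below, EvalFunction.eval can be mocked when only the Result-to-run-file conversion is under test.
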
Empty file added test/analysis/__init__.py
55 changes: 55 additions & 0 deletions test/analysis/test_response_analyzer.py
@@ -0,0 +1,55 @@
import unittest

from src.rank_llm.analysis.response_analysis_verbose import ResponseAnalyzer
from src.rank_llm.result import RankingExecInfo, Result


class TestResponseAnalyzer(unittest.TestCase):
# create a list of mock Result objects
def setUp(self):
self.mock_results = [
Result(
query="Query 1",
hits=[],
ranking_exec_summary=[
RankingExecInfo(
prompt="I will provide you with 3 passages",
response="1 > 2 > 3",
input_token_count=100,
output_token_count=50,
),
RankingExecInfo(
prompt="I will provide you with 2 passages",
response="2 > 1",
input_token_count=80,
output_token_count=40,
),
],
),
Result(
query="Query 2",
hits=[],
ranking_exec_summary=[
RankingExecInfo(
prompt="I will provide you with 4 passages",
response="4 > 3 > 2 > 1",
input_token_count=120,
output_token_count=60,
)
],
),
]

def test_read_results_responses(self):
analyzer = ResponseAnalyzer.from_inline_results(self.mock_results)
responses, num_passages = analyzer.read_results_responses()

self.assertEqual(len(responses), 3, "Should have 3 responses")
self.assertEqual(len(num_passages), 3, "Should have 3 num_passages")
self.assertEqual(
num_passages, [3, 2, 4], "Num passages should match expected values"
)


if __name__ == "__main__":
unittest.main()
Empty file added test/evaluation/__init__.py
32 changes: 32 additions & 0 deletions test/evaluation/test_trec_eval.py
@@ -0,0 +1,32 @@
import unittest
from unittest.mock import patch

from src.rank_llm.evaluation.trec_eval import EvalFunction
from src.rank_llm.result import Result


class TestEvalFunction(unittest.TestCase):
def setUp(self):
self.results = [
Result(
query="Query1",
hits=[
{"doc_id": "D1", "rank": 1, "score": 0.9},
{"doc_id": "D2", "rank": 2, "score": 0.8},
],
),
Result(query="Query2", hits=[{"doc_id": "D3", "rank": 1, "score": 0.85}]),
]
self.qrels_path = "path/to/qrels"

@patch("src.rank_llm.evaluation.trec_eval.EvalFunction.eval")
def test_from_results(self, mock_eval):
mock_eval.return_value = "Evaluation success"
eval_output = EvalFunction.from_results(self.results, self.qrels_path)

mock_eval.assert_called()
self.assertEqual(eval_output, "Evaluation success")


if __name__ == "__main__":
unittest.main()
4 changes: 3 additions & 1 deletion test/rerank/test_RankListwiseOSLLM.py
@@ -1,7 +1,8 @@
import unittest

from rank_llm.rerank.rank_listwise_os_llm import RankListwiseOSLLM
from rank_llm.rerank.rankllm import PromptMode
from rank_llm.result import Result
import unittest

# model, context_size, prompt_mode, num_few_shot_examples, variable_passages, window_size, system_message
valid_inputs = [
@@ -285,6 +286,7 @@ def test_create_prompt(
)

import re

def get_first_int(s):
match = re.search(r"\d+", s)
return int(match.group()) if match else None
5 changes: 3 additions & 2 deletions test/rerank/test_SafeOpenai.py
@@ -1,8 +1,9 @@
from rank_llm.rerank.rank_gpt import SafeOpenai
from rank_llm.rerank.rankllm import PromptMode
import unittest
from unittest.mock import patch

from rank_llm.rerank.rank_gpt import SafeOpenai
from rank_llm.rerank.rankllm import PromptMode

# model, context_size, prompt_mode, num_few_shot_examples, keys, key_start_id
valid_inputs = [
("gpt-3.5-turbo", 4096, PromptMode.RANK_GPT, 0, "OPEN_AI_API_KEY", None),
7 changes: 4 additions & 3 deletions test/retrieve/test_PyseriniRetriever.py
@@ -1,9 +1,10 @@
from rank_llm.retrieve.pyserini_retriever import PyseriniRetriever, RetrievalMethod
from rank_llm.retrieve.indices_dict import INDICES
from rank_llm.result import Result
import unittest
from unittest.mock import MagicMock, patch

from rank_llm.result import Result
from rank_llm.retrieve.indices_dict import INDICES
from rank_llm.retrieve.pyserini_retriever import PyseriniRetriever, RetrievalMethod

valid_inputs = [
("dl19", RetrievalMethod.BM25),
("dl19", RetrievalMethod.BM25_RM3),