From 64151b5272e36cbe3794f7e8815a5f001847e011 Mon Sep 17 00:00:00 2001
From: sahel
Date: Thu, 25 Jan 2024 02:16:13 -0500
Subject: [PATCH 1/7] added result and resultwriter classes

---
 src/rank_llm/result.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/rank_llm/result.py b/src/rank_llm/result.py
index 05a4e6e3..77f76299 100644
--- a/src/rank_llm/result.py
+++ b/src/rank_llm/result.py
@@ -5,7 +5,7 @@ class RankingExecInfo:
     def __init__(
         self, prompt, response: str, input_token_count: int, output_token_count: int
-    ):
+    ) -> None:
         self.prompt = prompt
         self.response = response
         self.input_token_count = input_token_count
         self.output_token_count = output_token_count
@@ -26,7 +26,6 @@ def __init__(
     def __repr__(self):
         return str(self.__dict__)
 
-
 class ResultsWriter:
     def __init__(self, results: List[Result], append: bool = False):
         self._results = results
@@ -49,6 +48,7 @@ def write_in_json_format(self, filename: str):
         with open(filename, "a" if self._append else "w") as f:
             json.dump(results, f, indent=2)
 
+
     def write_in_trec_eval_format(self, filename: str):
         with open(filename, "a" if self._append else "w") as f:
             for result in self._results:

From c7dfba53ab1894d206ef7afc411f4f34d6634524 Mon Sep 17 00:00:00 2001
From: sahel
Date: Fri, 26 Jan 2024 18:37:52 -0500
Subject: [PATCH 2/7] integrated result into pyserini retriever

---
 src/rank_llm/result.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/rank_llm/result.py b/src/rank_llm/result.py
index 77f76299..039a11b7 100644
--- a/src/rank_llm/result.py
+++ b/src/rank_llm/result.py
@@ -5,7 +5,7 @@ class RankingExecInfo:
     def __init__(
         self, prompt, response: str, input_token_count: int, output_token_count: int
-    ) -> None:
+    ):
         self.prompt = prompt
         self.response = response
         self.input_token_count = input_token_count
         self.output_token_count = output_token_count
@@ -25,6 +25,10 @@ def __init__(
     def __repr__(self):
         return str(self.__dict__)
 
+<<<<<<< HEAD
+=======
+
+>>>>>>> integrated result into pyserini retriever
 class ResultsWriter:
     def __init__(self, results: List[Result], append: bool = False):
         self._results = results
@@ -47,7 +51,10 @@ def write_in_json_format(self, filename: str):
             results.append({"query": result.query, "hits": result.hits})
         with open(filename, "a" if self._append else "w") as f:
             json.dump(results, f, indent=2)
+<<<<<<< HEAD
+
+=======
+>>>>>>> integrated result into pyserini retriever
     def write_in_trec_eval_format(self, filename: str):
         with open(filename, "a" if self._append else "w") as f:
             for result in self._results:

From 0bb8d75285d430fe58f0ef5a7152c53de8c78240 Mon Sep 17 00:00:00 2001
From: sahel
Date: Fri, 26 Jan 2024 23:42:19 -0500
Subject: [PATCH 3/7] integrated results in reranking

---
 .gitignore                                  |  1 +
 src/rank_llm/evaluation/trec_eval.py        | 16 ++--
 src/rank_llm/rerank/rank_gpt.py             | 23 +++---
 src/rank_llm/rerank/rank_listwise_os_llm.py |  9 ++-
 src/rank_llm/rerank/rankllm.py              | 60 +++++++--------
 src/rank_llm/rerank/reranker.py             | 84 ++++-----------------
 src/rank_llm/result.py                      |  3 +
 src/rank_llm/retrieve_and_rerank.py         | 14 +---
 8 files changed, 74 insertions(+), 136 deletions(-)

diff --git a/.gitignore b/.gitignore
index 0a1c2d76..f30839bf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -168,4 +168,5 @@ prompts_and_responses/
 rerank_results/
 token_counts/
 retrieve_results/
+ranking_execution_summary/
 repro/
diff --git a/src/rank_llm/evaluation/trec_eval.py b/src/rank_llm/evaluation/trec_eval.py
index 83da1c52..d23b52ae 100644
--- a/src/rank_llm/evaluation/trec_eval.py
+++ b/src/rank_llm/evaluation/trec_eval.py
@@ -26,8 +26,8 @@ class EvalFunction:
     @staticmethod
     def trunc(qrels, run):
        qrels = get_qrels_file(qrels)
-        run = pd.read_csv(run, delim_whitespace=True, header=None)
-        qrels = pd.read_csv(qrels, delim_whitespace=True, header=None)
+        run = pd.read_csv(run, sep='\s+', header=None)
+        qrels = pd.read_csv(qrels, sep='\s+', header=None)
         run[0] = run[0].astype(str)
         qrels[0] = qrels[0].astype(str)
 
@@ -76,7 +76,7 @@ def eval(args, trunc=True):
             print("msmarco run detected. Converting to trec...")
             run = pd.read_csv(
                 args[-1],
-                delim_whitespace=True,
+                sep='\s+',
                 header=None,
                 names=["query_id", "doc_id", "rank"],
             )
@@ -86,8 +86,8 @@ def eval(args, trunc=True):
             run.to_csv(temp_file, sep="\t", header=None, index=None)
             args[-1] = temp_file
 
-        run = pd.read_csv(args[-1], delim_whitespace=True, header=None)
-        qrels = pd.read_csv(args[-2], delim_whitespace=True, header=None)
+        run = pd.read_csv(args[-1], sep='\s+', header=None)
+        qrels = pd.read_csv(args[-2], sep='\s+', header=None)
 
         # cast doc_id column as string
         run[0] = run[0].astype(str)
@@ -148,13 +148,17 @@ def main(args):
     for retrieval_method in RetrievalMethod:
         if retrieval_method == RetrievalMethod.UNSPECIFIED:
             continue
+        directory = f"rerank_results/{retrieval_method.name}"
+        if not os.path.isdir(directory):
+            continue
         for top_k_canidadates in [20, 100]:
-            directory = f"rerank_results/{retrieval_method.name}"
             for filename in os.listdir(directory):
                 if not filename.startswith(
                     f"{model}_{context_size}_{top_k_canidadates}_{prompt_mode}_{dataset}"
                 ):
                     continue
+                if filename.endswith(".json"):
+                    continue
                 f = os.path.join(directory, filename)
                 # checking if it is a file
                 if os.path.isfile(f):
diff --git a/src/rank_llm/rerank/rank_gpt.py b/src/rank_llm/rerank/rank_gpt.py
index ce576e56..99bac435 100644
--- a/src/rank_llm/rerank/rank_gpt.py
+++ b/src/rank_llm/rerank/rank_gpt.py
@@ -8,6 +8,7 @@
 import tiktoken
 
 from rank_llm.rerank.rankllm import RankLLM, PromptMode
+from rank_llm.result import Result
 
 
 def replace_number(s: str) -> str:
@@ -138,24 +139,24 @@ def num_output_tokens(self) -> int:
         return 200
 
     def create_prompt(
-        self, retrieved_result: Dict[str, Any], rank_start: int, rank_end: int
+        self, result: Result, rank_start: int, rank_end: int
     ) -> Tuple[List[Dict[str, str]], int]:
         if self._prompt_mode == PromptMode.RANK_GPT:
-            return self.create_rank_gpt_prompt(retrieved_result, rank_start, rank_end)
+            return self.create_rank_gpt_prompt(result, rank_start, rank_end)
         else:
-            return self.create_LRL_prompt(retrieved_result, rank_start, rank_end)
+            return self.create_LRL_prompt(result, rank_start, rank_end)
 
     def create_rank_gpt_prompt(
-        self, retrieved_result: Dict[str, Any], rank_start: int, rank_end: int
+        self, result: Result, rank_start: int, rank_end: int
     ) -> Tuple[List[Dict[str, str]], int]:
-        query = retrieved_result["query"]
-        num = len(retrieved_result["hits"][rank_start:rank_end])
+        query = result.query
+        num = len(result.hits[rank_start:rank_end])
         max_length = 300 * (20 / (rank_end - rank_start))
         while True:
             messages = self._get_prefix_for_rank_gpt_prompt(query, num)
             rank = 0
-            for hit in retrieved_result["hits"][rank_start:rank_end]:
+            for hit in result.hits[rank_start:rank_end]:
                 rank += 1
                 content = hit["content"]
                 content = content.replace("Title: Content: ", "")
@@ -187,16 +188,16 @@ def create_rank_gpt_prompt(
         return messages, self.get_num_tokens(messages)
 
     def create_LRL_prompt(
-        self, retrieved_result: Dict[str, Any], rank_start: int, rank_end: int
+        self, result: Result, rank_start: int, rank_end: int
     ) -> Tuple[List[Dict[str, str]], int]:
-        query = retrieved_result["query"]
-        num = len(retrieved_result["hits"][rank_start:rank_end])
+        query = result.query
+        num = len(result.hits[rank_start:rank_end])
         max_length = 300 * (20 / (rank_end - rank_start))
         psg_ids = []
         while True:
             message = "Sort the list PASSAGES by how good each text answers the QUESTION (in descending order of relevancy).\n"
             rank = 0
-            for hit in retrieved_result["hits"][rank_start:rank_end]:
+            for hit in result.hits[rank_start:rank_end]:
                 rank += 1
                 psg_id = f"PASSAGE{rank}"
                 content = hit["content"]
diff --git a/src/rank_llm/rerank/rank_listwise_os_llm.py b/src/rank_llm/rerank/rank_listwise_os_llm.py
index c94ced31..7f1333d6 100644
--- a/src/rank_llm/rerank/rank_listwise_os_llm.py
+++ b/src/rank_llm/rerank/rank_listwise_os_llm.py
@@ -9,6 +9,7 @@
 from transformers.generation import GenerationConfig
 
 from rank_llm.rerank.rankllm import RankLLM, PromptMode
+from rank_llm.result import Result
 
 
 def replace_number(s):
@@ -105,10 +106,10 @@ def _add_few_shot_examples(self, conv):
         return conv
 
     def create_prompt(
-        self, retrieved_result: Dict[str, Any], rank_start: int, rank_end: int
+        self, result: Result, rank_start: int, rank_end: int
     ) -> Tuple[str, int]:
-        query = retrieved_result["query"]
-        num = len(retrieved_result["hits"][rank_start:rank_end])
+        query = result.query
+        num = len(result.hits[rank_start:rank_end])
         max_length = 300 * (20 / (rank_end - rank_start))
         while True:
             conv = get_conversation_template(self._model)
@@ -118,7 +119,7 @@ def create_prompt(
             prefix = self._add_prefix_prompt(query, num)
             rank = 0
             input_context = f"{prefix}\n"
-            for hit in retrieved_result["hits"][rank_start:rank_end]:
+            for hit in result.hits[rank_start:rank_end]:
                 rank += 1
                 content = hit["content"]
                 content = content.replace("Title: Content: ", "")
diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py
index 88487d8a..578f8cf5 100644
--- a/src/rank_llm/rerank/rankllm.py
+++ b/src/rank_llm/rerank/rankllm.py
@@ -6,6 +6,8 @@
 
 from tqdm import tqdm
 
+from rank_llm.result import Result, RankingExecInfo
+
 
 class PromptMode(Enum):
     UNSPECIFIED = "unspecified"
@@ -38,7 +40,7 @@ def run_llm(self, prompt: Union[str, List[Dict[str, str]]]) -> Tuple[str, int]:
 
     @abstractmethod
     def create_prompt(
-        self, retrieved_result: Dict[str, Any], rank_start: int, rank_end: int
+        self, result: Result, rank_start: int, rank_end: int
     ) -> Tuple[Union[str, List[Dict[str, str]]], int]:
         pass
 
@@ -56,7 +58,7 @@ def num_output_tokens(self) -> int:
 
     def permutation_pipeline(
         self,
-        result: Dict[str, Any],
+        result: Result,
         rank_start: int,
         rank_end: int,
         logging: bool = False,
@@ -69,14 +71,18 @@ def permutation_pipeline(
         )
         if logging:
             print(f"output: {permutation}")
-        rerank_result = self.receive_permutation(
+        ranking_exec_info = RankingExecInfo(prompt, permutation, in_token_count, out_token_count)
+        if result.ranking_exec_summary == None:
+            result.ranking_exec_summary = []
+        result.ranking_exec_summary.append(ranking_exec_info)
+        result = self.receive_permutation(
             result, permutation, rank_start, rank_end
         )
-        return rerank_result, in_token_count, out_token_count, prompt, permutation
+        return result
 
     def sliding_windows(
         self,
-        retrieved_result: Dict[str, Any],
+        retrieved_result: Result,
         rank_start: int,
         rank_end: int,
         window_size: int,
@@ -84,41 +90,27 @@ def sliding_windows(
         shuffle_candidates: bool = False,
         logging: bool = False,
     ):
-        in_token_count = 0
-        out_token_count = 0
         rerank_result = copy.deepcopy(retrieved_result)
         if shuffle_candidates:
             # First randomly shuffle rerank_result between rank_start and rank_end
-            rerank_result["hits"][rank_start:rank_end] = random.sample(
-                rerank_result["hits"][rank_start:rank_end],
-                len(rerank_result["hits"][rank_start:rank_end]),
+            rerank_result.hits[rank_start:rank_end] = random.sample(
+                rerank_result.hits[rank_start:rank_end],
+                len(rerank_result.hits[rank_start:rank_end]),
             )
             # Next rescore all candidates with 1/rank
-            for i, hit in enumerate(rerank_result["hits"]):
+            for i, hit in enumerate(rerank_result.hits):
                 hit["score"] = 1.0 / (i + 1)
                 hit["rank"] = i + 1
         end_pos = rank_end
         start_pos = rank_end - window_size
-        prompts = []
-        permutations = []
         # end_pos > rank_start ensures that the list is non-empty while allowing last window to be smaller than window_size
         # start_pos + step != rank_start prevents processing of redundant windows (e.g. 0-20, followed by 0-10)
         while end_pos > rank_start and start_pos + step != rank_start:
             start_pos = max(start_pos, rank_start)
-            (
-                rerank_result,
-                in_count,
-                out_count,
-                prompt,
-                permutation,
-            ) = self.permutation_pipeline(rerank_result, start_pos, end_pos, logging)
-            in_token_count += in_count
-            out_token_count += out_count
-            prompts.append(prompt)
-            permutations.append(permutation)
+            rerank_result = self.permutation_pipeline(rerank_result, start_pos, end_pos, logging)
             end_pos = end_pos - step
             start_pos = start_pos - step
-        return rerank_result, in_token_count, out_token_count, prompts, permutations
+        return rerank_result
 
     def get_ranking_cost_upperbound(
         self, num_q: int, rank_start: int, rank_end: int, window_size: int, step: int
@@ -181,19 +173,19 @@ def _remove_duplicate(self, response: List[int]) -> List[int]:
         return new_response
 
     def receive_permutation(
-        self, item: Dict[str, Any], permutation: str, rank_start: int, rank_end: int
-    ) -> Dict[str, Any]:
+        self, result: Result, permutation: str, rank_start: int, rank_end: int
+    ) -> Result:
         response = self._clean_response(permutation)
         response = [int(x) - 1 for x in response.split()]
         response = self._remove_duplicate(response)
-        cut_range = copy.deepcopy(item["hits"][rank_start:rank_end])
+        cut_range = copy.deepcopy(result.hits[rank_start:rank_end])
         original_rank = [tt for tt in range(len(cut_range))]
         response = [ss for ss in response if ss in original_rank]
         response = response + [tt for tt in original_rank if tt not in response]
         for j, x in enumerate(response):
-            item["hits"][j + rank_start] = copy.deepcopy(cut_range[x])
-            if "rank" in item["hits"][j + rank_start]:
-                item["hits"][j + rank_start]["rank"] = cut_range[j]["rank"]
-            if "score" in item["hits"][j + rank_start]:
-                item["hits"][j + rank_start]["score"] = cut_range[j]["score"]
-        return item
+            result.hits[j + rank_start] = copy.deepcopy(cut_range[x])
+            if "rank" in result.hits[j + rank_start]:
+                result.hits[j + rank_start]["rank"] = cut_range[j]["rank"]
+            if "score" in result.hits[j + rank_start]:
+                result.hits[j + rank_start]["score"] = cut_range[j]["score"]
+        return result
diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py
index 3441a34b..8f33fb99 100644
--- a/src/rank_llm/rerank/reranker.py
+++ b/src/rank_llm/rerank/reranker.py
@@ -6,6 +6,7 @@
 from tqdm import tqdm
 
 from rank_llm.rerank.rankllm import RankLLM
+from rank_llm.result import Result, ResultsWriter
 
 
 class Reranker:
@@ -13,21 +14,10 @@ def __init__(self, agent: RankLLM, top_k_candidates: int) -> None:
         self._agent = agent
         self._top_k_candidates = top_k_candidates
 
-    def rerank(self, retrieved_results: List[Dict[str, Any]], **kwargs):
+    def rerank(self, retrieved_results: List[Result], **kwargs):
         rerank_results = []
-        input_token_counts = []
-        output_token_counts = []
-        aggregated_prompts = []
-        aggregated_responses = []
         for result in tqdm(retrieved_results):
-            (
-                rerank_result,
-                in_token_count,
-                out_token_count,
-                prompts,
-                responses,
-            ) = self._agent.sliding_windows(
+            rerank_result = self._agent.sliding_windows(
                 result,
                 rank_start=0,
                 rank_end=kwargs["rank_end"],
@@ -37,43 +27,17 @@ def rerank(self, retrieved_results: List[Dict[str, Any]], **kwargs):
                 logging=kwargs["logging"],
             )
             rerank_results.append(rerank_result)
-            input_token_counts.append(in_token_count)
-            output_token_counts.append(out_token_count)
-            aggregated_prompts.extend(prompts)
-            aggregated_responses.extend(responses)
-
-        # print(f"rerank_results={rerank_results}")
-        print(f"input_tokens_counts={input_token_counts}")
-        print(f"total input token count={sum(input_token_counts)}")
-        print(f"output_token_counts={output_token_counts}")
-        print(f"total output token count={sum(output_token_counts)}")
-
-        return (
-            rerank_results,
-            input_token_counts,
-            output_token_counts,
-            aggregated_prompts,
-            aggregated_responses,
-        )
+        return rerank_results
 
     def write_rerank_results(
         self,
         retrieval_method_name: str,
-        rerank_results: List[Dict[str, Any]],
-        input_token_counts: List[int],
-        output_token_counts: List[int],
-        # List[str] for Vicuna, List[List[Dict[str, str]]] for gpt models.
-        prompts: Union[List[str], List[List[Dict[str, str]]]],
-        responses: List[str],
+        results: List[Result],
         shuffle_candidates: bool = False,
         pass_ct: int = None,
         window_size: int = None,
         dataset_name: str = None,
     ) -> str:
-        # write rerank results
-        Path(f"rerank_results/{retrieval_method_name}/").mkdir(
-            parents=True, exist_ok=True
-        )
         _modelname = self._agent._model.split("/")[-1]
         if _modelname.startswith("checkpoint"):
             _modelname = self._agent._model.split("/")[-2] + "_" + _modelname
@@ -91,37 +55,17 @@ def write_rerank_results(
             name += f"_window_{window_size}"
         if pass_ct is not None:
             name += f"_pass_{pass_ct}"
-        result_file_name = f"rerank_results/{retrieval_method_name}/{name}.txt"
-        with open(result_file_name, "w") as f:
-            for i in range(len(rerank_results)):
-                rank = 1
-                hits = rerank_results[i]["hits"]
-                for hit in hits:
-                    f.write(
-                        f"{hit['qid']} Q0 {hit['docid']} {rank} {hit['score']} rank\n"
-                    )
-                    rank += 1
-        # Write token counts
-        Path(f"token_counts/{retrieval_method_name}/").mkdir(
+        # write rerank results
+        writer = ResultsWriter(results)
+        Path(f"rerank_results/{retrieval_method_name}/").mkdir(
             parents=True, exist_ok=True
         )
-        count_file_name = f"token_counts/{retrieval_method_name}/{name}.txt"
-        counts = {}
-        for i, (in_count, out_count) in enumerate(
-            zip(input_token_counts, output_token_counts)
-        ):
-            counts[rerank_results[i]["query"]] = (in_count, out_count)
-        with open(count_file_name, "w") as f:
-            json.dump(counts, f, indent=4)
-        # Write prompts and responses
-        Path(f"prompts_and_responses/{retrieval_method_name}/").mkdir(
+        result_file_name = f"rerank_results/{retrieval_method_name}/{name}.txt"
+        writer.write_in_trec_eval_format(result_file_name)
+        writer.write_in_json_format(f"rerank_results/{retrieval_method_name}/{name}.json")
+        # Write ranking execution summary
+        Path(f"ranking_execution_summary/{retrieval_method_name}/").mkdir(
             parents=True, exist_ok=True
         )
-        with open(
-            f"prompts_and_responses/{retrieval_method_name}/{name}.json",
-            "w",
-        ) as f:
-            for p, r in zip(prompts, responses):
-                json.dump({"prompt": p, "response": r}, f)
-                f.write("\n")
+        writer.write_ranking_exec_summary(f"ranking_execution_summary/{retrieval_method_name}/{name}.txt")
         return result_file_name
diff --git a/src/rank_llm/result.py b/src/rank_llm/result.py
index 039a11b7..6ecb7824 100644
--- a/src/rank_llm/result.py
+++ b/src/rank_llm/result.py
@@ -11,6 +11,9 @@ def __init__(
         self.input_token_count = input_token_count
         self.output_token_count = output_token_count
 
+    def __repr__(self):
+        return str(self.__dict__)
+
 
 class Result:
     def __init__(
diff --git a/src/rank_llm/retrieve_and_rerank.py b/src/rank_llm/retrieve_and_rerank.py
index 3247cdfd..2c34c26b 100644
--- a/src/rank_llm/retrieve_and_rerank.py
+++ b/src/rank_llm/retrieve_and_rerank.py
@@ -100,13 +100,7 @@ def retrieve_and_rerank(
     reranker = Reranker(agent, top_k_candidates)
     for pass_ct in range(num_passes):
         print(f"Pass {pass_ct + 1} of {num_passes}:")
-        (
-            rerank_results,
-            input_token_counts,
-            output_token_counts,
-            aggregated_prompts,
-            aggregated_responses,
-        ) = reranker.rerank(
+        rerank_results = reranker.rerank(
             retrieved_results,
             rank_end=top_k_candidates,
             window_size=min(window_size, top_k_candidates),
@@ -120,10 +114,6 @@ def retrieve_and_rerank(
         file_name = reranker.write_rerank_results(
             retrieval_method.name,
             rerank_results,
-            input_token_counts,
-            output_token_counts,
-            aggregated_prompts,
-            aggregated_responses,
             shuffle_candidates,
             pass_ct=None if num_passes == 1 else pass_ct,
             window_size=window_size,
@@ -148,5 +138,7 @@ def retrieve_and_rerank(
             print(f"Skipping evaluation as {dataset} is not in TOPICS.")
         if num_passes > 1:
             retrieved_results = rerank_results
+            for r in retrieved_results:
+                r.ranking_exec_summary = None
 
     return rerank_results

From 7b933a7ed7666acd8cf2f9d97cbcc3bd0be41a06 Mon Sep 17 00:00:00 2001
From: sahel
Date: Sat, 27 Jan 2024 02:00:32 -0500
Subject: [PATCH 4/7] updated result.py

---
 src/rank_llm/result.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/src/rank_llm/result.py b/src/rank_llm/result.py
index 6ecb7824..a0aa7c12 100644
--- a/src/rank_llm/result.py
+++ b/src/rank_llm/result.py
@@ -28,10 +28,6 @@ def __init__(
     def __repr__(self):
         return str(self.__dict__)
 
-<<<<<<< HEAD
-=======
-
->>>>>>> integrated result into pyserini retriever
 class ResultsWriter:
     def __init__(self, results: List[Result], append: bool = False):
         self._results = results
@@ -54,10 +50,6 @@ def write_in_json_format(self, filename: str):
             results.append({"query": result.query, "hits": result.hits})
         with open(filename, "a" if self._append else "w") as f:
             json.dump(results, f, indent=2)
-<<<<<<< HEAD
-
-=======
->>>>>>> integrated result into pyserini retriever
     def write_in_trec_eval_format(self, filename: str):
         with open(filename, "a" if self._append else "w") as f:
             for result in self._results:

From 47d1b5a8bc6d03f5f277fd916c1e0305c2cec7f9 Mon Sep 17 00:00:00 2001
From: sahel
Date: Sat, 27 Jan 2024 02:02:12 -0500
Subject: [PATCH 5/7] format

---
 src/rank_llm/evaluation/trec_eval.py | 10 +++++-----
 src/rank_llm/rerank/rankllm.py       | 12 +++++++-----
 src/rank_llm/rerank/reranker.py      |  8 ++++++--
 src/rank_llm/result.py               |  1 +
 4 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/src/rank_llm/evaluation/trec_eval.py b/src/rank_llm/evaluation/trec_eval.py
index d23b52ae..e324dd0d 100644
--- a/src/rank_llm/evaluation/trec_eval.py
+++ b/src/rank_llm/evaluation/trec_eval.py
@@ -26,8 +26,8 @@ class EvalFunction:
     @staticmethod
     def trunc(qrels, run):
         qrels = get_qrels_file(qrels)
-        run = pd.read_csv(run, sep='\s+', header=None)
-        qrels = pd.read_csv(qrels, sep='\s+', header=None)
+        run = pd.read_csv(run, sep="\s+", header=None)
+        qrels = pd.read_csv(qrels, sep="\s+", header=None)
         run[0] = run[0].astype(str)
         qrels[0] = qrels[0].astype(str)
 
@@ -76,7 +76,7 @@ def eval(args, trunc=True):
             print("msmarco run detected. Converting to trec...")
             run = pd.read_csv(
                 args[-1],
-                sep='\s+',
+                sep="\s+",
                 header=None,
                 names=["query_id", "doc_id", "rank"],
             )
@@ -86,8 +86,8 @@ def eval(args, trunc=True):
             run.to_csv(temp_file, sep="\t", header=None, index=None)
             args[-1] = temp_file
 
-        run = pd.read_csv(args[-1], sep='\s+', header=None)
-        qrels = pd.read_csv(args[-2], sep='\s+', header=None)
+        run = pd.read_csv(args[-1], sep="\s+", header=None)
+        qrels = pd.read_csv(args[-2], sep="\s+", header=None)
 
         # cast doc_id column as string
         run[0] = run[0].astype(str)
diff --git a/src/rank_llm/rerank/rankllm.py b/src/rank_llm/rerank/rankllm.py
index 578f8cf5..651a0808 100644
--- a/src/rank_llm/rerank/rankllm.py
+++ b/src/rank_llm/rerank/rankllm.py
@@ -71,13 +71,13 @@ def permutation_pipeline(
         )
         if logging:
             print(f"output: {permutation}")
-        ranking_exec_info = RankingExecInfo(prompt, permutation, in_token_count, out_token_count)
+        ranking_exec_info = RankingExecInfo(
+            prompt, permutation, in_token_count, out_token_count
+        )
         if result.ranking_exec_summary == None:
             result.ranking_exec_summary = []
         result.ranking_exec_summary.append(ranking_exec_info)
-        result = self.receive_permutation(
-            result, permutation, rank_start, rank_end
-        )
+        result = self.receive_permutation(result, permutation, rank_start, rank_end)
         return result
 
@@ -107,7 +107,9 @@
         # start_pos + step != rank_start prevents processing of redundant windows (e.g. 0-20, followed by 0-10)
         while end_pos > rank_start and start_pos + step != rank_start:
             start_pos = max(start_pos, rank_start)
-            rerank_result = self.permutation_pipeline(rerank_result, start_pos, end_pos, logging)
+            rerank_result = self.permutation_pipeline(
+                rerank_result, start_pos, end_pos, logging
+            )
             end_pos = end_pos - step
             start_pos = start_pos - step
         return rerank_result
diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py
index 8f33fb99..28271040 100644
--- a/src/rank_llm/rerank/reranker.py
+++ b/src/rank_llm/rerank/reranker.py
@@ -62,10 +62,14 @@ def write_rerank_results(
         )
         result_file_name = f"rerank_results/{retrieval_method_name}/{name}.txt"
         writer.write_in_trec_eval_format(result_file_name)
-        writer.write_in_json_format(f"rerank_results/{retrieval_method_name}/{name}.json")
+        writer.write_in_json_format(
+            f"rerank_results/{retrieval_method_name}/{name}.json"
+        )
         # Write ranking execution summary
         Path(f"ranking_execution_summary/{retrieval_method_name}/").mkdir(
             parents=True, exist_ok=True
         )
-        writer.write_ranking_exec_summary(f"ranking_execution_summary/{retrieval_method_name}/{name}.txt")
+        writer.write_ranking_exec_summary(
+            f"ranking_execution_summary/{retrieval_method_name}/{name}.txt"
+        )
         return result_file_name
diff --git a/src/rank_llm/result.py b/src/rank_llm/result.py
index a0aa7c12..581bc5a2 100644
--- a/src/rank_llm/result.py
+++ b/src/rank_llm/result.py
@@ -29,6 +29,7 @@ def __repr__(self):
     def __repr__(self):
         return str(self.__dict__)
 
+
 class ResultsWriter:
     def __init__(self, results: List[Result], append: bool = False):
         self._results = results

From a54d12cce2fe685e34dbfa4803540f8d4cc1820d Mon Sep 17 00:00:00 2001
From: sahel
Date: Sat, 27 Jan 2024 02:08:08 -0500
Subject: [PATCH 6/7] use json format for writing exec summary

---
 src/rank_llm/rerank/reranker.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py
index 28271040..25153524 100644
--- a/src/rank_llm/rerank/reranker.py
+++ b/src/rank_llm/rerank/reranker.py
@@ -70,6 +70,6 @@ def write_rerank_results(
             parents=True, exist_ok=True
         )
         writer.write_ranking_exec_summary(
-            f"ranking_execution_summary/{retrieval_method_name}/{name}.txt"
+            f"ranking_execution_summary/{retrieval_method_name}/{name}.json"
         )
         return result_file_name

From 88d4c6ad26553ba26f0216b8614a4d2a1a4dd318 Mon Sep 17 00:00:00 2001
From: sahel
Date: Sun, 28 Jan 2024 17:40:56 -0500
Subject: [PATCH 7/7] merge all changes

---
 .../rerank_dataset_with_prebuilt_index.py   | 11 ++--
 src/rank_llm/demo/rerank_inline_docs.py     |  6 ++-
 src/rank_llm/demo/rerank_inline_hits.py     |  6 ++-
 .../demo/rerank_stored_retrieved_results.py |  6 ++-
 src/rank_llm/rerank/reranker.py             | 31 +++++++----
 src/rank_llm/rerank/vicuna_reranker.py      | 53 +++++++++++++++++++
 src/rank_llm/rerank/zephyr_reranker.py      | 53 +++++++++++++++++++
 src/rank_llm/retrieve_and_rerank.py         |  3 +-
 8 files changed, 147 insertions(+), 22 deletions(-)
 create mode 100644 src/rank_llm/rerank/vicuna_reranker.py
 create mode 100644 src/rank_llm/rerank/zephyr_reranker.py

diff --git a/src/rank_llm/demo/rerank_dataset_with_prebuilt_index.py b/src/rank_llm/demo/rerank_dataset_with_prebuilt_index.py
index 3e3d9750..830025a5 100644
--- a/src/rank_llm/demo/rerank_dataset_with_prebuilt_index.py
+++ b/src/rank_llm/demo/rerank_dataset_with_prebuilt_index.py
@@ -8,16 +8,19 @@
 
 from rank_llm.retrieve.retriever import Retriever
 from rank_llm.retrieve.pyserini_retriever import RetrievalMethod
+from rank_llm.rerank.vicuna_reranker import VicunaReranker
 
 # By default uses BM25 for retrieval
 dataset_name = "dl19"
 retrieved_results = Retriever.from_dataset_with_prebuit_index(dataset_name)
-print(retrieved_results)
-# TODO: add rerank instead of printing retrieved results
+reranker = VicunaReranker()
+rerank_results = reranker.rerank(retrieved_results)
+print(rerank_results)
 
 # Users can specify other retrieval methods:
 retrieved_results = Retriever.from_dataset_with_prebuit_index(
     dataset_name, RetrievalMethod.SPLADE_P_P_ENSEMBLE_DISTIL
 )
-print(retrieved_results)
-# TODO: add rerank instead of printing retrieved results
+reranker = VicunaReranker()
+rerank_results = reranker.rerank(retrieved_results)
+print(rerank_results)
diff --git a/src/rank_llm/demo/rerank_inline_docs.py b/src/rank_llm/demo/rerank_inline_docs.py
index d335fd21..c4486ca8 100644
--- a/src/rank_llm/demo/rerank_inline_docs.py
+++ b/src/rank_llm/demo/rerank_inline_docs.py
@@ -7,6 +7,7 @@
 sys.path.append(parent)
 
 from rank_llm.retrieve.retriever import Retriever
+from rank_llm.rerank.zephyr_reranker import ZephyrReranker
 
 query = "What is the capital of the United States?"
 docs = [
@@ -17,5 +18,6 @@
 ]
 
 retrieved_results = Retriever.from_inline_documents(query, documents=docs)
-print(retrieved_results)
-# TODO: add rerank instead of printing retrieved results
+reranker = ZephyrReranker()
+rerank_results = reranker.rerank(retrieved_results)
+print(rerank_results)
diff --git a/src/rank_llm/demo/rerank_inline_hits.py b/src/rank_llm/demo/rerank_inline_hits.py
index c2fe1dd6..0fabf9ca 100644
--- a/src/rank_llm/demo/rerank_inline_hits.py
+++ b/src/rank_llm/demo/rerank_inline_hits.py
@@ -7,6 +7,7 @@
 sys.path.append(parent)
 
 from rank_llm.retrieve.retriever import Retriever
+from rank_llm.rerank.zephyr_reranker import ZephyrReranker
 
 query = "how long is life cycle of flea"
 hits = [
@@ -69,5 +70,6 @@
 ]
 
 retrieved_results = Retriever.from_inline_hits(query=query, hits=hits)
-print(retrieved_results)
-# TODO: add rerank instead of printing retrieved results
+reranker = ZephyrReranker()
+rerank_results = reranker.rerank(retrieved_results)
+print(rerank_results)
diff --git a/src/rank_llm/demo/rerank_stored_retrieved_results.py b/src/rank_llm/demo/rerank_stored_retrieved_results.py
index 14f13816..fa3f8b70 100644
--- a/src/rank_llm/demo/rerank_stored_retrieved_results.py
+++ b/src/rank_llm/demo/rerank_stored_retrieved_results.py
@@ -7,8 +7,10 @@
 sys.path.append(parent)
 
 from rank_llm.retrieve.retriever import Retriever
+from rank_llm.rerank.zephyr_reranker import ZephyrReranker
 
 file_name = "retrieve_results/BM25/retrieve_results_dl19.json"
 retrieved_results = Retriever.from_saved_results(file_name)
-print(retrieved_results)
-# TODO: add rerank instead of printing retrieved results
+reranker = ZephyrReranker()
+rerank_results = reranker.rerank(retrieved_results)
+print(rerank_results)
diff --git a/src/rank_llm/rerank/reranker.py b/src/rank_llm/rerank/reranker.py
index 25153524..257876a1 100644
--- a/src/rank_llm/rerank/reranker.py
+++ b/src/rank_llm/rerank/reranker.py
@@ -1,7 +1,7 @@
 import json
 from datetime import datetime
 from pathlib import Path
-from typing import List, Union, Dict, Any
+from typing import List
 
 from tqdm import tqdm
 
@@ -10,21 +10,29 @@
 
 
 class Reranker:
-    def __init__(self, agent: RankLLM, top_k_candidates: int) -> None:
+    def __init__(self, agent: RankLLM) -> None:
         self._agent = agent
-        self._top_k_candidates = top_k_candidates
 
-    def rerank(self, retrieved_results: List[Result], **kwargs):
+    def rerank(
+        self,
+        retrieved_results: List[Result],
+        rank_start: int = 0,
+        rank_end: int = 100,
+        window_size: int = 20,
+        step: int = 10,
+        shuffle_candidates: bool = False,
+        logging: bool = False,
+    ):
         rerank_results = []
         for result in tqdm(retrieved_results):
             rerank_result = self._agent.sliding_windows(
                 result,
-                rank_start=0,
-                rank_end=kwargs["rank_end"],
-                window_size=kwargs["window_size"],
-                step=kwargs["step"],
-                shuffle_candidates=kwargs["shuffle_candidates"],
-                logging=kwargs["logging"],
+                rank_start=max(rank_start, 0),
+                rank_end=min(rank_end, len(result.hits)),
+                window_size=window_size,
+                step=step,
+                shuffle_candidates=shuffle_candidates,
+                logging=logging,
             )
             rerank_results.append(rerank_result)
         return rerank_results
@@ -34,6 +42,7 @@ def write_rerank_results(
         retrieval_method_name: str,
         results: List[Result],
         shuffle_candidates: bool = False,
+        top_k_candidates: int = 100,
         pass_ct: int = None,
         window_size: int = None,
         dataset_name: str = None,
@@ -41,7 +50,7 @@ def write_rerank_results(
         _modelname = self._agent._model.split("/")[-1]
         if _modelname.startswith("checkpoint"):
             _modelname = self._agent._model.split("/")[-2] + "_" + _modelname
-        name = f"{_modelname}_{self._agent._context_size}_{self._top_k_candidates}_{self._agent._prompt_mode}"
+        name = f"{_modelname}_{self._agent._context_size}_{top_k_candidates}_{self._agent._prompt_mode}"
         if dataset_name:
             name = f"{name}_{dataset_name}"
         if self._agent._num_few_shot_examples > 0:
diff --git a/src/rank_llm/rerank/vicuna_reranker.py b/src/rank_llm/rerank/vicuna_reranker.py
new file mode 100644
index 00000000..7ae5d804
--- /dev/null
+++ b/src/rank_llm/rerank/vicuna_reranker.py
@@ -0,0 +1,53 @@
+from typing import List
+
+from rank_llm.rerank.rank_listwise_os_llm import RankListwiseOSLLM
+from rank_llm.rerank.rankllm import PromptMode
+from rank_llm.rerank.reranker import Reranker
+from rank_llm.result import Result
+
+
+class VicunaReranker:
+    def __init__(
+        self,
+        model_path: str = "castorini/rank_vicuna_7b_v1",
+        context_size: int = 4096,
+        prompt_mode: PromptMode = PromptMode.RANK_GPT,
+        num_few_shot_examples: int = 0,
+        device: str = "cuda",
+        num_gpus: int = 1,
+        variable_passages: bool = False,
+        window_size: int = 20,
+        system_message: str = None,
+    ):
+        agent = RankListwiseOSLLM(
+            model=model_path,
+            context_size=context_size,
+            prompt_mode=prompt_mode,
+            num_few_shot_examples=num_few_shot_examples,
+            device=device,
+            num_gpus=num_gpus,
+            variable_passages=variable_passages,
+            window_size=window_size,
+            system_message=system_message,
+        )
+        self._reranker = Reranker(agent)
+
+    def rerank(
+        self,
+        retrieved_results: List[Result],
+        rank_start: int = 0,
+        rank_end: int = 100,
+        window_size: int = 20,
+        step: int = 10,
+        shuffle_candidates: bool = False,
+        logging: bool = False,
+    ):
+        return self._reranker.rerank(
+            retrieved_results=retrieved_results,
+            rank_start=rank_start,
+            rank_end=rank_end,
+            window_size=window_size,
+            step=step,
+            shuffle_candidates=shuffle_candidates,
+            logging=logging,
+        )
diff --git a/src/rank_llm/rerank/zephyr_reranker.py b/src/rank_llm/rerank/zephyr_reranker.py
new file mode 100644
index 00000000..36633849
--- /dev/null
+++ b/src/rank_llm/rerank/zephyr_reranker.py
@@ -0,0 +1,53 @@
+from typing import List
+
+from rank_llm.rerank.rank_listwise_os_llm import RankListwiseOSLLM
+from rank_llm.rerank.rankllm import PromptMode
+from rank_llm.rerank.reranker import Reranker
+from rank_llm.result import Result
+
+
+class ZephyrReranker:
+    def __init__(
+        self,
+        model_path: str = "castorini/rank_zephyr_7b_v1_full",
+        context_size: int = 4096,
+        prompt_mode: PromptMode = PromptMode.RANK_GPT,
+        num_few_shot_examples: int = 0,
+        device: str = "cuda",
+        num_gpus: int = 1,
+        variable_passages: bool = True,
+        window_size: int = 20,
+        system_message: str = "You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query",
+    ):
+        agent = RankListwiseOSLLM(
+            model=model_path,
+            context_size=context_size,
+            prompt_mode=prompt_mode,
+            num_few_shot_examples=num_few_shot_examples,
+            device=device,
+            num_gpus=num_gpus,
+            variable_passages=variable_passages,
+            window_size=window_size,
+            system_message=system_message,
+        )
+        self._reranker = Reranker(agent)
+
+    def rerank(
+        self,
+        retrieved_results: List[Result],
+        rank_start: int = 0,
+        rank_end: int = 100,
+        window_size: int = 20,
+        step: int = 10,
+        shuffle_candidates: bool = False,
+        logging: bool = False,
+    ):
+        return self._reranker.rerank(
+            retrieved_results=retrieved_results,
+            rank_start=rank_start,
+            rank_end=rank_end,
+            window_size=window_size,
+            step=step,
+            shuffle_candidates=shuffle_candidates,
+            logging=logging,
+        )
diff --git a/src/rank_llm/retrieve_and_rerank.py b/src/rank_llm/retrieve_and_rerank.py
index 2c34c26b..28958a4a 100644
--- a/src/rank_llm/retrieve_and_rerank.py
+++ b/src/rank_llm/retrieve_and_rerank.py
@@ -97,7 +97,7 @@ def retrieve_and_rerank(
     else:
         raise ValueError(f"Invalid retrieval mode: {retrieval_mode}")
     print("Reranking:")
-    reranker = Reranker(agent, top_k_candidates)
+    reranker = Reranker(agent)
     for pass_ct in range(num_passes):
         print(f"Pass {pass_ct + 1} of {num_passes}:")
         rerank_results = reranker.rerank(
@@ -115,6 +115,7 @@ def retrieve_and_rerank(
             retrieval_method.name,
             rerank_results,
             shuffle_candidates,
+            top_k_candidates=top_k_candidates,
             pass_ct=None if num_passes == 1 else pass_ct,
             window_size=window_size,
             dataset_name=dataset,
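
Note on ResultsWriter.write_ranking_exec_summary: the reranker changes above call this method (writing to ranking_execution_summary/{retrieval_method}/{name}.json as of PATCH 6), but its body does not appear in any hunk of this series. A minimal sketch of what it might look like, assuming it mirrors write_in_json_format and serializes each result's ranking_exec_summary list of RankingExecInfo objects; the implementation below is an illustration under those assumptions, not the actual code from the repository:

    def write_ranking_exec_summary(self, filename: str):
        # Sketch (assumed, not taken from this patch series): flatten each
        # RankingExecInfo (prompt, response, input_token_count, output_token_count)
        # into a plain dict and dump one entry per query, mirroring
        # write_in_json_format above.
        exec_summary = []
        for result in self._results:
            infos = [info.__dict__ for info in (result.ranking_exec_summary or [])]
            exec_summary.append({"query": result.query, "ranking_exec_summary": infos})
        with open(filename, "a" if self._append else "w") as f:
            json.dump(exec_summary, f, indent=2)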