Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create rz rv rerankers #57

Merged
merged 8 commits into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/rank_llm/demo/rerank_dataset_with_prebuilt_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@

from rank_llm.retrieve.retriever import Retriever
from rank_llm.retrieve.pyserini_retriever import RetrievalMethod
from rank_llm.rerank.vicuna_reranker import VicunaReranker

# By default uses BM25 for retrieval
dataset_name = "dl19"
retrieved_results = Retriever.from_dataset_with_prebuit_index(dataset_name)
print(retrieved_results)
# TODO: add rerank instead of printing retrieved results
reranker = VicunaReranker()
rerank_results = reranker.rerank(retrieved_results)
print(rerank_results)

# Users can specify other retrieval methods:
retrieved_results = Retriever.from_dataset_with_prebuit_index(
dataset_name, RetrievalMethod.SPLADE_P_P_ENSEMBLE_DISTIL
)
print(retrieved_results)
# TODO: add rerank instead of printing retrieved results
reranker = VicunaReranker()
rerank_results = reranker.rerank(retrieved_results)
print(rerank_results)
6 changes: 4 additions & 2 deletions src/rank_llm/demo/rerank_inline_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
sys.path.append(parent)

from rank_llm.retrieve.retriever import Retriever
from rank_llm.rerank.zephyr_reranker import ZephyrReranker

query = "What is the capital of the United States?"
docs = [
Expand All @@ -17,5 +18,6 @@
]

retrieved_results = Retriever.from_inline_documents(query, documents=docs)
print(retrieved_results)
# TODO: add rerank instead of printing retrieved results
reranker = ZephyrReranker()
rerank_results = reranker.rerank(retrieved_results)
print(rerank_results)
6 changes: 4 additions & 2 deletions src/rank_llm/demo/rerank_inline_hits.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
sys.path.append(parent)

from rank_llm.retrieve.retriever import Retriever
from rank_llm.rerank.zephyr_reranker import ZephyrReranker

query = "how long is life cycle of flea"
hits = [
Expand Down Expand Up @@ -69,5 +70,6 @@
]

retrieved_results = Retriever.from_inline_hits(query=query, hits=hits)
print(retrieved_results)
# TODO: add rerank instead of printing retrieved results
reranker = ZephyrReranker()
rerank_results = reranker.rerank(retrieved_results)
print(rerank_results)
6 changes: 4 additions & 2 deletions src/rank_llm/demo/rerank_stored_retrieved_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
sys.path.append(parent)

from rank_llm.retrieve.retriever import Retriever
from rank_llm.rerank.zephyr_reranker import ZephyrReranker

file_name = "retrieve_results/BM25/retrieve_results_dl19.json"
retrieved_results = Retriever.from_saved_results(file_name)
print(retrieved_results)
# TODO: add rerank instead of printing retrieved results
reranker = ZephyrReranker()
rerank_results = reranker.rerank(retrieved_results)
print(rerank_results)
31 changes: 20 additions & 11 deletions src/rank_llm/rerank/reranker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from datetime import datetime
from pathlib import Path
from typing import List, Union, Dict, Any
from typing import List

from tqdm import tqdm

Expand All @@ -10,21 +10,29 @@


class Reranker:
def __init__(self, agent: RankLLM, top_k_candidates: int) -> None:
def __init__(self, agent: RankLLM) -> None:
self._agent = agent
self._top_k_candidates = top_k_candidates

def rerank(self, retrieved_results: List[Result], **kwargs):
def rerank(
self,
retrieved_results: List[Result],
rank_start: int = 0,
rank_end: int = 100,
window_size: int = 20,
step: int = 10,
shuffle_candidates: bool = False,
logging: bool = False,
):
rerank_results = []
for result in tqdm(retrieved_results):
rerank_result = self._agent.sliding_windows(
result,
rank_start=0,
rank_end=kwargs["rank_end"],
window_size=kwargs["window_size"],
step=kwargs["step"],
shuffle_candidates=kwargs["shuffle_candidates"],
logging=kwargs["logging"],
rank_start=max(rank_start, 0),
rank_end=min(rank_end, len(result.hits)),
window_size=window_size,
step=step,
shuffle_candidates=shuffle_candidates,
logging=logging,
)
rerank_results.append(rerank_result)
return rerank_results
Expand All @@ -34,14 +42,15 @@ def write_rerank_results(
retrieval_method_name: str,
results: List[Result],
shuffle_candidates: bool = False,
top_k_candidates: int = 100,
pass_ct: int = None,
window_size: int = None,
dataset_name: str = None,
) -> str:
_modelname = self._agent._model.split("/")[-1]
if _modelname.startswith("checkpoint"):
_modelname = self._agent._model.split("/")[-2] + "_" + _modelname
name = f"{_modelname}_{self._agent._context_size}_{self._top_k_candidates}_{self._agent._prompt_mode}"
name = f"{_modelname}_{self._agent._context_size}_{top_k_candidates}_{self._agent._prompt_mode}"
if dataset_name:
name = f"{name}_{dataset_name}"
if self._agent._num_few_shot_examples > 0:
Expand Down
53 changes: 53 additions & 0 deletions src/rank_llm/rerank/vicuna_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import List

from rank_llm.rerank.rank_listwise_os_llm import RankListwiseOSLLM
from rank_llm.rerank.rankllm import PromptMode
from rank_llm.rerank.reranker import Reranker
from rank_llm.result import Result


class VicunaReranker:
def __init__(
self,
model_path: str = "castorini/rank_vicuna_7b_v1",
context_size: int = 4096,
prompt_mode: PromptMode = PromptMode.RANK_GPT,
num_few_shot_examples: int = 0,
device: str = "cuda",
num_gpus: int = 1,
variable_passages: bool = False,
window_size: int = 20,
system_message: str = None,
):
agent = RankListwiseOSLLM(
model=model_path,
context_size=context_size,
prompt_mode=prompt_mode,
num_few_shot_examples=num_few_shot_examples,
device=device,
num_gpus=num_gpus,
variable_passages=variable_passages,
window_size=window_size,
system_message=system_message,
)
self._reranker = Reranker(agent)

def rerank(
self,
retrieved_results: List[Result],
rank_start: int = 0,
rank_end: int = 100,
window_size: int = 20,
step: int = 10,
shuffle_candidates: bool = False,
logging: bool = False,
):
return self._reranker.rerank(
retrieved_results=retrieved_results,
rank_start=rank_start,
rank_end=rank_end,
window_size=window_size,
step=step,
shuffle_candidates=shuffle_candidates,
logging=logging,
)
53 changes: 53 additions & 0 deletions src/rank_llm/rerank/zephyr_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import List

from rank_llm.rerank.rank_listwise_os_llm import RankListwiseOSLLM
from rank_llm.rerank.rankllm import PromptMode
from rank_llm.rerank.reranker import Reranker
from rank_llm.result import Result


class ZephyrReranker:
def __init__(
self,
model_path: str = "castorini/rank_zephyr_7b_v1_full",
context_size: int = 4096,
prompt_mode: PromptMode = PromptMode.RANK_GPT,
num_few_shot_examples: int = 0,
device: str = "cuda",
num_gpus: int = 1,
variable_passages: bool = True,
window_size: int = 20,
system_message: str = "You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query",
):
agent = RankListwiseOSLLM(
model=model_path,
context_size=context_size,
prompt_mode=prompt_mode,
num_few_shot_examples=num_few_shot_examples,
device=device,
num_gpus=num_gpus,
variable_passages=variable_passages,
window_size=window_size,
system_message=system_message,
)
self._reranker = Reranker(agent)

def rerank(
self,
retrieved_results: List[Result],
rank_start: int = 0,
rank_end: int = 100,
window_size: int = 20,
step: int = 10,
shuffle_candidates: bool = False,
logging: bool = False,
):
return self._reranker.rerank(
retrieved_results=retrieved_results,
rank_start=rank_start,
rank_end=rank_end,
window_size=window_size,
step=step,
shuffle_candidates=shuffle_candidates,
logging=logging,
)
3 changes: 2 additions & 1 deletion src/rank_llm/retrieve_and_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def retrieve_and_rerank(
else:
raise ValueError(f"Invalid retrieval mode: {retrieval_mode}")
print("Reranking:")
reranker = Reranker(agent, top_k_candidates)
reranker = Reranker(agent)
for pass_ct in range(num_passes):
print(f"Pass {pass_ct + 1} of {num_passes}:")
rerank_results = reranker.rerank(
Expand All @@ -115,6 +115,7 @@ def retrieve_and_rerank(
retrieval_method.name,
rerank_results,
shuffle_candidates,
top_k_candidates=top_k_candidates,
pass_ct=None if num_passes == 1 else pass_ct,
window_size=window_size,
dataset_name=dataset,
Expand Down