add additional parameters #123

Open
wants to merge 14 commits into base: main
62 changes: 51 additions & 11 deletions src/rank_llm/api/server.py
@@ -1,30 +1,38 @@
import argparse

import torch
from flask import Flask, jsonify, request

from rank_llm import retrieve_and_rerank
from rank_llm.rerank.api_keys import get_azure_openai_args, get_openai_api_key
from rank_llm.rerank.rank_gpt import SafeOpenai
from rank_llm.rerank.rank_listwise_os_llm import RankListwiseOSLLM
from rank_llm.rerank.rankllm import PromptMode
from rank_llm.retrieve.pyserini_retriever import RetrievalMethod
from rank_llm.retrieve.retriever import RetrievalMode

""" API URL FORMAT

-http://localhost:8082/api/model/{model_name}/index/{index_name}/{retriever_base_host}?query={query}&hits_retriever={top_k_retriever}&hits_reranker={top_k_reranker}&qid={qid}&num_passes={num_passes}
+http://localhost:{host_name}/api/model/{model_name}/index/{index_name}/{retriever_base_host}?query={query}&hits_retriever={top_k_retriever}&hits_reranker={top_k_reranker}&qid={qid}&num_passes={num_passes}&retrieval_method={retrieval_method}

hits_retriever, hits_reranker, qid, and num_passes are OPTIONAL
-Default to 20, 5, None, and 1 respectively
+Default to 20, 10, None, and 1 respectively

"""


def create_app(model, port, use_azure_openai=False):
app = Flask(__name__)

global default_agent
default_agent = None

# Load specified model upon server initialization
if model == "rank_zephyr":
print(f"Loading {model} model...")
-# Load specified model upon server initialization
default_agent = RankListwiseOSLLM(
model=f"castorini/{model}_7b_v1_full",
name=model,
context_size=4096,
prompt_mode=PromptMode.RANK_GPT,
num_few_shot_examples=0,
@@ -36,9 +44,9 @@ def create_app(model, port, use_azure_openai=False):
)
elif model == "rank_vicuna":
print(f"Loading {model} model...")
-# Load specified model upon server initialization
default_agent = RankListwiseOSLLM(
model=f"castorini/{model}_7b_v1",
name=model,
context_size=4096,
prompt_mode=PromptMode.RANK_GPT,
num_few_shot_examples=0,
@@ -49,7 +57,6 @@ def create_app(model, port, use_azure_openai=False):
)
elif "gpt" in model:
print(f"Loading {model} model...")
-# Load specified model upon server initialization
openai_keys = get_openai_api_key()
print(openai_keys)
default_agent = SafeOpenai(
@@ -60,27 +67,55 @@
keys=openai_keys,
**(get_azure_openai_args() if use_azure_openai else {}),
)
elif model in ["rank_random", "rank_identity"]:
# No RankLLM agent is required for trivial rerankers.
default_agent = model
else:
raise ValueError(f"Unsupported model: {model}")

# Register the search endpoint
@app.route(
"/api/model/<string:model_path>/index/<string:dataset>/<string:retriever_host>",
methods=["GET"],
)
def search(model_path, dataset, retriever_host):
"""retrieve and rerank (search)

Args:
- model_path (str): name of reranking model (e.g., rank_zephyr)
- dataset (str): dataset from which to retrieve
- retriever_host (str): host of Anserini API
"""

# query to search for
query = request.args.get("query", type=str)
# search the entire dataset and return the top-k candidates
top_k_retrieve = request.args.get("hits_retriever", default=20, type=int)
top_k_rerank = request.args.get("hits_reranker", default=5, type=int)
# rerank top_k_retrieve candidates from retrieve stage and return top_k_rerank candidates
top_k_rerank = request.args.get("hits_reranker", default=10, type=int)
# qid of query
qid = request.args.get("qid", default=None, type=str)
# number of passes reranker goes through
num_passes = request.args.get("num_passes", default=1, type=int)
# retrieval method to use
retrieval_method = request.args.get(
"retrieval_method", default="bm25", type=str
)

if "bm25" in retrieval_method.lower():
_retrieval_method = RetrievalMethod.BM25
else:
return jsonify({"error": str("Retrieval method must be BM25")}), 500

# If the requested model differs from the currently loaded default agent
# (trivial rerankers are stored as plain strings, so guard the get_name() call)
global default_agent
if default_agent is not None and model_path != (
    default_agent if isinstance(default_agent, str) else default_agent.get_name()
):
    # Delete the old agent to free the CUDA cache
    del default_agent  # required so torch can actually release the memory
    torch.cuda.empty_cache()
    default_agent = None
try:
-# Assuming the function is called with these parameters and returns a response
-response = retrieve_and_rerank.retrieve_and_rerank(
+# calls the Anserini retriever API and then reranks the results
+(response, agent) = retrieve_and_rerank.retrieve_and_rerank(
dataset=dataset,
retrieval_mode=RetrievalMode.DATASET,
query=query,
model_path=model_path,
host="http://localhost:" + retriever_host,
@@ -91,8 +126,13 @@ def search(model_path, dataset, retriever_host):
populate_exec_summary=False,
default_agent=default_agent,
num_passes=num_passes,
retrieval_method=_retrieval_method,
print_prompts_responses=False
)

# cache the most recently used agent as the default for subsequent requests
default_agent = agent

return jsonify(response[0]), 200
except Exception as e:
return jsonify({"error": str(e)}), 500
2 changes: 1 addition & 1 deletion src/rank_llm/rerank/identity_reranker.py
@@ -19,7 +19,7 @@ def rerank_batch(
shuffle_candidates: bool = False,
) -> List[Result]:
"""
-A trivial reranker that returns a subsection of the retireved candidates list as-is or shuffled.
+A trivial reranker that returns a subsection of the retrieved candidates list as-is or shuffled.

Args:
requests (List[Request]): The list of requests. Each request has a query and a candidates list.
4 changes: 4 additions & 0 deletions src/rank_llm/rerank/rank_gpt.py
@@ -65,6 +65,7 @@ def __init__(
f"unsupported prompt mode for GPT models: {prompt_mode}, expected {PromptMode.RANK_GPT}, {PromptMode.RANK_GPT_APEER} or {PromptMode.LRL}."
)

self._model = model
self._window_size = window_size
self._output_token_estimate = None
self._keys = keys
@@ -346,3 +347,6 @@ def cost_per_1k_token(self, input_token: bool) -> float:
}
model_key = "gpt-3.5" if "gpt-3" in self._model else "gpt-4"
return cost_dict[(model_key, self._context_size)]

def get_name(self) -> str:
return self._model
5 changes: 5 additions & 0 deletions src/rank_llm/rerank/rank_listwise_os_llm.py
@@ -31,6 +31,7 @@ class RankListwiseOSLLM(RankLLM):
def __init__(
self,
model: str,
name: str,
context_size: int = 4096,
prompt_mode: PromptMode = PromptMode.RANK_GPT,
num_few_shot_examples: int = 0,
@@ -94,6 +95,7 @@ def __init__(
model, device=device, num_gpus=num_gpus
)
self._vllm_batched = vllm_batched
self._name = name
self._variable_passages = variable_passages
self._window_size = window_size
self._system_message = system_message
@@ -257,3 +259,6 @@ def get_num_tokens(self, prompt: str) -> int:

def cost_per_1k_token(self, input_token: bool) -> float:
return 0

def get_name(self) -> str:
return self._name
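
The new name argument and get_name hook are what the server's model-reload check keys on. A minimal sketch (hypothetical checkpoint path; note that constructing the agent loads the underlying model, so this assumes the weights are available):

agent = RankListwiseOSLLM(
    model="castorini/rank_zephyr_7b_v1_full",
    name="rank_zephyr",
)
assert agent.get_name() == "rank_zephyr"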
9 changes: 7 additions & 2 deletions src/rank_llm/rerank/rankllm.py
@@ -467,17 +467,22 @@ def receive_permutation(
Items not mentioned in the permutation string remain in their original sequence but are moved after
the permuted items.
"""

# Parse and normalize the permutation indices
response = self._clean_response(permutation)
response = [int(x) - 1 for x in response.split()]
response = self._remove_duplicate(response)

# Extract the relevant candidates and create a mapping for new order
cut_range = copy.deepcopy(result.candidates[rank_start:rank_end])
original_rank = [tt for tt in range(len(cut_range))]
response = [ss for ss in response if ss in original_rank]
response = response + [tt for tt in original_rank if tt not in response]

# Update candidates in the new order
for j, x in enumerate(response):
result.candidates[j + rank_start] = copy.deepcopy(cut_range[x])
if result.candidates[j + rank_start].score is not None:
    # reassign scores so they stay in the original descending order
    result.candidates[j + rank_start].score = cut_range[j].score

return result

def _replace_number(self, s: str) -> str:
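
To make the permutation handling above concrete, here is a worked example of the same index arithmetic with candidates reduced to plain strings (hypothetical values):

candidates = ["d0", "d1", "d2", "d3"]             # result.candidates[rank_start:rank_end]
permutation = "3 1"                               # cleaned model output, 1-based ranks
response = [int(x) - 1 for x in permutation.split()]         # [2, 0]
original_rank = list(range(len(candidates)))                 # [0, 1, 2, 3]
response = [i for i in response if i in original_rank]       # drop out-of-range indices
response += [i for i in original_rank if i not in response]  # [2, 0, 1, 3]
assert [candidates[i] for i in response] == ["d2", "d0", "d1", "d3"]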
56 changes: 36 additions & 20 deletions src/rank_llm/rerank/reranker.py
@@ -1,12 +1,24 @@
from datetime import datetime
from pathlib import Path
from typing import List
from enum import Enum

from tqdm import tqdm

from rank_llm.data import DataWriter, Request, Result
from rank_llm.rerank.rankllm import RankLLM

class OperationMode(Enum):
STANDARD = 0
VLLM = 1
T5 = 2

@classmethod
def from_int(cls, val):
for mode in cls:
if mode.value == val:
return mode
raise ValueError(f"{val} is not a valid {cls.__name__}")

class Reranker:
def __init__(self, agent: RankLLM) -> None:
@@ -21,7 +33,7 @@ def rerank_batch(
step: int = 10,
shuffle_candidates: bool = False,
logging: bool = False,
-vllm_batched: bool = False,
+operation_mode: OperationMode = OperationMode.STANDARD,
populate_exec_summary: bool = True,
) -> List[Result]:
"""
@@ -44,11 +56,26 @@
Returns:
List[Result]: A list containing the reranked candidates.
"""
-if vllm_batched:
-for i in range(1, len(requests)):
-assert len(requests[0].candidates) == len(
-requests[i].candidates
-), "Batched requests must have the same number of candidates"

if operation_mode == OperationMode.STANDARD:
results = []
for request in tqdm(requests):
result = self._agent.sliding_windows(
request,
rank_start=max(rank_start, 0),
rank_end=min(rank_end, len(request.candidates)),
window_size=window_size,
step=step,
shuffle_candidates=shuffle_candidates,
logging=logging,
populate_exec_summary=populate_exec_summary,
)
results.append(result)
return results
elif operation_mode == OperationMode.VLLM:
if len({len(req.candidates) for req in requests}) != 1:
    raise ValueError("Batched requests must have the same number of candidates")

return self._agent.sliding_windows_batched(
requests,
rank_start=max(rank_start, 0),
@@ -60,20 +87,9 @@
shuffle_candidates=shuffle_candidates,
logging=logging,
)
-results = []
-for request in tqdm(requests):
-result = self._agent.sliding_windows(
-request,
-rank_start=max(rank_start, 0),
-rank_end=min(rank_end, len(request.candidates)),
-window_size=window_size,
-step=step,
-shuffle_candidates=shuffle_candidates,
-logging=logging,
-populate_exec_summary=populate_exec_summary,
-)
-results.append(result)
-return results
else:  # OperationMode.T5
    # TODO: implement T5-based reranking
    raise NotImplementedError("T5 operation mode is not implemented yet")

def rerank(
self,
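
Callers that previously passed vllm_batched=True now select a mode explicitly. A short sketch of the new dispatch (assumes a constructed Reranker instance and a list of Request objects):

mode = OperationMode.from_int(1)   # OperationMode.VLLM
assert mode is OperationMode.VLLM
results = reranker.rerank_batch(requests, operation_mode=mode)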
2 changes: 0 additions & 2 deletions src/rank_llm/retrieve/retriever.py
@@ -8,8 +8,6 @@

from rank_llm.data import Request
from rank_llm.retrieve.pyserini_retriever import PyseriniRetriever, RetrievalMethod
-from rank_llm.retrieve.repo_info import HITS_INFO
-from rank_llm.retrieve.utils import compute_md5, download_cached_hits


class RetrievalMode(Enum):
2 changes: 1 addition & 1 deletion src/rank_llm/retrieve/service_retriever.py
@@ -68,7 +68,7 @@ def retrieve(
ValueError: If the retrieval mode is invalid or the result format is not as expected.
"""

url = f"{host}/api/index/{dataset}/search?query={parse.quote(request.query.text)}&hits={str(k)}&qid={request.query.qid}"
url = f"{host}/api/collection/{dataset}/search?query={parse.quote(request.query.text)}&hits={str(k)}&qid={request.query.qid}"

try:
response = requests.get(url, timeout=timeout)
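
With this change, a retrieval request is routed through the collection endpoint rather than the index endpoint; a hypothetical resulting URL (host, dataset, query, and qid illustrative only):

http://localhost:8081/api/collection/msmarco-v1-passage/search?query=lobster%20roll&hits=20&qid=1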