Skip to content

Commit

Permalink
add rerank endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
lfayoux committed Feb 1, 2023
1 parent 0c43388 commit d5b039e
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 1 deletion.
1 change: 1 addition & 0 deletions cohere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
FEEDBACK_URL = 'feedback'
GENERATE_URL = 'generate'
SUMMARIZE_URL = 'summarize'
RERANK_URL = 'rerank'

CHECK_API_KEY_URL = 'check-api-key'
TOKENIZE_URL = 'tokenize'
Expand Down
36 changes: 35 additions & 1 deletion cohere/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from urllib.parse import urljoin

import requests
Expand All @@ -22,6 +22,7 @@
from cohere.generation import Generations
from cohere.tokenize import Tokens
from cohere.summarize import SummarizeResponse
from cohere.rerank import Reranking

use_xhr_client = False
try:
Expand Down Expand Up @@ -309,6 +310,39 @@ def feedback(self, id: str, good_response: bool, desired_response: str = "", fee
Feedback: a Feedback object
"""

def rerank(self,
           query: str,
           documents: Union[List[str], List[Dict[str, Any]]],
           top_n: int = None) -> "List[cohere.rerank.RerankResult]":
    """Returns a list of documents ordered by their relevance to the provided query

    Args:
        query (str): The search query
        documents (list[str], list[dict]): The documents to rerank; each item is
            either a string or a dict that has a "text" key
        top_n (int): (optional) The number of results to return, defaults to return all results

    Returns:
        list[RerankResult]: the ``results`` of the parsed response, where each
            result's ``document`` is set to the submitted document found at the
            result's ``index``

    Raises:
        CohereError: if an item of ``documents`` is neither a string nor a dict
            with a "text" key
    """
    # Normalise every document to the dict form; bare strings become {'text': ...}.
    parsed_docs = []
    for doc in documents:
        if isinstance(doc, str):
            parsed_docs.append({'text': doc})
        elif isinstance(doc, dict) and 'text' in doc:
            parsed_docs.append(doc)
        else:
            raise CohereError(
                message='invalid format for documents, must be a list of strings or dicts with a "text" key')

    json_body = {
        "query": query,
        "documents": parsed_docs,
        "top_n": top_n,
        # Documents are not requested back; they are re-attached locally below.
        "return_documents": False,
    }
    rankings = Reranking(self.__request(cohere.RERANK_URL, json=json_body)).results
    for rank in rankings:
        # `index` is used as a position into the list that was submitted.
        rank.document = parsed_docs[rank.index]
    return rankings

def feedback(self, id: str, feedback: str, accepted: bool):
json_body = {
'id': id,
'good_response': good_response,
Expand Down
64 changes: 64 additions & 0 deletions cohere/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import List, Optional, Dict, NamedTuple, Any, Iterator

from cohere.response import CohereObject

# The functional NamedTuple form is kept so the generated type name stays "Document".
RerankDocument = NamedTuple("Document", [("text", str)])
RerankDocument.__doc__ = """
Returned by co.rerank,
dict which always contains text but can also contain arbitrary fields
"""


class RerankResult(CohereObject):
    """One entry of a rerank response: a document, its index and its relevance score."""

    def __init__(self,
                 document: Optional[Dict[str, Any]] = None,
                 index: Optional[int] = None,
                 relevance_score: Optional[float] = None,
                 *args, **kwargs) -> None:
        """Store the result fields; extra arguments are forwarded to CohereObject.

        Args:
            document (dict): (optional) the document for this result; when set,
                ``__repr__`` reads its "text" key
            index (int): (optional) the index reported for this result
            relevance_score (float): (optional) the score reported for this result
        """
        super().__init__(*args, **kwargs)
        self.document = document
        self.index = index
        self.relevance_score = relevance_score

    def __repr__(self) -> str:
        """Return a short description, including the document text when a document is set."""
        score = self.relevance_score
        index = self.index
        if self.document is None:
            return f"RerankResult<index: {index}, relevance_score: {score}>"
        text = self.document['text']
        return f"RerankResult<text: {text}, index: {index}, relevance_score: {score}>"


class Reranking(CohereObject):
    """Container for the ``results`` list of a rerank response.

    Iteration and indexing are delegated to ``self.results``.
    """

    def __init__(self,
                 response: Optional[Dict[str, Any]] = None,
                 **kwargs) -> None:
        """Parse ``response`` into ``self.results``.

        Args:
            response (dict): the decoded response; must contain a "results" list
            **kwargs: forwarded to CohereObject

        Raises:
            AssertionError: if ``response`` is None
        """
        super().__init__(**kwargs)
        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # the exception type is kept for backward compatibility.
        if response is None:
            raise AssertionError("response must not be None")
        self.results = self._results(response)

    def _results(self, response: Dict[str, Any]) -> List[RerankResult]:
        """Build one RerankResult per entry of ``response['results']``."""
        # A missing 'document' key yields document=None, which is the
        # RerankResult default, so both response shapes share one code path.
        return [
            RerankResult(
                document=res.get('document'),
                index=res['index'],
                relevance_score=res['relevance_score'],
            )
            for res in response['results']
        ]

    def __str__(self) -> str:
        return str(self.results)

    def __repr__(self) -> str:
        return repr(self.results)

    def __iter__(self) -> Iterator:
        return iter(self.results)

    def __getitem__(self, index) -> RerankResult:
        return self.results[index]

0 comments on commit d5b039e

Please sign in to comment.