Skip to content

Commit

Permalink
add rerank endpoint (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
lfayoux committed Feb 2, 2023
1 parent 0c43388 commit 6e5a272
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 3 deletions.
1 change: 1 addition & 0 deletions cohere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
FEEDBACK_URL = 'feedback'
GENERATE_URL = 'generate'
SUMMARIZE_URL = 'summarize'
RERANK_URL = 'rerank'

CHECK_API_KEY_URL = 'check-api-key'
TOKENIZE_URL = 'tokenize'
Expand Down
36 changes: 34 additions & 2 deletions cohere/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from urllib.parse import urljoin

import requests
Expand All @@ -22,6 +22,7 @@
from cohere.generation import Generations
from cohere.tokenize import Tokens
from cohere.summarize import SummarizeResponse
from cohere.rerank import Reranking

use_xhr_client = False
try:
Expand Down Expand Up @@ -308,7 +309,6 @@ def feedback(self, id: str, good_response: bool, desired_response: str = "", fee
Returns:
Feedback: a Feedback object
"""

json_body = {
'id': id,
'good_response': good_response,
Expand All @@ -318,6 +318,38 @@ def feedback(self, id: str, good_response: bool, desired_response: str = "", fee
self.__request(cohere.FEEDBACK_URL, json_body)
return Feedback(id=id, good_response=good_response, desired_response=desired_response, feedback=feedback)

def rerank(self,
query: str,
documents: Union[List[str], List[Dict[str, Any]]],
top_n: int = None) -> Reranking:
"""Returns an ordered list of documents ordered by their relevance to the provided query
Args:
query (str): The search query
documents (list[str], list[dict]): The documents to rerank
top_n (int): (optional) The number of results to return, defaults to returning all results
"""
parsed_docs = []
for doc in documents:
if isinstance(doc, str):
parsed_docs.append({'text': doc})
elif isinstance(doc, dict) and 'text' in doc:
parsed_docs.append(doc)
else:
raise CohereError(
message='invalid format for documents, must be a list of strings or dicts with a "text" key')

json_body = {
"query": query,
"documents": parsed_docs,
"top_n": top_n,
"return_documents": False
}
reranking = Reranking(self.__request(cohere.RERANK_URL, json=json_body))
for rank in reranking.results:
rank.document = parsed_docs[rank.index]
return reranking

def __print_warning_msg(self, response: Response):
if 'X-API-Warning' in response.headers:
print("\033[93mWarning: {}\n\033[0m".format(response.headers['X-API-Warning']), file=sys.stderr)
Expand Down
64 changes: 64 additions & 0 deletions cohere/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import List, Optional, Dict, NamedTuple, Any, Iterator

from cohere.response import CohereObject

RerankDocument = NamedTuple("Document", [("text", str)])
RerankDocument.__doc__ = """
Returned by co.rerank,
dict which always contains text but can also contain arbitrary fields
"""


class RerankResult(CohereObject):

def __init__(self,
document: Dict[str, Any] = None,
index: int = None,
relevance_score: float = None,
*args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.document = document
self.index = index
self.relevance_score = relevance_score

def __repr__(self) -> str:
score = self.relevance_score
index = self.index
if self.document is None:
return f"RerankResult<index: {index}, relevance_score: {score}>"
else:
text = self.document['text']
return f"RerankResult<document['text']: {text}, index: {index}, relevance_score: {score}>"


class Reranking(CohereObject):

def __init__(self,
response: Optional[Dict[str, Any]] = None,
**kwargs) -> None:
super().__init__(**kwargs)
assert response is not None
self.results = self._results(response)

def _results(self, response: Dict[str, Any]) -> List[RerankResult]:
results = []
for res in response['results']:
if 'document' in res.keys():
results.append(
RerankResult(res['document'], res['index'], res['relevance_score']))
else:
results.append(
RerankResult(index=res['index'], relevance_score=res['relevance_score']))
return results

def __str__(self) -> str:
return str(self.results)

def __repr__(self) -> str:
return self.results.__repr__()

def __iter__(self) -> Iterator:
return iter(self.results)

def __getitem__(self, index) -> RerankResult:
return self.results[index]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def has_ext_modules(foo) -> bool:


setuptools.setup(name='cohere',
version='3.2.6',
version='3.3.0',
author='1vn',
author_email='ivan@cohere.ai',
description='A Python library for the Cohere API',
Expand Down

0 comments on commit 6e5a272

Please sign in to comment.