Skip to content

Commit

Permalink
Add message to documents (#12552)
Browse files Browse the repository at this point in the history
This adds the model's response message as an additional document returned by
the RAG retriever, so users can choose to use it. Also drops the retriever's document limit (`top_k`).

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
  • Loading branch information
billytrend-cohere and baskaryan committed Nov 9, 2023
1 parent 5f38770 commit b346d4a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
14 changes: 12 additions & 2 deletions libs/langchain/langchain/chat_models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,16 @@ async def _astream(
if run_manager:
await run_manager.on_llm_new_token(delta)

def _get_generation_info(self, response: Any) -> Dict[str, Any]:
    """Get the generation info from cohere API response.

    Args:
        response: Response object returned by the Cohere chat call.

    Returns:
        A dict with the keys ``documents``, ``citations``,
        ``search_results``, ``search_queries`` and ``token_count``. A key
        whose attribute is absent from ``response`` maps to ``None``.
    """
    # The callers only check that ``documents`` exists before calling this
    # helper, so every field is read defensively rather than letting a
    # missing attribute raise AttributeError out of the generate call.
    fields = (
        "documents",
        "citations",
        "search_results",
        "search_queries",
        "token_count",
    )
    return {field: getattr(response, field, None) for field in fields}

def _generate(
self,
messages: List[BaseMessage],
Expand All @@ -185,7 +195,7 @@ def _generate(
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = {"documents": response.documents}
generation_info = self._get_generation_info(response)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
Expand All @@ -211,7 +221,7 @@ async def _agenerate(
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = {"documents": response.documents}
generation_info = self._get_generation_info(response)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
Expand Down
24 changes: 17 additions & 7 deletions libs/langchain/langchain/retrievers/cohere_rag_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,27 @@


def _get_docs(response: Any) -> List[Document]:
    """Convert a Cohere chat generation into a list of documents.

    Args:
        response: Generation object exposing ``generation_info`` (a dict, or
            ``None``) and ``message.content``.

    Returns:
        One ``Document`` per entry of ``generation_info["documents"]`` (page
        content taken from the entry's ``"snippet"``, the entry itself used
        as metadata), followed by a final ``Document`` holding the model's
        response text. That last document's metadata has
        ``"type": "model_response"`` plus the citations, search results,
        search queries and token count (``None`` for any that are missing).
    """
    # ``generation_info`` may be ``None`` or lack keys; treat that as
    # "no retrieved documents, no extra metadata" instead of raising.
    info = response.generation_info or {}
    docs = [
        Document(page_content=doc["snippet"], metadata=doc)
        for doc in info.get("documents") or []
    ]
    # The model's own answer is appended last so callers can tell it apart
    # from the retrieved snippets via ``metadata["type"]``.
    docs.append(
        Document(
            page_content=response.message.content,
            metadata={
                "type": "model_response",
                "citations": info.get("citations"),
                "search_results": info.get("search_results"),
                "search_queries": info.get("search_queries"),
                "token_count": info.get("token_count"),
            },
        )
    )
    return docs


class CohereRagRetriever(BaseRetriever):
"""`ChatGPT plugin` retriever."""

top_k: int = 3
"""Number of documents to return."""
"""Cohere Chat API with RAG."""

connectors: List[Dict] = Field(default_factory=lambda: [{"id": "web-search"}])
"""
Expand Down Expand Up @@ -55,7 +65,7 @@ def _get_relevant_documents(
callbacks=run_manager.get_child(),
**kwargs,
).generations[0][0]
return _get_docs(res)[: self.top_k]
return _get_docs(res)

async def _aget_relevant_documents(
self,
Expand All @@ -73,4 +83,4 @@ async def _aget_relevant_documents(
**kwargs,
)
).generations[0][0]
return _get_docs(res)[: self.top_k]
return _get_docs(res)

0 comments on commit b346d4a

Please sign in to comment.