In [1]:
# Install the notebook's dependencies listed in requirements.txt.
# The %pip magic targets the environment of the running kernel; --quiet trims pip's output.
%pip install -r requirements.txt --quiet

Note: you may need to restart the kernel to use updated packages.


In [None]:
import cohere
from typing import List
import numpy as np
import json

In [None]:
def generate_search_queries(message: str, model: str = "command-r-08-2024") -> List[str]:
    """Ask the chat model to turn a user message into search queries.

    The model is offered a single ``internet_search`` tool whose only
    parameter is a list of query strings. Every tool call the model emits is
    parsed and its queries are collected, in order, into one flat list.

    Args:
        message: The user's message to derive search queries from.
        model: Name of the chat model to call. Defaults to the model this
            function originally hard-coded.

    Returns:
        A flat list of query strings. The list is empty when the model makes
        no tool call (the instructions allow it to answer directly instead)
        or when a tool call carries no usable ``queries`` value.

    Raises:
        json.JSONDecodeError: If a tool call's arguments are not valid JSON.

    Note:
        Relies on a module-level client object named ``co`` that must be
        defined before this function is called; it is not created here.
    """
    # Tool schema the model fills in when it decides a search is needed.
    query_gen_tool = [
        {
            "type": "function",
            "function": {
                "name": "internet_search",
                "description": "Returns a list of relevant document snippets for a textual query retrieved from the internet",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "queries": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "a list of queries to search the internet with.",
                        }
                    },
                    "required": ["queries"],
                },
            },
        }
    ]

    # System instructions that steer the model towards query generation.
    instructions = "Write a search query that will find helpful information for answering the user's question accurately. If you need more than one search query, write a list of search queries. If you decide that a search is very unlikely to find information that would be useful in constructing a response to the user, you should instead directly answer."

    res = co.chat(
        model=model,
        messages=[
            {"role": "system", "content": instructions},
            {"role": "user", "content": message},
        ],
        tools=query_gen_tool,
    )

    search_queries: List[str] = []
    # tool_calls is falsy when the model chose to answer directly.
    for tool_call in res.message.tool_calls or []:
        arguments = json.loads(tool_call.function.arguments)
        queries = arguments.get("queries") if isinstance(arguments, dict) else None
        # A bare string would otherwise be split into single characters by extend().
        if isinstance(queries, str):
            queries = [queries]
        search_queries.extend(queries or [])

    return search_queries


In [None]:
# Try the query generator on a sample request and show the queries it returns.
query = "Suggest some products to Jane Doe based on their purchase history"
queries_for_search = generate_search_queries(message=query)
print(queries_for_search)

In [None]:
#### Embed search query

# The user's request
query = "Suggest some products to Jane Doe based on their purchase history"

# Generate the search query
# Note: For simplicity, we are assuming only one query generated. For actual implementations, you will need to perform search for each query.
generated_queries = generate_search_queries(query)

# The model is allowed to answer directly and return no queries; indexing [0]
# on an empty list would raise IndexError, so fall back to the raw user query.
queries_for_search = generated_queries[0] if generated_queries else query
print("Search query: ", queries_for_search)

# Embed the search query; `.embeddings.float` selects the float embeddings requested below
query_emb = co.embed(
    model="embed-english-v3.0",
    input_type="search_query",
    texts=[queries_for_search],
    embedding_types=["float"]).embeddings.float


In [None]:
### Search for the most relevant documents

# Number of top-scoring documents to keep
n = 5

# Dot-product similarity between the query embedding and every document embedding
scores = np.dot(query_emb, np.asarray(embeddings).T)[0]

# Negating the scores makes the ascending argsort yield highest-score-first indices
max_idx = np.argsort(-scores)[:n]

retrieved_documents = [document_texts[item] for item in max_idx]

# Show each hit with its 1-based rank and similarity score
for rank, idx in enumerate(max_idx):
    print(
        f"Rank: {rank+1}",
        f"Score: {scores[idx]}",
        f"Document: {retrieved_documents[rank]}\n",
        sep="\n",
    )


In [None]:
# Rerank the retrieved documents against the search query
# The documents are passed to the reranker in string form
retrieved_documents_str = list(map(str, retrieved_documents))

results = co.rerank(
    query=queries_for_search,
    documents=retrieved_documents_str,
    top_n=2,
    model="rerank-english-v3.0",
)

# Show the reranked hits; result.index is used to look up the matching entry in retrieved_documents
for idx, result in enumerate(results.results):
    print(
        f"Rank: {idx+1}",
        f"Score: {result.relevance_score}",
        f"Document: {retrieved_documents[result.index]}\n",
        sep="\n",
    )

# Keep the string form of the top results, in reranked order, for the generation step
reranked_documents = [retrieved_documents_str[hit.index] for hit in results.results]


In [None]:
# Generate the final answer from the original user query and the reranked documents
response = co.chat(
    model="command-r-plus-08-2024",
    messages=[{"role": "user", "content": query}],
    documents=reranked_documents,
)

# Show the text of the first content item in the reply
print(response.message.content[0].text)

# Show the citations, if the reply carries any
citations = response.message.citations
if citations:
    print("\nCITATIONS:")
    for citation in citations:
        print(citation, "\n")
