In [None]:
import os

import cohere
from dotenv import load_dotenv, find_dotenv
import weaviate

from utils import dense_retrieval, keyword_search, print_result

In [None]:
# Load variables from the local .env file *before* reading any of them:
# on a fresh kernel the API keys may exist only in .env, and the
# os.environ lookups below would raise KeyError if they ran first.
_ = load_dotenv(find_dotenv())

# Cohere client; `rerank_responses` further down calls `co.rerank`.
co = cohere.Client(os.environ['COHERE_API_KEY'])

# API-key credentials handed to the Weaviate client as auth_client_secret.
auth_config = weaviate.auth.AuthApiKey(
    api_key=os.environ['WEAVIATE_API_KEY'])

In [None]:
# Build the Weaviate client from environment settings. The Cohere API key is
# supplied through `additional_headers` next to the Weaviate credentials.
weaviate_url = os.environ['WEAVIATE_API_URL']
cohere_headers = {"X-Cohere-Api-Key": os.environ['COHERE_API_KEY']}
client = weaviate.Client(
    url=weaviate_url,
    auth_client_secret=auth_config,
    additional_headers=cohere_headers,
)

In [None]:
# Run the `dense_retrieval` helper from utils for a sample question and show
# what it returns via `print_result`.
# NOTE(review): presumably an embedding-based search -- confirm in utils.
query = "What is the capital of Canada?"
dense_retrieval_results = dense_retrieval(query, client)
print_result(dense_retrieval_results)

In [None]:
# Keyword search for the sample question, limited to the top 3 hits.
query_1 = "What is the capital of Canada?"
requested_properties = [
    "text", "title", "url", "views", "lang",
    "_additional {distance}"]
results = keyword_search(
    query_1, client, properties=requested_properties, num_results=3)

# Print each hit's position followed by its title and text.
for position, hit in enumerate(results):
    print(f"i:{position}")
    print(hit.get('title'))
    print(hit.get('text'))

In [None]:
# Repeat the keyword search for query_1, this time asking for 500 hits.
requested_properties = [
    "text", "title", "url", "views", "lang",
    "_additional {distance}"]
results = keyword_search(
    query_1, client, properties=requested_properties, num_results=500)

# Only the position and title of each hit are printed here; the text is not.
for position, hit in enumerate(results):
    print(f"i:{position}")
    print(hit.get('title'))

In [None]:
def rerank_responses(query, responses, num_responses=10):
    """Rerank candidate documents against a query via ``co.rerank``.

    Args:
        query: The query the documents are ranked against.
        responses: The candidate documents, passed as ``documents``.
        num_responses: Number of top results to request (sent as ``top_n``).

    Returns:
        Whatever ``co.rerank`` returns for the request, unmodified.
    """
    rerank_request = {
        'model': 'rerank-english-v2.0',
        'query': query,
        'documents': responses,
        'top_n': num_responses,
    }
    # `co` is the notebook-level Cohere client.
    return co.rerank(**rerank_request)

In [None]:
# Collect the text of every hit currently in `results` (the 500-hit keyword
# search when cells run in order) and rerank those texts against query_1.
texts = []
for hit in results:
    texts.append(hit.get('text'))
reranked_text = rerank_responses(query_1, texts)

# Print each reranked entry with its position, separated by blank lines.
for position, reranked_hit in enumerate(reranked_text):
    print(f"i:{position}")
    print(f"{reranked_hit}")
    print()

In [None]:
# Second question: fetch hits with `dense_retrieval`, then print each hit's
# position, title and text followed by a blank line.
query_2 = "Who is the tallest person in history?"
results = dense_retrieval(query_2, client)

for position, hit in enumerate(results):
    print(f"i:{position}")
    print(hit.get('title'))
    print(hit.get('text'))
    print()

In [None]:
# Collect the text of every hit currently in `results` (the query_2 dense
# retrieval when cells run in order) and rerank those texts against query_2.
texts = []
for hit in results:
    texts.append(hit.get('text'))
reranked_text = rerank_responses(query_2, texts)

# Print each reranked entry with its position, separated by blank lines.
for position, reranked_hit in enumerate(reranked_text):
    print(f"i:{position}")
    print(f"{reranked_hit}")
    print()