In [None]:
import cohere
import datasets
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.schema import TextNode
from llama_index.embeddings.cohere import CohereEmbedding
import pandas as pd

import json
from pathlib import Path
from tqdm import tqdm
from typing import List

In [3]:
from dotenv import load_dotenv
import os

# Load the .env file so COHERE_API_KEY can be read from the environment
load_dotenv()
# os.getenv returns None (it does not raise) when COHERE_API_KEY is unset
cohere_api_key = os.getenv('COHERE_API_KEY')

# Shared Cohere client; used by the embed and chat cells further down
co = cohere.Client(api_key=cohere_api_key)

### Basic embeddings

In [4]:
# Embedding model handed to the vector index (as `embed_model`) when the
# index is built and again when it is reloaded from disk.
embed_model = CohereEmbedding(
    cohere_api_key=cohere_api_key,
    model_name="embed-english-v3.0"
)

In [None]:
def prepare_qa_texts(file_path):
    """Load question/answer pairs from a JSON file as flat strings.

    Args:
        file_path: Path to a JSON file containing a list of objects, each
            with ``question`` and ``answer`` keys.

    Returns:
        A list with one ``"Q: <question> A: <answer>"`` string per entry,
        in file order (empty when the file holds an empty list).

    Raises:
        KeyError: If an entry lacks ``question`` or ``answer``.
    """
    # JSON is UTF-8 by specification; being explicit avoids depending on the
    # platform's locale encoding (which breaks non-ASCII text on some systems).
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    return [f"Q: {item['question']} A: {item['answer']}" for item in data]

# Flatten the Q/A dataset into one string per pair
texts = prepare_qa_texts("../data/home0001qa.json")
# Show only the first entry as a sanity check
print(texts[:1])

In [15]:
# Embed all Q/A strings directly through the Cohere client, using the
# "search_document" input type and requesting float embeddings only.
embeddings = co.embed(
    texts=texts,
    model="embed-english-v3.0",
    input_type="search_document",
    embedding_types=['float']
)

In [None]:
# Number of float vectors returned (presumably one per input text — confirm)
print(len(embeddings.embeddings.float))

### Embeddings with Text Nodes

In [None]:
def prepare_text_nodes(file_path):
    """Load question/answer pairs from a JSON file as ``TextNode`` objects.

    Args:
        file_path: Path to a JSON file containing a list of objects, each
            with ``question`` and ``answer`` keys.

    Returns:
        A list with one ``TextNode`` per entry, in file order. Each node's
        text is ``"Q: <question>\\nA: <answer>"`` and its metadata carries
        the question under the ``"question"`` key.

    Raises:
        KeyError: If an entry lacks ``question`` or ``answer``.
    """
    # JSON is UTF-8 by specification; don't rely on the locale's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # The entry's position is not used, so iterate over the entries directly.
    return [
        TextNode(
            text=f"Q: {entry['question']}\nA: {entry['answer']}",
            metadata={"question": entry["question"]},
        )
        for entry in data
    ]

# Build one TextNode per Q/A pair from the same dataset file
text_nodes = prepare_text_nodes("../data/home0001qa.json")

# Inspect the first node and the metadata attached to it
print(text_nodes[0])
print(text_nodes[0].metadata)


In [20]:
# Build a vector index over the text nodes, embedding them with embed_model
index = VectorStoreIndex(text_nodes, embed_model=embed_model)

# Persist the index's storage to ./cohere; the next cell reloads it from there
index.storage_context.persist("./cohere")

In [22]:
# Reload the index previously persisted to ./cohere instead of rebuilding it
storage_context = StorageContext.from_defaults(persist_dir="./cohere")

# The same embed_model is supplied again when loading the index
index = load_index_from_storage(storage_context, embed_model=embed_model)

In [21]:
# Plain retriever configured with similarity_top_k=3 (no reranking).
# NOTE: the name `retriever` is rebound to a RetrieverWithRerank in a later cell.
retriever = index.as_retriever(similarity_top_k=3)

### Retriever with rerank

In [26]:
class RetrieverWithRerank:
    """Two-stage retriever: first-pass retrieval followed by Cohere rerank.

    Attributes:
        retriever: Object exposing ``retrieve(query)`` that returns results
            whose ``.node`` has ``.text`` and ``.id_`` attributes.
        co: Cohere client used for the rerank call.
    """

    def __init__(self, retriever, api_key):
        """Store the first-pass retriever and create a Cohere client.

        Args:
            retriever: First-pass retriever (see class docstring).
            api_key: API key passed to ``cohere.Client``.
        """
        self.retriever = retriever
        self.co = cohere.Client(api_key=api_key)

    def retrieve(self, query: str, top_n: int):
        """Retrieve candidates for ``query`` and keep the ``top_n`` best after rerank.

        Args:
            query: The search query.
            top_n: Number of documents requested from the rerank call.

        Returns:
            A list of dicts with ``"text"`` and ``"llamaindex_id"`` keys, in
            the order given by the rerank results. Empty when the first-pass
            retriever returns nothing.
        """
        # First call to the retriever fetches the closest candidates
        hits = self.retriever.retrieve(query)
        documents = [
            {
                "text": hit.node.text,
                "llamaindex_id": hit.node.id_,
            }
            for hit in hits
        ]
        # Nothing to rank: skip the API call rather than sending an empty
        # document list to the rerank endpoint.
        if not documents:
            return []
        # Call co.rerank to improve the relevance of retrieved documents
        reranked = self.co.rerank(query=query, documents=documents, model="rerank-english-v3.0", top_n=top_n)
        # Each result's `index` points back into the `documents` list we sent
        return [documents[result.index] for result in reranked.results]


top_k = 10 # how many documents to fetch on first pass
top_n = 4 # how many documents to sub-select with rerank

# Rebinds `retriever`: from here on it is the rerank wrapper around a fresh
# top_k retriever, not the plain similarity_top_k=3 retriever created earlier.
retriever = RetrieverWithRerank(
    index.as_retriever(similarity_top_k=top_k),
    api_key=cohere_api_key,
)



In [None]:
# Example question to run through the retrieve -> rerank -> chat pipeline
query = "Where is HOME0001 available?"

# First-pass retrieval followed by rerank, keeping top_n documents
documents = retriever.retrieve(query, top_n=top_n)

# The reranked documents are passed to the chat call; the response's
# `citations` are consumed by build_answer_with_citations in the next cell.
response = co.chat(message=query, model="command-r", temperature=0., documents=documents)
print(response.text)

In [None]:
def build_answer_with_citations(response):
    """Return the response text with citation markers inserted after cited spans.

    Each cited span ``text[start:end]`` is followed by `` [a, b] `` where the
    entries are the citation's document ids with their first four characters
    removed (presumably the ``doc_`` prefix — confirm against the API output).

    Args:
        response: Object with a ``text`` string and a ``citations`` attribute.
            ``citations`` is either ``None``/empty or an iterable of objects
            exposing ``start``, ``end`` and ``document_ids``. Citations are
            processed in the order given, so they are expected to be sorted
            by position and non-overlapping.

    Returns:
        The annotated text; the original text unchanged when there are no
        citations.
    """
    text = response.text
    # `citations` can be None when nothing was cited; treat that as "no citations"
    # instead of failing on iteration.
    citations = response.citations or []

    # Collect pieces and join once rather than growing a string in the loop
    pieces = []
    end = 0

    for citation in citations:
        # Add snippet between last citation and current citation
        start = citation.start
        pieces.append(text[end:start])
        end = citation.end  # becomes the start of the next uncited snippet
        citation_block = " [" + ", ".join(stub[4:] for stub in citation.document_ids) + "] "
        pieces.append(text[start:end] + citation_block)
    # Add any left-over text after the last citation
    pieces.append(text[end:])

    return "".join(pieces)

# Annotate the chat answer from the previous cell with its citation markers
grounded_answer = build_answer_with_citations(response)
print(grounded_answer)