In [None]:
%pip install cohere

In [1]:
import os
from dotenv import load_dotenv
load_dotenv()

import cohere

# Cohere v2 client used by every API call in this notebook. The key is read from
# the process environment; a missing COHERE_API_KEY raises KeyError here.
co = cohere.ClientV2(
    api_key=os.environ["COHERE_API_KEY"],
)

In [3]:
from datasets import load_dataset
import pandas as pd
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter

# Passage corpus as a DataFrame indexed by passage id.
ds = (
    load_dataset("rag-datasets/rag-mini-bioasq", "text-corpus")["passages"]
    .to_pandas()
    .set_index("id", drop=True)
)

# Question/answer set; the first five test questions serve as example queries.
query_set = load_dataset("rag-datasets/rag-mini-bioasq", "question-answer-passages")
queries = query_set["test"].take(5)

# create a subset of the documents for faster testing:
# gather every passage id listed as relevant for the first 15 test questions.
# Each `relevant_passage_ids` entry is parsed as a bracketed, ", "-separated string.
passages_required = {
    int(raw_id)
    for id_list in query_set["test"].take(15)["relevant_passage_ids"]
    for raw_id in id_list[1:-1].split(", ")
}

# Wrap each required passage in a Document (keeping its id as metadata), then
# split the documents into chunks of at most 512 tokens with a 50-token overlap.
docs = [
    Document(text=ds.loc[passage_id].passage, metadata={"id": passage_id})
    for passage_id in passages_required
]
splitter = SentenceSplitter(chunk_size=512, chunk_overlap=50)
nodes = splitter.get_nodes_from_documents(docs)
chunks = [parsed_node.text for parsed_node in nodes]

In [4]:
model = "embed-english-v3.0"

def batch_embed(texts, batch_size=96, input_type="search_document"):
    """Embed ``texts`` through ``co.embed``, one batch at a time.

    Uses the notebook-level ``co`` client and ``model`` name.

    Args:
        texts: Iterable of strings to embed.
        batch_size: Maximum number of texts sent per ``co.embed`` call.
            NOTE(review): 96 appears to match the per-request limit of the
            Cohere embed endpoint — confirm before raising it.
        input_type: Forwarded as ``input_type``. Defaults to
            ``"search_document"``; pass ``"search_query"`` to embed queries.

    Returns:
        A list with one float embedding per input text, in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would make range() empty and silently return no embeddings.
        raise ValueError("batch_size must be a positive integer")

    # Materialise once so generators and other iterables without len() work too.
    texts = list(texts)

    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        response = co.embed(
            texts=texts[start:start + batch_size],
            model=model,
            input_type=input_type,
            embedding_types=['float']
        )
        all_embeddings.extend(response.embeddings.float)
    return all_embeddings

embeddings = batch_embed(chunks)
print(f"We just computed {len(embeddings)} embeddings.")

We just computed 155 embeddings.


In [5]:
import numpy as np
# Minimal in-memory "vector database": chunk position -> embedding as a NumPy array.
vector_database = dict(enumerate(np.array(vector) for vector in embeddings))

In [6]:


# The first sampled test question is the running example for the rest of the notebook.
query = queries[0]["question"]


In [7]:
# Embed the question with the same model as the chunks, but with the
# "search_query" input type instead of "search_document".
response = co.embed(
    model=model,
    input_type="search_query",
    embedding_types=["float"],
    texts=[query],
)
query_embedding = response.embeddings.float[0]

# Show only the first ten components, followed by an ellipsis marker.
print("query_embedding: ", [*query_embedding[:10], "..."])

query_embedding:  [-0.008018494, -0.025268555, -0.053741455, -0.04623413, -0.017044067, -0.0012226105, -0.04498291, 0.05731201, 0.011512756, -0.008666992, '...']


In [8]:
def cosine_similarity(a, b):
    """Return the cosine similarity of two vectors.

    Args:
        a: First vector (array-like of numbers).
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (norm(a) * norm(b))``, or ``0.0`` when either vector has
        zero norm.
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        # A zero vector has no direction. Dividing would produce NaN, which
        # np.argsort orders last, so a descending ranking built with [::-1]
        # would put that entry first.
        return 0.0
    return np.dot(a, b) / denominator

# Score every chunk embedding against the embedded user question
similarities = [cosine_similarity(query_embedding, vector) for vector in embeddings]
print("similarity scores: ", similarities)

# Rank chunk positions from most to least similar
sorted_indices = np.argsort(similarities)[::-1]

# Only the ten best-ranked positions are kept
top_indices = sorted_indices[:10]
print("Here are the indices of the top 10 chunks after retrieval: ", top_indices)

# Look up the text behind each of those positions
top_chunks_after_retrieval = [chunks[position] for position in top_indices]
print("Here are the top 10 chunks after retrieval: ")
for retrieved_chunk in top_chunks_after_retrieval:
    print("== " + retrieved_chunk)

similarity scores:  [0.6184294480165881, 0.35236333210235044, 0.6998376372823466, 0.1330268226639961, 0.19542812668472234, 0.20182491411760228, 0.1330268226639961, 0.13034141861550594, 0.15351200064770593, 0.1383580878989443, 0.1330268226639961, 0.2798832268739146, 0.32773845614129604, 0.22731230612971828, 0.243941297149989, 0.23709537181882878, 0.25090650199710235, 0.7327030962189807, 0.1330268226639961, 0.22924309225486195, 0.3232216151120532, 0.1330268226639961, 0.1330268226639961, 0.2986756883536321, 0.1330268226639961, 0.24557032868948817, 0.13361017291978516, 0.20641959870428853, 0.25782486171727714, 0.19298501049239145, 0.13361017291978516, 0.3154883379141274, 0.2377010196180163, 0.24871249075263652, 0.13361017291978516, 0.13361017291978516, 0.30832163746233104, 0.2758915890131315, 0.13361017291978516, 0.13361017291978516, 0.2915465699640612, 0.3405231815248782, 0.13361017291978516, 0.32547634308238665, 0.256965052759638, 0.17408926596160476, 0.2654720391680199, 0.17746757258965

In [9]:
# Ask the reranker for the three best of the retrieved chunks.
response = co.rerank(
    model="rerank-english-v3.0",
    query=query,
    documents=top_chunks_after_retrieval,
    top_n=3,
)

# Each result's `index` is used as a position in the list that was submitted,
# so the chunk text is looked up from `top_chunks_after_retrieval`.
top_chunks_after_rerank = [
    top_chunks_after_retrieval[hit.index] for hit in response.results
]

print("Here are the top 3 chunks after rerank: ")
for reranked_chunk in top_chunks_after_rerank:
    print("== " + reranked_chunk)

Here are the top 3 chunks after rerank: 
== Hirschsprung's disease (HSCR) is a fairly frequent cause of intestinal 
obstruction in children. It is characterized as a sex-linked heterogonous 
disorder with variable severity and incomplete penetrance giving rise to a 
variable pattern of inheritance. Although Hirschsprung's disease occurs as an 
isolated phenotype in at least 70% of cases, it is not infrequently associated 
with a number of congenital abnormalities and associated syndromes, 
demonstrating a spectrum of congenital anomalies. Certain of these syndromic 
phenotypes have been linked to distinct genetic sites, indicating underlying 
genetic associations of the disease and probable gene-gene interaction, in its 
pathogenesis. These associations with HSCR include Down's syndrome and other 
chromosomal anomalies, Waardenburg syndrome and other Dominant sensorineural 
deafness, the Congenital Central Hypoventilation and Mowat-Wilson and other 
brain-related syndromes, as well as 

In [10]:
# Instruction text sent as the system message of the chat request: it states the
# assistant's task (answer the question from the given context) and the writing style.
preamble = """
## Task & Context
You are a biology assistant. You are given a question and a context. Your task is to answer the question based on the context.

## Style Guide
Use scientific language and style, and avoid colloquialisms or slang.
"""

In [11]:
# One grounding document per reranked chunk. Building the list from the chunks
# themselves (rather than indexing positions 0-2 by hand) keeps this cell working
# when the reranker returns fewer or more than three results.
documents = [
    {"data": {"title": f"chunk {position}", "snippet": chunk}}
    for position, chunk in enumerate(top_chunks_after_rerank)
]

# get model response
response = co.chat(
  model="command-r-08-2024",
  messages=[{"role" : "system", "content" : preamble},
            {"role" : "user", "content" : query}],
  documents=documents,
  temperature=0.3
)

print("Final answer:")
print(response.message.content[0].text)

Final answer:
Hirschsprung disease (HSCR) is a hereditary disorder that causes intestinal obstruction. It is characterised by the absence of ganglion cells in the myenteric and submucosal plexuses of the gastrointestinal tract.

The majority of identified genes are related to Mendelian syndromic forms of Hirschsprung disease. However, non-syndromic non-familial, short-segment HSCR appears to represent a non-Mendelian condition with variable expression and sex-dependent penetrance.

Therefore, Hirschsprung disease is both a Mendelian and a multifactorial disorder.


In [12]:
print("Citations that support the final answer:")
# `citations` may be None when the answer contains no cited spans; iterate over
# an empty list in that case instead of raising TypeError.
for cite in response.message.citations or []:
    print(cite)

Citations that support the final answer:
start=21 end=27 text='(HSCR)' sources=[DocumentSource(type='document', id='doc:0', document={'id': 'doc:0', 'snippet': "Hirschsprung's disease (HSCR) is a fairly frequent cause of intestinal \nobstruction in children. It is characterized as a sex-linked heterogonous \ndisorder with variable severity and incomplete penetrance giving rise to a \nvariable pattern of inheritance. Although Hirschsprung's disease occurs as an \nisolated phenotype in at least 70% of cases, it is not infrequently associated \nwith a number of congenital abnormalities and associated syndromes, \ndemonstrating a spectrum of congenital anomalies. Certain of these syndromic \nphenotypes have been linked to distinct genetic sites, indicating underlying \ngenetic associations of the disease and probable gene-gene interaction, in its \npathogenesis. These associations with HSCR include Down's syndrome and other \nchromosomal anomalies, Waardenburg syndrome and other Dominant s

In [13]:
def insert_inline_citations(text, citations, field='text'):
    """Insert ``[id,...]`` markers into ``text`` right after each cited span.

    Args:
        text: The generated answer.
        citations: Iterable of citation objects exposing ``start``, ``end`` and
            ``sources`` (each source exposing ``id``). ``None`` or an empty
            iterable means there is nothing to insert.
        field: Unused; kept so existing callers keep working.

    Returns:
        ``text`` with a bracketed, comma-separated list of source ids inserted
        at every citation's ``end`` offset. The id shown is the part of
        ``source.id`` after the last ``:``.
    """
    if not citations:
        return text

    # Insert from the right-most `end` offset backwards so the offsets still to
    # be processed are never shifted by an earlier insertion. Ordering by `end`
    # (rather than `start`) is what keeps nested or overlapping spans correct.
    ordered = sorted(citations, key=lambda c: (c.end, c.start), reverse=True)

    for citation in ordered:
        source_ids = [source.id.split(':')[-1] for source in citation.sources]
        marker = f"[{','.join(source_ids)}]"
        text = text[:citation.end] + marker + text[citation.end:]

    return text

def list_sources(citations, fields=('text',)):
    """Build a footnote list of the distinct sources referenced by ``citations``.

    Args:
        citations: Iterable of citation objects exposing ``sources``; each
            source exposes ``id`` and a ``document`` mapping. ``None`` is
            treated as no citations.
        fields: Names of the ``document`` entries to show for each source;
            names missing from a document are skipped.

    Returns:
        One line per distinct source, formatted ``"[id] field: value, ..."``
        and joined with newlines. The id is the part of ``source.id`` after the
        last ``:``. Numeric ids are ordered numerically, followed by any
        non-numeric ids in string order. Empty string when there are no citations.
    """
    unique_sources = set()
    for citation in citations or ():
        for source in citation.sources:
            source_data = tuple(
                (field, source.document[field])
                for field in fields
                if field in source.document
            )
            unique_sources.add((source.id.split(':')[-1], source_data))

    def _numeric_first(entry):
        # Plain string sorting would put "10" before "2".
        source_id = entry[0]
        if source_id.isdecimal():
            return (0, int(source_id), source_id)
        return (1, 0, source_id)

    # The inner sorted() fixes a deterministic order for entries that share an
    # id; the stable outer sort then applies the numeric ordering.
    footnotes = []
    for source_id, source_data in sorted(sorted(unique_sources), key=_numeric_first):
        footnote = f"[{source_id}] " + ", ".join(f"{key}: {value}" for key, value in source_data)
        footnotes.append(footnote)

    return "\n".join(footnotes)

# Use the functions
# `citations` may be None when the answer contains no cited spans; normalise it
# once so the inline markers and the footnotes are both built from a list.
answer_citations = response.message.citations or []
cited_text = insert_inline_citations(response.message.content[0].text, answer_citations)

# Print the result with inline citations
print(cited_text)

# Print footnotes
if answer_citations:
    print("\nSource documents:")
    print(list_sources(answer_citations, fields=['title','snippet']))

Hirschsprung disease (HSCR)[0,1,2] is a hereditary disorder[1] that causes intestinal obstruction.[0,1] It is characterised by the absence of ganglion cells in the myenteric and submucosal plexuses of the gastrointestinal tract.[2]

The majority of identified genes are related to Mendelian syndromic forms of Hirschsprung disease.[2] However, non-syndromic non-familial, short-segment HSCR appears to represent a non-Mendelian condition with variable expression and sex-dependent penetrance.[0]

Therefore, Hirschsprung disease is both a Mendelian and a multifactorial disorder.

Source documents:
[0] title: chunk 0, snippet: Hirschsprung's disease (HSCR) is a fairly frequent cause of intestinal 
obstruction in children. It is characterized as a sex-linked heterogonous 
disorder with variable severity and incomplete penetrance giving rise to a 
variable pattern of inheritance. Although Hirschsprung's disease occurs as an 
isolated phenotype in at least 70% of cases, it is not infrequently as