# Implementing an Information Retrieval System

A key part of any RAG system is its Information Retrieval (IR) component: the part that finds the right documents (in this case, nodes and clusters) to answer the user's query. If we don't implement a good IR module, the system won't find the relevant data and will likely hallucinate its answers.

Here I'm experimenting with a few "classic" NLP techniques to index documents and extract meaning from them. This approach lets me run this part of the system locally with very few resources and without additional Hugging Face (HF) calls. I'm generally following the guidelines in [Anthropic's recent blog post](https://www.anthropic.com/news/contextual-retrieval).

## Setup

In [1]:
import pandas as pd
import pickle
import numpy as np
import bm25s 

data_processed = "../data/processed/"
data_external = "../data/external/"

## Load Data

I will be using preprocessed documents that have already been cleaned, filtered and lemmatized. Since these are medical texts, I don't think stemming is a good idea, because we could lose a lot of domain-specific vocabulary.

In [4]:
with open(data_external+"graph_data/processed_node_documents.pickle", 'rb') as handle:
    processed_node_documents = pickle.load(handle)

processed_node_documents = {int(k):v for k,v in processed_node_documents.items()}
graph_node_data = pd.read_csv(data_external+"graph_data/graph_node_data.csv")
raw_node_documents = pd.read_csv(data_external+"graph_data/disease_attributes.csv")

The next step generates "cluster documents": compilations of all the diseases in a cluster and their information. These documents are highly noisy and unstructured.

In [41]:
def cluster_as_document(cluster_id,cluster_algorithm):

    cluster_nodes = graph_node_data.loc[graph_node_data[cluster_algorithm] == cluster_id, "node_index"].values
    cluster_corpus = [processed_node_documents[node_index] for node_index in cluster_nodes]
    cluster_document = " ".join(cluster_corpus)

    return cluster_document

infomap_ids = range(int(graph_node_data.dropna().comunidades_infomap.max()))
corpus = [cluster_as_document(cluster_id,"comunidades_infomap") for cluster_id in infomap_ids]
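Two details are worth keeping in mind here. BM25 returns positions in `corpus`, so a position only doubles as a cluster id while the list is built from consecutive ids starting at 0. Also, `range(n)` stops at `n - 1`, so if the largest Infomap id should be indexed too, the range would need `max + 1`. The following is a minimal pure-Python sketch of an alternative that keeps an explicit position-to-id table. The toy dictionaries and the helper name `build_cluster_corpus` are hypothetical stand-ins for `graph_node_data` and `processed_node_documents`.

```python
from collections import defaultdict

# Hypothetical toy data: node_index -> cluster id, node_index -> processed text.
node_cluster = {10: 0, 11: 0, 12: 2, 13: 3}
node_text = {
    10: "alzheimer dementia",
    11: "memory loss",
    12: "contact dermatitis",
    13: "diabetes insulin",
}

def build_cluster_corpus(node_cluster, node_text):
    """Return (cluster_ids, corpus) where corpus[i] belongs to cluster_ids[i]."""
    grouped = defaultdict(list)
    for node_index, cluster_id in node_cluster.items():
        grouped[cluster_id].append(node_text[node_index])
    # Sorting the ids that actually occur tolerates gaps and includes the largest id.
    cluster_ids = sorted(grouped)
    corpus = [" ".join(grouped[cid]) for cid in cluster_ids]
    return cluster_ids, corpus

cluster_ids, toy_corpus = build_cluster_corpus(node_cluster, node_text)
```

With this layout, a retrieved position `i` maps back to its cluster through `cluster_ids[i]`, even when some ids are missing.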

In [50]:
disease_attributes = pd.merge(graph_node_data[["node_index","comunidades_infomap","comunidades_louvain"]],raw_node_documents,left_on="node_index",right_on="node_index",how="right").set_index("node_index",drop=True)

def node_as_document(node_index,df,join_titles):
    if not pd.isna(df.loc[node_index,"umls_description"]):
        data = df.loc[node_index,["node_name","umls_description"]].values.astype(str)
        if join_titles:
            document = " is described by UMLS as ".join(data)
        else: 
            document = str(data[1])
    elif not pd.isna(df.loc[node_index,"mondo_definition"]):
        data = df.loc[node_index,["node_name","mondo_definition"]].values.astype(str)
        if join_titles:
            document = " is defined by MONDO as ".join(data)
        else: 
            document = str(data[1])
    elif not pd.isna(df.loc[node_index,"orphanet_definition"]):
        data = df.loc[node_index,["node_name","orphanet_definition"]].values.astype(str)
        if join_titles:
            document = " is defined by Orphanet as ".join(data)
        else: 
            document = str(data[1])
    else:
        document = df.loc[node_index,"node_name"]
    return document

def cluster_as_raw_document(cluster_id,cluster_algorithm,join_titles=False):
    cluster_nodes = graph_node_data[graph_node_data.node_type == "disease"].loc[graph_node_data[cluster_algorithm] == cluster_id, "node_index"].values
    cluster_corpus = [node_as_document(node_index,disease_attributes,join_titles) for node_index in cluster_nodes]
    cluster_document = " ".join(cluster_corpus)

    return cluster_document
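The fallback order used by `node_as_document` (UMLS first, then MONDO, then Orphanet, then the bare node name) can also be written as a loop over (column, phrase) pairs. This is only a sketch: it works on plain dicts, represents missing values as `None` instead of pandas NaN, and the names `SOURCES` and `describe_node` are hypothetical.

```python
# (column, joining phrase) pairs in the same priority order as the function above.
SOURCES = [
    ("umls_description", " is described by UMLS as "),
    ("mondo_definition", " is defined by MONDO as "),
    ("orphanet_definition", " is defined by Orphanet as "),
]

def describe_node(row, join_titles=False):
    """Return the first available description for a node, else its name.

    `row` is a plain dict; a missing value is None (assumption of this
    sketch, the notebook itself checks for NaN with pd.isna).
    """
    for column, phrase in SOURCES:
        text = row.get(column)
        if text is not None:
            if join_titles:
                return f"{row['node_name']}{phrase}{text}"
            return str(text)
    return row["node_name"]

toy_row = {
    "node_name": "alopecia",
    "umls_description": None,
    "mondo_definition": "loss of hair",
}
```

Adding another description source then only requires a new entry in `SOURCES`.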

In [6]:
class GraphCorpus():
    def __init__(self, data_path):
        self.data_path = data_path
        self.node_data = pd.read_csv(self.data_path+"graph_node_data.csv")

        with open(self.data_path+"processed_node_documents.pickle", 'rb') as handle:
            processed_node_documents = pickle.load(handle)
        self.processed_node_documents = {int(k):v for k,v in processed_node_documents.items()}

        raw_node_documents = pd.read_csv(self.data_path+"disease_attributes.csv")
        self.disease_attributes = pd.merge(self.node_data[["node_index","comunidades_infomap"]],raw_node_documents,left_on="node_index",right_on="node_index",how="right").set_index("node_index",drop=True)
    
    def node_as_document(self,node_index,join_titles):
        if not pd.isna(self.disease_attributes.loc[node_index,"umls_description"]):
            data = self.disease_attributes.loc[node_index,["node_name","umls_description"]].values.astype(str)
            if join_titles:
                document = " is described by UMLS as ".join(data)
            else: 
                document = str(data[1])
        elif not pd.isna(self.disease_attributes.loc[node_index,"mondo_definition"]):
            data = self.disease_attributes.loc[node_index,["node_name","mondo_definition"]].values.astype(str)
            if join_titles:
                document = " is defined by MONDO as ".join(data)
            else: 
                document = str(data[1])
        elif not pd.isna(self.disease_attributes.loc[node_index,"orphanet_definition"]):
            data = self.disease_attributes.loc[node_index,["node_name","orphanet_definition"]].values.astype(str)
            if join_titles:
                document = " is defined by Orphanet as ".join(data)
            else: 
                document = str(data[1])
        else:
            document = self.disease_attributes.loc[node_index,"node_name"]
        return document

    def cluster_as_raw_document(self,cluster_id,join_titles=False):
        cluster_nodes = self.node_data[self.node_data.node_type == "disease"].loc[self.node_data["comunidades_infomap"] == cluster_id, "node_index"].values
        cluster_corpus = [self.node_as_document(node_index,join_titles) for node_index in cluster_nodes]
        cluster_document = " ".join(cluster_corpus)

        return cluster_document

    def cluster_as_document(self,cluster_id,cluster_algorithm):
        cluster_nodes = self.node_data.loc[self.node_data[cluster_algorithm] == cluster_id, "node_index"].values
        cluster_corpus = [self.processed_node_documents[node_index] for node_index in cluster_nodes]
        cluster_document = " ".join(cluster_corpus)

        return cluster_document
    
    def get_cluster_corpus(self):
        infomap_ids = range(int(self.node_data.dropna().comunidades_infomap.max()))
    
        return [self.cluster_as_document(cluster_id,"comunidades_infomap") for cluster_id in infomap_ids]


In [95]:
graph_corpus = GraphCorpus("../data/external/graph_data/")

In [96]:
graph_corpus.cluster_as_raw_document(62)

'A type of acute or chronic skin reaction in which sensitivity is manifested by reactivity to materials or substances coming in contact with the skin. It may involve allergic or non-allergic mechanisms. An inflammatory process in skin caused by an exogenous agent that directly or indirectly injure the skin. If the offending agent is identified and removed, the eruption will resolve. An unusual or patterned eruption may be a clue to the presence of a contact dermatitis. Patch testing may be helpful in the differential diagnosis.  An inflammatory skin condition caused by direct contact between the skin and either an irritating substance or an allergen. type of acute or chronic skin reaction in which sensitivity is manifested by reactivity to materials or substances coming in contact with the skin; may involve allergic or non-allergic mechanisms. A recurrent contact dermatitis caused by substances found in the work place. Contact dermatitis associated with allergens or irritants found in 

## Index corpus with BM25-S

In [52]:
# Tokenize the corpus and only keep the ids (faster and saves memory)
corpus_tokens = bm25s.tokenize(corpus, stopwords="en")

# Create the BM25 model and index the corpus
retriever = bm25s.BM25()
retriever.index(corpus_tokens)

Split strings:   0%|          | 0/1147 [00:00<?, ?it/s]

BM25S Count Tokens:   0%|          | 0/1147 [00:00<?, ?it/s]

BM25S Compute Scores:   0%|          | 0/1147 [00:00<?, ?it/s]
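For intuition about what the index computes, here is a simplified BM25 scorer in pure Python. It is a textbook formulation and not the bm25s source code: the values `k1=1.5` and `b=0.75` are commonly used defaults assumed for this sketch, the IDF term uses the `ln(1 + (N - n + 0.5) / (n + 0.5))` variant, and the function name and toy documents are hypothetical.

```python
import math
from collections import Counter

def bm25_scores(query_tokens, docs_tokens, k1=1.5, b=0.75):
    """Score every tokenized document against the query with textbook BM25."""
    n_docs = len(docs_tokens)
    avg_len = sum(len(doc) for doc in docs_tokens) / n_docs
    # Document frequency: number of documents that contain each term.
    doc_freq = Counter(term for doc in docs_tokens for term in set(doc))
    scores = []
    for doc in docs_tokens:
        term_freq = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if term not in term_freq:
                continue
            n_t = doc_freq[term]
            idf = math.log(1 + (n_docs - n_t + 0.5) / (n_t + 0.5))
            tf = term_freq[term]
            # Length normalization: long documents need more hits for the same score.
            norm = tf + k1 * (1 - b + b * len(doc) / avg_len)
            score += idf * tf * (k1 + 1) / norm
        scores.append(score)
    return scores

toy_docs = [
    "progressive dementia loss memory".split(),
    "skin reaction contact dermatitis".split(),
    "diabetes insulin hyperglycemia".split(),
]
toy_scores = bm25_scores(["memory", "loss"], toy_docs)
```

Rare terms get a high IDF weight and repeated terms saturate, which is why long, noisy cluster documents can still be ranked sensibly by a handful of query words.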

In [58]:
# Query the corpus
query = "Show me diseases that affect memory"
query_tokens = bm25s.tokenize(query,stopwords="en")

# Get top-k results as a tuple of (documents, scores); because corpus is passed, the texts are returned instead of ids. Both are arrays of shape (n_queries, k)
results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=2)

for i in range(results.shape[1]):
    doc, score = results[0, i], scores[0, i]
    print(f"Rank {i+1} (score: {score:.2f}): {doc}")

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

Rank 1 (score: 3.67):  familial alzheimer fad  early onset autosomal dominant alzheimer eoad progressive dementia reduction cognitive function eoad present phenotype sporadic alzheimer ad early age onset usually year old  early onset autosomal dominant alzheimer eoad progressive dementia reduction cognitive function eoad present phenotype sporadic alzheimer ad early age onset usually year old   alzheimer s  progressive neurodegenerative loss function death nerve cell area brain lead loss cognitive function memory language  acute potentially fatal metabolic condition cerebral edema fatty liver hypoglycemia occurs primarily child associate use aspirin treatment viral infection occur absence aspirin use disable degenerative nervous occur middle age old person dementia failure memory recent event follow total incapacitation death alzheimer differentiate age onset genetic characteristic early onset form late onset form form identify ad ad ad clinical characteristic alzheimer similar pick al

If we only want the indices of the documents, we can omit the "corpus" argument.

In [59]:
# Get top-k results as a tuple of (doc ids, scores). Both are arrays of shape (n_queries, k)
query = "Show me diseases that affect memory"
query_tokens = bm25s.tokenize(query,stopwords="en")
results, scores = retriever.retrieve(query_tokens, k=2)

for i in range(results.shape[1]):
    doc, score = results[0, i], scores[0, i]
    print(f"Rank {i+1} (score: {score:.2f}): {doc}")

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

Rank 1 (score: 3.67): 392
Rank 2 (score: 3.54): 1


In [60]:
best_match = results[0][0]

In [63]:
print(cluster_as_raw_document(best_match,"comunidades_infomap"))



In [22]:
class QueryModule():
    def __init__(self,corpus):
        # Tokenize the corpus and only keep the ids (faster and saves memory)
        corpus_tokens = bm25s.tokenize(corpus, stopwords="en")

        # Create the BM25 model and index the corpus
        self.retriever = bm25s.BM25()
        self.retriever.index(corpus_tokens)
    
    def get_cluster(self, user_query):
        query_tokens = bm25s.tokenize(user_query,stopwords="en")
        top_match = self.retriever.retrieve(query_tokens, return_as ="documents", k=2)
        return int(top_match[0][0])

In [81]:
query_module = QueryModule(corpus)

Split strings:   0%|          | 0/1147 [00:00<?, ?it/s]

BM25S Count Tokens:   0%|          | 0/1147 [00:00<?, ?it/s]

BM25S Compute Scores:   0%|          | 0/1147 [00:00<?, ?it/s]

In [78]:
query_module.get_cluster("Show me diseases that affect memory")

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

392

# Basic RAG pipeline

In [2]:
import requests
import os
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from transformers import pipeline

load_dotenv(".env")
API_TOKEN = os.getenv("biollama_gnn_exploration")

In [3]:
class GraphCorpus():
    def __init__(self, data_path):
        self.data_path = data_path
        self.node_data = pd.read_csv(self.data_path+"graph_node_data.csv")

        with open(self.data_path+"processed_node_documents.pickle", 'rb') as handle:
            processed_node_documents = pickle.load(handle)
        self.processed_node_documents = {int(k):v for k,v in processed_node_documents.items()}

        raw_node_documents = pd.read_csv(self.data_path+"disease_attributes.csv")
        self.disease_attributes = pd.merge(self.node_data[["node_index","comunidades_infomap"]],raw_node_documents,left_on="node_index",right_on="node_index",how="right").set_index("node_index",drop=True)
    
    def node_as_document(self,node_index,join_titles):
        if not pd.isna(self.disease_attributes.loc[node_index,"umls_description"]):
            data = self.disease_attributes.loc[node_index,["node_name","umls_description"]].values.astype(str)
            if join_titles:
                document = " is described by UMLS as ".join(data)
            else: 
                document = str(data[1])
        elif not pd.isna(self.disease_attributes.loc[node_index,"mondo_definition"]):
            data = self.disease_attributes.loc[node_index,["node_name","mondo_definition"]].values.astype(str)
            if join_titles:
                document = " is defined by MONDO as ".join(data)
            else: 
                document = str(data[1])
        elif not pd.isna(self.disease_attributes.loc[node_index,"orphanet_definition"]):
            data = self.disease_attributes.loc[node_index,["node_name","orphanet_definition"]].values.astype(str)
            if join_titles:
                document = " is defined by Orphanet as ".join(data)
            else: 
                document = str(data[1])
        else:
            document = self.disease_attributes.loc[node_index,"node_name"]
        return document

    def cluster_as_raw_document(self,cluster_id,join_titles=False):
        cluster_nodes = self.node_data[self.node_data.node_type == "disease"].loc[self.node_data["comunidades_infomap"] == cluster_id, "node_index"].values
        cluster_corpus = [self.node_as_document(node_index,join_titles) for node_index in cluster_nodes]
        cluster_document = " ".join(cluster_corpus)

        return cluster_document

    def cluster_as_document(self,cluster_id):
        cluster_nodes = self.node_data.loc[self.node_data["comunidades_infomap"] == cluster_id, "node_index"].values
        cluster_corpus = [self.processed_node_documents[node_index] for node_index in cluster_nodes]
        cluster_document = " ".join(cluster_corpus)

        return cluster_document
    
    def get_cluster_corpus(self):
        infomap_ids = range(int(self.node_data.dropna().comunidades_infomap.max()))
    
        return [self.cluster_as_document(cluster_id) for cluster_id in infomap_ids]
    
    def get_raw_cluster(self,cluster_id):
        return self.cluster_as_raw_document(cluster_id)
    

In [4]:
class QueryModule():
    def __init__(self,corpus):
        # Tokenize the corpus and only keep the ids (faster and saves memory)
        corpus_tokens = bm25s.tokenize(corpus, stopwords="en")

        # Create the BM25 model and index the corpus
        self.retriever = bm25s.BM25()
        self.retriever.index(corpus_tokens)
    
    def get_cluster(self, user_query):
        query_tokens = bm25s.tokenize(user_query,stopwords="en")
        top_match = self.retriever.retrieve(query_tokens, return_as ="documents", k=2)
        return int(top_match[0][0])

In [5]:
user_query = "What causes diabetes?"

In [6]:
graph_corpus = GraphCorpus("../data/external/graph_data/")
cluster_corpus = graph_corpus.get_cluster_corpus()

query_module = QueryModule(cluster_corpus)

Split strings:   0%|          | 0/1147 [00:00<?, ?it/s]

BM25S Count Tokens:   0%|          | 0/1147 [00:00<?, ?it/s]

BM25S Compute Scores:   0%|          | 0/1147 [00:00<?, ?it/s]

In [7]:
cluster_id = query_module.get_cluster(user_query)
graph_corpus.get_raw_cluster(cluster_id)

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

'Any maturity-onset diabetes of the young in which the cause of the disease is a mutation in the CEL gene. A rare autosomal dominant form of diabetes mellitus affecting young people with a positive family history. mody is a form of monogenic diabetes, resulting from mutations in a single gene. The most common forms are HNF1alpha-MODY and gck-mody , due to mutations in the hnf1a and gck genes, respectively. The term Maturity-onset diabetes of the young was initially used for patients diagnosed with fasting hyperglycemia that could be treated without insulin for more than two years, where the initial diagnosis was made at a young age Thus, mody combines characteristics of type 1 diabetes and type 2 diabetes The term mody is now most often used to refer to a group of monogenic diseases with these characteristics. Here, the term is used to describe hyperglycemia diagnosed at a young age with no or minor insulin dependency, no evidence of insulin resistence, and lack of evidence of autoimmu

In [8]:
process_query_context = {"role":"system","content":"In the following text, extract only the disease subject. Be concise and ONLY answer the disease name"}
text_client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=API_TOKEN)
messages = [process_query_context,
{"role":"user","content":"hello is it me youre lookjinf for ?? conditions with hair loss"}]

answer = text_client.chat_completion(messages, max_tokens=100)

print(answer.get("choices")[0]["message"]["content"])

Alopecia


In [14]:
user_query = "I want to learn more about obsessive compulsive disorder"

process_query_context = {"role":"system","content":"In the following text, extract only the disease subject. Be concise and ONLY answer the disease name"}
text_client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct", token=API_TOKEN)
messages = [process_query_context,
{"role":"user","content":user_query}]

answer = text_client.chat_completion(messages, max_tokens=100)
disease_subject = answer.get("choices")[0]["message"]["content"]

cluster_id = query_module.get_cluster(disease_subject)
disease_info = graph_corpus.get_raw_cluster(cluster_id)

form_answer_context = {"role":"system", "content":f"You are a research assistant. You provide clear and concise answers, based only in the following information: {disease_info}"}
messages = [form_answer_context,
{"role":"user","content":user_query}]

answer = text_client.chat_completion(messages, max_tokens=100)

print(answer.get("choices")[0]["message"]["content"])

Split strings:   0%|          | 0/1 [00:00<?, ?it/s]

BM25S Retrieve:   0%|          | 0/1 [00:00<?, ?it/s]

Obessive-Compulsive Disorder (OCD) is a mental health condition where individuals experience recurring, intrusive thoughts (obsessions) and repetitive behaviors (compulsions) that they feel are necessary or urgent. These thoughts and behaviors can interfere with an individual's daily life, relationships, and work.

Here are some key things to know about OCD:

**Causes:** The exact cause of OCD is still unknown, but research suggests that it may be related to a combination of genetic, environmental,


In [15]:
disease_info

"Inadequate responses to physical, social, and emotional demands; general ineptness and instability, despite absence of actual physical or mental deficit. Personality disorder characterized by pervasive patterns of dependent, passive, and submissive behavior. A personality disorder characterized by a pervasive and excessive need to be taken care of that leads to submissive and clinging behavior and fears of separation, beginning by early adulthood and present in a variety of contexts.  A disorder characterized by an enduring pattern of an extreme need to be taken care of together with fear of separation that lead the individual to urgently seek out and submit to another person and allow that person to make decisions that impact all areas of the individual's life. Personality disorders are a group of mental illnesses. They involve long-term patterns of thoughts and behaviors that are unhealthy and inflexible. The behaviors cause serious problems with relationships and work. People with 

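The visible part of this retrieved context describes dependent personality disorder rather than OCD, which suggests the answer above came mostly from the model's own knowledge and not from the cluster. One cheap guard, sketched here under the assumption that plain token overlap is a good enough signal, is to check how many query terms actually occur in the retrieved text before passing it to the LLM. The helper names, the tiny stopword list, the shortened toy context and the 0.5 threshold are all hypothetical.

```python
import re

# Tiny illustrative stopword list; a real one would be much longer.
STOPWORDS = {"i", "want", "to", "learn", "more", "about", "the", "a", "of", "what", "is"}

def _terms(text):
    """Lowercase the text and return its set of alphabetic tokens."""
    return set(re.findall(r"[a-z]+", text.lower()))

def query_coverage(query, context):
    """Fraction of non-stopword query terms that appear in the context."""
    query_terms = _terms(query) - STOPWORDS
    if not query_terms:
        return 0.0
    return len(query_terms & _terms(context)) / len(query_terms)

def context_is_relevant(query, context, threshold=0.5):
    """Accept the context only if enough query terms are covered."""
    return query_coverage(query, context) >= threshold

toy_query = "I want to learn more about obsessive compulsive disorder"
toy_context = "Personality disorder characterized by dependent, passive, and submissive behavior."
coverage = query_coverage(toy_query, toy_context)
```

Here only "disorder" is covered, so the context would be rejected and the pipeline could fall back to the next-ranked cluster or tell the user that nothing relevant was found.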
# Generate titles and summarize clusters (TODO)

Use LLMs to extract summaries of clusters, finding symptoms and potential risk factors.

Use T5 to generate cluster titles and summaries.

In [97]:
import requests
import os
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from transformers import pipeline

load_dotenv(".env")
API_TOKEN = os.getenv("biollama_gnn_exploration")

In [98]:
summarizer = pipeline("summarization", model="t5-small")

In [125]:
cluster_info = cluster_as_raw_document(best_match,"comunidades_infomap")
summary = summarizer(cluster_info, max_length=20, min_length=10, do_sample=False)
print(summary[0]['summary_text'])

Token indices sequence length is longer than the specified maximum sequence length for this model (9373 > 512). Running this sequence through the model will result in indexing errors
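The warning means the cluster document (9373 tokens) is far longer than the 512 tokens the model accepts. A common workaround, sketched here, is to split the text into pieces, summarize each piece, and then summarize the joined partial summaries. The sketch uses whitespace-separated words as a rough proxy for tokens and a hypothetical limit of 300 words per chunk. The `summarize` argument is a placeholder for any callable that maps text to text, and both function names are made up for illustration.

```python
def split_into_chunks(text, max_words=300):
    """Split text into consecutive chunks of at most max_words words."""
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]

def summarize_long_text(text, summarize, max_words=300):
    """Summarize each chunk, then summarize the joined partial summaries."""
    partial = [summarize(chunk) for chunk in split_into_chunks(text, max_words)]
    combined = " ".join(partial)
    # A second pass is only needed when the text was split into several chunks.
    return summarize(combined) if len(partial) > 1 else combined

def first_words(chunk):
    """Stand-in summarizer for illustration: keep the first five words."""
    return " ".join(chunk.split()[:5])

toy_text = " ".join(f"w{i}" for i in range(700))
toy_summary = summarize_long_text(toy_text, first_words)
```

With the pipeline above, `summarize` could wrap the existing call, for example `lambda t: summarizer(t, max_length=20, min_length=10, do_sample=False)[0]['summary_text']`. For very long clusters the joined partial summaries may themselves exceed the limit, in which case the same step would have to be repeated.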

