<a href="https://colab.research.google.com/github/fsminako/text_rag/blob/main/5588654_rag_m1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RETRIEVAL AUGMENTED GENERATION (RAG) FOR MEDICAL RESEARCH

## Dataset Loading

The dataset used in this study consists of medical research abstracts sourced from the arXiv library.

In [1]:
#Installing necessary packages
!pip install arxiv

Collecting arxiv
  Downloading arxiv-2.1.0-py3-none-any.whl (11 kB)
Collecting feedparser==6.0.10 (from arxiv)
  Downloading feedparser-6.0.10-py3-none-any.whl (81 kB)
Collecting sgmllib3k (from feedparser==6.0.10->arxiv)
  Downloading sgmllib3k-1.0.0.tar.gz (5.8 kB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: sgmllib3k
  Building wheel for sgmllib3k (setup.py) ... done
  Created wheel for sgmllib3k: filename=sgmllib3k-1.0.0-py3-none-any.whl size=6049 sha256=bbe9368ecf35cf6bc850654e5abc2285d0ef13e08eb23e0ba317028f99434128
  Stored in directory: /root/.cache/pip/wheels/f0/69/93/a47e9d621be168e9e33c7ce60524393c0b92ae83cf6c6e89c5
Successfully built sgmllib3k
Installing collected packages: sgmllib3k, feedparser, arxiv
Successfully installed arxiv-2.1.0 feedparser-6.0.10 sgmllib3k-1.0.0


In [2]:
#Import packages
import arxiv
import numpy as np
import pandas as pd

In [3]:
#This study uses a total of 100 abstracts
n_records = 100

client = arxiv.Client()

search = arxiv.Search(
  query = "medical", #search term that specifies the research topic
  max_results = n_records,
  sort_by = arxiv.SortCriterion.SubmittedDate #sort by submission date so the most recent papers come first
)

results = client.results(search)

In [4]:
#Abstract extraction process
abstracts = []

for r in client.results(search):
  abstracts.append(r.summary)

# Naming the column for the dataframe
df_data = {'abstract': abstracts}


In [5]:
#Saving the extracted data as a data frame
df = pd.DataFrame(df_data)
df.head()

Unnamed: 0,abstract
0,The mining of adverse drug events (ADEs) is pi...
1,To address existing challenges with intravascu...
2,"In the past years, the amount of research on a..."
3,Many observational studies feature irregular l...
4,"In medical image analysis, the expertise scarc..."


## Data Cleaning

In [6]:
import re

In [7]:
#Defining the cleaning function
def cleaning(text):
    if isinstance(text, str):
        url_pattern = re.compile(r'https?://\S+|www\.\S+')  # match http:// as well as https:// links
        text = url_pattern.sub('', text)
        text = re.sub(r"[’]", "'", text)
        text = re.sub(r"[^a-zA-Z\s'-]", "", text)
        text = ' '.join(text.split())
        text = text.lower()
    return text

df['abstract'] = df['abstract'].apply(cleaning)

This process will:
*   remove URLs from the text
*   convert the curly apostrophe " ’ " to the straight apostrophe " ' "
*   remove every non-alphabetic character except ' and -
*   collapse extra whitespace so that words are separated by a single space
*   convert all characters to lowercase
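
The steps listed above can be traced on a single made-up sentence. The snippet below is a minimal, self-contained sketch that repeats the same regular expressions as the notebook's cleaning function; the sample string is invented for illustration, and matching plain `http://` links is an assumption added here.

```python
import re

def cleaning(text):
    """Apply the cleaning steps listed above to one string."""
    if isinstance(text, str):
        # Assumption: match http:// as well as https:// links
        text = re.sub(r"https?://\S+|www\.\S+", "", text)   # remove URLs
        text = re.sub(r"[’]", "'", text)                    # curly -> straight apostrophe
        text = re.sub(r"[^a-zA-Z\s'-]", "", text)           # keep letters, whitespace, ' and -
        text = " ".join(text.split())                       # collapse whitespace
        text = text.lower()                                 # lowercase
    return text

# Invented sample text, not taken from the dataset
sample = "Deep-learning models (see https://example.org/paper) don’t  always\ngeneralise: 95% accuracy!"
cleaned = cleaning(sample)
print(cleaned)
```

Note that digits and sentence punctuation are removed along with the URL, so numeric results and sentence boundaries in the abstracts do not survive this step.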

In [8]:
df.head()

Unnamed: 0,abstract
0,the mining of adverse drug events ades is pivo...
1,to address existing challenges with intravascu...
2,in the past years the amount of research on ac...
3,many observational studies feature irregular l...
4,in medical image analysis the expertise scarci...


In [9]:
#Save the dataframe as a csv file
df["abstract"].to_csv("abstract.csv")

## Chunking

In [10]:
#Installing required library
!pip install llama_index.core
!pip install llama_index.readers.file

Collecting llama_index.core
  Downloading llama_index_core-0.10.39.post1-py3-none-any.whl (15.4 MB)
Collecting dataclasses-json (from llama_index.core)
  Downloading dataclasses_json-0.6.6-py3-none-any.whl (28 kB)
Collecting deprecated>=1.2.9.3 (from llama_index.core)
  Downloading Deprecated-1.2.14-py2.py3-none-any.whl (9.6 kB)
Collecting dirtyjson<2.0.0,>=1.0.8 (from llama_index.core)
  Downloading dirtyjson-1.0.8-py3-none-any.whl (25 kB)
Collecting httpx (from llama_index.core)
  Downloading httpx-0.27.0-py3-none-any.whl (75 kB)
Collecting llamaindex-py-client<0.2.0,>=0.1.18 (from llama_index.core)
  Downloading llamaindex_py_client-0.1.19-py3-none-any.whl (141 kB)

In [11]:
#Importing the library
from llama_index.readers.file import FlatReader
from llama_index.core.node_parser import SentenceSplitter
from pathlib import Path

To separate the text into chunks, we will use the SentenceSplitter from llama_index. It splits the text while trying to keep each sentence intact, so that a sentence is not divided across different chunks.
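
The idea of sentence-aware chunking can be illustrated without the library. The following is a simplified sketch, not the SentenceSplitter implementation: it counts words rather than tokens, applies no overlap, and relies on sentence-ending punctuation being present. The function name `split_into_chunks` and the `max_words` parameter are hypothetical.

```python
import re

def split_into_chunks(text, max_words=12):
    """Greedily pack whole sentences into chunks of at most max_words words.

    Hypothetical helper: a single sentence longer than max_words is kept
    whole as its own chunk instead of being cut.
    """
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    chunks, current, count = [], [], 0
    for sentence in sentences:
        n_words = len(sentence.split())
        # Start a new chunk when adding this sentence would exceed the limit
        if current and count + n_words > max_words:
            chunks.append(" ".join(current))
            current, count = [], 0
        current.append(sentence)
        count += n_words
    if current:
        chunks.append(" ".join(current))
    return chunks

# Invented sample text
text = (
    "Malaria is caused by plasmodium parasites. "
    "Early detection is crucial for treatment. "
    "Deep learning can help detect parasites in blood cells."
)
chunks = split_into_chunks(text, max_words=12)
for chunk in chunks:
    print(chunk)
```

With this input the first two sentences share a chunk and the third starts a new one, so no sentence is cut in half.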

In [12]:
#Importing the dataset
documents = FlatReader().load_data(Path("/content/abstract.csv"))

# Limit each chunk to 100 tokens, with a 10-token overlap between consecutive chunks
parser = SentenceSplitter(chunk_size=100, chunk_overlap=10)
doc_nodes = parser.get_nodes_from_documents(documents)

In [13]:
#Create a separate directory for the chunk files so that they do not get mixed up with other data files
!mkdir -p '/content/chunk_data/'

In [14]:
# Directory to save the individual chunk files
output_dir = Path("/content/chunk_data/")

# Save each chunk into a separate file
for i, node in enumerate(doc_nodes):
    output_file_path = output_dir / f"chunk_{i+1}.txt"
    with output_file_path.open("w", encoding="utf-8") as f:
        f.write(node.text)

print(f"Saved {len(doc_nodes)} chunks to {output_dir}")

Saved 265 chunks to /content/chunk_data


Each chunk is saved to its own file. This later helps identify which chunks were used to generate a response during query processing.
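
Because the files are named `chunk_{i+1}.txt`, a file name can be mapped back to the position of the chunk in `doc_nodes`. The helper below is a small sketch of that mapping; the function name `chunk_index_from_filename` is hypothetical.

```python
import re

def chunk_index_from_filename(file_name):
    """Return the zero-based chunk index encoded in a name like 'chunk_224.txt'."""
    match = re.fullmatch(r"chunk_(\d+)\.txt", file_name)
    if match is None:
        raise ValueError(f"Unexpected chunk file name: {file_name!r}")
    # Files are numbered from 1, list positions from 0
    return int(match.group(1)) - 1

index = chunk_index_from_filename("chunk_224.txt")
print(index)
```

The returned value can then be used as `doc_nodes[index]`, assuming the chunk files were written from the same `doc_nodes` list in the same order.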

## Embedding

In [15]:
#Installing necessary packages
!pip install langchain
!pip install langchain-community
!pip install sentence-transformers
!pip install llama-index-embeddings-langchain

Collecting langchain
  Downloading langchain-0.2.1-py3-none-any.whl (973 kB)
Collecting langchain-core<0.3.0,>=0.2.0 (from langchain)
  Downloading langchain_core-0.2.1-py3-none-any.whl (308 kB)
Collecting langchain

For our medical abstracts dataset, we will use PubMedBERT as our embedding model. PubMedBERT is trained on abstracts extracted from PubMed, which makes it well suited to our dataset.
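
Once every chunk has an embedding vector, retrieval amounts to ranking chunks by how close their vectors are to the query vector. The sketch below illustrates this with cosine similarity on tiny made-up vectors; real PubMedBERT embeddings have hundreds of dimensions, and the vector store used later may apply a different distance metric by default. The names `cosine_similarity` and `top_k` are hypothetical.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def top_k(query_vec, chunk_vecs, k=2):
    """Return the names of the k chunks most similar to the query vector."""
    ranked = sorted(
        chunk_vecs.items(),
        key=lambda item: cosine_similarity(query_vec, item[1]),
        reverse=True,
    )
    return [name for name, _ in ranked[:k]]

# Toy three-dimensional "embeddings", invented for illustration
toy_vectors = {
    "chunk_a": [1.0, 0.0, 0.0],
    "chunk_b": [0.8, 0.6, 0.0],
    "chunk_c": [0.0, 0.0, 1.0],
}
query_vector = [1.0, 0.1, 0.0]
best = top_k(query_vector, toy_vectors, k=2)
print(best)
```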

In [16]:
#Importing the necessary library for the embedding model
from langchain_community.embeddings import HuggingFaceEmbeddings

#Loading PubMedBERT from the Hugging Face Hub
embedding_model = HuggingFaceEmbeddings(model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



## Indexing

In [18]:
from llama_index.core import SimpleDirectoryReader

# Load all the documents in the chunk_data directory
reader = SimpleDirectoryReader("/content/chunk_data") # load documents from the chunk_data folder
documents = reader.load_data()
print(f"{len(documents)} documents are loaded")

265 documents are loaded


In [19]:
#Installing necessary library
!pip install llama-index-vector-stores-chroma
!pip install chromadb

Collecting llama-index-vector-stores-chroma
  Downloading llama_index_vector_stores_chroma-0.1.8-py3-none-any.whl (4.8 kB)
Collecting chromadb<0.6.0,>=0.4.0 (from llama-index-vector-stores-chroma)
  Downloading chromadb-0.5.0-py3-none-any.whl (526 kB)
Collecting chroma-hnswlib==0.7.3 (from chromadb<0.6.0,>=0.4.0->llama-index-vector-stores-chroma)
  Downloading chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.4 MB)
Collecting fastapi>=0.95.2 (from chromadb<0.6.0,>=0.4.0->llama-index-vector-stores-chroma)
  Downloading fastapi-0.111.0-py3-none-any.whl (91 kB)
Collecting uvicorn[standard]

In [20]:
%%time
#Importing required packages
from llama_index.vector_stores.chroma import ChromaVectorStore
import chromadb
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex

# Creating a medical_articles database
db = chromadb.PersistentClient(path="./medical_articles_db")

# Create a collection called "medical-abstract" inside the database (reuse it if it already exists)
chroma_collection = db.get_or_create_collection("medical-abstract")

vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

# Indexing the documents into the database
vector_index = VectorStoreIndex.from_documents(
    documents,
    storage_context = storage_context,
    embed_model = embedding_model
)

# Printing the metadata
print(chroma_collection)

name='medical-abstract' id=UUID('db6e0803-936f-4692-8ca3-eb3a98afa1e4') metadata=None tenant='default_tenant' database='default_database'
CPU times: user 1min 54s, sys: 441 ms, total: 1min 54s
Wall time: 2min 1s


## Prompt Template

A prompt template is crucial for engineering better responses. We will use a customised prompt template built with LlamaIndex.
The template instructs the LLM to respond as a medical expert while avoiding medical terminology that is not in general use.

In [21]:
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core import ChatPromptTemplate

#Prompt string for the LLM
qa_prompt_str = (
    "You are a medical expert, give responses to the following "
    "question: {query_str}. Do not use technical words, give easy "
    "to understand responses."
)

# Text QA Prompt
chat_text_qa_msgs = [
    ChatMessage(
        role=MessageRole.SYSTEM,
        content=(
            "Always answer the question, even if the context isn't helpful."
        ),
    ),
    ChatMessage(role=MessageRole.USER, content=qa_prompt_str),
]

text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)
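
The prompt string above contains only a `{query_str}` placeholder. LlamaIndex text QA templates are normally filled with both a `context_str` and a `query_str`, so a template without `{context_str}` may leave the retrieved chunks out of what the LLM sees. The following is a minimal sketch, using plain `str.format` rather than the LlamaIndex classes, of a prompt that carries the retrieved context; the template wording and the names `QA_WITH_CONTEXT` and `build_prompt` are assumptions made for illustration.

```python
# Hypothetical template: same instructions as above, plus a slot for retrieved context
QA_WITH_CONTEXT = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "You are a medical expert. Using the context above, give responses to "
    "the following question: {query_str}. Do not use technical words, give "
    "easy to understand responses."
)

def build_prompt(retrieved_chunks, question):
    """Join the retrieved chunk texts and place them in the prompt with the question."""
    context = "\n\n".join(retrieved_chunks)
    return QA_WITH_CONTEXT.format(context_str=context, query_str=question)

# Invented chunk text and question
prompt = build_prompt(
    ["deep learning is widely applied in medical image analysis"],
    "Explain the application of deep learning models in medical image analysis",
)
print(prompt)
```

The same string, with both placeholders, could be passed as the user message content in place of the query-only prompt string.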

## Query Processing and Response Generation

We will integrate our RAG system with T5 as the LLM, using the google/flan-t5-base checkpoint. T5 takes a text-to-text approach in which every NLP problem is cast as text generation, which makes it adaptable to specialised domains, including the medical field.
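
To make the text-to-text idea concrete, the sketch below shows how different tasks can be written as plain input strings by prepending a task prefix, in the style used by the original T5 models. The helper name `to_text_to_text` and the sample sentence are hypothetical, and instruction-tuned variants such as FLAN-T5 also accept free-form instructions instead of fixed prefixes.

```python
# Task prefixes in the style of the original T5 models
TASK_PREFIXES = {
    "summarize": "summarize: ",
    "translate_en_de": "translate English to German: ",
}

def to_text_to_text(task, text):
    """Turn a (task, text) pair into a single input string for a text-to-text model."""
    return TASK_PREFIXES[task] + text

# Invented sample sentence
model_input = to_text_to_text("summarize", "malaria is a life-threatening infectious disease")
print(model_input)
```

Whatever the task, the model receives a string and produces a string, which is why a question answering prompt can be passed to it directly.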

In [22]:
#Installing necessary packages
!pip install transformers
!pip install llama-index-llms-langchain

Collecting llama-index-llms-langchain
  Downloading llama_index_llms_langchain-0.1.3-py3-none-any.whl (4.6 kB)
Collecting langchain<0.2.0,>=0.1.3 (from llama-index-llms-langchain)
  Downloading langchain-0.1.20-py3-none-any.whl (1.0 MB)
Collecting llama-index-llms-anyscale<0.2.0,>=0.1.1 (from llama-index-llms-langchain)
  Downloading llama_index_llms_anyscale-0.1.4-py3-none-any.whl (4.2 kB)
Collecting llama-index-llms-openai<0.2.0,>=0.1.1 (from llama-index-llms-langchain)
  Downloading llama_index_llms_openai-0.1.21-py3-none-any.whl (11 kB)
Collecting langchain-community<0.1,>=0.0.38 (from langchain<0.2.0,>=0.1.3->llama-index-llms-langchain)
  Downloading langchain_community-0.0.38-py3-none-any.whl (2.0 MB)
Collecting langchain-core<0.2.0,>=0.1.52

In [25]:
#Importing the LLM from Hugging Face
from langchain.llms import HuggingFaceHub
from getpass import getpass
import os

#Read the API token at runtime rather than hardcoding it in the notebook
if "HUGGINGFACEHUB_API_TOKEN" not in os.environ:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = getpass("Hugging Face API token: ")
llm = HuggingFaceHub(repo_id = "google/flan-t5-base", model_kwargs={"temperature":0.6})


In [26]:
#Testing the LLM without integrating with our vector database
llm("How does text mining assist the treatment of mental health disorder?")


'text mining is a form of cognitive behavioral therapy'

In [28]:
#Input query for our RAG system
query = "Explain the application of deep learning models in medical image analysis"

In [29]:
%%time
#Response processing
query_engine = vector_index.as_query_engine(
   text_qa_template=text_qa_template,
   llm=llm
)

response = query_engine.query(query)
response.response


CPU times: user 465 ms, sys: 42.4 ms, total: 508 ms
Wall time: 1.03 s


'A deep learning model is a model that can be used to analyze images.'

In [30]:
#Retrieving the metadata of the documents used to generate the response
response.metadata

{'2e3e413d-5b7a-4ae9-98b0-b40990f64d60': {'file_path': '/content/chunk_data/chunk_224.txt',
  'file_name': 'chunk_224.txt',
  'file_type': 'text/plain',
  'file_size': 527,
  'creation_date': '2024-05-27',
  'last_modified_date': '2024-05-27'},
 'b273fdb7-fbea-4ca1-97d3-3589a2b1fd7f': {'file_path': '/content/chunk_data/chunk_194.txt',
  'file_name': 'chunk_194.txt',
  'file_type': 'text/plain',
  'file_size': 631,
  'creation_date': '2024-05-27',
  'last_modified_date': '2024-05-27'}}

In [31]:
#Checking the text of a chunk
doc_nodes[98].text

'regions we conduct disease localization experiments on medical image datasets and achieve the best performance on multiple evaluation metrics compared with previous interpretable attribution methods we performed additional ablation studies to verify the effectiveness of each method\n36,malaria is a life-threatening infectious disease caused by plasmodium parasites which poses a significant public health challenge worldwide particularly in tropical and subtropical regions timely and accurate detection of malaria parasites in blood cells is crucial for effective treatment and control of the'

In [32]:
#Checking the text of another chunk
doc_nodes[95].text

'state-of-the-art baselines the code is available at\n35,with the widespread application of deep learning technology in medical image analysis how to effectively explain model decisions and improve diagnosis accuracy has become an urgent problem that needs to be solved attribution methods have become a key tool to help doctors better understand the diagnostic basis of models and they are used to explain and localize diseases in medical images however previous methods suffer from inaccurate and incomplete localization problems for fundus diseases with complex'