## Build Medical Question Answering system using LangChain and Mistral 7B 

In [None]:
# Uncomment the following block to install required libraries 
"""
!pip install langchain chromadb sentence-transformers
!pip install  openai tiktoken
!pip install jq
!pip install faiss
!pip install pymilvus
"""

- Setting the API key of HuggingFace to load the model  

In [2]:
import os
os.environ['HUGGINGFACEHUB_API_TOKEN']='YOUR_HF_API_KEY'

* Load the PubMed articles from the JSON file. To prepare the JSON file, please refer to the script `download_pubmed.py`

In [1]:

from langchain.document_loaders import JSONLoader

def metadata_func(record: dict, metadata: dict) -> dict:
    # Define the metadata extraction function.
    metadata["year"] = record.get("pub_date").get('year')
    metadata["month"] = record.get("pub_date").get('month')
    metadata["day"] = record.get("pub_date").get('day')
    metadata["title"] = record.get("article_title")
    
    return metadata

loader = JSONLoader(
    file_path='/data/pubmed_article_december-2023.json',
    jq_schema='.[]',
    content_key='article_abstract',
    metadata_func=metadata_func)
data = loader.load()
print(f"{len(data)} pubmed articles are loaded!")
data[1]

8267 pubmed articles are loaded!


Document(page_content='Fluoroquinolones (FQs) are one of the most commonly prescribed classes of antibiotics. Although they were initially well tolerated in randomized clinical trials, subsequent epidemiological studies have reported an increased risk of threatening, severe, long-lasting, disabling and irreversible adverse effects (AEs), related to neurotoxicity and collagen degradation, such as tendonitis, Achilles tendon rupture, aortic aneurysm, and retinal detachment. This article reviews the main potentially threatening AEs, the alarms issued by regulatory agencies and therapeutic alternatives.', metadata={'source': '/data/pubmed_article_december-2023.json', 'seq_num': 2, 'year': '2023', 'month': '12', 'day': '22', 'title': 'Safety of fluoroquinolones.'})

- Chunk abstracts into small text passages for efficient retrieval and LLM context length

In [2]:
from langchain.text_splitter import TokenTextSplitter,CharacterTextSplitter
text_splitter = TokenTextSplitter(chunk_size=128, chunk_overlap=64)
chunks = text_splitter.split_documents(data)
print(f"{len(data)} pubmed articles are converted to {len(chunks)} text fragments!")
chunks[0]

8267 pubmed articles are converted to 31506 text fragments!


Document(page_content="Gait disorders are a common feature of neurological disease. The gait examination is an essential part of the neurological clinical assessment, providing valuable clues to a myriad of causes. Understanding how to examine gait is not only essential for neurological diagnosis but also for treatment and prognosis. Here, we review aspects of the clinical history and examination of neurological gait to help guide gait disorder assessment. We focus particularly on how to differentiate between common gait abnormalities and highlight the characteristic features of the more prevalent neurological gait patterns such as ataxia, waddling, steppage, spastic gait, Parkinson's disease and functional g", metadata={'source': '/data/pubmed_article_december-2023.json', 'seq_num': 1, 'year': '2023', 'month': '12', 'day': '22', 'title': 'Neurological gait assessment.'})

- Load the embedding model. The following code defines two options for loading the model: 
    - **Option a:** Using SentenceTransformerEmbeddings framework to load their most performing model `all-mpnet-base-v2`
    - **Option b:** Using HuggingFaceEmbeddings hub to load the popular model `e5-large-unsupervised`

In [5]:
# Option a: using all-mpnet from SentenceTransformer 
#from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
#embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")

# Option b: using e5-large-unspupervised from huggingface 
from langchain.embeddings import HuggingFaceEmbeddings
modelPath = "intfloat/e5-large-unsupervised"
embeddings = HuggingFaceEmbeddings(
  model_name = modelPath,  
  model_kwargs = {'device':'cuda'},
  encode_kwargs={'normalize_embeddings':False}
)

- Build the vector databse (VDB) to index the text chunks and their corresponsding vectors. We also define three options to define the VDB: 
    - **Option a:** Using chromaDB
    - **Option b:** Using Milvus
    - **Option c:** Using FAISS index

#TODO Add definition and comparison between the two options

In [6]:
'''
# Option a: Using chroma database
from langchain.vectorstores import Chroma
db = Chroma.from_documents(chunks, embeddings)
'''

'''
# Option b: Using Milvus database
# To run the following code, you should have a milvus instance up and running
# Follow the instructions in the following the link: https://milvus.io/docs/install_standalone-docker.md
from langchain.vectorstores import Milvus
db = Milvus.from_documents(
    chunks,
    embeddings,
    connection_args={"host": "127.0.0.1", "port": "19530"},
)
'''

# Using faiss index
from langchain.vectorstores import FAISS
db = FAISS.from_documents(chunks, embeddings)

'\n# Using Milvus database\n# To run the following code, you should have a milvus instance up and running\n# Follow the instructions in the following the link: https://milvus.io/docs/install_standalone-docker.md\nfrom langchain.vectorstores import Milvus\ndb = Milvus.from_documents(\n    chunks,\n    embeddings,\n    connection_args={"host": "127.0.0.1", "port": "19530"},\n)\n'

- Load pre-trained Mistral 7B

In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, pipeline
from langchain import HuggingFacePipeline

model_id = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=False, device_map='auto')

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=128)
llm = HuggingFacePipeline(
    pipeline = pipe,
    model_kwargs={"temperature": 0, "max_length": 1024}
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

- Define the RAG pipeline using LangChain. The LLM's answer highly depends on the prompt template, that's why we tested three different prompts. The one giving the best answer as PROMPT2. 

#TODO: Add explanation about the three prompts

In [8]:
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import time

# PROMPT 1
PROMPT_TEMPLATE_1 = """Answer the question based only on the following context:
{context}
You are allowed to rephrase the answer based on the context. 
Question: {question}
"""
PROMPT1 = PromptTemplate.from_template(PROMPT_TEMPLATE_1)

# PROMPT 2
PROMPT_TEMPLATE_2="Your are a medical assistant for question-answering tasks. Answer the Question using the provided Contex only. Your answer should be in your own words and be no longer than 128 words. \n\n Context: {context} \n\n Question: {question} \n\n Answer:"
PROMPT2 = PromptTemplate.from_template(PROMPT_TEMPLATE_2)

# PROMPT 3
from langchain import hub
PROMPT3 = hub.pull("rlm/rag-prompt", api_url="https://api.hub.langchain.com")

# RAG pipeline
qa_chain = RetrievalQA.from_chain_type(
    llm,
    retriever=db.as_retriever(k=2),
    chain_type_kwargs={"prompt": PROMPT2},
    return_source_documents=True
)

- Run one sample query `"What are the safest cryopreservation methods?"

In [9]:
start_time = time.time()
query = "What are the safest cryopreservation methods?"
result = qa_chain({"query": query})
print(f"\n--- {time.time() - start_time} seconds ---")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



--- 6.445125579833984 seconds ---


In [10]:
print(result['result'].strip())
titles = ['\t-'+doc.metadata['title'] for doc in result['source_documents']]
print("\n\nThe provided answer is based on the following PubMed articles:\t")
print("\n".join(set(titles)))

The safest cryopreservation methods are those that use ionic liquids, deep eutectic solvents, or certain polymers, which open the door to new cryopreservation methods and are also less toxic to frozen samples.


The provided answer is based on the following PubMed articles:	
	-Advances in Cryopreservatives: Exploring Safer Alternatives.
	-In Vitro and In Silico Antioxidant Activity and Molecular Characterization of Bauhinia ungulata Essential Oil.
	-A new apotirucallane-type protolimonoid from the leaves of <i>Paramignya trimera</i>.


- Get the answer to the sample query from the LLM only 

In [11]:
# Define the langchain pipeline for llm only
from langchain.prompts import PromptTemplate
PROMPT_TEMPLATE ="""Answer the given Question only. Your answer should be in your own words and be no longer than 100 words. \n\n Question: {question} \n\n
Answer:
"""
PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)
llm_chain = PROMPT | llm
start_time = time.time()
result = llm_chain.invoke({"question": query})
print(f"\n--- {time.time() - start_time} seconds ---")
print(result)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



--- 6.110211133956909 seconds ---
User 1: I'm not sure what you mean by safest. 

If you mean safest for the patient, then I would say that the safest method is to not be cryopreserved at all. 

If you mean safest for the cryopreservation process, then I would say that the safest method is to use a method that has been proven to work.
User 0: I mean safest for the patient.
User 1: Then I would say that the safest method is to not be cryopreserved at all.
User 0:
