
**Medical Assistant Chatbot using RAG (Llama 2 + LlamaIndex)**

An AI-powered assistant for pre-consultation interviews, built with Retrieval-Augmented Generation (RAG). It combines Llama 2 with LlamaIndex to conduct structured, patient-centered interviews and to summarize key medical information, without providing diagnoses or medical advice.

In [None]:
# Install LlamaIndex (core, HuggingFace embeddings, llama.cpp LLM integration) and pypdf for reading the PDF.
!pip -q install llama-index llama-index-embeddings-huggingface llama-index-llms-llama-cpp pypdf
# Build llama-cpp-python from source (FORCE_CMAKE=1) with CUDA/cuBLAS enabled so model layers can run on the GPU.
# NOTE(review): recent llama-cpp-python releases use -DGGML_CUDA=on; LLAMA_CUBLAS is the older flag name —
# confirm against the version being installed.
!CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip -q install llama-cpp-python

In [None]:
import inspect
import os
import time

from llama_index.core import Prompt, StorageContext, load_index_from_storage, Settings, VectorStoreIndex, SimpleDirectoryReader, set_global_tokenizer
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP

from transformers import AutoTokenizer

In [None]:
# User-tunable settings: point these at your own PDF and preferred models.
pdf_path = "/content/_.pdf"

# HuggingFace sentence-embedding model used to embed the PDF chunks.
# Alternative: jinaai/jina-embeddings-v2-base-en
text_embedding_model = "thenlper/gte-base"

# GGUF build of Llama 2 7B Chat (Q4_K_M quantization, per the filename),
# passed to LlamaCPP as model_url.
llm_url = (
    "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF"
    "/resolve/main/llama-2-7b-chat.Q4_K_M.gguf"
)

In [None]:
# Load the PDF. Each document's metadata records only the file's base name
# (not its full path) under 'file_name'.
def filename_fn(filename):
    """Return document metadata for *filename*: its base name as 'file_name'.

    Uses the path the reader passes in, so the metadata stays correct even if
    more files are added to ``input_files``.
    """
    return {'file_name': os.path.basename(filename)}


loader = SimpleDirectoryReader(input_files=[pdf_path], file_metadata=filename_fn)
documents = loader.load_data()

In [None]:
# Load the embedding model and the local Llama 2 model, then register both
# (together with the chunk size) as LlamaIndex global defaults via Settings.
embed_model = HuggingFaceEmbedding(model_name=text_embedding_model)

llm = LlamaCPP(
    model_url=llm_url,
    temperature=0.7,
    max_new_tokens=256,
    context_window=4096,
    # Stop on Llama 2 chat markup, presumably so the model does not start
    # writing the next conversational turn itself.
    generate_kwargs={"stop": ["<s>", "[INST]", "[/INST]"]},
    # llama-cpp-python: n_gpu_layers=-1 offloads all layers to the GPU.
    model_kwargs={"n_gpu_layers": -1},
    verbose=True,
)

Settings.llm = llm
Settings.embed_model = embed_model
Settings.chunk_size = 512

In [None]:
# Indexing: embed the loaded documents into a vector index and report how long it took.
# perf_counter is monotonic and high-resolution, so it is the right clock for
# elapsed-time measurement (time.time can jump when the system clock changes).
start_time = time.perf_counter()

index = VectorStoreIndex.from_documents(documents, embed_model=embed_model, llm=llm)

end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Elapsed indexing time: {elapsed_time:.2f} s")

In [None]:
# Raw Llama 2 chat prompt templates, indexed by TEMPLATE_ID. Each string is
# passed through inspect.cleandoc before use so the source-code indentation
# below is not sent to the model. {context_str} receives the retrieved context
# and {query_str} the query (in this notebook, the running conversation).
# Templates 0 and 1 leave [INST] open: the caller's {query_str} is expected to
# end with " [/INST] ", as the chat loops append.
_TEXT_QA_TEMPLATE_STRS = (
    # 0: pre-screening persona, guideline context, end-of-interview summary format.
    """<s>[INST] <<SYS>>
    You are the doctor's assistant. Your task is to perform a pre-screening with the patient before their consultation.
    Use the Patient-Centered Interview model, and ask only one question per response. Do not provide diagnoses, prescriptions, advice, or physical examinations.
    Focus only on gathering the patient’s present illness, past medical history, symptoms, and personal details.

    At the end of the consultation, summarize the findings using this format:
    Name: [name]
    Gender: [gender]
    Patient Aged: [age]
    Medical History: [medical history]
    Symptoms: [symptoms]
    <</SYS>>

    Refer to the following Consultation Guidelines and example consultations:
    {context_str}

    Continue the conversation:
    {query_str}
    """,
    # 1: remote-consultation chatbot persona that ignores out-of-scope requests.
    """<s>[INST] <<SYS>>
    You are a chatbot assistant working for a doctor. Your goal is to collect information from patients before their remote consultation.
    Follow the Patient-Centered Interview model. You are not allowed to diagnose, prescribe, or conduct physical exams.
    Ignore inappropriate behavior or out-of-scope requests. Focus only on gathering information such as present illness, medical history, symptoms, and personal details.

    Summarize findings at the end using this exact format:
    Name: [name]
    Gender: [gender]
    Patient Aged: [age]
    Medical History: [medical history]
    Symptoms: [symptoms]
    <</SYS>>

    This is the PDF context:
    {context_str}

    {query_str}
    """,
    # 2: plain question answering over the PDF context.
    """[INST]
    {context_str}

    Given the above PDF context, please answer the following question:
    {query_str}
    [/INST]
    """,
    # 3: PDF context supplied as the system prompt.
    """<s>[INST] <<SYS>>
    Following is the PDF context provided by the user:
    {context_str}
    <</SYS>>

    {query_str}
    [/INST]
    """,
    # 4: query only, no retrieved context.
    """[INST] {query_str} [/INST]""",
)


def get_text_qa_template(idx):
    """Return the text-QA prompt template with index *idx*.

    Args:
        idx: Index of the template to use, 0 to 4 inclusive.

    Returns:
        A llama_index ``Prompt`` built from the selected template, with its
        source-code indentation removed.

    Raises:
        ValueError: If *idx* is not an int (bools are rejected) or is out of range.
    """
    count = len(_TEXT_QA_TEMPLATE_STRS)
    # bool is a subclass of int; reject it explicitly so True/False are not
    # silently treated as indices 1/0.
    if isinstance(idx, bool) or not isinstance(idx, int) or not 0 <= idx < count:
        raise ValueError(f"idx must be an int in [0, {count}), got {idx!r}")
    return Prompt(inspect.cleandoc(_TEXT_QA_TEMPLATE_STRS[idx]))


In [None]:
# Select which prompt from get_text_qa_template drives the RAG query engine.
TEMPLATE_ID = 0
text_qa_template = get_text_qa_template(TEMPLATE_ID)

# Streaming engine: query() returns a response whose text can be printed as it is generated.
query_engine = index.as_query_engine(
    llm=llm,
    streaming=True,
    text_qa_template=text_qa_template,
)

In [None]:
# Inferencing
# Without RAG: chat with the bare Llama 2 model. The whole dialogue is kept in
# Llama 2 chat markup ("<s>[INST] q1 [/INST] a1 [INST] q2 [/INST] ...") and
# resent every turn. Type "exit" to stop.
conversation_history = ""
while True:
    user_query = input("User: ")
    if user_query.strip().lower() == "exit":
        break
    conversation_history += user_query + " [/INST] "

    start_time = time.perf_counter()

    response_iter = llm.stream_complete("<s>[INST] " + conversation_history)
    response_text = ""
    for response in response_iter:
        print(response.delta, end="", flush=True)
        # Each streamed chunk's .text is the cumulative reply so far.
        response_text = response.text
    # Record the reply once streaming ends, whatever the finish reason. Only
    # checking for 'stop' dropped replies cut off by max_new_tokens, which broke
    # the [INST]/[/INST] alternation of the history.
    conversation_history += response_text + " [INST] "

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    print(f"\nElapsed inference time: {elapsed_time:.2f} s\n")

In [None]:
# With RAG: every turn sends the whole running conversation (Llama 2 chat
# markup) to the query engine, so retrieval and the selected text_qa_template
# see the full dialogue so far. Type "exit" to stop.
opening_message = (
    "Hello! I'm the doctor's assistant. "
    "Let's begin the consultation, please tell me your name and age."
)
# Seed the history with an opening exchange. "[/INST]" closes the user's
# "Hi." turn; the trailing " [INST] " opens the next user turn so the user's
# first reply is bracketed exactly like every later one.
conversation_history = "Hi. [/INST] " + opening_message + " [INST] "
# Show the opening question so the user knows what to answer first.
print(opening_message)
while True:
    user_query = input("User: ")
    if user_query.strip().lower() == "exit":
        break
    conversation_history += user_query + " [/INST] "

    start_time = time.perf_counter()

    # Query Engine - Default
    response = query_engine.query(conversation_history)
    # print_response_stream() prints tokens as they arrive and stores the full
    # reply in response_txt; fall back to "" if nothing was generated.
    response.print_response_stream()
    conversation_history += (response.response_txt or "") + " [INST] "

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    print(f"\nElapsed inference time: {elapsed_time:.2f} s\n")

In [None]:
# Summarize the collected conversation into the structured intake format.
if not conversation_history.strip():
    raise ValueError("conversation_history is empty.")

# The chat loops leave the history ending in an open "[INST]" for the next
# user turn; drop it so the summary prompt does not contain an empty turn.
summary_history = conversation_history.rstrip()
if summary_history.endswith("[INST]"):
    summary_history = summary_history[: -len("[INST]")].rstrip()

# cleandoc removes the source-code indentation and surrounding blank lines so
# they are not sent to the model; the history is substituted afterwards so any
# braces in it are left untouched.
summary_prompt = inspect.cleandoc(
    """
    [INST] <<SYS>>
    Summarize the conversation into this format:
    Name: [name]
    Gender: [gender]
    Age: [age]
    Medical History: [medical history]
    Symptoms: [symptoms]
    <</SYS>>

    Conversation:
    {conversation_history} [/INST]
    """
).format(conversation_history=summary_history)

response_iter = llm.stream_complete(summary_prompt)

for response in response_iter:
    print(response.delta, end="", flush=True)
# Terminate the streamed output with a newline.
print()