In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="5"

import pandas as pd
import warnings
warnings.filterwarnings("ignore")
import faiss
import numpy as np
import torch

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers import TFIDFRetriever
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
from functools import lru_cache
from collections import defaultdict

### Get Retriever with Vectorstore

In [None]:
class MyRetriever:
    def __init__(
            self,
            raw_data_path:str="",
            retriever_type:str="contriever",
            chunk_size:int=256,
            device:str="cuda"
        ):

        self.raw_data_path = raw_data_path
        self.retriever_type = retriever_type
        self.chunk_size = chunk_size
        self.device = device

        self.texts = self._load_dataset()
    
        if self.retriever_type == "medcpt":
            model_name = "ncbi/MedCPT-Article-Encoder"
        elif self.retriever_type == 'specter':
            model_name = "allenai/specter"
        else:
            model_name = "facebook/contriever"
        
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=chunk_size)
        # print(f"Model and Tokenizer {model_name} loaded.")
        
        # Dense - Contriever, MedCPT, Specter, RRF-2
        if self.retriever_type in ["contriever", "medcpt", "specter", "rrf2"]:
            text_embedding = self._compute_embeddings(self.texts)
            text_embedding = text_embedding.cpu()
            self.faiss_index = self._create_dense_retriever(text_embedding)
            # print("FAISS index created.")

        # Sparse - BM25, TF-IDF
        self.bm25_index = self._get_bm25_index(self.texts)
        self.tfidf_retriever = TFIDFRetriever.from_texts(self.texts)

    def _compute_embeddings(self, texts):
        def mean_pooling(token_embeddings, mask):
            token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
            sentence_embeddings = (token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]).detach()
            return sentence_embeddings
        
        all_embeddings = []
        with torch.no_grad():
            for i in range(0, len(texts), 32):
                batch_texts = texts[i:i+32]
                inputs = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(self.device)
                outputs = self.model(**inputs)
                embeddings = mean_pooling(outputs[0], inputs['attention_mask']).cpu()
                embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
                all_embeddings.append(embeddings)
        return torch.cat(all_embeddings, dim=0)

    def _compute_query_embedding(self, query):
        return self._compute_embeddings([query]).numpy()
    
    def _create_dense_retriever(self, text_embedding):
        d = text_embedding.shape[1]
        faiss_index = faiss.IndexFlatIP(d) # [IndexFlatIP, IndexFlatL2]
        faiss_index.add(np.ascontiguousarray(text_embedding.numpy()))
        return faiss_index
    
    def _get_dense_results(self, query, k=5):
        query_embedding = self._compute_query_embedding(query)
        distances, indices = self.faiss_index.search(query_embedding, k)
        faiss_results = []
        for idx_list, dist_list in zip(indices, distances):
            for idx, dist in zip(idx_list, dist_list):
                faiss_results.append((self.texts[idx], dist))
        return faiss_results[:k]
    
    def _get_bm25_index(self, texts):
        tokenized_corpus = [text.lower().split() for text in texts]
        bm25_index = BM25Okapi(tokenized_corpus)
        return bm25_index
    
    def _get_bm25_results(self, query):
        query_tokens = query.lower().split()
        bm25_scores = self.bm25_index.get_scores(query_tokens)
        bm25_results = [(self.texts[i], bm25_scores[i]) for i in range(len(bm25_scores))]
        bm25_results = sorted(bm25_results, key=lambda x: x[1], reverse=True)
        return bm25_results
    
    def _load_dataset(self):
        texts = []
        for file_name in os.listdir(self.raw_data_path):
            file_path = os.path.join(self.raw_data_path, file_name)
            if file_name.endswith(".txt"):
                with open(file_path, "r", encoding="utf-8") as f:
                    texts.append(f.read().strip())
            elif file_name.endswith(".csv"):
                df = pd.read_csv(file_path)
                if "summary" in df.columns:
                    texts.extend(df["summary"].dropna().tolist())

        concatenated_text = " ".join(texts)
        texts = self._split_text(concatenated_text, self.chunk_size)
        # print(f"Loaded {len(texts)} documents from directory: {self.raw_data_path}")
        return texts

    def _ensemble_scores(self, *results_with_weights):
        """
        Args: results_with_weights: Tuples of (results, weight) where results are lists of (text, score).
        Returns: List of (text, combined_score) tuples sorted by score in descending order.
        """
        combined_scores = defaultdict(float)

        for results, weight in results_with_weights:
            scores_dict = {text: score for text, score in results}
            values = np.array(list(scores_dict.values()))
            if len(values) > 0:
                min_val, max_val = values.min(), values.max()
                if max_val > min_val:
                    scores_dict = {k: (v - min_val) / (max_val - min_val) for k, v in scores_dict.items()}

            for text, score in scores_dict.items():
                combined_scores[text] += weight * score

        sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_results
    
    def _split_text(self, texts, chunk_size, chunk_overlap=50):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len
        )
        return text_splitter.split_text(texts)

    def search(self, query:str, k:int=5):
        if self.retriever_type in ["contriever", "medcpt", "specter"]:
            return self._get_dense_results(query, k=k)

        elif self.retriever_type == "bm25":
            return self._get_bm25_results(query)[:k]
        
        elif self.retriever_type == "tfidf":
            return self.tfidf_retriever.invoke(query, k=k)
        
        elif self.retriever_type == "rrf2":
            faiss_results = self._get_dense_results(query, k=k)
            bm25_results = self._get_bm25_results(query)
            return self._ensemble_scores((faiss_results, 0.5), (bm25_results, 0.5))[:k]
        else:
            raise ValueError("Invalid retriever type. Choose from ['contriever', 'specter', 'medcpt', 'bm25', 'tfidf', 'rrf2']")


In [3]:
retriever = MyRetriever(raw_data_path="../data/raw/medical_googleai", retriever_type="contriever")

# Search query with ensemble of Specter + BM25
query = "What is neurology?"
results = retriever.search(query, k=8)

# Print results
for text, score in results:
    print(f"Text: {text}\nScore: {score}\n")

Text: Clinical neurology begins with the neurological examination, the most important diagnostic tool. It systematically assesses: Mental Status (level of consciousness, orientation, attention, memory, language, executive function); Cranial Nerves (I-XII);
Score: 0.7006931900978088

Text: of key concepts in Neurology. Remember to integrate your knowledge of basic neuroscience with clinical findings, and to practice the neurological examination until it becomes second nature. Continuous learning and staying current with new developments are
Score: 0.6949046850204468

Text: Diagnostic techniques in neurology include neuroimaging: computed tomography (CT: head CT, CT angiography) and magnetic resonance imaging (MRI: brain MRI, spinal cord MRI, MR angiography, MR spectroscopy). Positron emission tomography (PET) assesses brain
Score: 0.692385733127594

Text: Therapeutic approaches in neurology include pharmacological therapies: antiepileptic drugs (AEDs) for seizures, anti-Parkinsonian med

### Model

In [4]:
from langchain_ollama import ChatOllama

model = ChatOllama(model="llama3.2:1b", temperature=0)

In [5]:
# Run the model with a prompt
def run_simple_prompt(model, question, options):
    template = '''You are an AI assistant for grad students.
    Answer the question and select the most appropriate answer from the given choices.
    You must return only the correct answer (a, b, c or d) and nothing else.
    
    Question: {question}
    Choices:
    {options}

    You must return a single character (a, b, c or d). Do not provide any explanation.
    '''
    prompt = ChatPromptTemplate.from_template(template)
    simple_chain = (prompt | model | StrOutputParser())
    answer_without_rag = simple_chain.invoke({"question": question, "options": options})
    return answer_without_rag

question = "You are tasked with determining the optimal location for a new hospital, considering factors like population density, proximity to major roads, and distance from existing healthcare facilities. Which geoprocessing technique is least directly relevant to this initial site selection process?"
options = ["Buffer Analysis", "Network Analysis", "Raster Reclassification", "Thiessen Polygon Generation"]
answer_without_rag = run_simple_prompt(model, question, options)

print("Query (No RAG):", question)
print("Answer (No RAG):", answer_without_rag)

Query (No RAG): You are tasked with determining the optimal location for a new hospital, considering factors like population density, proximity to major roads, and distance from existing healthcare facilities. Which geoprocessing technique is least directly relevant to this initial site selection process?
Answer (No RAG): c


In [6]:
# Run the model with a prompt and RAG
def run_rag_chain(model, retriever, question, options, k=2):
    def format_docs(docs):
        if hasattr(docs[0], 'page_content'):
            return '\n\n'.join([d.page_content for d in docs])
        else:
            return '\n\n'.join([d[0] for d in docs])
        # return '\n\n'.join([ d.page_content for d in docs])

    template = '''
    Here are the relevant summaries: {context}
    
    You are an AI assistant for graduate students.
    Use the provided context to select the most accurate answer.
    Return only one letter (a, b, c, or d) with no additional text or explanation.

    Question: {question}
    Potential choices: {options}

    Output only one character (a, b, c, or d) with no explanation.
    '''

    prompt = ChatPromptTemplate.from_template(template)
    retrieved_docs = retriever.search(query=question, k=k)
    formatted_context = format_docs(retrieved_docs)
    rag_chain = (prompt | model | StrOutputParser())

    answer = rag_chain.invoke({"context": formatted_context, "question": question, "options": options}).strip()
    return answer

# question = "You are tasked with determining the optimal location for a new hospital, considering factors like population density, proximity to major roads, and distance from existing healthcare facilities. Which geoprocessing technique is least directly relevant to this initial site selection process?"
# options = ["Buffer Analysis", "Network Analysis", "Raster Reclassification", "Thiessen Polygon Generation"]
# answer = run_rag_chain(model, question, options)

# print("Query (RAG):", question)
# print("Answer (RAG):", answer)

### Evaluation

In [9]:
import pandas as pd
import json
from datasets import load_dataset

"""
Prompt-made Computer Science dataset
"""
# with open("../data/qna/computer_science.json", "r", encoding="utf-8") as f:
#     data = json.load(f)
# qna_df = pd.DataFrame([{
#     "question": item["question"],
#     "options": item["options"],
#     "answer": item["answer_idx"]
# } for item in data])

"""
GBaker/MedQA-USMLE-4-options-hf
"""
dataset = load_dataset("GBaker/MedQA-USMLE-4-options-hf", split="test") # 4 choices
qna_df = dataset.to_pandas()
qna_df.rename(columns={"sent1": "question"}, inplace=True)
answer_mapping = {0: 'a', 1: 'b', 2: 'c', 3: 'd'}
qna_df['answer'] = qna_df['label'].map(answer_mapping)
qna_df['options'] = qna_df.apply(lambda row: [row['ending0'], row['ending1'], row['ending2'], row['ending3']], axis=1)
# print(qna_df.head())

"""
MedMCQA dataset
"""
# dataset = load_dataset("openlifescienceai/medmcqa", split="validation") # 4 choices / subject_name
# qna_df = dataset.to_pandas()
# qna_df.rename(columns={"opa": "optionA", "opb": "optionB", "opc": "optionC", "opd": "optionD", "cop": "answerindex"}, inplace=True)
# answer_mapping = {0: 'a', 1: 'b', 2: 'c', 3: 'd'}
# qna_df['answer'] = qna_df['answerindex'].map(answer_mapping)
# qna_df['options'] = qna_df.apply(lambda row: [row['optionA'], row['optionB'], row['optionC'], row['optionD']], axis=1)
# # qna_df['subject_name'].unique()

'\nMedMCQA dataset\n'

In [11]:
import time
from tqdm import tqdm

correct_answers = 0
total_samples = len(qna_df)

chunk_size = 256
k = 2

results = []

for retriever_type in ["naiveq&a", "contriever", "specter", "medcpt", "bm25", "tfidf", "rrf2"]:
    print(f"\n{retriever_type.upper()} :")
    
    if retriever_type != "naiveq&a":
        retriever = MyRetriever(
            raw_data_path ="../data/raw/computer_science",
            retriever_type=retriever_type,
            chunk_size=chunk_size
        )

    start_time = time.time()
    for idx, (index, sample) in enumerate(tqdm(qna_df.iterrows(), total=len(qna_df))):
        question = sample["question"]
        options = sample["options"]
        gt_answer = sample["answer"].lower()

        if retriever_type == "naiveq&a":
            pred_answer = run_simple_prompt(model, question, options).lower()
        else:
            pred_answer = run_rag_chain(model, retriever, question, options, k).lower()

        if pred_answer not in ['a', 'b', 'c', 'd']:
            print("pred_answer:", pred_answer)

        correct_answers += int(pred_answer == gt_answer)
    end_time = time.time()

    acc = correct_answers / total_samples * 100
    processing_time = end_time - start_time
    
    results.append({
        "Retriever": retriever_type.upper(),
        "K / Chunk Size": f"{k} / {chunk_size}",
        "Accuracy (%)": f"{acc:.2f}",
        "Processing Time (s)": f"{processing_time:.2f}",
    })

    correct_answers = 0

results_df = pd.DataFrame(results)
print(f"\n🎯 Accuracy Result:")
print(results_df.to_markdown(index=False))



NAIVEQ&A :


100%|██████████| 1273/1273 [01:32<00:00, 13.78it/s]



CONTRIEVER :


100%|██████████| 1273/1273 [02:03<00:00, 10.33it/s]



SPECTER :


100%|██████████| 1273/1273 [02:07<00:00, 10.02it/s]



MEDCPT :


100%|██████████| 1273/1273 [02:00<00:00, 10.58it/s]



BM25 :


100%|██████████| 1273/1273 [02:21<00:00,  9.02it/s]



TFIDF :


100%|██████████| 1273/1273 [01:51<00:00, 11.45it/s]



RRF2 :


100%|██████████| 1273/1273 [02:33<00:00,  8.30it/s]


🎯 Accuracy Result:
| Retriever   | K / Chunk Size   |   Accuracy (%) |   Processing Time (s) |
|:------------|:-----------------|---------------:|----------------------:|
| NAIVEQ&A    | 2 / 256          |          27.49 |                 92.39 |
| CONTRIEVER  | 2 / 256          |          22.15 |                123.23 |
| SPECTER     | 2 / 256          |          23.8  |                127.03 |
| MEDCPT      | 2 / 256          |          22.07 |                120.32 |
| BM25        | 2 / 256          |          23.25 |                141.18 |
| TFIDF       | 2 / 256          |          22.78 |                111.17 |
| RRF2        | 2 / 256          |          21.76 |                153.42 |



