In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="5"

import pandas as pd
import warnings
warnings.filterwarnings("ignore")
import faiss
import numpy as np
import torch

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
from collections import defaultdict
from sentence_transformers import CrossEncoder

### Get Retriever with Vectorstore

In [2]:
class MyRetriever:
    def __init__(
            self,
            raw_data_path:str="",
            retriever_type:str="contriever",
            chunk_size:int=256,
            device:str="cuda"
        ):

        self.raw_data_path = raw_data_path
        self.retriever_type = retriever_type
        if retriever_type == 'domain_detector':
            self.chunk_size = None
        else:
            self.chunk_size = chunk_size
        self.device = device

        self.texts = self._load_dataset()

        self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        
        # Dense - Contriever, MedCPT, Specter, RRF-2
        if self.retriever_type in ["contriever", "specter", "longformer", "medcpt", "rrf2"]:
            if self.retriever_type == 'contriever':
                model_name = "facebook/contriever"
            elif self.retriever_type == 'specter':
                model_name = "allenai/specter"
            elif self.retriever_type == 'longformer':
                model_name = "allenai/longformer-base-4096"
            elif self.retriever_type == "medcpt":
                model_name = "ncbi/MedCPT-Article-Encoder"
            else:
                model_name = "facebook/contriever"

            self.model = AutoModel.from_pretrained(model_name).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            text_embedding = self._compute_embeddings(self.texts)
            text_embedding = text_embedding.cpu()
            self.faiss_index = self._create_dense_retriever(text_embedding)
            # print("FAISS index created.")

        # Sparse - BM25, TF-IDF
        self.bm25_index = self._get_bm25_index(self.texts)

    def _compute_embeddings(self, texts):
        def mean_pooling(token_embeddings, mask):
            token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
            sentence_embeddings = (token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]).detach()
            return sentence_embeddings
        
        all_embeddings = []
        with torch.no_grad():
            for i in range(0, len(texts), 32):
                batch_texts = texts[i:i+32]
                inputs = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt", max_length=self.model.config.max_position_embeddings).to(self.device)
                outputs = self.model(**inputs)
                embeddings = mean_pooling(outputs[0], inputs['attention_mask']).cpu()
                embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
                all_embeddings.append(embeddings)
        return torch.cat(all_embeddings, dim=0)

    def _compute_query_embedding(self, query):
        return self._compute_embeddings([query]).numpy()
    
    def _create_dense_retriever(self, text_embedding):
        d = text_embedding.shape[1]
        faiss_index = faiss.IndexFlatIP(d) # [IndexFlatIP, IndexFlatL2]
        faiss_index.add(np.ascontiguousarray(text_embedding.numpy()))
        return faiss_index
    
    def _get_dense_results(self, query, k=5):
        query_embedding = self._compute_query_embedding(query)
        distances, indices = self.faiss_index.search(query_embedding, k)
        faiss_results = []
        for idx_list, dist_list in zip(indices, distances):
            for idx, dist in zip(idx_list, dist_list):
                faiss_results.append((self.texts[idx], dist))
        return faiss_results[:k]
    
    def _get_bm25_index(self, texts):
        tokenized_corpus = [text.lower().split() for text in texts]
        bm25_index = BM25Okapi(tokenized_corpus)
        return bm25_index
    
    def _get_bm25_results(self, query):
        query_tokens = query.lower().split()
        bm25_scores = self.bm25_index.get_scores(query_tokens)
        bm25_results = [(self.texts[i], bm25_scores[i]) for i in range(len(bm25_scores))]
        bm25_results = sorted(bm25_results, key=lambda x: x[1], reverse=True)
        return bm25_results
    
    def _load_dataset(self):
        texts = []
        for root, _, files in os.walk(self.raw_data_path):  # Recursively walk through directories
            for file_name in files:
                file_path = os.path.join(root, file_name)
                if file_name.endswith(".txt"):
                    with open(file_path, "r", encoding="utf-8") as f:
                        text = f.read().strip()
                        if self.chunk_size is not None:
                            texts.extend(self._split_text(text, self.chunk_size))
                        else:
                            texts.append(text)  # Add the entire text as a single element

                elif file_name.endswith(".csv"):
                    df = pd.read_csv(file_path)
                    if "summary" in df.columns:
                        texts.extend(df["summary"].dropna().tolist())

        if self.chunk_size is not None:
            concatenated_text = " ".join(texts)
            texts = self._split_text(concatenated_text, self.chunk_size)
        # print(f"Loaded {len(texts)} documents from directory: {self.raw_data_path}")
        return texts

    def _ensemble_scores(self, *results_with_weights):
        """
        Args: results_with_weights: Tuples of (results, weight) where results are lists of (text, score).
        Returns: List of (text, combined_score) tuples sorted by score in descending order.
        """
        combined_scores = defaultdict(float)

        for results, weight in results_with_weights:
            scores_dict = {text: score for text, score in results}
            values = np.array(list(scores_dict.values()))
            if len(values) > 0:
                min_val, max_val = values.min(), values.max()
                if max_val > min_val:
                    scores_dict = {k: (v - min_val) / (max_val - min_val) for k, v in scores_dict.items()}

            for text, score in scores_dict.items():
                combined_scores[text] += weight * score

        sorted_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_results
    
    def _split_text(self, texts, chunk_size, chunk_overlap=50):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len
        )
        return text_splitter.split_text(texts)

    def search(self, query:str, k:int=5, rerank:bool=False):
        retrieved_docs = None
        if self.retriever_type in ["contriever", "longformer", "specter", "medcpt"]:
            retrieved_docs = self._get_dense_results(query, k=k)

        elif self.retriever_type == "bm25":
            retrieved_docs = self._get_bm25_results(query)[:k]
        
        elif self.retriever_type == "rrf2":
            faiss_results = self._get_dense_results(query, k=k)
            bm25_results = self._get_bm25_results(query)
            retrieved_docs = self._ensemble_scores((faiss_results, 0.5), (bm25_results, 0.5))[:k]
        else:
            raise ValueError("Invalid retriever type. Choose from ['contriever', 'specter', 'medcpt', 'bm25', 'rrf2']")
        if rerank:
            retrieved_docs = self.rerank(query, retrieved_docs, self.reranker, top_n=2) # top_n = final value of k
        return retrieved_docs
    
    def rerank(self, question, docs, reranker, top_n=2):
        pairs = [(question, doc.page_content if hasattr(doc, 'page_content') else doc[0]) for doc in docs]
        scores = reranker.predict(pairs)
        scored_docs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
        return [d[0] for d in scored_docs[:top_n]]

### Model

In [3]:
from langchain_ollama import ChatOllama

slm_model = ChatOllama(model="llama3.2:1b", temperature=0)
llm_model = ChatOllama(model="gemma:7b", temperature=0) # $ ollama pull gemma:7b to run a new model

In [4]:
# Run the model with a prompt
def run_simple_prompt(model, question, options=None):
    template = None
    if options is None:
        template = '''You are an AI assistant for grad students.
        Answer the following question in concise sentences.
        Question: {question}
        '''
    else:
        template = '''You are an AI assistant for grad students.
        Answer the question and select the most appropriate answer from the given choices.
        You must return only the correct answer (a, b, c or d) and nothing else.
        
        Question: {question}
        Choices:
        {options}

        You must return a single character (a, b, c or d). Do not provide any explanation.
        '''
    prompt = ChatPromptTemplate.from_template(template)
    simple_chain = (prompt | model | StrOutputParser())
    answer_without_rag = simple_chain.invoke({"question": question, "options": options}, config={"timeout": 30}).strip()
    return answer_without_rag

question = "You are tasked with determining the optimal location for a new hospital, considering factors like population density, proximity to major roads, and distance from existing healthcare facilities. Which geoprocessing technique is least directly relevant to this initial site selection process?"
# options = ["Buffer Analysis", "Network Analysis", "Raster Reclassification", "Thiessen Polygon Generation"]
options = None
slm_answer_without_rag = run_simple_prompt(slm_model, question, options)
llm_answer_without_rag = run_simple_prompt(llm_model, question, options)

print("Query (No RAG):", question)
print("sLM Answer (No RAG):", slm_answer_without_rag)
print("LLM Answer (No RAG):", llm_answer_without_rag)

Query (No RAG): You are tasked with determining the optimal location for a new hospital, considering factors like population density, proximity to major roads, and distance from existing healthcare facilities. Which geoprocessing technique is least directly relevant to this initial site selection process?
sLM Answer (No RAG): The least directly relevant geoprocessing technique for initial site selection would be 3D visualization or spatial analysis using tools like ArcGIS Pro or QGIS.

These tools allow for the creation of detailed, interactive maps that can help visualize and analyze data related to population density, road networks, and healthcare facilities. However, they are not typically used as a primary step in the site selection process, which often involves more straightforward calculations and data analysis using techniques like linear programming or regression modeling.

Other relevant geoprocessing techniques for site selection might include:

* Spatial autocorrelation anal

In [5]:
# Run the model with a prompt and RAG
def run_rag_chain(model, retriever, question, options=None, k=2, is_rerank=False):
    def format_docs(docs):
        if hasattr(docs[0], 'page_content'):
            return '\n\n'.join([d.page_content for d in docs])
        else:
            return '\n\n'.join([d[0] for d in docs])
        # return '\n\n'.join([ d.page_content for d in docs])
    
    if options is None:
        template = '''
        Here are the relevant documents: {context}
        
        You are an AI assistant for graduate students.
        Use the provided context to answer the question.
        Answer the following question in concise sentences.

        Question: {question}
        
        Output only the answer. Do not provide any explanation.
        '''
    else:
        template = '''
        Here are the relevant documents: {context}
        
        You are an AI assistant for graduate students.
        Use the provided context to select the most accurate answer.
        Return only one letter (a, b, c, or d) with no additional text or explanation.

        Question: {question}
        Potential choices: {options}

        Output only one character (a, b, c, or d) with no explanation.
        '''

    prompt = ChatPromptTemplate.from_template(template)
    retrieved_docs = retriever.search(query=question, k=k, rerank=is_rerank)
    formatted_context = format_docs(retrieved_docs)
    rag_chain = (prompt | model | StrOutputParser())
    answer = rag_chain.invoke({"context": formatted_context, "question": question, "options": options},config={"timeout": 30}).strip()
    return answer

question = "You are tasked with determining the optimal location for a new hospital, considering factors like population density, proximity to major roads, and distance from existing healthcare facilities. Which geoprocessing technique is least directly relevant to this initial site selection process?"
# options = ["Buffer Analysis", "Network Analysis", "Raster Reclassification", "Thiessen Polygon Generation"]
options = None
retriever = MyRetriever(
    raw_data_path ="data/text/random",
    retriever_type='contriever',
    chunk_size=512
)
answer = run_rag_chain(slm_model, retriever, question, options)

print("Query (RAG):", question)
print("Answer (RAG):", answer)

config.json:   0%|          | 0.00/845 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/732 [00:00<?, ?B/s]

Query (RAG): You are tasked with determining the optimal location for a new hospital, considering factors like population density, proximity to major roads, and distance from existing healthcare facilities. Which geoprocessing technique is least directly relevant to this initial site selection process?
Answer (RAG): ArcGIS Network Analysis.


### Evaluation

In [6]:
# TEST_DATASET = "MedMCQA", "MedQuad", "MMLU"
TEST_DATASET = "MMLU"

In [7]:
import pandas as pd
from dataset import QuestionAnswering

qa = QuestionAnswering(TEST_DATASET)
qna_df = qa.get_question_answering_dataframe()
# qna_df['subject'].unique() # MedMCQA

Loading dataset...


In [8]:
import time
from tqdm import tqdm
from rouge import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from bert_score import score
from transformers import AutoTokenizer

bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
MAX_BERT_TOKENS = 512

def truncate_and_decode(text, tokenizer, max_len=512):
    enc = tokenizer(text, truncation=True, max_length=max_len, return_tensors="pt", add_special_tokens=False)
    return tokenizer.decode(enc["input_ids"][0], skip_special_tokens=True)

rouge = Rouge()
smooth_fn = SmoothingFunction().method1

correct_answers = 0
total_samples = len(qna_df)

chunk_size = 512
k = 10
is_rerank = True

results = []

#  ["llm_wo_knowledge", "slm_wo_knowledge", "slm_random", "slm_contriever", "slm_specter", "slm_longformer", "slm_medcpt", "slm_bm25", "slm_rrf2"]
for retriever_type in ["slm_longformer", "slm_medcpt", "slm_bm25", "slm_rrf2"]:
    print(f"\n{retriever_type.upper()} :")
    
    if retriever_type not in ["llm_wo_knowledge", "slm_wo_knowledge"]:
        retriever_name = retriever_type.split("_")[-1]
        retriever_path = "random" if retriever_name == "random" else TEST_DATASET.lower()
        retriever = MyRetriever(
            raw_data_path=f"data/text/{retriever_path}",
            retriever_type="contriever" if retriever_name == "random" else retriever_name,
            chunk_size=chunk_size
        )

    if "options" in qna_df.columns and qna_df["options"].apply(lambda x: isinstance(x, list) and len(x) > 0).all():
        evaluation_mode = "multiple_choice"
        print("Running evaluation for multiple choice mode...")
    else:
        evaluation_mode = "sentence"
        print("Running evaluation for sentence mode...")

    scores_list = []
    start_time = time.time()

    for idx, (index, sample) in enumerate(tqdm(qna_df.iterrows(), total=len(qna_df))):
        question = sample["question"].lower()
        options = list(map(str.lower, sample["options"])) if evaluation_mode == "multiple_choice" else None
        gt_answer = sample["answer"].lower()

        if retriever_type == "llm_wo_knowledge":
            pred_answer = run_simple_prompt(llm_model, question, options).lower()
        elif retriever_type == "slm_wo_knowledge":
            pred_answer = run_simple_prompt(slm_model, question, options).lower()
        else:
            pred_answer = run_rag_chain(slm_model, retriever, question, options, k, is_rerank).lower()
        
        if evaluation_mode == "multiple_choice":
            if pred_answer not in ['a', 'b', 'c', 'd']:
                print("Invalid prediction:", pred_answer)
            correct_answers += int(pred_answer == gt_answer)

        else:
            bleu = sentence_bleu([gt_answer.split()], pred_answer.split(), smoothing_function=smooth_fn)
            rouge_l = rouge.get_scores(pred_answer, gt_answer)[0]['rouge-l']['f']
            
            # Truncate for BERTScore
            truncated_pred = truncate_and_decode(pred_answer, bert_tokenizer)
            truncated_gt = truncate_and_decode(gt_answer, bert_tokenizer)
            try:
                _, _, bert_f1 = score([truncated_pred], [truncated_gt], lang='en', model_type='allenai/scibert_scivocab_uncased')
                bert_score_value = bert_f1.item()
            except RuntimeError:
                bert_score_value = 0.0

            scores_list.append({
                "BLEU": bleu,
                "ROUGE-L": rouge_l,
                "BERTScore": bert_f1.item()
            })

    end_time = time.time()
    processing_time = end_time - start_time

    if evaluation_mode == "multiple_choice":
        acc = correct_answers / total_samples * 100
        results.append({
            "Retriever": retriever_type.upper(),
            "K / Chunk Size": f"{k} / {chunk_size}",
            "Accuracy (%)": f"{acc:.2f}",
            "Processing Time (s)": f"{processing_time:.2f}",
        })
        print(f"{retriever_type}, {k} / {chunk_size}, {acc:.2f}, {processing_time:.2f}")
        correct_answers = 0

    else:
        mean_bleu = sum(s["BLEU"] for s in scores_list) / total_samples
        mean_rouge = sum(s["ROUGE-L"] for s in scores_list) / total_samples
        mean_bert = sum(s["BERTScore"] for s in scores_list) / total_samples
        results.append({
            "Retriever": retriever_type.upper(),
            "Rerank": is_rerank,
            "K / Chunk Size": f"{k} / {chunk_size}",
            "BLEU": f"{mean_bleu:.4f}",
            "ROUGE-L": f"{mean_rouge:.4f}",
            "BERTScore": f"{mean_bert:.4f}",
            "Processing Time (s)": f"{processing_time:.2f}",
        })
        print(f"{retriever_type}, BLEU: {mean_bleu:.4f}, ROUGE-L: {mean_rouge:.4f}, BERTScore: {mean_bert:.4f}, Time: {processing_time:.2f}")

results_df = pd.DataFrame(results)
print(f"\nEvaluation Result:")
print(results_df.to_markdown(index=False))



SLM_LONGFORMER :


Input ids are automatically padded to be a multiple of `config.attention_window`: 512


Running evaluation for sentence mode...


100%|██████████| 1000/1000 [09:57<00:00,  1.67it/s]


slm_longformer, BLEU: 0.0003, ROUGE-L: 0.0048, BERTScore: 0.4930, Time: 597.62

SLM_MEDCPT :
Running evaluation for sentence mode...


100%|██████████| 1000/1000 [09:55<00:00,  1.68it/s]


slm_medcpt, BLEU: 0.0004, ROUGE-L: 0.0053, BERTScore: 0.4933, Time: 595.24

SLM_BM25 :
Running evaluation for sentence mode...


100%|██████████| 1000/1000 [08:59<00:00,  1.85it/s]


slm_bm25, BLEU: 0.0007, ROUGE-L: 0.0070, BERTScore: 0.4935, Time: 539.30

SLM_RRF2 :
Running evaluation for sentence mode...


100%|██████████| 1000/1000 [09:13<00:00,  1.81it/s]

slm_rrf2, BLEU: 0.0006, ROUGE-L: 0.0058, BERTScore: 0.4934, Time: 553.39

Evaluation Result:
| Retriever      | Rerank   | K / Chunk Size   |   BLEU |   ROUGE-L |   BERTScore |   Processing Time (s) |
|:---------------|:---------|:-----------------|-------:|----------:|------------:|----------------------:|
| SLM_LONGFORMER | True     | 10 / 512         | 0.0003 |    0.0048 |      0.493  |                597.62 |
| SLM_MEDCPT     | True     | 10 / 512         | 0.0004 |    0.0053 |      0.4933 |                595.24 |
| SLM_BM25       | True     | 10 / 512         | 0.0007 |    0.007  |      0.4935 |                539.3  |
| SLM_RRF2       | True     | 10 / 512         | 0.0006 |    0.0058 |      0.4934 |                553.39 |



