# **Evaluation Plan**

1. Use MedQA questions as the model input (inference).
2. Run the model on each MedQA question to generate a response.
3. Compare the generated response with the NHS dataset.
    - Evaluate how accurately the model's output reflects the NHS source.
    - Run the evaluation twice, with and without RAG, to compare the two.

---


## **Thought Process**
### Indirect knowledge validation
*Measure the model's accuracy against the NHS dataset, treating the NHS dataset as the ground truth.*

### Cross-dataset accuracy
*Measure how well the model's answers agree with NHS information when the questions come from a non-NHS source (MedQA).*

### Realistic benchmark
*Shows whether the model generalises correctly while staying grounded in NHS facts.*

---

# **Implementation Idea**
We will run two evaluations:
1. Without RAG
2. With RAG
---  


## **Outline of Idea - Single Inference**
### **Without RAG**
- **Query Processing:**  
    - Pass the input query through the RAG retrieval system.  
    - Retrieve **k** documents from FAISS search → **This is our *Ground Truth***. Here the documents are used only as the reference and are not shown to the model.  
- **Inference:**  
    - Run inference on the query alone, **without** the RAG results.  
    - Receive model output.  
- **Evaluation:**  
    - Compare the model's output against the **retrieved k documents** to compute the evaluation metric.  
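
The retrieval step in the list above can be illustrated with a minimal sketch. It uses a brute-force nearest-neighbour search in plain Python as a stand-in for FAISS, with squared L2 distance as an assumed metric (the index type used in this project is not stated). The helper names `squared_l2` and `retrieve_top_k` and the toy vectors are hypothetical and only for illustration.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def retrieve_top_k(query_vec, doc_vecs, texts, k=2):
    """Return the k texts closest to the query as (text, distance) pairs."""
    scored = [(text, squared_l2(query_vec, vec)) for text, vec in zip(texts, doc_vecs)]
    return sorted(scored, key=lambda pair: pair[1])[:k]


# Toy 2-D "embeddings"; real ones would come from the sentence encoder.
texts = ["asthma overview", "migraine overview", "asthma treatment"]
doc_vecs = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]

# The k nearest documents play the role of the ground truth for this query.
ground_truth = retrieve_top_k([1.0, 0.0], doc_vecs, texts, k=2)
print(ground_truth)
```

The `(text, distance)` pairs are the same shape of result that the evaluation later treats as the reference for one query.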

### **With RAG**
- **Query Processing:**  
    - Pass the input query through the RAG retrieval system.  
    - Retrieve **k** documents from FAISS search → **This is our *Ground Truth***.  
- **Inference:**  
    - Combine the RAG search results with the query into a single model prompt.  
    - Run inference on the **combined query + RAG context**.  
    - Receive model output.  
- **Evaluation:**  
    - Compare the model’s output against the **retrieved k documents** to determine the evaluation metric.  
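
The comparison step, for both the with-RAG and without-RAG runs, could look roughly like the sketch below. It scores one model output against each of the k retrieved documents with a simple unigram-overlap F1 and keeps the best match. This is a simplified stand-in for metrics such as ROUGE, not the metric used in the actual evaluation; the function names and example strings are hypothetical.

```python
from collections import Counter


def unigram_f1(prediction, reference):
    """F1 over whitespace tokens shared by the prediction and one reference."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def best_reference_score(prediction, references):
    """Score against every retrieved document and keep the best match."""
    return max(unigram_f1(prediction, ref) for ref in references)


# Made-up strings standing in for k = 2 retrieved documents.
retrieved = ["rest and drink plenty of fluids", "see a gp if symptoms persist"]
score = best_reference_score("drink plenty of fluids", retrieved)
print(round(score, 3))
```

Taking the maximum over the references means a response is rewarded for matching any one of the retrieved documents; averaging over them would be a stricter alternative.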
---



# **Batch Evaluation Setup**

In [None]:
from unsloth import FastLanguageModel
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import json

med_qa = load_dataset("MedQA", split='test')

# Load Language Model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="esrgesbrt/trained_health_model_llama3.1_8B_bnb_4bits",
    load_in_4bit=True
)
FastLanguageModel.for_inference(model)

# Load RAG Model
RAG_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Query pre-processing using a collator
INDEX_FILE = r'..\..\dataset\nhsInform\faiss_index.bin'
TEXTS_FILE = r'..\..\dataset\nhsInform\texts.json'

PROMPT_INTRO = "Below is a query from a user regarding a medical condition or a description of symptoms.\n"

class LanguageDataCollator:
    def __init__(self, use_rag):
        self.use_rag = use_rag
        # Load the FAISS index and its matching texts once, not on every query
        self.index = faiss.read_index(INDEX_FILE)
        with open(TEXTS_FILE, "r", encoding="utf-8") as f:
            self.texts = json.load(f)
        # Decoder-only models should be left-padded for batched generation
        tokenizer.padding_side = "left"

    def __call__(self, batch):
        questions = [sample['Question'] for sample in batch]
        # Retrieval always runs: the retrieved documents are the evaluation reference
        retrieved_info = [self.search_faiss(question) for question in questions]

        if self.use_rag:
            # Insert only the retrieved texts (not the distances) into the prompt
            prompts = [
                PROMPT_INTRO
                + "Please provide an appropriate response to the user input, making use of the retrieved information from our knowledge source.\n"
                + f"### User Input:\n{question}\n\n"
                + "### Retrieved Information:\n"
                + "\n".join(text for text, _ in info)
                + "\n\n### Response:\n"
                for question, info in zip(questions, retrieved_info)
            ]
        else:
            prompts = [
                PROMPT_INTRO
                + "Please provide an appropriate response to the user input.\n"
                + f"### User Input:\n{question}\n\n### Response:\n"
                for question in questions
            ]

        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to('cuda')
        return {'inputs': inputs, 'retrieved_info': retrieved_info}

    def search_faiss(self, query, k=5):
        query_embedding = RAG_model.encode([query], convert_to_numpy=True).astype("float32")
        distances, indices = self.index.search(query_embedding, k)
        # Pair each retrieved text with its distance to the query
        return [(self.texts[i], distances[0][j]) for j, i in enumerate(indices[0])]



# **Evaluation Steps**

In [None]:
from torch.utils.data import DataLoader
import torch
from evaluate import load

# RAG Switch
use_rag = False

# Setting up DataLoader
collator = LanguageDataCollator(use_rag)
test_loader = DataLoader(med_qa, batch_size=8, collate_fn=collator)

# Evaluation Metrics
bleu = load('bleu')
rouge = load('rouge')
bertscore = load("bertscore")

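# Sampling is enabled, so scores will vary slightly between runs
gen_kwargs = {
    "max_new_tokens": 50,
    "do_sample": True,
    "temperature": 0.7,
    "top_k": 50
}

predictions = []
references = []
for batch in test_loader:
    inputs = batch['inputs']
    retrieved_info = batch['retrieved_info']

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens so the prompt (and any RAG context
    # inside it) is not scored as part of the response
    prompt_length = inputs['input_ids'].shape[1]
    responses = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

    for response, info in zip(responses, retrieved_info):
        # Each prediction is scored against all k retrieved texts
        references.append([text for text, _ in info])
        predictions.append(response)

# compute() only accepts keyword arguments; keep the results so they can be inspected
rouge_scores = rouge.compute(predictions=predictions, references=references)
bleu_scores = bleu.compute(predictions=predictions, references=references)
# bertscore.compute(predictions=predictions, references=references, model_type="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
bertscore_scores = bertscore.compute(predictions=predictions, references=references, model_type="distilbert-base-uncased")
print(rouge_scores)
print(bleu_scores)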
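
Once the loop has been run twice (once with `use_rag = False` and once with `use_rag = True`), the two sets of scores still need to be compared. A minimal sketch of that comparison is shown below, assuming each run yields a flat dictionary of metric names to floats, as the ROUGE metric does. The helper `metric_deltas` is hypothetical, and the numbers are placeholders for illustration only, not results from this project.

```python
def metric_deltas(without_rag, with_rag):
    """Return with_rag - without_rag for every metric present in both runs."""
    shared = sorted(set(without_rag) & set(with_rag))
    return {name: with_rag[name] - without_rag[name] for name in shared}


# Placeholder numbers for illustration only, not measured results.
without_rag = {"rouge1": 0.25, "rougeL": 0.125}
with_rag = {"rouge1": 0.5, "rougeL": 0.375}

deltas = metric_deltas(without_rag, with_rag)
for name, delta in deltas.items():
    # A positive delta means the RAG run scored higher on that metric.
    print(f"{name}: {delta:+.3f}")
```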