# Setup OS Env

In [None]:
import os
# Runtime flags for the CUDA allocator, CUDA kernel launches and Unsloth.
# They are exported here, ahead of the cell that imports torch/unsloth.
os.environ.update({
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    "CUDA_LAUNCH_BLOCKING": "1",
    'UNSLOTH_DISABLE_AUTO_UPDATES': '1',
})

# Import Libraries

In [None]:
from unsloth import FastVisionModel, FastLanguageModel

from torch.utils.data import DataLoader
import torch
from evaluate import load
from concurrent.futures import ThreadPoolExecutor, as_completed
from sentence_transformers import SentenceTransformer
from transformers import Qwen2VLImageProcessor
from datasets import load_dataset
import faiss
import json
import os
from tqdm import tqdm
import pandas as pd

#### Check Device

In [None]:
# Pick the GPU when PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Load Dataset

In [None]:
# Evaluation subset: the 'test' split, shuffled with a fixed seed (42) so the
# selection is reproducible, truncated to its first 2500 rows.
pmc = (
    load_dataset("hamzamooraj99/PMC-VQA-1", split='test')
    .shuffle(seed=42)
    .select(range(2500))
)

## Load Models

### Load Text Model

In [None]:
# Text (LLM) stage: load the checkpoint in 4-bit together with its
# tokenizer/processor, then hand the model to Unsloth's inference switch.
text_model_name = "esrgesbrt/trained_health_model_llama3.1_8B_bnb_4bits"
text_model, text_processor = FastLanguageModel.from_pretrained(
    model_name=text_model_name,
    load_in_4bit=True,
)

FastLanguageModel.for_inference(text_model)

### Load Vision Model

In [None]:
# Vision stage: load the vision-language checkpoint in 4-bit with its processor.
vision_model_name = "hamzamooraj99/MedQA-Qwen-2B-LoRA16"
vision_model, vision_processor = FastVisionModel.from_pretrained(
    model_name=vision_model_name,
    load_in_4bit=True,
)

# Replace the processor's image pre-processing with one that resizes every
# image into the [224*224, 256*256] pixel budget.
vision_processor.image_processor = Qwen2VLImageProcessor(
    do_resize=True,
    min_pixels=224*224,
    max_pixels=256*256,
)

FastVisionModel.for_inference(vision_model)

### Load RAG Model

In [None]:
# Sentence-embedding model used to encode retrieval queries for the FAISS search.
RAG_model_name = "sentence-transformers/all-MiniLM-L6-v2"
RAG_model = SentenceTransformer(model_name_or_path=RAG_model_name)

# RAG

In [None]:
class RAGPipeline:
    """Retrieve the passages closest to a query and format them as prompt context.

    Constructing an instance runs the whole pipeline: it builds the query from
    the user text and the vision model's answer, searches the FAISS index for
    the ``k`` nearest passages and stores the formatted result in ``context``.

    Attributes:
        index: FAISS index read from ``faiss_path``.
        texts: Passages loaded from ``texts_json``; looked up by FAISS label.
        k: Number of neighbours requested from the index.
        query: Text that is embedded and searched for.
        results: List of ``(passage, distance)`` tuples returned by the search.
        context: ``results`` rendered as "Retrieved Info N: ..." lines.
    """

    # Default on-disk locations of the prebuilt index and its passage texts.
    DEFAULT_FAISS_PATH = r'C:\Users\hamza\Documents\Heriot-Watt\Y4\F20CA\Medical-CA-w-RAG\dataset\nhsInform\faiss_index.bin'
    DEFAULT_TEXTS_JSON = r'C:\Users\hamza\Documents\Heriot-Watt\Y4\F20CA\Medical-CA-w-RAG\dataset\nhsInform\texts.json'

    # (faiss_path, texts_json) -> (index, texts). One pipeline is built per
    # sample, so without this cache both files would be re-read for every query.
    _resource_cache = {}

    def __init__(self, user_text, vision_response, k=5, faiss_path=None, texts_json=None):
        """Run retrieval for one query.

        Args:
            user_text: The user's question.
            vision_response: The vision model's answer; may be empty or None.
            k: Number of passages to retrieve.
            faiss_path: Path of the FAISS index file; defaults to
                ``DEFAULT_FAISS_PATH``.
            texts_json: Path of the JSON file holding the passages; defaults
                to ``DEFAULT_TEXTS_JSON``.
        """
        self.index, self.texts = self._load_resources(
            faiss_path or self.DEFAULT_FAISS_PATH,
            texts_json or self.DEFAULT_TEXTS_JSON,
        )
        self.k = k
        self.query = self.embed_query(user_text, vision_response)
        self.results = self.search_faiss()
        self.context = self.format_rag_context()

    @classmethod
    def _load_resources(cls, faiss_path, texts_json):
        """Return ``(index, texts)`` for the given paths, reading them at most once."""
        cache_key = (faiss_path, texts_json)
        if cache_key not in cls._resource_cache:
            index = faiss.read_index(faiss_path)
            with open(texts_json, "r", encoding="utf-8") as f:
                texts = json.load(f)
            cls._resource_cache[cache_key] = (index, texts)
        return cls._resource_cache[cache_key]

    def embed_query(self, text: str, vision_response: str) -> str:
        """Build the retrieval query string.

        The user text is stripped and terminated with a period unless it
        already ends in '.', '?' or '!'. If a vision response is given it is
        appended after a single space.

        Args:
            text: The user's question.
            vision_response: The vision model's answer; falsy values are ignored.

        Returns:
            The combined query. Empty user text contributes nothing, so the
            result is just the stripped vision response (or "").
        """
        text = text.strip()
        # `text and ...` guards the empty case, where text[-1] would raise IndexError.
        if text and not text.endswith(('.', '?', '!')):
            text += '.'
        if vision_response:
            vision_text = vision_response.strip()
            return f"{text} {vision_text}" if text else vision_text

        return text

    def search_faiss(self):
        """Embed ``self.query`` and return the nearest passages.

        Returns:
            A list of ``(passage, distance)`` tuples in the order FAISS
            returned them, with at most ``self.k`` entries.
        """
        query_embedding = RAG_model.encode([self.query], convert_to_numpy=True).astype("float32")
        distances, indices = self.index.search(query_embedding, self.k)

        # FAISS pads missing neighbours with label -1 when the index holds fewer
        # than k vectors; skip those instead of silently returning self.texts[-1].
        return [
            (self.texts[idx], dist)
            for idx, dist in zip(indices[0], distances[0])
            if idx >= 0
        ]

    def format_rag_context(self):
        """Render ``self.results`` as newline-separated "Retrieved Info N: ..." lines."""
        return "\n".join(
            f"Retrieved Info {rank}: {passage}"
            for rank, (passage, _distance) in enumerate(self.results, start=1)
        )

# Data Collator

In [None]:
class DataCollator:
    """Collate function that turns raw samples into tokenized text-model prompts.

    For every batch it (1) asks the vision model to describe each image and
    answer its question, (2) retrieves context for each question/answer pair
    with ``RAGPipeline`` and (3) builds and tokenizes one text prompt per sample.

    Args:
        use_rag: When true the retrieved context is included in the prompt;
            otherwise only the question and the vision response are.
        vision_processor: Processor used to template, tokenize and decode for
            the vision model.
        vision_model: Model whose ``generate`` produces the image responses.
    """

    def __init__(self, use_rag, vision_processor, vision_model):
        self.use_rag = use_rag
        self.vision_processor = vision_processor
        self.vision_model = vision_model

    def run_vision_inference(self, images, questions):
        """Generate one image description/answer per (image, question) pair.

        Args:
            images: Images of the batch, aligned with ``questions``.
            questions: Question strings of the batch.

        Returns:
            A list of decoded responses (at most 128 new tokens each), without
            the prompt tokens.
        """
        messages = [[
            {'role': 'system',
            'content': [
                {'type': 'text', 'text': "You are a medical imaging analyst. Your job is to firstly provide a description of the image and then answer the question provided by the user with reference to the image"}
                ]
            },
            {'role': 'user',
            'content': [
                {'type': 'image'},
                {'type': 'text', 'text': f"Please describe what is shown in the image and answer the following query with reference to the image: '{question}'"}
                ]
            }
        ] for question in questions]

        input_text = self.vision_processor.apply_chat_template(messages, add_generation_prompt=True)

        # Preprocessing
        inputs = self.vision_processor(
            images,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to('cuda')

        # Inference
        with torch.no_grad():
            gen_ids = self.vision_model.generate(**inputs, max_new_tokens=128, use_cache=True)
            # Drop the prompt tokens so only the newly generated part is decoded.
            gen_ids = gen_ids[:, inputs.input_ids.shape[1]:]
            # Decode with the processor this collator was given, not a module-level global.
            vision_responses = self.vision_processor.batch_decode(gen_ids, skip_special_tokens=True)

        return vision_responses
    
    def __call__(self, batch):
        """Build the text-model inputs for one batch.

        Args:
            batch: Samples, each providing a 'Question' string and an 'image'.

        Returns:
            A dict with the tokenized prompts ('inputs', moved to CUDA), the
            retrieved context string per sample ('retrieved_info') and the
            vision model's answers ('vision_responses').
        """
        questions = [sample['Question'] for sample in batch]
        images = [sample['image'] for sample in batch]
        vision_responses = self.run_vision_inference(images, questions)
        # Retrieval runs even when use_rag is false because the context is
        # returned to the caller alongside the prompts.
        retrieved_info = [RAGPipeline(question, vision_response).context for question, vision_response in zip(questions, vision_responses)]

        # The prompts end at the "### Response:" header. The former "{{}}" line
        # rendered a literal "{}" in these f-strings, which the model then saw
        # as the start of its own answer.
        if(self.use_rag):
            prompts = [f""" 
                You are a medical assistant providing health information.  
                - Use the retrieved information to **enhance the accuracy** of your response.  
                - Do **not generate external links** unless explicitly stated by the user.  
                - Respond clearly and concisely. 
                ### User Input:
                {question}  

                ### Image Response:
                {vr}

                ### Retrieved Information
                {info}
            
                ### Response:
            """
                for question, vr, info in zip(questions, vision_responses, retrieved_info)
            ]
        else:
            prompts = [
                f""" 
                Below is a query from a user regarding a medical condition or a description of symptoms. The user may also provide an image related to the query. Please provide an appropriate response to the user input with reference to the image response (if provided).
                ### User Input:
                {question}

                ### Image Response:
                {vr}
            
                ### Response:
                
                """
                for question, vr in zip(questions, vision_responses)
            ]
        
        inputs = text_processor(prompts, return_tensors="pt", padding=True, truncation=True).to('cuda')
        return {'inputs': inputs, 'retrieved_info': retrieved_info, 'vision_responses': vision_responses}

# Evaluation

## Load Metrics

In [None]:
# Metrics used by the evaluation cells below, fetched via `evaluate.load`
# in the order BLEU, ROUGE, BERTScore.
bleu, rouge, bertscore = (load(metric_name) for metric_name in ('bleu', 'rouge', 'bertscore'))

## Batch Inference

In [None]:
def batch_inference(dataloader):
    """Generate a text-model answer for every sample served by ``dataloader``.

    Args:
        dataloader: Iterable of batches, each a dict with 'inputs' (tokenized
            prompts accepted by ``text_model.generate``) and 'retrieved_info'
            (one retrieval result per sample).

    Returns:
        A ``(predictions, references)`` tuple. ``predictions`` holds one
        generated answer string per sample; ``references`` holds, per sample,
        a list of reference strings taken from the retrieved information.
    """
    # Sampling settings are loop-invariant, so build them once.
    gen_kwargs = {
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.7,
        "top_k": 50
    }

    predictions = []
    references = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Running Batch Inference"):
            inputs = batch['inputs']
            retrieved_info = batch['retrieved_info']

            outputs = text_model.generate(**inputs, **gen_kwargs)

            # The generated sequences start with the prompt tokens. Slice them off
            # by length (as the vision path does) so that nothing from the prompt
            # can leak into the prediction.
            prompt_length = inputs['input_ids'].shape[1]
            responses = text_processor.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

            for response, info in zip(responses, retrieved_info):
                # A context string is one reference. Iterating over it would
                # yield single characters, so only unpack `info` when it is a
                # sequence of (text, distance) pairs.
                if isinstance(info, str):
                    reference = [info]
                else:
                    reference = [item[0] for item in info]

                references.append(reference)
                predictions.append(response.strip())

    return predictions, references

## Batch Eval

In [None]:
def eval_batch(predictions, references):
    """Score ``predictions`` against ``references`` with BLEU, ROUGE and BERTScore.

    Args:
        predictions: Generated answers, one per sample.
        references: Reference text(s) for each prediction, in the same order.

    Returns:
        A dict with the keys "BLEU", "ROUGE-1", "ROUGE-2", "ROUGE-L" and
        "BERTScore", the last being the mean of the per-sample F1 values.
    """
    scoring_inputs = {"predictions": predictions, "references": references}

    bleu_scores = bleu.compute(**scoring_inputs)
    rouge_scores = rouge.compute(**scoring_inputs)
    bert_scores = bertscore.compute(**scoring_inputs, lang='en')

    # BERTScore reports one F1 per sample; collapse it to a single average.
    f1_values = bert_scores["f1"]
    mean_f1 = sum(f1_values) / len(f1_values)

    return {
        "BLEU": bleu_scores["bleu"],
        "ROUGE-1": rouge_scores["rouge1"],
        "ROUGE-2": rouge_scores["rouge2"],
        "ROUGE-L": rouge_scores["rougeL"],
        "BERTScore": mean_f1,
    }

## Main

In [None]:
if __name__ == "__main__":

    # use_rag=False: the retrieved context is left out of the text prompt.
    collator = DataCollator(False, vision_processor, vision_model)

    # The collator runs vision-model inference on the GPU and returns CUDA
    # tensors, so batches have to be built in the main process. DataLoader
    # worker processes cannot share the already-initialised CUDA context, so
    # num_workers stays 0 (persistent_workers is only valid with workers > 0).
    test_loader = DataLoader(pmc, batch_size=8, collate_fn=collator, num_workers=0)

    predictions, references = batch_inference(test_loader)
    results = eval_batch(predictions, references)

    print("\nEvaluation Results:")
    print(pd.DataFrame([results]))