# Import Libraries

In [1]:
from unsloth import FastVisionModel, FastLanguageModel
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
from sentence_transformers import SentenceTransformer
import json
import faiss

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


ModuleNotFoundError: No module named 'faiss'

# User Inputs

In [None]:
# Per-request user inputs. Leave any modality the user did not supply as None.
#   user_text  - typed query string
#   user_image - image handed to the vision model (see preprocess_image)
#   user_audio - recording handed to Whisper (see transcribe_audio)
user_text = user_image = user_audio = None

# Load Models

### Load Text Model

In [None]:
# Llama 3.1 8B checkpoint (4-bit bitsandbytes, per the repo name) that writes the final answer.
# Unsloth returns the model together with its tokenizer, kept here as text_processor.
text_model_name = "esrgesbrt/trained_health_model_llama3.1_8B_bnb_4bits"
text_model, text_processor = FastLanguageModel.from_pretrained(
    model_name=text_model_name,
    load_in_4bit=True,
)
# Put the model into Unsloth's inference mode before any generate() call.
FastLanguageModel.for_inference(text_model)

### Load Vision Model

In [None]:
# Qwen 2B vision-language LoRA checkpoint (per the repo name) that answers image queries.
vision_model_name = "hamzamooraj99/MedQA-Qwen-2B-LoRA16"
vision_model, vision_processor = FastVisionModel.from_pretrained(
    model_name=vision_model_name,
    load_in_4bit=True,
)

# Bound the image pixel budget (224x224 .. 512x512 total area) to keep memory use down.
# NOTE(review): presumably consulted by the Qwen2-VL image processor when resizing;
# confirm the installed transformers version still reads these attributes.
vision_processor.image_processor.min_pixels = 224 * 224
vision_processor.image_processor.max_pixels = 512 * 512

# Switch to Unsloth's inference mode before generating.
FastVisionModel.for_inference(vision_model)

### Load Whisper Model

In [None]:
# Whisper (base) for speech-to-text; the processor and the model come from the same checkpoint.
# The model is left on the default device, matching the CPU tensors transcribe_audio builds.
speech_model_name = "openai/whisper-base"
speech_processor = WhisperProcessor.from_pretrained(speech_model_name)
speech_model = WhisperForConditionalGeneration.from_pretrained(speech_model_name)

### Load RAG Embedding Model

In [None]:
# Sentence-embedding model that encodes queries for FAISS retrieval.
# NOTE(review): it must match the model used to embed the indexed passages — confirm.
RAG_model_name = "sentence-transformers/all-MiniLM-L6-v2"
RAG_model = SentenceTransformer(model_name_or_path=RAG_model_name)

# Process Inputs

### Text

In [None]:
def preprocess_text(text, vision_response, retrieved_info):
    """Build the augmented Alpaca-style prompt and tokenize it for the text model.

    Args:
        text: The user's query (typed or transcribed); may be None.
        vision_response: The vision model's answer about the user's image, or
            None when no image was supplied.
        retrieved_info: Formatted knowledge-base passages (as produced by
            ``format_rag_context``); may be empty.

    Returns:
        The ``text_processor`` batch (``return_tensors="pt"``) for the single
        prompt, moved to CUDA. The "### Response:" section is left empty for
        the model to complete.
    """
    # Built from explicit lines rather than an indented triple-quoted string so
    # the prompt carries no stray leading indentation or trailing spaces.
    alpaca_prompt = (
        "Below is a query from a user regarding a medical condition or a description of symptoms.\n"
        "The user may also provide an image related to the query.\n"
        "Please provide an appropriate response to the user input with reference to the image response (if provided), making use of the retrieved information from our knowledge source.\n"
        "\n"
        "### User Input:\n"
        "{}\n"
        "\n"
        "### Image Response:\n"
        "{}\n"
        "\n"
        "### Retrieved Information:\n"
        "{}\n"
        "\n"
        "### Response:\n"
        "{}"
    )

    # Render absent sections explicitly instead of as the literal string "None".
    prompt = alpaca_prompt.format(
        text or "",
        vision_response or "No image provided.",
        retrieved_info or "No information retrieved.",
        "",
    )
    inputs = text_processor([prompt], return_tensors="pt").to('cuda')
    return inputs

### Vision

In [None]:
def preprocess_image(image, text):
    """Build vision-model inputs asking it to describe ``image`` and answer ``text``.

    Args:
        image: The user's image, in a form ``vision_processor`` accepts
            (presumably a PIL image — TODO confirm).
        text: The user's query, or None/blank when only an image was supplied.

    Returns:
        The processor's tensor batch, moved to CUDA, ready for
        ``vision_model.generate``.
    """
    if text and text.strip():
        instruction = f"Please describe what is shown in the image and answer the following query with reference to the image: '{text}'"
    else:
        # No query supplied: avoid interpolating the literal 'None' into the prompt.
        instruction = "Please describe what is shown in the image."

    messages = [
        {'role': 'user',
         'content': [
             {'type': 'image'},
             # Chat templates read a text part's payload from the 'text' key;
             # any other key renders as empty and the instruction is lost.
             {'type': 'text', 'text': instruction},
         ]}
    ]

    input_text = vision_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = vision_processor(
        image,
        input_text,
        # The chat template already inserted the special tokens.
        add_special_tokens=False,
        return_tensors="pt"
    ).to('cuda')

    return inputs

### Whisper

In [None]:
def transcribe_audio(audio_input, sampling_rate=16000):
    """Transcribe a speech recording to text with the Whisper model.

    Args:
        audio_input: Raw waveform samples in a form ``WhisperProcessor``
            accepts (presumably a 1-D float array — TODO confirm with the caller).
        sampling_rate: Sample rate of ``audio_input`` in Hz. Passing it lets the
            Whisper feature extractor raise on mismatched audio instead of only
            warning and producing wrong features; Whisper expects 16 kHz.

    Returns:
        The decoded transcription of the first generated sequence.
    """
    inputs = speech_processor(audio_input, sampling_rate=sampling_rate, return_tensors="pt")
    with torch.no_grad():
        speech_out = speech_model.generate(**inputs)
    transcription = speech_processor.decode(speech_out[0], skip_special_tokens=True)
    return transcription

# Transcribe Audio

In [None]:
# Use speech input only when no typed text was given.
# Compare against None explicitly: if the audio is held as a NumPy array,
# bool(array) raises ValueError for arrays with more than one element.
if user_audio is not None and not user_text:
    user_text = transcribe_audio(user_audio)

# Generate Response to Image

In [None]:
# Ask the vision model about the image (if any); its answer feeds both the
# retrieval query and the final text-model prompt.
if user_image is not None:
    vision_inputs = preprocess_image(user_image, user_text)
    with torch.no_grad():
        vision_outputs = vision_model.generate(**vision_inputs, max_new_tokens=128, use_cache=True)
    # generate() returns the prompt tokens followed by the continuation; keep only
    # the new tokens so the chat prompt is not echoed into vision_response.
    prompt_length = vision_inputs["input_ids"].shape[-1]
    vision_response = vision_processor.decode(
        vision_outputs[0][prompt_length:], skip_special_tokens=True
    )
else:
    vision_response = None

# RAG

### Concatenate User Query and Vision Response

In [None]:
def embed_query(text, vision_response):
    """Build the retrieval query string from the user text and optional image description.

    The user text is stripped and given a terminal period when it does not
    already end in '.', '?' or '!', so the two parts read as separate
    sentences. The stripped vision response is appended after a single space
    when it is non-empty.

    Args:
        text: The user's query (typed or transcribed). May be None or blank,
            e.g. when the user supplied only an image.
        vision_response: Description produced by the vision model, or None.

    Returns:
        The combined query string; "" when both inputs are missing or blank.
    """
    query = (text or "").strip()
    # Only check the last character of non-empty text: indexing "" raises IndexError.
    if query and not query.endswith((".", "?", "!")):
        query += "."

    description = (vision_response or "").strip()
    return " ".join(part for part in (query, description) if part)

### Search FAISS Function

In [None]:
def search_faiss(query, index, texts, k=5, *, encoder=None):
    """Return up to ``k`` passages from ``index`` nearest to ``query``.

    Args:
        query: Query string to embed and search for.
        index: FAISS index over passage embeddings. Position ``i`` in the index
            is assumed to correspond to ``texts[i]`` — TODO confirm against the
            index-building script.
        texts: Passages, in the same order as the index vectors.
        k: Maximum number of hits to request.
        encoder: Object with a SentenceTransformer-style ``encode`` method.
            Defaults to the notebook-level ``RAG_model``.

    Returns:
        List of ``(text, distance)`` tuples in the order FAISS returns them.
        The list is shorter than ``k`` when FAISS has fewer than ``k`` results.
    """
    model = RAG_model if encoder is None else encoder
    query_embedding = model.encode([query], convert_to_numpy=True).astype("float32")
    distances, indices = index.search(query_embedding, k)

    # FAISS pads missing results with index -1; texts[-1] would silently return
    # the last passage, so skip those slots instead.
    return [
        (texts[i], distances[0][j])
        for j, i in enumerate(indices[0])
        if i >= 0
    ]

### Load RAG Files and Retrieve Context

In [None]:
from pathlib import Path

# Knowledge-base artifacts, resolved relative to the notebook's working directory.
# Joined with pathlib instead of hard-coded backslashes so the notebook also runs
# on Linux/macOS hosts, where '\' is not a path separator.
_rag_dir = Path("..", "..", "dataset", "nhsInform")
index_file = str(_rag_dir / "faiss_index.bin")
texts_file = str(_rag_dir / "texts.json")

index = faiss.read_index(index_file)
with open(texts_file, "r", encoding="utf-8") as f:
    texts = json.load(f)

# Retrieve the top 3 passages for the combined text + image query.
query = embed_query(user_text, vision_response)
results = search_faiss(query, index, texts, k=3)

### Format RAG Context

In [2]:
def format_rag_context(results):
    """Render retrieved passages as numbered lines for the LLM prompt.

    Args:
        results: Sequence of ``(text, score)`` pairs, e.g. from ``search_faiss``.
            Only the first element of each pair is used.

    Returns:
        One "Retrieved Info N: <text>" line per result, numbered from 1 and
        joined by newlines; an empty string when ``results`` is empty.
    """
    lines = []
    for number, hit in enumerate(results, start=1):
        passage = hit[0]
        lines.append(f"Retrieved Info {number}: {passage}")
    return "\n".join(lines)

# Response Generation

### Collect Input Strings

In [None]:
# Gather the three strings that populate the LLM prompt.
retrieved_info = format_rag_context(results)
user_input, image_input = user_text, vision_response

### Format Input for Llama 3.1

In [None]:
# Assemble the augmented prompt from the collected inputs and tokenize it.
inputs = preprocess_text(user_input, image_input, retrieved_info)

# TO-DO

### Final Workflow Summary
---
 1. User provides input (text, speech, or image). ✅
 2. Preprocessing (Whisper for speech, Qwen for images). ✅
 3. Retrieval Step (User query is embedded & FAISS retrieves relevant texts). ✅
 4. Augmentation Step (Relevant texts are appended to the user query).
 5. Llama 3.1 generates a response based on the augmented input.
 6. TTS converts text to speech if needed.
 7. Response is delivered to the user (as text or speech).