# Import Libraries

In [None]:
import json
import torch
import numpy as np
from tqdm import tqdm
import faiss
from sentence_transformers import SentenceTransformer
import os

# Load Sentence Transformer

In [None]:
# Identifier of the sentence-embedding checkpoint to load.
model_name = "sentence-transformers/all-MiniLM-L6-v2"

# Shared encoder instance; get_batch_embeddings() and search_faiss() read this
# module-level name when they call model.encode(...).
model = SentenceTransformer(model_name_or_path=model_name)

# Load NHS Data

In [None]:
def load_text_data(json_file):
    """Load disease records from a JSON file and flatten each into one text.

    The file is expected to contain a list of objects. For every object the
    "Disease", "Symptoms" and "Treatments" entries are combined into a single
    three-line string of the form::

        Disease: <disease>
        Symptoms: <symptoms joined by spaces>
        Treatments: <treatments joined by spaces>

    Args:
        json_file: Path of the UTF-8 encoded JSON file to read.

    Returns:
        A list with one combined string per record, in file order.
    """

    def _as_text(value):
        # A missing key or JSON null contributes an empty field instead of
        # crashing inside str.join.
        if value is None:
            return ""
        # A plain string must be kept whole: joining it directly would insert
        # a space between every character.
        if isinstance(value, str):
            return value
        # str() lets non-string list entries (e.g. numbers) be joined too.
        return " ".join(str(part) for part in value)

    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    texts = []
    for item in data:
        # "or ''" keeps a JSON null disease from being rendered as "None".
        disease = item.get("Disease") or ""
        symptoms = _as_text(item.get("Symptoms"))
        treatments = _as_text(item.get("Treatments"))
        texts.append(
            f"Disease: {disease}\nSymptoms: {symptoms}\nTreatments: {treatments}"
        )
    return texts

# Get Embeddings in Batches

In [None]:
def get_batch_embeddings(texts, batch_size=16):
    """Encode ``texts`` with the module-level ``model``, one slice at a time.

    Args:
        texts: Sequence of strings to embed.
        batch_size: Number of texts passed to ``model.encode`` per call.

    Returns:
        A float32 NumPy array assembled from the per-text embeddings, in the
        same order as ``texts`` (presumably of shape ``(len(texts), dim)`` --
        confirm against the encoder in use).
    """
    batch_starts = range(0, len(texts), batch_size)
    collected = []
    # tqdm only wraps the iteration to show a progress bar per batch.
    for start in tqdm(batch_starts, desc="Computing embeddings"):
        chunk = texts[start : start + batch_size]
        collected.extend(model.encode(chunk, convert_to_numpy=True))
    return np.array(collected, dtype="float32")

# Build FAISS Index using IVFFlat

In [None]:
def build_faiss_index(texts, nlist=100):
    """Embed ``texts`` and build a trained, populated FAISS IVF-Flat index.

    Args:
        texts: List of strings to index.
        nlist: Requested number of IVF clusters. It is capped at the number
            of embedded texts, because the clustering step cannot be trained
            with fewer points than clusters. Defaults to 100.

    Returns:
        A ``(index, texts)`` tuple: the FAISS index containing one vector per
        text, and the unchanged ``texts`` list whose positions match the
        vector ids in the index.

    Raises:
        ValueError: If ``texts`` produces no embeddings (e.g. it is empty).
    """
    print("\nComputing text embeddings...")
    embeddings = get_batch_embeddings(texts)

    # An empty input yields an array without a second axis; fail with a clear
    # message instead of an IndexError on ``shape[1]``.
    if embeddings.ndim != 2 or embeddings.shape[0] == 0:
        raise ValueError("Cannot build a FAISS index from an empty list of texts.")

    print("\nBuilding FAISS index...")
    num_vectors, dimension = embeddings.shape
    # Small corpora: never ask for more clusters than there are vectors.
    nlist = max(1, min(nlist, num_vectors))

    quantizer = faiss.IndexFlatL2(dimension)  # Used for clustering
    index = faiss.IndexIVFFlat(quantizer, dimension, nlist)

    index.train(embeddings)  # Learn the cluster centroids
    index.add(embeddings)  # Add embeddings to index

    return index, texts

# Save FAISS index and text data

In [None]:
def save_retrieval_system(index, texts, index_file, texts_file):
    """Persist the FAISS index and its companion texts to disk.

    Args:
        index: FAISS index to serialise with ``faiss.write_index``.
        texts: JSON-serialisable texts stored alongside the index.
        index_file: Destination path for the FAISS index.
        texts_file: Destination path for the UTF-8 JSON file holding ``texts``.
    """
    print("\nSaving FAISS index and text data...")

    faiss.write_index(index, index_file)

    # ensure_ascii=False keeps non-ASCII characters readable in the file.
    with open(texts_file, "w", encoding="utf-8") as out:
        out.write(json.dumps(texts, ensure_ascii=False, indent=4))

    print("✅ FAISS retrieval system built and saved successfully!")

# Search FAISS

In [None]:
def search_faiss(query, index, texts, k=5):
    """Return the texts closest to ``query`` according to ``index``.

    The query is embedded with the module-level ``model`` and looked up in
    the FAISS index.

    Args:
        query: Query string.
        index: FAISS index to search.
        texts: List of texts whose positions correspond to the vector ids
            stored in ``index``.
        k: Maximum number of neighbours to request.

    Returns:
        A list of ``(text, distance)`` tuples in the order FAISS returned
        them. It may contain fewer than ``k`` entries when the index could
        not find ``k`` neighbours.
    """
    query_embedding = model.encode([query], convert_to_numpy=True).astype("float32")
    distances, indices = index.search(query_embedding, k)

    # FAISS fills unused result slots with id -1. Indexing ``texts`` with -1
    # would silently return the last document as a fake hit, so skip them.
    return [
        (texts[idx], dist)
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]

# Build and save

In [None]:
def build_and_save_text_retrieval_system(json_file, index_file, texts_file):
    """Run the full pipeline: load records, build the index, write it out.

    Args:
        json_file: Path of the source JSON data, read by ``load_text_data``.
        index_file: Destination path for the FAISS index.
        texts_file: Destination path for the JSON list of indexed texts.
    """
    raw_texts = load_text_data(json_file)
    faiss_index, indexed_texts = build_faiss_index(raw_texts)
    save_retrieval_system(faiss_index, indexed_texts, index_file, texts_file)

# MAIN

In [None]:
# Build the paths with os.path.join rather than hard-coded backslashes so they
# resolve on any operating system; on Windows the resulting strings are the
# same as before.
data_dir = os.path.join("..", "..", "dataset", "nhsInform")
text_data_file = os.path.join(data_dir, "NHS_Data.json")
index_file = os.path.join(data_dir, "faiss_index.bin")
texts_file = os.path.join(data_dir, "texts.json")

build_and_save_text_retrieval_system(text_data_file, index_file, texts_file)

# TEST

In [None]:
# Smoke test: reload the saved texts and index from disk, then run one query.
with open(texts_file, "r", encoding="utf-8") as f:
    texts = json.load(f)
index = faiss.read_index(index_file)

query = "What are the symptoms of pneumonia?"
results = search_faiss(query, index, texts, k=3)

print("\n🔍 Search Results:")
for i, hit in enumerate(results):
    # Each hit is a (text, distance) pair as returned by search_faiss.
    text, score = hit
    print(f"{i+1}. {text} (Score: {score:.4f})")

In [None]:
def format_rag_context(results):
    """Render retrieval results as a newline-separated context string.

    Args:
        results: Iterable of result entries; only the first element of each
            entry (the retrieved text) is used.

    Returns:
        One line per entry of the form ``Retrieved Info <n>: <text>``, with
        ``n`` counting from 1. An empty input yields an empty string.
    """
    numbered_lines = (
        f"Retrieved Info {rank}: {entry[0]}"
        for rank, entry in enumerate(results, start=1)
    )
    return "\n".join(numbered_lines)

In [None]:
# Show the retrieved passages formatted as a single context string.
rag_context = format_rag_context(results)
print(rag_context)