Step 1: Install Required Packages


In [7]:
#!pip install -q gradio sentence-transformers faiss-cpu transformers datasets nltk
#!apt-get install -y ffmpeg  # optional, for media support
#!pip install -q unrar
#!apt-get install -y unrar
#!unrar x samples.rar

Step 2: Import Libraries and Download NLTK Resources

In [48]:
import gradio as gr
import os
import json
import pandas as pd
import nltk
from nltk.corpus import stopwords
import string
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import pipeline

# Set NLTK data path
nltk.data.path.append("/usr/share/nltk_data")
nltk.download('punkt', download_dir="/usr/share/nltk_data")
nltk.download('stopwords', download_dir="/usr/share/nltk_data")
nltk.download('punkt_tab', download_dir="/usr/share/nltk_data") # Download punkt_tab

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt_tab to /usr/share/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

Step 3: Load Clinical Notes from Finished/ Folder

In [41]:
def load_notes_from_folder(root_folder="Finished"):
    note_records = []

    for disease in os.listdir(root_folder):
        disease_path = os.path.join(root_folder, disease)
        if not os.path.isdir(disease_path):
            continue
        for subfolder in os.listdir(disease_path):
            subfolder_path = os.path.join(disease_path, subfolder)
            if not os.path.isdir(subfolder_path):
                continue
            for filename in os.listdir(subfolder_path):
                if filename.endswith(".json"):
                    file_path = os.path.join(subfolder_path, filename)
                    try:
                        with open(file_path, "r") as f:
                            content = json.load(f)

                        # Extract diagnosis from the key (e.g., NSTEMI from "NSTEMI$Intermedia_5")
                        diagnosis = list(content.keys())[0].split("$")[0] if content else "Unknown"

                        # Extract clinical note text
                        input_text = " ".join([
                            content.get(f"input{i}", "").strip()
                            for i in range(1, 7)
                        ])

                        if input_text.strip():
                            note_records.append({
                                "text": input_text.strip(),
                                "label": diagnosis
                            })

                    except Exception as e:
                        print(f"Error reading {file_path}: {e}")

    df = pd.DataFrame(note_records)
    print(f"Loaded {len(df)} clinical notes.")
    return df

# Load dataset
print("Loading dataset...")
df = load_notes_from_folder("Finished")
if len(df) == 0:
    raise ValueError("No notes were loaded. Check folder structure and JSON parsing.")

Loading dataset...
Loaded 343 clinical notes.


Step 4: Preprocess Text and Build FAISS Index

In [49]:
# Preprocess text
def preprocess_text(text, stop_words):
    tokens = nltk.word_tokenize(text.lower())
    tokens = [t for t in tokens if t not in stop_words and t not in string.punctuation]
    return " ".join(tokens)

# Build FAISS index
def build_dense_index(df):
    model = SentenceTransformer('all-MiniLM-L6-v2')
    stop_words = set(stopwords.words('english'))
    df['clean_text'] = df['text'].apply(lambda x: preprocess_text(x, stop_words))
    embeddings = model.encode(df['clean_text'].tolist(), convert_to_numpy=True)
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index, df, model

# Build index
print("Building FAISS index...")
index, df, model = build_dense_index(df)

Building FAISS index...


Step 5: Load the Generator Model (Flan-T5)

In [50]:
# Load generator
def load_generator():
    generator = pipeline("text2text-generation", model="google/flan-t5-base", max_length=200)
    return generator

# Load model
print("Loading generator model...")
generator = load_generator()

Loading generator model...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Device set to use cuda:0


Step 6: Define Prompt Template and Evaluation Function

In [51]:
# Generate prompt with context and labels
def generate_prompt(query, retrieved_docs, labels):
    context = ""
    for doc, label in zip(retrieved_docs, labels):
        context += f"Note: {doc}\nDiagnosis: {label}\n\n"

    prompt = f"""
    You are a clinical assistant. Based on the following clinical notes and diagnoses:
    {context}

    Answer the following question:
    {query}

    Your response should be concise, accurate, and in plain English.
    """
    return prompt

# Evaluate answer against the label of the most relevant note
def evaluate_answer(answer, true_label):
    answer = answer.lower()
    true_label = true_label.lower()
    return int(true_label in answer or answer in true_label)

Step 7: Define the RAG Pipeline


In [52]:
# Define RAG function
def rag_pipeline(query, k=3):
    stop_words = set(stopwords.words('english'))
    clean_query = preprocess_text(query, stop_words)
    query_embedding = model.encode([clean_query], convert_to_numpy=True)
    distances, indices = index.search(query_embedding, k)

    retrieved_docs = df.iloc[indices[0]]['text'].tolist()
    retrieved_labels = df.iloc[indices[0]]['label'].tolist()
    most_relevant_label = retrieved_labels[0]

    prompt = generate_prompt(query, retrieved_docs, retrieved_labels)

    # Generate answer
    try:
        answer = generator(prompt, max_length=200, num_return_sequences=1)[0]['generated_text']
    except Exception as e:
        answer = f"Error in generation: {str(e)}"

    # Evaluate
    accuracy = evaluate_answer(answer, most_relevant_label)
    eval_text = "✅ Matched diagnosis label." if accuracy else "❌ Did not match diagnosis label."

    return answer, "\n\n".join(retrieved_docs), eval_text

Step 8: Build and Launch the Gradio Interface


In [55]:
# Gradio Interface
def respond(query):
    answer, context, evaluation = rag_pipeline(query)
    return answer, context, evaluation

with gr.Blocks() as demo:
    gr.Markdown("## 🏥 MediRAG - Clinical RAG Assistant using MIMIC-IV-Ext-DiReCT")
    gr.Markdown("Ask a clinical question, and the system will retrieve relevant clinical notes, generate an answer, and evaluate against diagnosis labels.")

    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Enter your clinical question:", placeholder="e.g., What is the treatment for heart failure?")
            submit_btn = gr.Button("Get Answer")
        with gr.Column():
            answer_output = gr.Textbox(label="Generated Answer")
            context_output = gr.Textbox(label="Retrieved Clinical Notes", lines=10)
            eval_output = gr.Textbox(label="Evaluation Against Diagnosis Label")

    submit_btn.click(fn=respond, inputs=user_input, outputs=[answer_output, context_output, eval_output])

# Launch Gradio
demo.launch()

It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://6260902c607e9ac946.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


