# Cell 1: Setup & Imports


In [None]:
import os
import json
import re
import time
import pandas as pd
import numpy as np
import nltk
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Root folder of the dataset: samples/<Disease Category>/<PDD>/<note>.json
DATA_DIR = 'samples'

print("Downloading NLTK punkt tokenizer...")
# Tokenizer resources used by nltk.word_tokenize later in the notebook. Both
# names are requested; packages that are already present are only reported as
# up to date.
for _nltk_package in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_package)

print("Setup Complete.")

Downloading NLTK punkt tokenizer...


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\muham\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\muham\AppData\Roaming\nltk_data...


Setup Complete.


[nltk_data]   Unzipping tokenizers\punkt_tab.zip.


In [None]:


def extract_text_from_json(data):
    """
    Pick the note text out of a parsed JSON document (longest-string heuristic).

    Candidates considered:
      * a root that is itself a string;
      * for a list root, its string items joined with newlines;
      * for a dict root, each top-level string value, and each top-level list
        value's string items joined with newlines (non-strings are skipped).
    Nested dicts are not searched. On ties the first candidate wins. The
    winner is returned stripped; any other root type, or no candidate, gives "".
    """
    def _join_strings(items):
        # Keep only the string elements, newline-separated.
        return "\n".join(item for item in items if isinstance(item, str))

    if isinstance(data, str):
        return data.strip()
    if isinstance(data, list):
        return _join_strings(data).strip()
    if not isinstance(data, dict):
        return ""

    best = ""
    for value in data.values():
        if isinstance(value, str):
            candidate = value
        elif isinstance(value, list):
            candidate = _join_strings(value)
        else:
            continue
        if len(candidate) > len(best):
            best = candidate
    return best.strip()


def _categories_from_path(root, data_dir):
    """Return (disease_category, pdd_category) for a directory found under data_dir."""
    # relpath works for relative and absolute data_dir alike. The previous
    # samefile()-on-path-prefixes scan failed for absolute paths (e.g. the
    # leading '' component on POSIX).
    rel_parts = [
        part for part in os.path.relpath(root, data_dir).split(os.sep)
        if part not in ('', os.curdir)
    ]
    if len(rel_parts) >= 2:
        # Layout: data_dir / Disease Category / PDD [/ deeper ...]
        return rel_parts[0], rel_parts[1]

    # Too shallow to hold both levels: fall back to the folder and its parent.
    print(f"Warning: Could not reliably determine Disease/PDD from path: {root}. Using parent folders.")
    path_parts = os.path.normpath(root).split(os.sep)
    pdd_category = path_parts[-1] if len(path_parts) > 0 else "Unknown PDD"
    disease_category = path_parts[-2] if len(path_parts) > 1 else "Unknown Disease"
    return disease_category, pdd_category


def load_data_from_structure(data_dir):
    """
    Load every JSON note under data_dir into a list of record dicts.

    Expected layout: data_dir / <Disease Category> / <PDD> / <note>.json.
    The disease category and PDD come from the first two directory levels
    below data_dir. The note text is chosen by extract_text_from_json.

    Args:
        data_dir: Root directory of the dataset (relative or absolute).

    Returns:
        A list of dicts with keys 'id' (file path), 'disease_category', 'pdd'
        and 'text'. Files that fail to load, or yield no text, are skipped
        (with a printed warning for load failures). Returns [] when data_dir
        is not a directory.
    """
    all_records = []
    print(f"Starting data loading from: {data_dir}")
    if not os.path.isdir(data_dir):
        print(f"Error: Data directory '{data_dir}' not found.")
        return []

    for root, dirs, files in os.walk(data_dir):
        # Sorting dirs in place makes os.walk descend in a stable order, so the
        # record order (and BM25 tie-breaking downstream) is reproducible
        # across platforms and filesystems.
        dirs.sort()
        json_files = sorted(name for name in files if name.endswith('.json'))
        if not json_files:
            continue

        disease_category, pdd_category = _categories_from_path(root, data_dir)

        for filename in json_files:
            file_path = os.path.join(root, filename)
            try:
                # utf-8-sig also accepts files saved with a UTF-8 BOM, which
                # plain utf-8 would hand to json as an invalid leading char.
                with open(file_path, 'r', encoding='utf-8-sig') as f:
                    data = json.load(f)
                text_content = extract_text_from_json(data)
            except json.JSONDecodeError:
                print(f"Warning: Could not decode JSON from {file_path}")
                continue
            except Exception as e:
                print(f"Warning: Error processing file {file_path}: {e}")
                continue

            if text_content:  # Only keep notes where some text was found
                all_records.append({
                    'id': file_path,  # File path doubles as a unique ID
                    'disease_category': disease_category,
                    'pdd': pdd_category,  # 'pdd' key is what downstream code reads
                    'text': text_content,
                })

    print(f"Loaded {len(all_records)} records using path structure.")
    return all_records

raw_data = load_data_from_structure(DATA_DIR)

# Tabular view of the loaded records (an empty frame when nothing was loaded).
df_raw = pd.DataFrame(raw_data) if raw_data else pd.DataFrame()

print(f"Data loading finished. Number of records loaded: {len(df_raw)}")

if df_raw.empty:
    print("\nNo data loaded.")
else:
    print("\n--- Sample Loaded Record Structure ---")
    print(df_raw.head())

Starting data loading from: samples
Loaded 511 records using path structure.
Data loading finished. Number of records loaded: 511

--- Sample Loaded Record Structure ---
                                                  id         disease_category  \
0  samples\Acute Coronary Syndrome\NSTEMI\1153590...  Acute Coronary Syndrome   
1  samples\Acute Coronary Syndrome\NSTEMI\1185908...  Acute Coronary Syndrome   
2  samples\Acute Coronary Syndrome\NSTEMI\1199071...  Acute Coronary Syndrome   
3  samples\Acute Coronary Syndrome\NSTEMI\1199283...  Acute Coronary Syndrome   
4  samples\Acute Coronary Syndrome\NSTEMI\1205401...  Acute Coronary Syndrome   

      pdd                                               text  
0  NSTEMI  F presents with history of HTN, hypothyroidism...  
3  NSTEMI  Female with PMH of rheumatoid arthritis on pre...  
4  NSTEMI  ADMISSION LABS\n___ 12:42PM BLOOD WBC-10.8 RBC...  


# Cell 3: Data Exploration


In [None]:
# Quick look at the loaded corpus: head, one full record, PDD distribution.
if df_raw.empty:
    print("No data loaded, skipping exploration.")
else:
    print("--- First 5 Records ---")
    print(df_raw.head())

    # A non-empty frame always has a first row, so no extra length check.
    print("\n--- Sample Record Details ---")
    print(df_raw.iloc[0])

    print("\n--- Value Counts for PDD (Top 10) ---")
    print(df_raw['pdd'].value_counts().head(10))

    print("\n--- Basic Stats ---")
    print(f"Total records: {len(df_raw)}")
    print(f"Unique PDDs: {df_raw['pdd'].nunique()}")

--- First 5 Records ---
                                                  id         disease_category  \
0  samples\Acute Coronary Syndrome\NSTEMI\1153590...  Acute Coronary Syndrome   
1  samples\Acute Coronary Syndrome\NSTEMI\1185908...  Acute Coronary Syndrome   
2  samples\Acute Coronary Syndrome\NSTEMI\1199071...  Acute Coronary Syndrome   
3  samples\Acute Coronary Syndrome\NSTEMI\1199283...  Acute Coronary Syndrome   
4  samples\Acute Coronary Syndrome\NSTEMI\1205401...  Acute Coronary Syndrome   

      pdd                                               text  
0  NSTEMI  F presents with history of HTN, hypothyroidism...  
3  NSTEMI  Female with PMH of rheumatoid arthritis on pre...  
4  NSTEMI  ADMISSION LABS\n___ 12:42PM BLOOD WBC-10.8 RBC...  

--- Sample Record Details ---
id                  samples\Acute Coronary Syndrome\NSTEMI\1153590...
disease_category                              Acute Coronary Syndrome
pdd                                                            NST

# Cell 4: Preprocessing Functions

In [None]:
def clean_text(text):
    """
    Normalize text for indexing: lower-case it and collapse every whitespace
    run into a single space, with no leading/trailing whitespace.

    Non-string input (e.g. None or a NaN float) yields an empty string.
    """
    if isinstance(text, str):
        return re.sub(r'\s+', ' ', text.lower()).strip()
    return ""

def preprocess_documents(documents):
    """
    Clean and tokenize raw records for BM25 indexing.

    Each input record must provide 'id', 'pdd' and 'text'. Records whose text
    is empty after clean_text() are dropped. Each kept record becomes a dict
    with:
      - 'id', 'pdd': copied from the input,
      - 'original_text': the untouched text (used later as prompt context),
      - 'cleaned_text': the clean_text() result,
      - 'tokens': nltk.word_tokenize of the cleaned text.
    """
    kept = []
    for record in documents:
        normalized = clean_text(record['text'])
        if not normalized:
            continue  # Nothing left to index for this record.
        # word_tokenize splits punctuation into separate tokens (unlike a plain
        # whitespace split).
        token_list = nltk.word_tokenize(normalized)
        kept.append({
            'id': record['id'],
            'pdd': record['pdd'],
            'original_text': record['text'],
            'cleaned_text': normalized,
            'tokens': token_list,
        })
    print(f"Preprocessing complete. Kept {len(kept)} documents.")
    return kept


if not raw_data:
    print("Skipping preprocessing as no raw data was loaded.")
    preprocessed_data = []
else:
    preprocessed_data = preprocess_documents(raw_data)

    if not preprocessed_data:
        print("No documents remained after preprocessing.")
    else:
        # Show one processed record as a sanity check.
        _sample = preprocessed_data[0]
        print("\n--- Sample Preprocessed Record ---")
        print(f"ID: {_sample['id']}")
        print(f"PDD: {_sample['pdd']}")
        print(f"Cleaned Text: {_sample['cleaned_text'][:200]}...")
        print(f"Tokens (first 20): {_sample['tokens'][:20]}")

Preprocessing complete. Kept 511 documents.

--- Sample Preprocessed Record ---
ID: samples\Acute Coronary Syndrome\NSTEMI\11535902-DS-14.json
PDD: NSTEMI
Cleaned Text: f presents with history of htn, hypothyroidism, no priorcardiac hx who presented to ed with chest pain. patient endorses right sided chest pain for the last 2 days which worsened today, at which point...
Tokens (first 20): ['f', 'presents', 'with', 'history', 'of', 'htn', ',', 'hypothyroidism', ',', 'no', 'priorcardiac', 'hx', 'who', 'presented', 'to', 'ed', 'with', 'chest', 'pain', '.']


# Cell 6: BM25 Indexing

In [None]:

if preprocessed_data:
    print("Starting BM25 Indexing...")
    # One token list per document, in the same order as preprocessed_data, so
    # a BM25 score at position i refers to preprocessed_data[i].
    tokenized_corpus = [doc['tokens'] for doc in preprocessed_data]

    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # elapsed time; time.time() follows the wall clock and can jump.
    start_time = time.perf_counter()
    bm25 = BM25Okapi(tokenized_corpus)
    end_time = time.perf_counter()
    print(f"BM25 Index created in {end_time - start_time:.2f} seconds.")
else:
    print("Skipping BM25 Indexing as there is no preprocessed data.")
    bm25 = None
    tokenized_corpus = []

Starting BM25 Indexing...
BM25 Index created in 0.02 seconds.


# Cell 7: Retrieval Function

In [34]:
# Cell 7: Retrieval Function

def retrieve_documents(query, bm25_index, preprocessed_docs, top_k=3):
    """
    Return up to top_k documents ranked best-first by BM25 score for query.

    The query goes through the same clean_text + nltk.word_tokenize pipeline
    used for the indexed documents. Documents with a non-positive score are
    dropped, so fewer than top_k documents may be returned.

    Args:
        query: Free-text question.
        bm25_index: BM25 index built over the token lists of preprocessed_docs,
            in the same order (score position i -> preprocessed_docs[i]).
        preprocessed_docs: Records produced by preprocess_documents.
        top_k: Maximum number of documents to return; values <= 0 give [].

    Returns:
        A list of document dicts (possibly empty).
    """
    if not bm25_index or not preprocessed_docs:
        print("Error: BM25 index or preprocessed data not available.")
        return []

    # 1. Preprocess the query exactly like the documents.
    tokenized_query = nltk.word_tokenize(clean_text(query))

    # 2. Score the query against every indexed document.
    doc_scores = bm25_index.get_scores(tokenized_query)

    # 3. Clamp k into [0, n]: a negative top_k would otherwise turn the slice
    #    below into "all but the last |top_k| documents".
    k = max(0, min(top_k, len(preprocessed_docs)))
    # argsort is ascending; reverse for best-first, then keep the first k.
    top_n_indices = np.argsort(doc_scores)[::-1][:k]

    # 4. Map indices back to documents, skipping non-positive scores.
    retrieved_docs = [preprocessed_docs[i] for i in top_n_indices if doc_scores[i] > 0]

    print(f"Retrieved {len(retrieved_docs)} documents for query: '{query}'")
    return retrieved_docs

# Cell 8: Generator Setup (LLM)

In [None]:
model_name = 'google/flan-t5-base'

# Run on the first CUDA device when one is available, otherwise on the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("Using CPU")
else:
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")

# Load tokenizer and weights. On any failure, both are reset to None and
# llm_ready is cleared, so later cells can skip generation instead of crashing.
try:
    print(f"Loading Tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print(f"Loading Model: {model_name} (this may take a while)...")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
except Exception as load_error:
    print(f"Error loading model or tokenizer: {load_error}")
    print("LLM setup failed. Generation will not be possible.")
    tokenizer = model = None
    llm_ready = False
else:
    print("Model and Tokenizer loaded successfully.")
    llm_ready = True

Using GPU: NVIDIA GeForce RTX 3080 Ti Laptop GPU
Loading Tokenizer: google/flan-t5-base


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


Loading Model: google/flan-t5-base (this may take a while)...


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


Model and Tokenizer loaded successfully.


# Cell 9: RAG Prompt Template and Generation Function

In [35]:
# --- Prompt Template ---
# Filled with str.format in generate_answer:
#   {context_str} - the retrieved notes' original text, joined with separators
#   {query}       - the user's question
# The question is placed AFTER the context, so a length limit that cuts the
# tail of the tokenized prompt would drop the question first.
PROMPT_TEMPLATE = """
Based *only* on the following context regarding diagnostic procedures for specific diseases, please answer the question. Do not use any prior knowledge. If the context does not contain the answer, state that the information is not available in the provided documents.

Context:
---
{context_str}
---

Question: {query}

Answer:
"""

def _fit_prompt_to_budget(query, context_str, tokenizer, max_input_length):
    """Format PROMPT_TEMPLATE, shortening only context_str until it fits max_input_length tokens."""
    prompt = PROMPT_TEMPLATE.format(context_str=context_str, query=query)
    excess = len(tokenizer(prompt)["input_ids"]) - max_input_length
    if excess <= 0:
        return prompt  # Already fits: leave the prompt untouched.

    context_ids = tokenizer(context_str, add_special_tokens=False)["input_ids"]
    # Trim `excess` tokens off the end of the context and re-measure. A
    # decode/re-encode round trip need not reproduce the exact token count, so
    # repeat until the prompt fits or the context is gone. Every pass strictly
    # shortens context_ids, so the loop terminates.
    while excess > 0 and context_ids:
        context_ids = context_ids[:max(len(context_ids) - excess, 0)]
        context_str = tokenizer.decode(context_ids, skip_special_tokens=True)
        prompt = PROMPT_TEMPLATE.format(context_str=context_str, query=query)
        excess = len(tokenizer(prompt)["input_ids"]) - max_input_length
    return prompt


def generate_answer(query, retrieved_docs, model, tokenizer, device, max_length=10000,
                    max_input_length=10000):
    """
    Generate an answer with the seq2seq LLM from the query and retrieved context.

    Args:
        query: The user's question.
        retrieved_docs: Retrieved documents; each must have 'original_text'.
        model: Loaded seq2seq model, or None if loading failed.
        tokenizer: Matching tokenizer, or None if loading failed.
        device: torch device that the model lives on.
        max_length: Maximum length of the generated answer, in tokens.
        max_input_length: Token budget for the whole prompt. When exceeded, the
            *context* is shortened, so the instructions, the question and the
            trailing "Answer:" cue survive. Plain tokenizer truncation would
            cut from the end and drop the question. NOTE(review): the 10000
            defaults combined with 20 beams are very expensive for
            flan-t5-base; consider smaller values.

    Returns:
        The decoded answer, or a human-readable message when the model is
        unavailable, nothing was retrieved, or generation failed.
    """
    if not model or not tokenizer:
        return "Error: LLM model or tokenizer not available."
    if not retrieved_docs:
        return "No relevant documents were found to answer the question."

    # 1. Join the retrieved notes. Original (uncleaned) text keeps casing and
    #    punctuation, giving the LLM richer context.
    context_str = "\n\n---\n\n".join(doc['original_text'] for doc in retrieved_docs)

    # 2. Build the prompt, shrinking the context (never the question) if needed.
    prompt = _fit_prompt_to_budget(query, context_str, tokenizer, max_input_length)

    # 3. Tokenize. truncation=True remains only as a last-resort guard, e.g.
    #    for a query that is too long on its own.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=max_input_length,
                       truncation=True).to(device)

    # 4. Generate with beam search, then decode the best beam.
    try:
        print("Generating answer...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                min_length=10,
                num_beams=20,
                early_stopping=True,
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during generation: {e}")
        return "Error occurred while generating the answer."


# Cell 10: Example Usage / User Simulation

In [None]:
def ask_question(query, bm25_index, preprocessed_docs, model, tokenizer, device, top_k=3):
    """
    Run one query end to end: BM25 retrieval followed by LLM generation.

    Returns whatever generate_answer returns: the answer text, or its
    explanatory message (e.g. when nothing was retrieved or the model is
    missing).
    """
    print(f"\nProcessing Query: '{query}'")

    docs = retrieve_documents(query, bm25_index, preprocessed_docs, top_k=top_k)
    if not docs:
        # Still delegate to generate_answer: it produces the user-facing
        # message for an empty retrieval (or the missing-model error first).
        print("No relevant documents found by retriever.")

    return generate_answer(query, docs, model, tokenizer, device)


# User Simulation
# Push a handful of sample questions through the full RAG pipeline; runs only
# when the BM25 index, the preprocessed corpus and the LLM are all available.
if bm25 and preprocessed_data and llm_ready:
    _rule = "=" * 30
    print(f"\n{_rule} RAG System Demo {_rule}")

    # --- Sample queries ---
    queries = [
        "What are the typical diagnostic procedures for Congestive Heart Failure?",
        "Describe the process for diagnosing Pulmonary Embolism.",
        "How is Sepsis diagnosed according to these records?",
        "Outline the diagnostic steps for Acute Myocardial Infarction.",
        "What is the cause of short of breath?"
    ]

    # --- Retrieve + generate for each query, retrieving 3 documents apiece ---
    for q in queries:
        final_answer = ask_question(q, bm25, preprocessed_data, model, tokenizer, device, top_k=3)
        print(f"\nQuery: {q}")
        print(f"Generated Answer:\n{final_answer}")
        print("-" * 70)

    print(f"{_rule} Demo Finished {_rule}\n")



Processing Query: 'What are the typical diagnostic procedures for Congestive Heart Failure?'
Retrieved 3 documents for query: 'What are the typical diagnostic procedures for Congestive Heart Failure?'
Generating answer...

Query: What are the typical diagnostic procedures for Congestive Heart Failure?
Generated Answer:
pulmonary edema and small bilateral effusions
----------------------------------------------------------------------

Processing Query: 'Describe the process for diagnosing Pulmonary Embolism.'
Retrieved 3 documents for query: 'Describe the process for diagnosing Pulmonary Embolism.'
Generating answer...

Query: Describe the process for diagnosing Pulmonary Embolism.
Generated Answer:
CTA Chest: IMPRESSION: 1. No evidence of pulmonary embolism or acute aortic abnormality. 2. Mucous plugging in bilateral airways, most substantial in the left upper lobe. 3. Left lower lobe pulmonary nodule with somewhat spiculated margins, suspicious for malignancy. 4. 8 x 10 mm left low