## Step 1: Data Preparation

In [1]:
import pandas as pd
from datasets import load_dataset
import os

def preprocess_medqa(df):
    """ Prepares dataset for retrieval-based QA. """
    processed_data = []
    for _, row in df.iterrows():
        question = row["sent1"]
        # Four multiple-choice options
        options = [row[f"ending{i}"] for i in range(4)]
        # 'label' indicates which option is correct
        correct_answer = options[row["label"]]

        processed_data.append({
            "question": question,
            "options": options,
            "answer": correct_answer,
            # For real RAG, you might fetch relevant doc passages or knowledge base entries. 
            # Here we use the 4 options as "context" for demonstration.
            "context": " ".join(options)  
        })
    return pd.DataFrame(processed_data)

# Check if processed files already exist
if (os.path.exists("medical_train.csv") and 
    os.path.exists("medical_val.csv") and 
    os.path.exists("medical_test.csv")):
    
    # Load existing processed files
    df_train_proc = pd.read_csv("medical_train.csv")
    df_val_proc = pd.read_csv("medical_val.csv") 
    df_test_proc = pd.read_csv("medical_test.csv")
    
else:
    # Load and process MedQA-USMLE dataset
    dataset = load_dataset("GBaker/MedQA-USMLE-4-options-hf")

    df_train = pd.DataFrame(dataset["train"])
    df_val = pd.DataFrame(dataset["validation"])
    df_test = pd.DataFrame(dataset["test"])

    df_train_proc = preprocess_medqa(df_train)
    df_val_proc = preprocess_medqa(df_val)
    df_test_proc = preprocess_medqa(df_test)

    # Save processed datasets
    df_train_proc.to_csv("medical_train.csv", index=False)
    df_val_proc.to_csv("medical_val.csv", index=False)
    df_test_proc.to_csv("medical_test.csv", index=False)

print(df_train_proc.head())


  from .autonotebook import tqdm as notebook_tqdm


                                            question  \
0  A 23-year-old pregnant woman at 22 weeks gesta...   
1  A 3-month-old baby died suddenly at night whil...   
2  A mother brings her 3-week-old infant to the p...   
3  A pulmonary autopsy specimen from a 58-year-ol...   
4  A 20-year-old woman presents with menorrhagia ...   

                                             options  \
0  ['Ampicillin', 'Ceftriaxone', 'Doxycycline', '...   
1  ['Placing the infant in a supine position on a...   
2  ['Abnormal migration of ventral pancreatic bud...   
3  ['Thromboembolism', 'Pulmonary ischemia', 'Pul...   
4  ['Hemophilia A', 'Lupus anticoagulant', 'Prote...   

                                              answer  \
0                                     Nitrofurantoin   
1  Placing the infant in a supine position on a f...   
2       Abnormal migration of ventral pancreatic bud   
3                                    Thromboembolism   
4                             Von Willebrand d

## Step 2: Create medical knowledge base

In [2]:
import os
import torch
import numpy as np
import faiss
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from pprint import pprint

# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Define available models with their characteristics
AVAILABLE_MODELS = {
    "pritamdeka/S-PubMedBert-MS-MARCO": {
        "parameters": "110M",
        "embedding_dim": 768,
        "description": "Specialized for medical/biomedical text, fine-tuned on PubMed"
    },
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext": {
        "parameters": "110M",
        "embedding_dim": 768,
        "description": "Trained on PubMed abstracts and full-text articles"
    },
    "gsarti/biobert-nli": {
        "parameters": "110M",
        "embedding_dim": 768,
        "description": "BioBERT fine-tuned on NLI tasks, good for medical similarity"
    },
    "all-MiniLM-L6-v2": {
        "parameters": "22M",
        "embedding_dim": 384,
        "description": "Fast and lightweight model, good balance of speed and performance"
    },
    "all-mpnet-base-v2": {
        "parameters": "110M", 
        "embedding_dim": 768,
        "description": "One of the best performing general models"
    }
}

# Define available FAISS index types
FAISS_INDEXES = {
    "IndexFlatL2": {
        "description": "Exact L2 distance search. Most accurate but slower for large datasets.",
        "use_case": "Small to medium datasets where accuracy is critical",
        "recommended_size": "< 1M vectors"
    },
    "IndexIVFFlat": {
        "description": "Inverted file with exact post-verification. Good balance of speed and accuracy.",
        "use_case": "Medium to large datasets, allows approximate search",
        "recommended_size": "1M - 10M vectors"
    },
    "IndexHNSWFlat": {
        "description": "Hierarchical Navigable Small World graph. Very fast search with good accuracy.",
        "use_case": "Large datasets where search speed is critical",
        "recommended_size": "10M - 100M vectors"
    },
    "IndexLSH": {
        "description": "Locality-Sensitive Hashing. Fast but less accurate.",
        "use_case": "Very large datasets where approximate results are acceptable",
        "recommended_size": "> 100M vectors"
    }
}

def create_medical_knowledge_base():
    """Create a comprehensive medical knowledge base from multiple sources with quality filtering"""
    knowledge_base = []
    
    print("Loading datasets...")
    pubmedqa = load_dataset("pubmed_qa", "pqa_labeled")
    medmcqa = load_dataset("medmcqa")
    
    # Show some examples
    print("\nExample from PubMedQA:")
    example_pubmed = pubmedqa['train'][0]
    # Convert context to string before slicing
    context_preview = str(example_pubmed['context'])[:200] + "..."
    pprint({
        'question': example_pubmed['question'],
        'context_preview': context_preview,
        'long_answer': example_pubmed['long_answer']
    })
    
    print("\nExample from MedMCQA:")
    example_medmcqa = medmcqa['train'][0]
    # Convert explanation to string before slicing
    exp_preview = str(example_medmcqa['exp'])[:200] + "..." if example_medmcqa['exp'] else "No explanation"
    pprint({
        'question': example_medmcqa['question'],
        'explanation_preview': exp_preview,
        'correct_option': example_medmcqa['cop']
    })
    
    # Add PubMedQA abstracts
    for item in pubmedqa['train']:
        # Join all context pieces into a single string
        context_text = " ".join(item['context']['contexts'])
        if len(context_text.split()) >= 20:  # Length check
            knowledge_base.append({
                'text': context_text,
                'source': 'PubMedQA',
                'type': 'research_abstract',
                'metadata': {
                    'question': item['question'],
                    'long_answer': item['long_answer'],
                    'pubid': item['pubid']
                }
            })
    
    # Add MedMCQA explanations
    for item in medmcqa['train']:
        if item['exp'] and len(item['exp'].split()) >= 20:  # Check if explanation exists and length
            knowledge_base.append({
                'text': item['exp'],
                'source': 'MedMCQA', 
                'type': 'expert_explanation',
                'metadata': {
                    'question': item['question'],
                    'correct_answer': item['cop']
                }
            })
    
    print(f"\nCreated knowledge base with {len(knowledge_base)} entries")
    return knowledge_base

class TextDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        return self.texts[idx]

def build_retrieval_system(knowledge_base):
    """Build dense retrieval system with FAISS index"""
     # Print available models and their info
    print("\nAvailable SentenceTransformer Models:")
    for model_name, info in AVAILABLE_MODELS.items():
        print(f"\n{model_name}:")
        print(f"Parameters: {info['parameters']}")
        print(f"Embedding Dimension: {info['embedding_dim']}")
        print(f"Description: {info['description']}")

    # Model selection
    selected_model = input("\nEnter the name of the model you want to use (default: pritamdeka/S-PubMedBert-MS-MARCO): ").strip()
    if not selected_model or selected_model not in AVAILABLE_MODELS:
        print(f"Using default model: pritamdeka/S-PubMedBert-MS-MARCO")
        selected_model = "pritamdeka/S-PubMedBert-MS-MARCO"

    # Initialize encoder and configure GPU usage
    print(f"\nLoading {selected_model}...")
    encoder = SentenceTransformer(selected_model)

    # Enhanced GPU detection and configuration
    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
        print(f"Found {n_gpus} CUDA GPUs")
        
        if n_gpus > 1:
            print(f"Using {n_gpus} GPUs in parallel")
            # Use DataParallel with all available GPUs
            encoder = torch.nn.DataParallel(encoder)
            # Scale batch size with number of GPUs, but cap it for stability
            batch_size = min(32 * n_gpus, 256)  # Cap at 256 to prevent OOM
        else:
            print("Using single GPU")
            device = torch.device("cuda:0")
            encoder.to(device)
            batch_size = 64
    else:
        print("CUDA is not available. Using CPU.")
        device = torch.device("cpu")
        batch_size = 32

    # Optimize data loading for multi-GPU
    texts = [entry['text'] for entry in knowledge_base]
    dataset = TextDataset(texts)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=True if torch.cuda.is_available() else False,
        persistent_workers=True if torch.cuda.is_available() else False
    )

    # Optimized embedding generation
    print("Generating embeddings...")
    embeddings_list = []
    encoder.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            try:
                if isinstance(encoder, torch.nn.DataParallel):
                    # Automatic batch splitting across GPUs
                    batch_embeddings = encoder.module.encode(
                        batch,
                        convert_to_numpy=True,
                        device=None  # Let DataParallel handle device placement
                    )
                else:
                    # Single GPU or CPU processing
                    batch_embeddings = encoder.encode(
                        batch,
                        convert_to_numpy=True,
                        device=device if 'device' in locals() else None
                    )
                
                embeddings_list.append(batch_embeddings)

            except RuntimeError as e:
                print(f"Error processing batch: {e}")
                # Graceful fallback with reduced batch size
                if len(batch) > 1:
                    print("Reducing batch size and retrying...")
                    # Process in smaller chunks
                    chunk_size = len(batch) // 4
                    for i in range(0, len(batch), chunk_size):
                        sub_batch = batch[i:i + chunk_size]
                        if isinstance(encoder, torch.nn.DataParallel):
                            sub_embeddings = encoder.module.encode(
                                sub_batch,
                                convert_to_numpy=True,
                                device=None
                            )
                        else:
                            sub_embeddings = encoder.encode(
                                sub_batch,
                                convert_to_numpy=True,
                                device=device if 'device' in locals() else None
                            )
                        embeddings_list.append(sub_embeddings)
                else:
                    raise e

            # Explicit GPU memory cleanup
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Concatenate embeddings
    print("Concatenating embeddings...")
    embeddings = np.concatenate(embeddings_list, axis=0)
    
    # FAISS index selection and creation
    print("\nAvailable FAISS Index Types:")
    for index_name, info in FAISS_INDEXES.items():
        print(f"\n{index_name}:")
        print(f"Description: {info['description']}")
        print(f"Use Case: {info['use_case']}")
        print(f"Recommended Dataset Size: {info['recommended_size']}")

    data_size = len(embeddings)
    recommended_index = "IndexFlatL2"
    if data_size > 100_000_000:
        recommended_index = "IndexLSH"
    elif data_size > 10_000_000:
        recommended_index = "IndexHNSWFlat"
    elif data_size > 1_000_000:
        recommended_index = "IndexIVFFlat"

    print(f"\nBased on your dataset size ({data_size:,} vectors), we recommend using: {recommended_index}")
    
    selected_index = input("\nEnter the name of the index type you want to use (default: recommended): ").strip()
    if not selected_index or selected_index not in FAISS_INDEXES:
        print(f"Using recommended index: {recommended_index}")
        selected_index = recommended_index

    # Build FAISS index
    dimension = embeddings.shape[1]
    
    try:
        if selected_index == "IndexFlatL2":
            index = faiss.IndexFlatL2(dimension)
        elif selected_index == "IndexIVFFlat":
            nlist = min(4096, max(data_size // 30, 100))
            quantizer = faiss.IndexFlatL2(dimension)
            index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
            print("Training IVF index...")
            index.train(embeddings)
        elif selected_index == "IndexHNSWFlat":
            M = 32
            index = faiss.IndexHNSWFlat(dimension, M)
        elif selected_index == "IndexLSH":
            nbits = min(64, dimension)
            index = faiss.IndexLSH(dimension, nbits)

        print("Adding vectors to index...")
        index.add(embeddings)
        
    except Exception as e:
        print(f"Error creating FAISS index: {e}")
        print("Falling back to simple IndexFlatL2...")
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
    
    return index, encoder, selected_model

def retrieve_contexts(query, index, encoder, knowledge_base, k=3):
    """Retrieve relevant contexts for a query"""
    try:
        # Handle both DataParallel and regular encoder cases
        if isinstance(encoder, torch.nn.DataParallel):
            actual_encoder = encoder.module
        else:
            actual_encoder = encoder
            
        # Move query to same device as encoder
        device = next(actual_encoder.parameters()).device
        
        # Encode query
        with torch.no_grad():
            query_vector = actual_encoder.encode([query], convert_to_numpy=True, device=device)
        
        # Search index
        distances, indices = index.search(query_vector, k)
        
        # Return relevant contexts with metadata and distances
        retrieved = []
        for idx, distance in zip(indices[0], distances[0]):
            if 0 <= idx < len(knowledge_base):  # Validate index
                context = knowledge_base[idx].copy()
                context['distance'] = float(distance)
                retrieved.append(context)
        
        return retrieved
        
    except Exception as e:
        print(f"Error retrieving contexts: {e}")
        return []

if __name__ == "__main__":
    kb_file = "medical_knowledge_base.pkl"
    index_file = "faiss_index.bin"
    model_name_file = "model_name.txt"
    
    try:
        if os.path.exists(kb_file) and os.path.exists(index_file) and os.path.exists(model_name_file):
            print("Loading existing knowledge base and index...")
            with open(kb_file, 'rb') as f:
                kb = pickle.load(f)
            index = faiss.read_index(index_file)
            with open(model_name_file, 'r') as f:
                model_name = f.read().strip()
            encoder = SentenceTransformer(model_name)
        else:
            print("Creating new knowledge base and index...")
            kb = create_medical_knowledge_base()
            index, encoder, model_name = build_retrieval_system(kb)
            
            print("Saving knowledge base, index and model name...")
            with open(kb_file, 'wb') as f:
                pickle.dump(kb, f)
            faiss.write_index(index, index_file)
            with open(model_name_file, 'w') as f:
                f.write(model_name)
        
        # Example queries
        example_queries = [
            "What are the symptoms of diabetes?",
            "How is breast cancer diagnosed?",
            "What are the side effects of chemotherapy?"
        ]
        
        print("\nTesting retrieval system with example queries:")
        for query in example_queries:
            print(f"\nQuery: {query}")
            relevant_contexts = retrieve_contexts(query, index, encoder, kb)
            for i, context in enumerate(relevant_contexts, 1):
                print(f"\nRelevant Context {i}:")
                print(f"Source: {context['source']}")
                print(f"Type: {context['type']}")
                print(f"Text preview: {context['text'][:200]}...")
                
    except Exception as e:
        print(f"An error occurred: {e}")

  from .autonotebook import tqdm as notebook_tqdm


Creating new knowledge base and index...
Loading datasets...

Example from PubMedQA:
{'context_preview': "{'contexts': ['Programmed cell death (PCD) is the "
                    'regulated death of cells within an organism. The lace '
                    'plant (Aponogeton madagascariensis) produces perforations '
                    'in its leaves through PCD. The leaves ...',
 'long_answer': 'Results depicted mitochondrial dynamics in vivo as PCD '
                'progresses within the lace plant, and highlight the '
                'correlation of this organelle with other organelles during '
                'developmental PCD. To the best of our knowledge, this is the '
                'first report of mitochondria and chloroplasts moving on '
                'transvacuolar strands to form a ring structure surrounding '
                'the nucleus during developmental PCD. Also, for the first '
                'time, we have shown the feasibility for the use of CsA in a '
       

100%|██████████| 1016/1016 [07:44<00:00,  2.19it/s]


Concatenating embeddings...

Available FAISS Index Types:

IndexFlatL2:
Description: Exact L2 distance search. Most accurate but slower for large datasets.
Use Case: Small to medium datasets where accuracy is critical
Recommended Dataset Size: < 1M vectors

IndexIVFFlat:
Description: Inverted file with exact post-verification. Good balance of speed and accuracy.
Use Case: Medium to large datasets, allows approximate search
Recommended Dataset Size: 1M - 10M vectors

IndexHNSWFlat:
Description: Hierarchical Navigable Small World graph. Very fast search with good accuracy.
Use Case: Large datasets where search speed is critical
Recommended Dataset Size: 10M - 100M vectors

IndexLSH:
Description: Locality-Sensitive Hashing. Fast but less accurate.
Use Case: Very large datasets where approximate results are acceptable
Recommended Dataset Size: > 100M vectors

Based on your dataset size (129,952 vectors), we recommend using: IndexFlatL2
Adding vectors to index...
Saving knowledge base, inde

## Step 3: Small LLM for RAG: Flan-T5

In [13]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = "google/flan-t5-base"

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS backend for Apple Silicon.")
else:
    device = torch.device("cpu")
    print("Using CPU.")
model.to(device)

# RAG inference pipeline
def rag_infer(question, top_k=1, max_length=128):
    # Retrieve context
    retrieved_list = retrieve_context(question, k=top_k)
    combined_context = "\n".join(retrieved_list)
    
    # Prepare prompt for T5
    prompt = f"answer the medical question based on context:\nContext: {combined_context}\nQuestion: {question}\nAnswer:"
    # prompt = f"answer the medical question based on context:\nQuestion: {question}\nAnswer:"
    
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=max_length,
            num_beams=4,
            early_stopping=True
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return combined_context, answer
    
# Quick test
context, sample_answer = rag_infer("What is the cause of low blood pressure?")
print(f"Provided context:\n {context}")
print(f"Generated Asnwer: {sample_answer}")


Using MPS backend for Apple Silicon.
Provided context:
 Decreased vascular resistance Increased cardiac output Diastolic murmur Low blood pressure
Generated Asnwer: Decreased vascular resistance


## Fine-Tuning Flan-T5 with RAG

In [25]:
from datasets import Dataset
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

def prepare_finetune_data(df):
    data_list = []
    for _, row in df.iterrows():
        # input: RAG style prompt
        input_text = (
            f"answer the medical question based on context:\n"
            f"Context: {row['context']}\n"
            f"Question: {row['question']}\n"
            f"Answer:"
        )
        # target: correct answer
        target_text = row["answer"]
        data_list.append({"input_text": input_text, "target_text": target_text})
    return pd.DataFrame(data_list)

ft_train = prepare_finetune_data(df_train_proc)
ft_val = prepare_finetune_data(df_val_proc)

train_dataset = Dataset.from_pandas(ft_train)
val_dataset = Dataset.from_pandas(ft_val)

def tokenize_function(example):
    model_inputs = tokenizer(example["input_text"], max_length=512, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(example["target_text"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    max_length=512,
    label_pad_token_id=-100,
)

train_dataset = train_dataset.map(tokenize_function, batched=True)
valid_dataset = val_dataset.map(tokenize_function, batched=True)

train_dataset = train_dataset.remove_columns(["input_text", "target_text"])
val_dataset = val_dataset.remove_columns(["input_text", "target_text"])

training_args = TrainingArguments(
    output_dir="./flan_t5_medical",
    evaluation_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs = 1,
    weight_decay=0.01,
    save_total_limit=2,
    logging_steps=100,
    push_to_hub=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer, 
    data_collator=data_collator
)

trainer.train()




Map: 100%|██████████| 10178/10178 [00:03<00:00, 2731.16 examples/s]
Map: 100%|██████████| 1272/1272 [00:00<00:00, 2980.02 examples/s]
  trainer = Trainer(


RuntimeError: MPS backend out of memory (MPS allocated: 8.42 GB, other allocations: 11.98 GB, max allowed: 20.40 GB). Tried to allocate 741.00 KB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).