## Step 1: Data Preparation

In [1]:
import pandas as pd
from datasets import load_dataset
import os

def preprocess_medqa(df):
    """Flatten MedQA rows into question / options / answer / context records.

    Each input row must provide ``sent1`` (the question text), ``ending0`` ..
    ``ending3`` (the four answer choices) and ``label`` (position of the
    correct choice among those four).

    Returns:
        A new DataFrame with one row per input row and the columns
        ``question``, ``options`` (list of the four choices), ``answer``
        (text of the correct choice) and ``context`` (the four choices joined
        with single spaces).
    """
    def _to_record(example):
        choices = [example[f"ending{k}"] for k in range(4)]
        return {
            "question": example["sent1"],
            "options": choices,
            "answer": choices[example["label"]],
            # A real RAG setup would fetch passages from a knowledge base;
            # for this demonstration the four choices stand in as the context.
            "context": " ".join(choices),
        }

    return pd.DataFrame([_to_record(example) for _, example in df.iterrows()])

# The processed splits are cached as CSV so the dataset only has to be
# downloaded and flattened once.
import ast

# CSV stores the "options" list as its string repr. Parsing it back on load
# keeps the column a real list on both the cached and the freshly-built path.
_OPTIONS_CONVERTERS = {"options": ast.literal_eval}

# Check if processed files already exist
if (os.path.exists("medical_train.csv") and
    os.path.exists("medical_val.csv") and
    os.path.exists("medical_test.csv")):

    # Load existing processed files
    df_train_proc = pd.read_csv("medical_train.csv", converters=_OPTIONS_CONVERTERS)
    df_val_proc = pd.read_csv("medical_val.csv", converters=_OPTIONS_CONVERTERS)
    df_test_proc = pd.read_csv("medical_test.csv", converters=_OPTIONS_CONVERTERS)

else:
    # Load and process MedQA-USMLE dataset
    dataset = load_dataset("GBaker/MedQA-USMLE-4-options-hf")

    df_train = pd.DataFrame(dataset["train"])
    df_val = pd.DataFrame(dataset["validation"])
    df_test = pd.DataFrame(dataset["test"])

    df_train_proc = preprocess_medqa(df_train)
    df_val_proc = preprocess_medqa(df_val)
    df_test_proc = preprocess_medqa(df_test)

    # Save processed datasets
    df_train_proc.to_csv("medical_train.csv", index=False)
    df_val_proc.to_csv("medical_val.csv", index=False)
    df_test_proc.to_csv("medical_test.csv", index=False)

print(df_train_proc.head())


  from .autonotebook import tqdm as notebook_tqdm


                                            question  \
0  A 23-year-old pregnant woman at 22 weeks gesta...   
1  A 3-month-old baby died suddenly at night whil...   
2  A mother brings her 3-week-old infant to the p...   
3  A pulmonary autopsy specimen from a 58-year-ol...   
4  A 20-year-old woman presents with menorrhagia ...   

                                             options  \
0  ['Ampicillin', 'Ceftriaxone', 'Doxycycline', '...   
1  ['Placing the infant in a supine position on a...   
2  ['Abnormal migration of ventral pancreatic bud...   
3  ['Thromboembolism', 'Pulmonary ischemia', 'Pul...   
4  ['Hemophilia A', 'Lupus anticoagulant', 'Prote...   

                                              answer  \
0                                     Nitrofurantoin   
1  Placing the infant in a supine position on a f...   
2       Abnormal migration of ventral pancreatic bud   
3                                    Thromboembolism   
4                             Von Willebrand d

## Step 2: Create medical knowledge base

In [3]:
import os
import torch
import numpy as np
import faiss
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from pprint import pprint

# Set environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Catalogue of embedding models offered by build_retrieval_system().
# The values are shown to the user; "embedding_dim" is informational only,
# since the index dimension is read from the embeddings actually produced.
AVAILABLE_MODELS = {
    "pritamdeka/S-PubMedBert-MS-MARCO": dict(
        parameters="110M",
        embedding_dim=768,
        description="Specialized for medical/biomedical text, fine-tuned on PubMed",
    ),
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext": dict(
        parameters="110M",
        embedding_dim=768,
        description="Trained on PubMed abstracts and full-text articles",
    ),
    "gsarti/biobert-nli": dict(
        parameters="110M",
        embedding_dim=768,
        description="BioBERT fine-tuned on NLI tasks, good for medical similarity",
    ),
    "all-MiniLM-L6-v2": dict(
        parameters="22M",
        embedding_dim=384,
        description="Fast and lightweight model, good balance of speed and performance",
    ),
    "all-mpnet-base-v2": dict(
        parameters="110M",
        embedding_dim=768,
        description="One of the best performing general models",
    ),
}

# FAISS index types offered by build_retrieval_system(), with the
# human-readable guidance that is printed before the user picks one.
FAISS_INDEXES = {
    "IndexFlatL2": dict(
        description="Exact L2 distance search. Most accurate but slower for large datasets.",
        use_case="Small to medium datasets where accuracy is critical",
        recommended_size="< 1M vectors",
    ),
    "IndexIVFFlat": dict(
        description="Inverted file with exact post-verification. Good balance of speed and accuracy.",
        use_case="Medium to large datasets, allows approximate search",
        recommended_size="1M - 10M vectors",
    ),
    "IndexHNSWFlat": dict(
        description="Hierarchical Navigable Small World graph. Very fast search with good accuracy.",
        use_case="Large datasets where search speed is critical",
        recommended_size="10M - 100M vectors",
    ),
    "IndexLSH": dict(
        description="Locality-Sensitive Hashing. Fast but less accurate.",
        use_case="Very large datasets where approximate results are acceptable",
        recommended_size="> 100M vectors",
    ),
}

def create_medical_knowledge_base():
    """Assemble retrieval passages from PubMedQA abstracts and MedMCQA explanations.

    Loads the ``pubmed_qa`` (``pqa_labeled`` config) and ``medmcqa`` datasets,
    pretty-prints the first training example of each, and keeps only texts
    with at least 20 whitespace-separated words.

    Returns:
        list[dict]: entries with the keys ``text``, ``source``, ``type`` and
        ``metadata``. PubMedQA entries come first, followed by MedMCQA ones.
    """
    min_words = 20  # shortest passage (in whitespace-separated words) that is kept

    print("Loading datasets...")
    pubmedqa = load_dataset("pubmed_qa", "pqa_labeled")
    medmcqa = load_dataset("medmcqa")

    # Show one example per source
    print("\nExample from PubMedQA:")
    first_pubmed = pubmedqa['train'][0]
    pprint({
        'question': first_pubmed['question'],
        # str() of the whole context field, truncated for display
        'context_preview': str(first_pubmed['context'])[:200] + "...",
        'long_answer': first_pubmed['long_answer']
    })

    print("\nExample from MedMCQA:")
    first_mcqa = medmcqa['train'][0]
    if first_mcqa['exp']:
        explanation_preview = str(first_mcqa['exp'])[:200] + "..."
    else:
        explanation_preview = "No explanation"
    pprint({
        'question': first_mcqa['question'],
        'explanation_preview': explanation_preview,
        'correct_option': first_mcqa['cop']
    })

    # PubMedQA: each item's context pieces are merged into a single passage
    pubmed_entries = []
    for item in pubmedqa['train']:
        abstract = " ".join(item['context']['contexts'])
        if len(abstract.split()) < min_words:
            continue
        pubmed_entries.append({
            'text': abstract,
            'source': 'PubMedQA',
            'type': 'research_abstract',
            'metadata': {
                'question': item['question'],
                'long_answer': item['long_answer'],
                'pubid': item['pubid']
            }
        })

    # MedMCQA: keep items whose explanation exists and is long enough
    mcqa_entries = [
        {
            'text': item['exp'],
            'source': 'MedMCQA',
            'type': 'expert_explanation',
            'metadata': {
                'question': item['question'],
                'correct_answer': item['cop']
            }
        }
        for item in medmcqa['train']
        if item['exp'] and len(item['exp'].split()) >= min_words
    ]

    knowledge_base = pubmed_entries + mcqa_entries
    print(f"\nCreated knowledge base with {len(knowledge_base)} entries")
    return knowledge_base

class TextDataset(Dataset):
    """Dataset wrapper around an in-memory sequence of texts.

    Used by build_retrieval_system() to feed knowledge-base texts to a
    DataLoader for batching.
    """

    def __init__(self, texts):
        # Stored by reference; no copy is made.
        self.texts = texts
    
    def __len__(self):
        """Return the number of stored texts."""
        return len(self.texts)
    
    def __getitem__(self, idx):
        """Return the text at position ``idx``."""
        return self.texts[idx]

def _encode_texts(encoder, texts, device=None):
    """Encode ``texts`` and return the embeddings as a NumPy array.

    When ``encoder`` is wrapped in ``torch.nn.DataParallel``, ``encode`` is
    called on the wrapped module with ``device=None``; otherwise ``device`` is
    forwarded to ``encoder.encode``.
    """
    if isinstance(encoder, torch.nn.DataParallel):
        # NOTE: calling encode() on .module bypasses DataParallel's
        # scatter/gather, so the wrapper itself does not split this batch.
        return encoder.module.encode(texts, convert_to_numpy=True, device=None)
    return encoder.encode(texts, convert_to_numpy=True, device=device)


def build_retrieval_system(knowledge_base):
    """Build a dense retrieval system: embed every entry and index it with FAISS.

    Interactive: prompts on stdin for the embedding model (a key of
    ``AVAILABLE_MODELS``) and for the FAISS index type (a key of
    ``FAISS_INDEXES``). An empty or unknown answer falls back to the default
    model or to the size-based recommended index.

    Args:
        knowledge_base: sequence of dicts, each with a ``'text'`` key.

    Returns:
        tuple: ``(index, encoder, selected_model)``: the populated FAISS
        index, the encoder (wrapped in ``torch.nn.DataParallel`` when more
        than one CUDA device is present) and the chosen model name.

    Raises:
        ValueError: if ``knowledge_base`` is empty.
        RuntimeError: if a single-item batch fails to encode, or a chunk fails
            again during the reduced-size retry.
    """
    # Fail fast: without texts there is nothing to embed, and np.concatenate
    # would otherwise fail much later with a less helpful message.
    texts = [entry['text'] for entry in knowledge_base]
    if not texts:
        raise ValueError("knowledge_base is empty; nothing to embed or index")

    # Print available models and their info
    print("\nAvailable SentenceTransformer Models:")
    for model_name, info in AVAILABLE_MODELS.items():
        print(f"\n{model_name}:")
        print(f"Parameters: {info['parameters']}")
        print(f"Embedding Dimension: {info['embedding_dim']}")
        print(f"Description: {info['description']}")

    # Model selection
    selected_model = input("\nEnter the name of the model you want to use (default: pritamdeka/S-PubMedBert-MS-MARCO): ").strip()
    if not selected_model or selected_model not in AVAILABLE_MODELS:
        print(f"Using default model: pritamdeka/S-PubMedBert-MS-MARCO")
        selected_model = "pritamdeka/S-PubMedBert-MS-MARCO"

    # Initialize encoder and configure GPU usage
    print(f"\nLoading {selected_model}...")
    encoder = SentenceTransformer(selected_model)

    # Pick the device and batch size. `device` stays None in the multi-GPU
    # branch, where the wrapped module chooses its own device.
    device = None
    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
        print(f"Found {n_gpus} CUDA GPUs")

        if n_gpus > 1:
            print(f"Using {n_gpus} GPUs in parallel")
            encoder = torch.nn.DataParallel(encoder)
            # Scale batch size with number of GPUs, but cap it for stability
            batch_size = min(32 * n_gpus, 256)  # Cap at 256 to prevent OOM
        else:
            print("Using single GPU")
            device = torch.device("cuda:0")
            encoder.to(device)
            batch_size = 64
    else:
        print("CUDA is not available. Using CPU.")
        device = torch.device("cpu")
        batch_size = 32

    # Batch the texts; order is preserved (shuffle=False) so that embedding
    # row i corresponds to knowledge_base[i].
    dataset = TextDataset(texts)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=torch.cuda.is_available(),
        persistent_workers=torch.cuda.is_available()
    )

    print("Generating embeddings...")
    embeddings_list = []
    encoder.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            try:
                embeddings_list.append(_encode_texts(encoder, batch, device))

            except RuntimeError as e:
                print(f"Error processing batch: {e}")
                if len(batch) <= 1:
                    # Nothing smaller to retry with
                    raise
                print("Reducing batch size and retrying...")
                # Retry in roughly four chunks. max(1, ...) guards batches of
                # 2-3 items, where len // 4 == 0 would make range() raise.
                chunk_size = max(1, len(batch) // 4)
                for start in range(0, len(batch), chunk_size):
                    sub_batch = batch[start:start + chunk_size]
                    embeddings_list.append(_encode_texts(encoder, sub_batch, device))

            # Explicit GPU memory cleanup
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Concatenate embeddings
    print("Concatenating embeddings...")
    embeddings = np.concatenate(embeddings_list, axis=0)

    # FAISS index selection and creation
    print("\nAvailable FAISS Index Types:")
    for index_name, info in FAISS_INDEXES.items():
        print(f"\n{index_name}:")
        print(f"Description: {info['description']}")
        print(f"Use Case: {info['use_case']}")
        print(f"Recommended Dataset Size: {info['recommended_size']}")

    # Recommend an index type from the number of vectors
    data_size = len(embeddings)
    recommended_index = "IndexFlatL2"
    if data_size > 100_000_000:
        recommended_index = "IndexLSH"
    elif data_size > 10_000_000:
        recommended_index = "IndexHNSWFlat"
    elif data_size > 1_000_000:
        recommended_index = "IndexIVFFlat"

    print(f"\nBased on your dataset size ({data_size:,} vectors), we recommend using: {recommended_index}")

    selected_index = input("\nEnter the name of the index type you want to use (default: recommended): ").strip()
    if not selected_index or selected_index not in FAISS_INDEXES:
        print(f"Using recommended index: {recommended_index}")
        selected_index = recommended_index

    # Build FAISS index
    dimension = embeddings.shape[1]

    try:
        if selected_index == "IndexFlatL2":
            index = faiss.IndexFlatL2(dimension)
        elif selected_index == "IndexIVFFlat":
            # Number of inverted lists: data_size // 30, clamped to [100, 4096]
            nlist = min(4096, max(data_size // 30, 100))
            quantizer = faiss.IndexFlatL2(dimension)
            index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
            print("Training IVF index...")
            index.train(embeddings)
        elif selected_index == "IndexHNSWFlat":
            M = 32
            index = faiss.IndexHNSWFlat(dimension, M)
        elif selected_index == "IndexLSH":
            nbits = min(64, dimension)
            index = faiss.IndexLSH(dimension, nbits)

        print("Adding vectors to index...")
        index.add(embeddings)

    except Exception as e:
        # Any failure while building the chosen index falls back to a fresh
        # exact index.
        print(f"Error creating FAISS index: {e}")
        print("Falling back to simple IndexFlatL2...")
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)

    return index, encoder, selected_model

def retrieve_contexts(query, index, encoder, knowledge_base, k=3):
    """Return up to ``k`` knowledge-base entries nearest to ``query``.

    Args:
        query: text to search for.
        index: object with a ``search(vectors, k)`` method returning
            ``(distances, indices)``.
        encoder: the embedding model, optionally wrapped in
            ``torch.nn.DataParallel``.
        knowledge_base: sequence of dict entries addressed by index position.
        k: number of neighbours requested from the index.

    Returns:
        list[dict]: shallow copies of the matching entries, each with an added
        ``'distance'`` float, in the order the index returned them. Any
        exception is printed and results in an empty list.
    """
    try:
        # encode() is called on the underlying model, so unwrap DataParallel first.
        model = encoder.module if isinstance(encoder, torch.nn.DataParallel) else encoder

        # Encode on whichever device holds the model's first parameter.
        model_device = next(model.parameters()).device
        with torch.no_grad():
            query_vector = model.encode([query], convert_to_numpy=True, device=model_device)

        distances, indices = index.search(query_vector, k)

        hits = []
        for position, dist_value in zip(indices[0], distances[0]):
            # Skip ids that do not address a knowledge-base entry
            if not (0 <= position < len(knowledge_base)):
                continue
            hit = knowledge_base[position].copy()
            hit['distance'] = float(dist_value)
            hits.append(hit)
        return hits

    except Exception as e:
        print(f"Error retrieving contexts: {e}")
        return []

if __name__ == "__main__":
    # Artifacts written once and reused on later runs
    kb_file = "medical_knowledge_base.pkl"
    index_file = "faiss_index.bin"
    model_name_file = "model_name.txt"

    try:
        if not all(os.path.exists(path) for path in (kb_file, index_file, model_name_file)):
            # At least one artifact is missing: rebuild everything and persist it
            print("Creating new knowledge base and index...")
            kb = create_medical_knowledge_base()
            index, encoder, model_name = build_retrieval_system(kb)

            print("Saving knowledge base, index and model name...")
            with open(kb_file, 'wb') as fh:
                pickle.dump(kb, fh)
            faiss.write_index(index, index_file)
            with open(model_name_file, 'w') as fh:
                fh.write(model_name)
        else:
            print("Loading existing knowledge base and index...")
            # NOTE: pickle can execute arbitrary code on load; only open files
            # this notebook wrote itself.
            with open(kb_file, 'rb') as fh:
                kb = pickle.load(fh)
            index = faiss.read_index(index_file)
            with open(model_name_file, 'r') as fh:
                model_name = fh.read().strip()
            encoder = SentenceTransformer(model_name)

        # Smoke-test the retriever with a few sample questions
        example_queries = [
            "What are the symptoms of diabetes?",
            "How is breast cancer diagnosed?",
            "What are the side effects of chemotherapy?"
        ]

        print("\nTesting retrieval system with example queries:")
        for query in example_queries:
            print(f"\nQuery: {query}")
            for rank_no, hit in enumerate(retrieve_contexts(query, index, encoder, kb), 1):
                print(f"\nRelevant Context {rank_no}:")
                print(f"Source: {hit['source']}")
                print(f"Type: {hit['type']}")
                print(f"Text preview: {hit['text'][:200]}...")

    except Exception as e:
        print(f"An error occurred: {e}")

Loading existing knowledge base and index...

Testing retrieval system with example queries:

Query: What are the symptoms of diabetes?

Relevant Context 1:
Source: MedMCQA
Type: expert_explanation
Text preview: DM is a syndrome consisting of hyperglycemia, large vessel disease, micro vascular disease, and neuropathy. The classic presenting symptoms are increased thirst, polyuria, polyphagia, and weight loss....

Relevant Context 2:
Source: MedMCQA
Type: expert_explanation
Text preview: Manifestations of DKA are : * Symptoms - Nausea/vomiting, thirst/polyuria, abdominal pain, shoness of breath. * Physical Findings - Tachycardia, dehydration/hypotension, tachypnea/Kussmaul respiration...

Relevant Context 3:
Source: MedMCQA
Type: expert_explanation
Text preview: Symptoms of hyperglycemia include polyuria, polydipsia, weight loss, fatigue, weakness, blurry vision, frequent superficial infections (vaginitis, fungal skin infections), and slow healing of skin les...

Query: How is breast ca

## Step 3: Use an LLM with RAG

In [None]:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 
from rouge_score import rouge_scorer
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
import nltk
from bert_score import score
import spacy

# Download required NLTK data
# NOTE(review): no NLTK tokenizer call is visible in this cell (the BLEU
# computation below is fed str.split() tokens) — confirm 'punkt' is needed.
nltk.download('punkt')

# Generator models offered by initialize_qa_model(); the values are only
# shown to the user before they pick one.
# NOTE(review): initialize_qa_model() loads every entry with
# AutoModelForSeq2SeqLM. The PubMedBERT and LLaMA entries are presumably not
# seq2seq checkpoints — confirm they can be loaded that way.
MEDICAL_QA_MODELS = {
    "google/flan-t5-large": dict(
        parameters="780M",
        description="Strong general-purpose model, good at following instructions",
        strengths="Versatile, good at structured responses",
    ),
    "GanjinZero/biomedical-flan-t5-large": dict(
        parameters="780M",
        description="FLAN-T5 fine-tuned on medical datasets",
        strengths="Specialized for medical domain",
    ),
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract": dict(
        parameters="110M",
        description="Trained on PubMed abstracts",
        strengths="Strong medical domain knowledge",
    ),
    "epfl-llm/medical-llama-7b": dict(
        parameters="7B",
        description="LLaMA fine-tuned on medical data",
        strengths="Comprehensive medical knowledge",
    ),
}

class MultiGPUConfig:
    """Records the CUDA device count and joins the process group when needed.

    Attributes:
        n_gpus: value of ``torch.cuda.device_count()`` at construction time.
        using_multi_gpu: True when more than one CUDA device was reported.
    """

    def __init__(self):
        gpu_count = torch.cuda.device_count()
        self.n_gpus = gpu_count
        self.using_multi_gpu = gpu_count > 1

    def setup_distributed(self, rank):
        """Join the NCCL process group as ``rank`` and select that CUDA device.

        Does nothing unless ``using_multi_gpu`` is set.
        """
        if not self.using_multi_gpu:
            return

        dist.init_process_group(
            backend='nccl',
            init_method='tcp://localhost:12355',
            world_size=self.n_gpus,
            rank=rank
        )
        torch.cuda.set_device(rank)

def initialize_qa_model(gpu_config, rank=0):
    """Interactively choose a QA model, load it and place it on a device.

    Prints the ``MEDICAL_QA_MODELS`` catalogue and reads the model name from
    stdin; an empty or unknown answer selects the default model. The model is
    loaded with ``AutoModelForSeq2SeqLM``.

    Args:
        gpu_config: object exposing ``using_multi_gpu``.
        rank: CUDA device index used in multi-GPU mode.

    Returns:
        tuple: ``(model, tokenizer, device)``; in multi-GPU mode ``model`` is
        wrapped in ``DistributedDataParallel``.
    """
    default_model = "GanjinZero/biomedical-flan-t5-large"

    print("\nAvailable Medical QA Models:")
    for name, details in MEDICAL_QA_MODELS.items():
        print(f"\n{name}:")
        print(f"Parameters: {details['parameters']}")
        print(f"Description: {details['description']}")
        print(f"Strengths: {details['strengths']}")

    selected_model = input("\nEnter the name of the model you want to use (default: GanjinZero/biomedical-flan-t5-large): ").strip()
    if not (selected_model and selected_model in MEDICAL_QA_MODELS):
        print(f"Using default model: {default_model}")
        selected_model = default_model

    tokenizer = AutoTokenizer.from_pretrained(selected_model)
    model = AutoModelForSeq2SeqLM.from_pretrained(selected_model)

    # Device preference: distributed CUDA rank > single CUDA > Apple MPS > CPU
    distributed = gpu_config.using_multi_gpu
    if distributed:
        device = torch.device(f'cuda:{rank}')
        status = f"Using GPU {rank} in distributed mode"
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        status = "Using single GPU"
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        status = "Using MPS backend for Apple Silicon"
    else:
        device = torch.device("cpu")
        status = "Using CPU"

    model = model.to(device)
    if distributed:
        model = DDP(model, device_ids=[rank])
    print(status)

    return model, tokenizer, device

class ParallelQualityMetrics:
    """Quality metrics for generated answers, optionally sharded across ranks.

    Attributes:
        rouge_scorer: ROUGE-1/2/L scorer with stemming enabled.
        nlp: spaCy pipeline ``en_core_web_sm``.
        gpu_config: object exposing ``using_multi_gpu`` and ``n_gpus``.
        rank: this process's rank in the multi-GPU setup.
        device: device string handed to BERTScore.
    """

    def __init__(self, gpu_config, rank=0):
        self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        self.nlp = spacy.load('en_core_web_sm')
        self.gpu_config = gpu_config
        self.rank = rank

        # Device on which BERTScore is computed
        if gpu_config.using_multi_gpu:
            self.device = f'cuda:{rank}'
        elif torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

    def calculate_metrics_batch(self, batch_texts, batch_contexts=None):
        """Compute metrics for every text and return one dict per text, in order.

        In multi-GPU mode each rank scores a contiguous slice of the batch and
        the slices are exchanged with ``all_gather_object``, so every rank
        returns the full, flat list. This must be called by all ranks.

        Args:
            batch_texts: list of generated answers.
            batch_contexts: optional list of context strings aligned with
                ``batch_texts``.
        """
        contexts = batch_contexts or [None] * len(batch_texts)

        if not self.gpu_config.using_multi_gpu:
            return [self.calculate_metrics(text, context=ctx)
                    for text, ctx in zip(batch_texts, contexts)]

        # Split as evenly as possible; the first `extra` ranks take one more
        # item so that no remainder is dropped.
        world_size = self.gpu_config.n_gpus
        base, extra = divmod(len(batch_texts), world_size)
        start_idx = self.rank * base + min(self.rank, extra)
        end_idx = start_idx + base + (1 if self.rank < extra else 0)

        local_results = [self.calculate_metrics(text, context=ctx)
                         for text, ctx in zip(batch_texts[start_idx:end_idx],
                                              contexts[start_idx:end_idx])]

        # all_gather_object fills one slot per rank, so the output list must
        # have world_size entries (not one per text).
        gathered = [None] * world_size
        dist.all_gather_object(gathered, local_results)

        # Shards arrive in rank order; flattening restores the batch order.
        return [metric for shard in gathered for metric in shard]

    def calculate_metrics(self, generated_text, reference_text=None, context=None):
        """Compute metrics for a single generated answer.

        Always returns ``medical_entities``, ``response_length`` and
        ``avg_word_length`` (0.0 for empty text). Adds ``context_relevance``
        when ``context`` is truthy, and ``rouge1`` / ``rouge2`` / ``rougeL`` /
        ``bleu`` / ``bert_score`` when ``reference_text`` is truthy.
        """
        metrics = {}

        # Content relevance (if context is provided)
        if context:
            metrics['context_relevance'] = self._calculate_context_relevance(generated_text, context)

        # Medical entity coverage
        metrics['medical_entities'] = self._count_medical_entities(generated_text)

        # Response length and complexity
        words = generated_text.split()
        metrics['response_length'] = len(words)
        # Guard the empty case: np.mean([]) would give NaN plus a RuntimeWarning
        metrics['avg_word_length'] = float(np.mean([len(word) for word in words])) if words else 0.0

        # Reference-based metrics (if reference is provided)
        if reference_text:
            # RougeScorer.score expects (target, prediction)
            rouge_scores = self.rouge_scorer.score(reference_text, generated_text)
            metrics['rouge1'] = rouge_scores['rouge1'].fmeasure
            metrics['rouge2'] = rouge_scores['rouge2'].fmeasure
            metrics['rougeL'] = rouge_scores['rougeL'].fmeasure

            # BLEU score on whitespace tokens
            reference_tokens = [reference_text.split()]
            metrics['bleu'] = sentence_bleu(reference_tokens, words)

            # BERTScore on the device chosen in __init__
            P, R, F1 = score([generated_text], [reference_text], lang='en',
                             verbose=False, device=self.device)
            metrics['bert_score'] = F1.mean().item()

        return metrics

    def _calculate_context_relevance(self, text, context):
        """Return the spaCy ``Doc.similarity`` between ``text`` and ``context``."""
        # NOTE(review): en_core_web_sm presumably ships without static word
        # vectors, which would make this similarity unreliable — confirm.
        doc1 = self.nlp(text)
        doc2 = self.nlp(context)
        return doc1.similarity(doc2)

    def _count_medical_entities(self, text):
        """Count entities in ``text`` labelled DISEASE, CHEMICAL or PROCEDURE."""
        # NOTE(review): the label set of en_core_web_sm presumably does not
        # include these labels, in which case this always returns 0 — confirm,
        # or load a biomedical pipeline.
        doc = self.nlp(text)
        medical_ents = [ent for ent in doc.ents if ent.label_ in ['DISEASE', 'CHEMICAL', 'PROCEDURE']]
        return len(medical_ents)

def parallel_rag_infer(questions, kb, index, encoder, qa_model, tokenizer, device, metrics,
                      gpu_config, rank=0, use_context=True, top_k=3, batch_size=8):
    """Answer ``questions`` in batches, optionally grounding each prompt in retrieved context.

    Args:
        questions: list of question strings.
        kb, index, encoder: retrieval components passed to ``retrieve_contexts``
            (only used when ``use_context`` is True).
        qa_model: seq2seq model, optionally wrapped in ``DistributedDataParallel``.
        tokenizer: tokenizer matching ``qa_model``.
        device: device the tokenized inputs are moved to.
        metrics: object providing ``calculate_metrics_batch(texts, contexts)``.
        gpu_config, rank: accepted for interface compatibility; not read here.
        use_context: when True, retrieve ``top_k`` passages per question and
            put them into the prompt.
        top_k: number of passages retrieved per question.
        batch_size: number of questions generated per forward pass.

    Returns:
        list[dict]: one record per question with the keys ``question``,
        ``context`` (None without retrieval), ``answer`` and ``metrics``.
    """
    # generate() is called on the underlying model rather than the DDP wrapper
    generator = qa_model.module if isinstance(qa_model, DDP) else qa_model

    results = []

    for i in range(0, len(questions), batch_size):
        batch_questions = questions[i:i + batch_size]

        if use_context:
            # Retrieve and merge the top-k passages for every question
            batch_contexts = []
            for question in batch_questions:
                contexts = retrieve_contexts(question, index, encoder, kb, k=top_k)
                batch_contexts.append("\n".join(ctx['text'] for ctx in contexts))

            prompts = [
                f"Answer the medical question based on the following context:\nContext: {ctx}\nQuestion: {q}\nAnswer:"
                for q, ctx in zip(batch_questions, batch_contexts)
            ]
        else:
            batch_contexts = None
            prompts = [
                f"Answer the medical question:\nQuestion: {q}\nAnswer:"
                for q in batch_questions
            ]

        # Tokenize batch
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
            padding=True
        ).to(device)

        with torch.no_grad():
            outputs = generator.generate(
                inputs.input_ids,
                # Prompts are padded to a common length; the mask keeps the
                # model from attending to the padding positions.
                attention_mask=inputs.attention_mask,
                max_length=256,
                # Beam search. The former temperature/top_p arguments were
                # dropped: do_sample was never set, so they had no effect.
                num_beams=4,
                early_stopping=True
            )

        batch_answers = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        batch_metrics = metrics.calculate_metrics_batch(
            batch_answers,
            batch_contexts if use_context else None
        )

        # Combine results
        record_contexts = batch_contexts if use_context else [None] * len(batch_questions)
        for q, a, m, c in zip(batch_questions, batch_answers, batch_metrics, record_contexts):
            results.append({
                'question': q,
                'context': c,
                'answer': a,
                'metrics': m
            })

    return results

def compare_rag_performance_parallel():
    """Run the RAG vs. no-RAG comparison, one process per GPU when several exist.

    With more than one CUDA device, ``run_distributed_comparison`` is spawned
    once per device; otherwise it is called in-process as rank 0.
    """
    gpu_config = MultiGPUConfig()

    if not gpu_config.using_multi_gpu:
        # Single GPU or CPU mode
        run_distributed_comparison(0, gpu_config)
        return

    print(f"Using {gpu_config.n_gpus} GPUs")

    # spawn() passes the process index as the first argument (the rank)
    torch.multiprocessing.spawn(
        run_distributed_comparison,
        args=(gpu_config,),
        nprocs=gpu_config.n_gpus
    )

def _print_qa_results(title, results):
    """Print ``title`` followed by the question, answer and metrics of each result."""
    print(title)
    for result in results:
        print(f"\nQ: {result['question']}")
        print(f"A: {result['answer']}")
        print(f"Metrics: {result['metrics']}")


def run_distributed_comparison(rank, gpu_config):
    """Run the RAG vs. no-RAG comparison for one rank.

    Loads the QA model, the metrics helper and the persisted retrieval
    artifacts (``medical_knowledge_base.pkl``, ``faiss_index.bin``,
    ``model_name.txt``), answers a fixed set of test questions with and
    without retrieved context, and prints both result sets on rank 0.

    Args:
        rank: index of this process (0 in single-process mode).
        gpu_config: object exposing ``using_multi_gpu`` and
            ``setup_distributed(rank)``.

    Returns:
        None. If the retrieval artifacts cannot be loaded, the error is
        printed and the function returns early.
    """
    if gpu_config.using_multi_gpu:
        gpu_config.setup_distributed(rank)

    try:
        # Initialize components
        qa_model, tokenizer, device = initialize_qa_model(gpu_config, rank)
        metrics = ParallelQualityMetrics(gpu_config, rank)

        # Load knowledge base and retrieval components
        try:
            # NOTE: pickle can execute arbitrary code on load; only open a
            # knowledge-base file that this notebook wrote itself.
            with open("medical_knowledge_base.pkl", 'rb') as f:
                kb = pickle.load(f)
            index = faiss.read_index("faiss_index.bin")
            with open("model_name.txt", 'r') as f:
                model_name = f.read().strip()
            encoder = SentenceTransformer(model_name)

            if gpu_config.using_multi_gpu:
                encoder = encoder.to(f'cuda:{rank}')
        except Exception as e:
            print(f"Error loading knowledge base: {e}")
            return

        # Test questions
        test_questions = [
            "What are the early symptoms of diabetes?",
            "How is rheumatoid arthritis diagnosed?",
            "What are the common side effects of chemotherapy?",
            "How is high blood pressure treated?",
            "What causes migraine headaches?"
        ]

        if rank == 0:  # Only print on main process
            print("\nComparing RAG vs. No-RAG Performance:")
            print("=" * 80)

        # Generate answers with and without RAG
        rag_results = parallel_rag_infer(
            test_questions, kb, index, encoder, qa_model, tokenizer,
            device, metrics, gpu_config, rank, use_context=True
        )

        no_rag_results = parallel_rag_infer(
            test_questions, kb, index, encoder, qa_model, tokenizer,
            device, metrics, gpu_config, rank, use_context=False
        )

        if rank == 0:  # Only print results on main process
            _print_qa_results("\nResults with RAG:", rag_results)
            _print_qa_results("\nResults without RAG:", no_rag_results)
    finally:
        # Leave the process group on every exit path, including the early
        # return above and any exception.
        if gpu_config.using_multi_gpu:
            dist.destroy_process_group()

# Entry point: run the RAG vs. no-RAG comparison when executed as __main__.
if __name__ == "__main__":
    compare_rag_performance_parallel()

ImportError: cannot import name 'AutoModelForSeq2SeqGeneration' from 'transformers' (/home/azureuser/mambaforge/envs/llm/lib/python3.10/site-packages/transformers/__init__.py)

## Step 4: Fine-Tuning Flan-T5 with RAG

In [25]:
from datasets import Dataset
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

def prepare_finetune_data(df):
    """Turn processed QA rows into (input_text, target_text) fine-tuning pairs.

    Each row must provide ``context``, ``question`` and ``answer``. The input
    is a RAG-style prompt ending in ``Answer:`` and the target is the row's
    answer.

    Returns:
        A new DataFrame with the columns ``input_text`` and ``target_text``.
    """
    def _to_pair(row):
        prompt = "\n".join([
            "answer the medical question based on context:",
            f"Context: {row['context']}",
            f"Question: {row['question']}",
            "Answer:",
        ])
        return {"input_text": prompt, "target_text": row["answer"]}

    return pd.DataFrame([_to_pair(row) for _, row in df.iterrows()])

# Build prompt/target pairs from the processed splits created in Step 1.
ft_train = prepare_finetune_data(df_train_proc)
ft_val = prepare_finetune_data(df_val_proc)

# Wrap the frames as `datasets.Dataset` objects so they can be tokenized with .map().
train_dataset = Dataset.from_pandas(ft_train)
val_dataset = Dataset.from_pandas(ft_val)

def tokenize_function(example):
    """Tokenize prompts and targets for seq2seq fine-tuning.

    Args:
        example: mapping with ``input_text`` and ``target_text`` (single
            strings, or lists of strings when mapped with ``batched=True``).

    Returns:
        The tokenized inputs (truncated to 512 tokens) with an added
        ``labels`` entry holding the target token ids (truncated to 128).
    """
    # NOTE(review): relies on a module-level `tokenizer`; none is defined at
    # top level in the visible cells — confirm it exists before this cell runs.
    model_inputs = tokenizer(example["input_text"], max_length=512, truncation=True)
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(text_target=example["target_text"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Pads inputs per batch; label padding uses -100, the value passed below as
# label_pad_token_id.
# NOTE(review): relies on module-level `tokenizer` and `model`; neither is
# defined at top level in the visible cells — confirm they exist.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    max_length=512,
    label_pad_token_id=-100,
)

train_dataset = train_dataset.map(tokenize_function, batched=True)
valid_dataset = val_dataset.map(tokenize_function, batched=True)

# Drop the raw text columns. The validation set handed to the Trainer must be
# derived from the tokenized `valid_dataset`, not the untokenized `val_dataset`.
train_dataset = train_dataset.remove_columns(["input_text", "target_text"])
val_dataset = valid_dataset.remove_columns(["input_text", "target_text"])

training_args = TrainingArguments(
    output_dir="./flan_t5_medical",
    evaluation_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs = 1,
    weight_decay=0.01,
    save_total_limit=2,
    logging_steps=100,
    push_to_hub=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer, 
    data_collator=data_collator
)

trainer.train()




Map: 100%|██████████| 10178/10178 [00:03<00:00, 2731.16 examples/s]
Map: 100%|██████████| 1272/1272 [00:00<00:00, 2980.02 examples/s]
  trainer = Trainer(


RuntimeError: MPS backend out of memory (MPS allocated: 8.42 GB, other allocations: 11.98 GB, max allowed: 20.40 GB). Tried to allocate 741.00 KB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).