# Healthcare RAG System Lab
## Overview

In this lab, you'll take on the role of a junior data scientist at a healthcare technology company that specializes in creating educational resources for patients. Your team has been tasked with developing a system that can automatically generate informative responses to common patient questions about medical conditions, treatments, and wellness practices.

The challenge is to ensure these responses are both accurate and grounded in authoritative medical information. Your specific assignment is to implement a Retrieval-Augmented Generation (RAG) system that can:
1. Understand patient questions about various health topics
2. Retrieve relevant information from a trusted knowledge base
3. Generate helpful, accurate responses based on that information
4. Avoid "hallucinated" content that could potentially misinform patients

This lab follows the generative AI implementation process we've studied, with particular focus on:
- Data Strategy and Knowledge Foundation
- Model Selection and Generation Control
- Evaluation Framework Development

## Setup

First, let's import the necessary libraries:

In [13]:
import torch
import pandas as pd
import numpy as np
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cpu


## Part 1: Knowledge Base Setup

Let's create a sample medical knowledge base with information about common health conditions, treatments, and wellness practices:

In [14]:
# Create a sample medical knowledge base
knowledge_base = pd.DataFrame({
    'content': [
        "Diabetes is a chronic condition that affects how your body turns food into energy. There are three main types: Type 1, Type 2, and gestational diabetes. Type 2 diabetes is the most common form, accounting for about 90-95% of diabetes cases.",
        "Type 1 diabetes is an autoimmune reaction that stops your body from making insulin. Symptoms include increased thirst, frequent urination, hunger, fatigue, and blurred vision. It's usually diagnosed in children, teens, and young adults.",
        "Type 2 diabetes occurs when your body becomes resistant to insulin or doesn't make enough insulin. Risk factors include being overweight, being 45 years or older, having a parent or sibling with type 2 diabetes, and being physically active less than 3 times a week.",
        "Managing diabetes involves monitoring blood sugar levels, taking medications as prescribed, eating a healthy diet, maintaining a healthy weight, and getting regular physical activity. It's important to work with healthcare providers to develop a management plan.",
        "Hypertension, or high blood pressure, is when the force of blood pushing against the walls of your arteries is consistently too high. It's often called the 'silent killer' because it typically has no symptoms but significantly increases the risk of heart disease and stroke.",
        "Blood pressure is measured using two numbers: systolic (top number) and diastolic (bottom number). Normal blood pressure is less than 120/80 mm Hg. Hypertension is diagnosed when readings are consistently 130/80 mm Hg or higher.",
        "Lifestyle changes to manage hypertension include reducing sodium in your diet, getting regular physical activity, maintaining a healthy weight, limiting alcohol, quitting smoking, and managing stress. Medications may also be prescribed if lifestyle changes aren't enough.",
        "Regular physical activity offers numerous health benefits, including weight management, reduced risk of heart disease, strengthened bones and muscles, improved mental health, and enhanced ability to perform daily activities. Adults should aim for at least 150 minutes of moderate-intensity activity per week.",
        "A balanced diet should include a variety of fruits, vegetables, whole grains, lean proteins, and healthy fats. It's recommended to limit intake of added sugars, sodium, saturated fats, and processed foods. Proper nutrition helps prevent chronic diseases and supports overall health.",
        "Vaccination is one of the most effective ways to prevent infectious diseases. Vaccines work by helping the body recognize and fight specific pathogens. Common adult vaccines include influenza (flu), Tdap (tetanus, diphtheria, pertussis), shingles, and pneumococcal vaccines."
    ],
    'metadata': [
        {'topic': 'diabetes', 'subtopic': 'overview', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'diabetes', 'subtopic': 'type1', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'diabetes', 'subtopic': 'type2', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'diabetes', 'subtopic': 'management', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'hypertension', 'subtopic': 'overview', 'source': 'medical_guidelines', 'last_updated': '2023-07-22'},
        {'topic': 'hypertension', 'subtopic': 'diagnosis', 'source': 'medical_guidelines', 'last_updated': '2023-07-22'},
        {'topic': 'hypertension', 'subtopic': 'management', 'source': 'medical_guidelines', 'last_updated': '2023-07-22'},
        {'topic': 'wellness', 'subtopic': 'physical_activity', 'source': 'health_promotion', 'last_updated': '2023-05-15'},
        {'topic': 'wellness', 'subtopic': 'nutrition', 'source': 'health_promotion', 'last_updated': '2023-05-15'},
        {'topic': 'prevention', 'subtopic': 'vaccination', 'source': 'medical_guidelines', 'last_updated': '2023-08-05'}
    ]
})

print(f"Knowledge base loaded with {len(knowledge_base)} entries")
knowledge_base.head(2)

Knowledge base loaded with 10 entries


| | content | metadata |
|---|---|---|
| 0 | Diabetes is a chronic condition that affects h... | {'topic': 'diabetes', 'subtopic': 'overview', ... |
| 1 | Type 1 diabetes is an autoimmune reaction that... | {'topic': 'diabetes', 'subtopic': 'type1', 'so... |


### Task 1: Create Document Embeddings

Complete the function below to create embeddings for each document in the knowledge base. These embeddings will be used to find relevant documents based on patient queries.

In [15]:

def create_document_embeddings(documents):
    """
    Create embeddings for a list of documents.
    
    Args:
        documents: List of text documents to embed
        
    Returns:
        Numpy array of document embeddings
    """
    # Initialize a sentence transformer model
    embedding_model = SentenceTransformer('all-mpnet-base-v2')
    
    # Generate embeddings for all documents
    # convert_to_numpy=True returns a NumPy array; show_progress_bar gives feedback
    document_embeddings = embedding_model.encode(
        documents,
        convert_to_numpy=True,
        show_progress_bar=True
    )
    
    return document_embeddings

# Extract document content
documents = knowledge_base['content'].tolist()

# Create document embeddings
document_embeddings = create_document_embeddings(documents)

# Verify the shape of embeddings
if document_embeddings is not None:
    print(f"Generated embeddings with shape: {document_embeddings.shape}")
else:
    print("Embeddings not created yet.")


Generated embeddings with shape: (10, 768)
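
The retrieval step in Part 2 compares these embeddings by cosine similarity. As a minimal sketch of what that comparison computes, the block below L2-normalizes a few toy vectors with NumPy and checks that the dot product of unit-length vectors equals their cosine similarity. The toy vectors and the helper name `l2_normalize` are illustrative assumptions, not part of the lab's code.

```python
import numpy as np

def l2_normalize(vectors: np.ndarray) -> np.ndarray:
    """Scale each row to unit length (hypothetical helper for illustration)."""
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Guard against division by zero for all-zero rows
    return vectors / np.clip(norms, 1e-12, None)

# Toy 3-dimensional "embeddings"; the real ones in this lab have 768 dimensions
toy_docs = np.array([
    [1.0, 0.0, 0.0],   # same direction as the query
    [1.0, 1.0, 0.0],   # 45 degrees away from the query
    [0.0, 0.0, 2.0],   # orthogonal to the query
])
toy_query = np.array([[2.0, 0.0, 0.0]])

# For unit-length vectors, the dot product is the cosine similarity
scores = (l2_normalize(toy_query) @ l2_normalize(toy_docs).T)[0]
print(np.round(scores, 3))
```

Vector length does not affect the score, only direction does: the query `[2, 0, 0]` matches the first document `[1, 0, 0]` perfectly even though their magnitudes differ.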


## Part 2: Implementing the Retrieval Component

Now, let's implement the function to retrieve relevant documents based on a patient query.

In [16]:
def retrieve_documents(query, embeddings, contents, metadata, top_k=3, threshold=0.3):
    """
    Retrieve the most relevant documents for a given query.
    
    Args:
        query: The patient's question
        embeddings: The precomputed document embeddings
        contents: The text content of the documents
        metadata: The metadata for each document
        top_k: Maximum number of documents to retrieve
        threshold: Minimum similarity score to include a document
        
    Returns:
        List of (content, metadata, similarity_score) tuples
    """
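    # Load the embedding model once and reuse it on later calls. It must be the
    # same model used in create_document_embeddings so the vectors are comparable.
    if not hasattr(retrieve_documents, "_embedding_model"):
        retrieve_documents._embedding_model = SentenceTransformer('all-mpnet-base-v2')
    embedding_model = retrieve_documents._embedding_model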
    
    # Embed the query
    query_embedding = embedding_model.encode(
        [query],            # single‐item list
        convert_to_numpy=True
    )
        
    # Calculate cosine similarity between the query and every document;
    # [0] selects the single row that corresponds to our one query
    similarities = cosine_similarity(query_embedding, embeddings)[0]
    
    # Keep only documents at or above the threshold, then sort the survivors
    # by descending similarity and keep at most top_k of them
    valid_idxs = np.where(similarities >= threshold)[0]
    top_idxs = valid_idxs[np.argsort(similarities[valid_idxs])[::-1]][:top_k]

    # Return the top documents with their metadata and scores
    results = [
        (contents[i], metadata[i], float(similarities[i]))
        for i in top_idxs
    ]
    return results



# Test the retrieval function with a sample query
if document_embeddings is not None:
    sample_query = "What are the symptoms of Type 1 diabetes?"
    retrieved_docs = retrieve_documents(
        query=sample_query,
        embeddings=document_embeddings,
        contents=documents,
        metadata=knowledge_base['metadata'].tolist(),
        top_k=2
    )
    
    print(f"Query: {sample_query}")
    print("\nRetrieved Documents:")
    for i, (content, meta, score) in enumerate(retrieved_docs):
        print(f"{i+1}. [{score:.4f}] {content[:100]}...")
        print(f"   Topic: {meta['topic']}, Subtopic: {meta['subtopic']}")
else:
    print("Cannot test retrieval without document embeddings.")

Query: What are the symptoms of Type 1 diabetes?

Retrieved Documents:
1. [0.7585] Type 1 diabetes is an autoimmune reaction that stops your body from making insulin. Symptoms include...
   Topic: diabetes, Subtopic: type1
2. [0.4625] Diabetes is a chronic condition that affects how your body turns food into energy. There are three m...
   Topic: diabetes, Subtopic: overview
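
To make the threshold-then-top-k selection easier to follow, here is a small standalone sketch that applies the same NumPy indexing to hand-picked similarity scores. The helper name `select_top_k` and the toy scores are assumptions made for illustration; the logic mirrors the two indexing lines inside `retrieve_documents`.

```python
import numpy as np

def select_top_k(similarities: np.ndarray, top_k: int = 3, threshold: float = 0.3) -> list:
    """Return indices of the best-scoring items at or above threshold, best first."""
    # Indices of every score that passes the threshold
    valid_idxs = np.where(similarities >= threshold)[0]
    # Sort the surviving scores in descending order
    order = np.argsort(similarities[valid_idxs])[::-1]
    # Map back to original positions and keep at most top_k
    return valid_idxs[order][:top_k].tolist()

toy_scores = np.array([0.10, 0.76, 0.46, 0.31, 0.05])

print(select_top_k(toy_scores, top_k=2))        # [1, 2]
print(select_top_k(toy_scores, top_k=5))        # [1, 2, 3]
print(select_top_k(toy_scores, threshold=0.9))  # []
```

The last call shows that the function can return an empty list when nothing clears the threshold, which is why downstream code needs a path for "no relevant documents found."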


## Part 3: Building the Generation Component

Now, let's implement the generation component that will use the retrieved documents to create informative responses.

In [17]:
# Initialize the generative model
def initialize_generator(model_name="gpt2"):
    """
    Initialize the generative model and tokenizer.
    
    Args:
        model_name: Name of the pretrained model to use
        
    Returns:
        Tokenizer and model objects
    """
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model     = AutoModelForCausalLM.from_pretrained(model_name)
    
    
    # Set padding token if needed
    # Check if pad_token exists, if not set it to eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id
    
    return tokenizer, model

# Initialize the generator
tokenizer, model = initialize_generator()
if tokenizer and model:
    print(f"Initialized {model.config.name_or_path} with {model.num_parameters()} parameters")

Initialized gpt2 with 124439808 parameters


In [18]:
import torch

def generate_rag_response(query, contents, metadata, document_embeddings, tokenizer, model, max_length=100):
    """
    Generate a response using Retrieval-Augmented Generation.
    
    Args:
        query: The patient's question
        contents: List of document contents
        metadata: List of document metadata
        document_embeddings: Precomputed embeddings for the documents
        tokenizer: The tokenizer for the language model
        model: The language model for generation
        max_length: Maximum number of new tokens to generate for the response
        
    Returns:
        Dictionary with the generated response and the retrieved documents
    """
    # Retrieve relevant documents for the query
    retrieved_docs = retrieve_documents(
        query=query,
        embeddings=document_embeddings,
        contents=contents,
        metadata=metadata,
        top_k=3,
        threshold=0.3
    )
    
    # Format prompt with retrieved context
    if retrieved_docs:
        context = "\n\n".join([doc for doc, _, _ in retrieved_docs])
        prompt = f"Context:\n{context}\n\nQ: {query}\nA:"
    else:
        prompt = f"Q: {query}\nA:"
    
    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    
    # Generate the response
    output_sequences = model.generate(
        **inputs,
        max_new_tokens=max_length,
        temperature=0.7,
        top_k=50,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    
    # Decode only the newly generated tokens. Splitting the full text on "A:"
    # would break whenever the context or the answer itself contains "A:".
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        output_sequences[0][prompt_length:],
        skip_special_tokens=True
    ).strip()
    
    # Return the results
    return {
        "query": query,
        "response": response,
        "retrieved_documents": retrieved_docs
    }

# Test the RAG system with several queries
if document_embeddings is not None and tokenizer and model:
    test_queries = [
        "What are the different types of diabetes?",
        "How can I manage my high blood pressure through lifestyle changes?",
        "Why is regular physical activity important for health?",
        "What vaccines should adults consider getting?"
    ]
    
    for query in test_queries:
        print(f"\nQuery: {query}")
        result = generate_rag_response(
            query=query,
            contents=documents,
            metadata=knowledge_base['metadata'].tolist(),
            document_embeddings=document_embeddings,
            tokenizer=tokenizer,
            model=model
        )
        
        print("\nRetrieved Documents:")
        for i, (doc, meta, score) in enumerate(result["retrieved_documents"], 1):
            print(f"{i}. [{score:.4f}] Topic: {meta['topic']}, Subtopic: {meta['subtopic']}")
        
        print(f"\nGenerated Response:\n{result['response']}")
        print("-" * 80)
else:
    print("Cannot test generation without embeddings or model.")



Query: What are the different types of diabetes?

Retrieved Documents:
1. [0.7130] Topic: diabetes, Subtopic: overview
2. [0.6430] Topic: diabetes, Subtopic: type1
3. [0.6042] Topic: diabetes, Subtopic: type2

Generated Response:
Diabetes can affect your ability to get insulin. If you have diabetes, your doctor can't help you get insulin. If you have a medical condition that can cause diabetes
--------------------------------------------------------------------------------

Query: How can I manage my high blood pressure through lifestyle changes?

Retrieved Documents:
1. [0.7775] Topic: hypertension, Subtopic: management
2. [0.4690] Topic: hypertension, Subtopic: overview
3. [0.4077] Topic: hypertension, Subtopic: diagnosis

Generated Response:
There are three ways to manage hypertension:

You can take steps to reduce blood pressure through diet or exercise.

You can take steps to reduce blood pressure through lifestyle changes.

You can take steps to manage blood pressure through phy
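
One detail of the prompt format deserves attention: GPT-2 has a 1024-token context window, and the tokenizer call uses `truncation=True`, which by default cuts an over-long prompt from the right. Because the question sits at the end of the prompt, a long context could push the question out entirely. The sketch below shows one way to guard against that by budgeting the context before the prompt is assembled. The helper name `build_prompt`, the `max_context_chars` parameter, and the character-based budget are all assumptions made for illustration; a token-based budget would be more precise.

```python
def build_prompt(query: str, passages: list, max_context_chars: int = 1500) -> str:
    """Assemble a context-then-question prompt without exceeding a context budget.

    passages is assumed to be sorted best-first, so lower-ranked passages
    are the ones dropped when the budget runs out.
    """
    selected, used = [], 0
    for passage in passages:
        if used + len(passage) > max_context_chars:
            break  # stop at the first passage that would overflow the budget
        selected.append(passage)
        used += len(passage)

    if not selected:
        # Same no-context format the lab uses when retrieval returns nothing
        return f"Q: {query}\nA:"

    context = "\n\n".join(selected)
    return f"Context:\n{context}\n\nQ: {query}\nA:"


prompt = build_prompt(
    "What is hypertension?",
    ["Hypertension is high blood pressure.", "x" * 5000],
    max_context_chars=200,
)
print(prompt)
```

Here the 5000-character second passage is dropped, so the question always survives at the end of the prompt.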

## Part 4: Evaluation and Analysis

Let's implement a basic evaluation function to assess the quality of our generated responses.

In [19]:
def evaluate_response(response_data):
    """
    Evaluate the quality of a generated response based on various criteria.
    
    Args:
        response_data: Dictionary containing the query, response, and retrieved docs
        
    Returns:
        Evaluation metrics
    """
    query    = response_data['query']
    response = response_data['response']
    docs     = response_data['retrieved_documents']
    
    # 1) Content relevance: fraction of unique words from retrieved docs that appear in the response
    doc_text   = " ".join([doc for doc, _, _ in docs])
    doc_words  = set(re.findall(r'\w+', doc_text.lower()))
    resp_words = set(re.findall(r'\w+', response.lower()))
    common     = doc_words & resp_words
    relevance_score = len(common) / max(len(doc_words), 1)
    
    # 2) Response length appropriateness: ideal between 1× and 3× query length
    resp_len  = len(response.split())
    query_len = len(query.split())
    if query_len == 0:
        length_score = 0.0
    elif query_len <= resp_len <= 3 * query_len:
        length_score = 1.0
    else:
        length_score = min(resp_len / (3 * query_len), resp_len / query_len)
    
    # 3) Medical terminology usage: fraction of key terms included
    medical_terms = [
        "diabetes", "insulin", "glucose", "hypertension", "blood pressure",
        "systolic", "diastolic", "cardiovascular", "cholesterol", "nutrition",
        "obesity", "physical activity", "vaccination", "immune", "prevention"
    ]
    used_terms = [t for t in medical_terms if t in response.lower()]
    usage_score = len(used_terms) / len(medical_terms)
    
    metrics = {
        "relevance_score":     round(relevance_score,  3),
        "length_score":        round(length_score,     3),
        "medical_usage_score": round(usage_score,      3),
        "response_length":     resp_len,
        "used_terms":          used_terms
    }
    
    return metrics


result = generate_rag_response(
    query=sample_query,
    contents=documents,
    metadata=knowledge_base['metadata'].tolist(),
    document_embeddings=document_embeddings,
    tokenizer=tokenizer,
    model=model
)
metrics = evaluate_response(result)
print("Evaluation for query:", result['query'])
for name, score in metrics.items():
    print(f"  {name}: {score}")


Evaluation for query: What are the symptoms of Type 1 diabetes?
  relevance_score: 0.191
  length_score: 3.417
  medical_usage_score: 0.2
  response_length: 82
  used_terms: ['diabetes', 'insulin', 'glucose']
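
Note that the `length_score` of 3.417 above is greater than 1. In the `else` branch of `evaluate_response`, a response longer than three times the query length yields `resp_len / (3 * query_len)`, which grows without bound as the response gets longer. The sketch below shows one possible bounded alternative that stays in the range 0 to 1 and penalizes both too-short and too-long responses. The function name `bounded_length_score` is an assumption, and the 1× to 3× target range is carried over from the lab's own comment.

```python
def bounded_length_score(resp_len: int, query_len: int) -> float:
    """Score in [0, 1]: 1.0 inside the target range, decaying outside it."""
    if query_len <= 0 or resp_len <= 0:
        return 0.0
    lower, upper = query_len, 3 * query_len
    if resp_len < lower:
        # Too short: fraction of the minimum length that was reached
        return resp_len / lower
    if resp_len > upper:
        # Too long: shrinks toward 0 as the response keeps growing
        return upper / resp_len
    return 1.0


# Same lengths as the evaluation above: 82-word response, 8-word query
print(round(bounded_length_score(82, 8), 3))  # 0.293
print(bounded_length_score(16, 8))            # 1.0
print(bounded_length_score(4, 8))             # 0.5
```

With this variant, the 82-word response scores 0.293 instead of 3.417, so a higher number consistently means a better-sized answer.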


## Reflection Questions

Answer the following questions about your RAG implementation and its potential applications in healthcare:

### How does the RAG approach improve factual accuracy compared to regular generation?

RAG conditions the model on real, pre-indexed medical texts instead of leaving it to rely only on whatever it memorized during pretraining. By retrieving the top-k relevant passages and placing them in the prompt as context, you reduce the room for hallucination, although you don't eliminate it. The answer is no longer just “what the LLM thinks is plausible”; it can draw on “what’s literally in our trusted knowledge base.”

### What are potential challenges or limitations of your current implementation?

Coverage gaps: If the KB misses a topic, the model can’t retrieve the right info and will default back to guesswork.

Retrieval errors: Poor embeddings or a badly tuned similarity threshold can pull in irrelevant docs, which drags the response off course.

Hallucination fallback: Even with RAG, the generator can still manufacture details when context isn’t clear.
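
The coverage-gap problem is visible in the current code: when no document passes the threshold, `generate_rag_response` falls back to a prompt with no context and lets the model answer anyway. A minimal sketch of one alternative, abstaining instead of guessing, is shown below. The names `answer_or_abstain`, `FALLBACK_MESSAGE`, and `generate_fn` are hypothetical, and the wording of the fallback message is only a placeholder.

```python
FALLBACK_MESSAGE = (
    "I don't have reliable information on that topic. "
    "Please ask your healthcare provider."
)

def answer_or_abstain(retrieved_docs: list, generate_fn) -> dict:
    """Generate only when retrieval found supporting documents.

    retrieved_docs: list of (content, metadata, score) tuples
    generate_fn:    callable that turns the retrieved docs into an answer
                    (a stand-in for the real generation step)
    """
    if not retrieved_docs:
        # Nothing in the knowledge base supports an answer, so do not guess
        return {"response": FALLBACK_MESSAGE, "abstained": True}
    return {"response": generate_fn(retrieved_docs), "abstained": False}


# Usage with a stand-in generator that just echoes the best passage
echo_best = lambda docs: docs[0][0]

print(answer_or_abstain([], echo_best))
print(answer_or_abstain([("Hypertension is high blood pressure.", {}, 0.8)], echo_best))
```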



### How might you enhance this system for a production healthcare environment?

Curated, versioned KB: Build on audited medical sources (e.g. FDA labels, peer-reviewed summaries) and keep them up to date.

Monitoring & feedback: Log user interactions, flag low-confidence outputs for human review, and feed the results back into retrieval tuning and model updates.
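
As a rough sketch of what "flag low-confidence outputs" could mean in this system, the block below uses the retrieval similarity scores as a confidence proxy. The function name `needs_human_review` and the `min_top_score` cutoff of 0.5 are assumptions chosen for illustration; a real cutoff would have to be calibrated against reviewed examples.

```python
def needs_human_review(retrieved_docs: list, min_top_score: float = 0.5) -> bool:
    """Flag a response for review when retrieval support looks weak.

    retrieved_docs: list of (content, metadata, score) tuples, in the same
                    shape that retrieve_documents returns
    """
    if not retrieved_docs:
        return True  # no supporting documents at all
    top_score = max(score for _, _, score in retrieved_docs)
    return top_score < min_top_score


strong = [("Type 1 diabetes is an autoimmune reaction...", {"topic": "diabetes"}, 0.7585)]
weak = [("A balanced diet should include...", {"topic": "wellness"}, 0.34)]

print(needs_human_review(strong))  # False
print(needs_human_review(weak))    # True
print(needs_human_review([]))      # True
```

Retrieval similarity says nothing about whether the generated text actually used the retrieved passage, so this check catches weak retrieval but not a generator that ignores good context.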

### What ethical considerations are particularly important for healthcare content generation?

Several stand out:

Transparency & disclaimers: Clearly label “AI-generated” and include “This is educational only—consult your provider.”

Privacy & HIPAA: Don’t leak or store sensitive patient info in logs, and ensure data handling meets compliance standards.

Bias & fairness: Watch out for under-represented conditions or demographics.

Accountability: Maintain audit trails of which documents were retrieved, so you can trace back any mistakes.
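
A minimal sketch of such an audit trail is shown below, combining the accountability and privacy points: it records which knowledge-base entries backed a response while storing only a hash of the patient's question. The function name `build_audit_record` and the record layout are assumptions, the metadata keys follow the knowledge base defined in Part 1, and hashing alone should not be read as sufficient for HIPAA compliance.

```python
import hashlib
import json
from datetime import datetime, timezone

def build_audit_record(query: str, response: str, retrieved_docs: list) -> dict:
    """Build a JSON-serializable record of what was retrieved for a response."""
    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        # Store a hash rather than the raw question to limit sensitive text in logs
        "query_sha256": hashlib.sha256(query.encode("utf-8")).hexdigest(),
        "response": response,
        "sources": [
            {
                "topic": meta["topic"],
                "subtopic": meta["subtopic"],
                "last_updated": meta["last_updated"],
                "score": round(score, 4),
            }
            for _, meta, score in retrieved_docs
        ],
    }


docs = [(
    "Hypertension, or high blood pressure, is when...",
    {"topic": "hypertension", "subtopic": "overview",
     "source": "medical_guidelines", "last_updated": "2023-07-22"},
    0.469,
)]
record = build_audit_record("What is hypertension?", "Example answer.", docs)
print(json.dumps(record, indent=2))
```

Keeping `last_updated` in each source entry makes it possible to tell later whether a questionable answer came from an outdated knowledge-base entry.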