## Retrieval-Augmented Generation (RAG) Proof of Concept on Healthcare Mortality Data

This notebook demonstrates a **proof of concept (POC)** for building a RAG pipeline. 

The goal: allow users to ask natural language questions (e.g., *"What was the heart disease mortality rate in Texas in 2019?"*) and get grounded answers **only from structured healthcare data**.

**Key steps:**
1. Convert structured rows into natural-language "facts"
2. Embed facts into vector space using `SentenceTransformers`
3. Store embeddings in a FAISS index for efficient retrieval
4. Given a query, retrieve the most relevant facts
5. Pass facts + query into an LLM (via Ollama) for grounded generation

This is a minimal POC to test feasibility before scaling up with bigger datasets.

### Setup

We load required libraries:

- `sentence_transformers` for embeddings
- `faiss` for vector similarity search
- `ollama` for running LLMs locally; the Ollama CLI must be installed on your system, since the notebook calls it through `subprocess`

In [1]:
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import subprocess
import json



### Turn Data into Facts

We use a small dataset of healthcare mortality rates. Each row will be converted into a "fact" — a natural-language string describing the data point.

This makes the data easier for LLMs to consume.

In [2]:
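# The file is tab-delimited text despite its .xls extension, hence sep="\t"
df = pd.read_csv("../poc/data/Underlying Cause of Death, 1999-2020.xls", sep="\t")

# Normalize column names and keep only the fields used to build facts
df = df.rename(
    columns={
        'State': 'state',
        'Year': 'year',
        'ICD-10 113 Cause List': 'cause',
        'Deaths': 'deaths',
        'Population': 'population',
        'Crude Rate': 'crude_rate',
    }
)
df = df[['state', 'year', 'cause', 'deaths', 'population', 'crude_rate']]

# Coerce to numeric; values that cannot be parsed become NaN
df['crude_rate'] = pd.to_numeric(df['crude_rate'], errors='coerce')
df['year'] = pd.to_numeric(df['year'], errors='coerce')
df = df.astype({"deaths": "float64", "population": "float64"})

# Strip parenthesised text and '#' markers from the cause names
df['cause'] = (
    df['cause']
    .str.replace(r"\(.*?\)", "", regex=True)
    .str.replace(r"#", "", regex=True)
    .str.strip()
)

# Recompute the crude rate as deaths per 100,000 population, rounded to one decimal
df['crude_rate'] = np.round(df['deaths'] / df['population'] * 1e5, 1)
df.head()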

| | state | year | cause | deaths | population | crude_rate |
|---|---|---|---|---|---|---|
| 0 | Alabama | 2015.0 | Certain other intestinal infections | 159.0 | 4858979.0 | 3.3 |
| 1 | Alabama | 2015.0 | Tuberculosis | 11.0 | 4858979.0 | 0.2 |
| 2 | Alabama | 2015.0 | Septicemia | 1046.0 | 4858979.0 | 21.5 |
| 3 | Alabama | 2015.0 | Viral hepatitis | 96.0 | 4858979.0 | 2.0 |
| 4 | Alabama | 2015.0 | Human immunodeficiency virus disease | 126.0 | 4858979.0 | 2.6 |


In [3]:
facts = []

for i, row in df.iterrows():
    fact = f"In {row['state']} in {row['year']}, the {row['cause']} mortality rate was {row['crude_rate']} per 100,000."
    facts.append(fact)

print(facts[:5])

['In Alabama in 2015.0, the Certain other intestinal infections mortality rate was 3.3 per 100,000.', 'In Alabama in 2015.0, the Tuberculosis mortality rate was 0.2 per 100,000.', 'In Alabama in 2015.0, the Septicemia mortality rate was 21.5 per 100,000.', 'In Alabama in 2015.0, the Viral hepatitis mortality rate was 2.0 per 100,000.', 'In Alabama in 2015.0, the Human immunodeficiency virus  disease mortality rate was 2.6 per 100,000.']
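
The facts above print the year as `2015.0` because the column was parsed as a float, and stripping the parenthesised text leaves a double space in some cause names. A minimal sketch of a formatting helper that avoids both problems follows; `format_fact` is a hypothetical name, not part of the notebook.

```python
def format_fact(state, year, cause, crude_rate):
    """Render one row as a sentence, with an integer year and tidy spacing."""
    # Collapse runs of whitespace left behind by the regex clean-up of `cause`.
    cause = " ".join(str(cause).split())
    # int() drops the trailing ".0" that a float year would otherwise print.
    return (
        f"In {state} in {int(year)}, the {cause} mortality rate "
        f"was {crude_rate} per 100,000."
    )


print(format_fact("Alabama", 2015.0, "Human immunodeficiency virus  disease", 2.6))
# In Alabama in 2015, the Human immunodeficiency virus disease mortality rate was 2.6 per 100,000.
```

Note that `int(year)` raises on `NaN`, so rows whose year failed to parse would need to be dropped before calling the helper.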


### Embeddings + FAISS Index

We use `SentenceTransformers` (the `all-MiniLM-L6-v2` model) to embed each fact into a 384-dimensional vector that captures its semantic meaning. We then store the vectors in a FAISS index, which supports fast similarity search.

In [4]:
embedder = SentenceTransformer('all-MiniLM-L6-v2')
fact_embeddings = embedder.encode(facts, convert_to_numpy=True)
print(fact_embeddings.shape)

dimension = fact_embeddings.shape[1]
# set up a search engine that compares vectors by distance
index = faiss.IndexFlatL2(dimension)
# load all the fact embeddings into the search engine so they can be queried later
index.add(fact_embeddings.astype("float32"))

(18805, 384)
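
To make concrete what a flat L2 index does, here is a pure-Python sketch of the same idea: compare the query against every stored vector and keep the `k` smallest squared Euclidean distances. The function name `l2_search` and the toy vectors are made up for illustration; FAISS performs this search far more efficiently on the real 384-dimensional embeddings.

```python
def l2_search(vectors, query, k):
    """Return (squared_distance, row_index) pairs for the k nearest vectors."""
    scored = []
    for idx, vec in enumerate(vectors):
        # Squared Euclidean distance between the query and this stored vector.
        dist = sum((a - b) ** 2 for a, b in zip(vec, query))
        scored.append((dist, idx))
    # Smallest distance first; ties fall back to the lower row index.
    scored.sort()
    return scored[:k]


# Toy 2-dimensional "embeddings" (made up for illustration).
toy_vectors = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (3.0, 3.0)]
nearest = l2_search(toy_vectors, (0.9, 0.1), k=2)
print([idx for _, idx in nearest])  # row indices of the two closest vectors
```

The returned row indices play the same role as `I[0]` in the retrieval cell: they point back into the `facts` list.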


### Why convert rows to "facts"?

An LLM works on text. It has no built-in way to query a structured store such as a SQL table, and a raw CSV row gives it little context about what each value means. A sentence such as "In Texas in 2019, the heart disease mortality rate was 153 per 100,000." states the entity, the year, the measure, and the unit explicitly, so the model can read and reason about it directly. That is why we transform structured data into sentences, or "facts".

Turning rows into sentences also enables semantic search. If a user asks "What was the heart disease mortality rate in Texas in 2019?", the embedding of that query should land close to the embedding of the matching fact string.

### What is RAG?

RAG stands for Retrieval-Augmented Generation. Suppose you ask an LLM, "What was the heart disease mortality rate in Texas in 2019?" The model does not have your dataset; it answers from patterns learned during training, so the answer may be wrong, outdated, or incomplete.

RAG adds three steps:

- **Retrieve**: Before answering, look up relevant information in an external knowledge base (here, the facts stored in FAISS).
- **Augment**: Add the retrieved facts to the prompt given to the LLM, so the prompt contains both the context and the question.
- **Generate**: The LLM produces the answer using its language skills and the retrieved context, so it reads from your actual data instead of guessing.
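
The three steps can be sketched as a single function that takes the retriever and the generator as arguments. This is an illustrative skeleton, not the notebook's implementation: `answer_with_rag`, `fake_retrieve`, and `fake_generate` are hypothetical names, and the stand-ins below replace the real FAISS search and the Ollama call.

```python
def answer_with_rag(query, retrieve, generate, k=3):
    """Run retrieve -> augment -> generate with caller-supplied components."""
    # Retrieve: fetch the k facts most relevant to the query.
    facts = retrieve(query, k)
    # Augment: put the retrieved facts and the question into one prompt.
    context = "\n".join(f"- {fact}" for fact in facts)
    prompt = f"Facts:\n{context}\n\nQuestion: {query}\nAnswer:"
    # Generate: hand the augmented prompt to the language model.
    return generate(prompt)


# Stand-ins for the real retriever and LLM (hypothetical, for illustration).
def fake_retrieve(query, k):
    known = ["In Texas in 2019, the heart disease mortality rate was 153 per 100,000."]
    return known[:k]


def fake_generate(prompt):
    # A real model would read the prompt; this stub only echoes the first fact line.
    return prompt.splitlines()[1]


print(answer_with_rag(
    "What was the heart disease mortality rate in Texas in 2019?",
    fake_retrieve,
    fake_generate,
))
```

Passing the components in as functions keeps the control flow testable without loading an embedding model or starting an LLM.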

Why it matters:

✅ Keeps answers grounded in your data

✅ Works with private datasets the LLM has never seen

✅ Reduces hallucinations (the model making things up), provided the relevant facts are retrieved

### Query + Retrieval

Given a user query, we:

- Embed the query in the same vector space as the facts
- Search the FAISS index built above for the closest facts
- Return the top `k` most relevant facts

In [13]:
query = "What are the top 3 states for diabetes mortality in 2017?"

query_embeddings = embedder.encode([query], convert_to_numpy=True)
D, I = index.search(query_embeddings.astype("float32"), k=10)
retrieved_facts = [facts[i] for i in I[0]]

print(retrieved_facts)

# Augment + Generate

prompt = f"""
    You are a helpful assistant. Use only the facts provided to answer the question. 
    If the answer is not in the facts, say "I don't have data on that."

    Facts:
    {retrieved_facts}

    Question:

    {query}

    Answer:
"""

process = subprocess.run(
    ["ollama", "run", "mistral", prompt],
    capture_output=True,
    text=True
)

print(process.stdout)

['In Pennsylvania in 2017.0, the Diabetes mellitus mortality rate was 28.9 per 100,000.', 'In Massachusetts in 2017.0, the Diabetes mellitus mortality rate was 19.3 per 100,000.', 'In Pennsylvania in 2016.0, the Diabetes mellitus mortality rate was 27.8 per 100,000.', 'In Montana in 2017.0, the Diabetes mellitus mortality rate was 27.8 per 100,000.', 'In Pennsylvania in 2015.0, the Diabetes mellitus mortality rate was 29.5 per 100,000.', 'In Pennsylvania in 2020.0, the Diabetes mellitus mortality rate was 33.2 per 100,000.', 'In Pennsylvania in 2018.0, the Diabetes mellitus mortality rate was 28.1 per 100,000.', 'In Washington in 2017.0, the Diabetes mellitus mortality rate was 24.5 per 100,000.', 'In Arkansas in 2017.0, the Diabetes mellitus mortality rate was 39.3 per 100,000.', 'In Massachusetts in 2015.0, the Diabetes mellitus mortality rate was 20.6 per 100,000.']




 In the provided data for the year 2017, the three states with the highest Diabetes mellitus mortality rates are Arkansas (39.3 per 100,000), Pennsylvania (28.9 per 100,000), and Montana (27.8 per 100,000). These states are listed in order of decreasing mortality rate.
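
The cell above interpolates the Python list directly into the prompt, so the model sees brackets and quotes around the facts. A possible alternative, sketched here under the assumption that a plain bulleted list is easier for the model to read, is a small helper that writes one fact per line. `build_prompt` is a hypothetical name, and whether this changes the answer quality has not been tested in this notebook.

```python
def build_prompt(query, retrieved_facts):
    """Assemble the grounded prompt with one retrieved fact per line."""
    # A bulleted list avoids the brackets and quotes of a Python list repr.
    fact_lines = "\n".join(f"- {fact}" for fact in retrieved_facts)
    return (
        "You are a helpful assistant. Use only the facts provided to answer "
        "the question.\n"
        'If the answer is not in the facts, say "I don\'t have data on that."\n\n'
        f"Facts:\n{fact_lines}\n\n"
        f"Question:\n{query}\n\n"
        "Answer:"
    )


example = build_prompt(
    "What are the top 3 states for diabetes mortality in 2017?",
    ["In Arkansas in 2017.0, the Diabetes mellitus mortality rate was 39.3 per 100,000."],
)
print(example)
```

The returned string can be passed to the `subprocess.run` call in place of `prompt`.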




In [11]:
df[(df['year']==2017)&(df['cause'].str.contains("diabetes", case=False, na=False))]

| | state | year | cause | deaths | population | crude_rate |
|---|---|---|---|---|---|---|
| 159 | Alabama | 2017.0 | Diabetes mellitus | 1173.0 | 4874747.0 | 24.1 |
| 499 | Alaska | 2017.0 | Diabetes mellitus | 130.0 | 739795.0 | 17.6 |
| 836 | Arizona | 2017.0 | Diabetes mellitus | 2054.0 | 7016270.0 | 29.3 |
| 1217 | Arkansas | 2017.0 | Diabetes mellitus | 1180.0 | 3004279.0 | 39.3 |
| 1604 | California | 2017.0 | Diabetes mellitus | 9595.0 | 39536653.0 | 24.3 |
| 1998 | Colorado | 2017.0 | Diabetes mellitus | 1017.0 | 5607154.0 | 18.1 |
| 2371 | Connecticut | 2017.0 | Diabetes mellitus | 694.0 | 3588184.0 | 19.3 |
| 2726 | Delaware | 2017.0 | Diabetes mellitus | 244.0 | 961939.0 | 25.4 |
| 3050 | District of Columbia | 2017.0 | Diabetes mellitus | 138.0 | 693972.0 | 19.9 |
| 3394 | Florida | 2017.0 | Diabetes mellitus | 6172.0 | 20984400.0 | 29.4 |


## Observation: Effect of `k` on Query Accuracy

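The check above exposes a limit of retrieval for the "top 3" query. With `k=10`, only five of the ten retrieved facts are for 2017, and the model's answer (Arkansas, Pennsylvania, Montana) omits Florida (29.4) and Arizona (29.3), both of which have higher rates than Pennsylvania (28.9). More generally, the following pattern emerged during testing: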

- For **single-entity lookup questions** (e.g., *"What was the heart disease mortality rate in Texas in 2019?"*), the model produced accurate answers when `k` (the number of retrieved facts) was set to a small value (e.g., `k=1–3`).  
  A small `k` helps because only one fact is relevant, and extra rows introduce noise.

- For **aggregation or comparison questions** (e.g., *"What are the top 3 states in 2019 by mortality rate?"*), a small `k` failed.  
  The retriever returned only a few facts, so the model lacked the full set of states needed to compute "top 3."  
  In these cases, a larger `k` (e.g., `k=50`) is necessary.

### Why this happens
The choice of `k` depends on the **type of query**:
- **Lookup queries** → need just 1 fact (small `k`).
- **Aggregation/comparison queries** → need many facts to cover the data slice (large `k`).
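
One way to act on this distinction is a crude keyword check that picks `k` from the wording of the question. The sketch below is an assumption, not the notebook's `classify_query`: the name `choose_k`, the cue-word list, and the default values are all illustrative, and a keyword list like this will miss many phrasings.

```python
import re

# Cue words that suggest the question spans many rows (an assumed, incomplete list).
AGGREGATION_CUES = re.compile(
    r"\b(top|highest|lowest|most|least|rank|ranking|compare|average|total)\b",
    re.IGNORECASE,
)


def choose_k(query, lookup_k=3, aggregation_k=50):
    """Pick a retrieval depth from a simple keyword check on the query."""
    if AGGREGATION_CUES.search(query):
        # Aggregation or comparison: retrieve many facts to cover the data slice.
        return aggregation_k
    # Lookup: a handful of facts is enough and keeps noise out of the prompt.
    return lookup_k


print(choose_k("What are the top 3 states for diabetes mortality in 2017?"))    # 50
print(choose_k("What was the heart disease mortality rate in Texas in 2019?"))  # 3
```

The returned value would replace the hard-coded `k=10` in the `index.search` call.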

### Possible solutions
1. **Query classification**  
   Detect whether the query is a lookup or aggregation.  
   - Lookup mode → use small `k`  
   - Aggregation mode → use larger `k`

2. **Filtering before retrieval**  
   Restrict by query terms (e.g., year=2019) before embedding search.  
   This avoids pulling irrelevant rows while still returning enough facts.

3. **Hybrid approach**  
   For precise lookups, use embeddings + facts.  
   For aggregations, consider passing the query through a SQL-like execution step instead of relying only on the LLM.
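
The SQL-like route in option 3 can be sketched with Python's built-in `sqlite3` module. The table name `mortality`, the helper `top_states`, and the five sample rows (copied from the 2017 diabetes output earlier in the notebook, not the full dataset) are assumptions for illustration. The `WHERE` clause also shows the kind of year filter described in option 2.

```python
import sqlite3

# A few 2017 diabetes rows copied from the notebook output (not the full dataset).
rows = [
    ("Alabama", 2017, "Diabetes mellitus", 24.1),
    ("Arizona", 2017, "Diabetes mellitus", 29.3),
    ("Arkansas", 2017, "Diabetes mellitus", 39.3),
    ("California", 2017, "Diabetes mellitus", 24.3),
    ("Florida", 2017, "Diabetes mellitus", 29.4),
]

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE mortality (state TEXT, year INTEGER, cause TEXT, crude_rate REAL)"
)
conn.executemany("INSERT INTO mortality VALUES (?, ?, ?, ?)", rows)


def top_states(conn, year, cause_keyword, n=3):
    """Return the n states with the highest crude rate for a year and cause."""
    cursor = conn.execute(
        """
        SELECT state, crude_rate
        FROM mortality
        WHERE year = ? AND cause LIKE ?
        ORDER BY crude_rate DESC
        LIMIT ?
        """,
        (year, f"%{cause_keyword}%", n),
    )
    return cursor.fetchall()


print(top_states(conn, 2017, "diabetes"))
# [('Arkansas', 39.3), ('Florida', 29.4), ('Arizona', 29.3)]
```

Because the ranking is computed by the database over every matching row, the result does not depend on how many facts the retriever happens to return; the LLM would only need to phrase the answer.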

---


In [5]:
def rag_pipeline(query):
    # Placeholder: the pipeline body has not been written yet
    raise NotImplementedError

In [6]:
classify_query("What are the top 3 states for diabetes mortality in 2017?")

what are the top 3 states for diabetes mortality in 2017?
