# Prerequisites
* PostgreSQL reachable at localhost:5432 with a database named patient_db and credentials matching the connection cell (user="kevin", password="password123"). Install the pgvector extension in this database.
* Python environment capable of installing packages listed below (the notebook relies on pip inside the runtime).
* Data files ./patients.csv and ./allergies.csv present relative to the notebook.
* Ollama running locally on the default port 11434 with the phi4-mini embedding model pulled and ready (ollama pull phi4-mini).

# 1. Install dependencies

In [1]:
!pip install psycopg2-binary pandas numpy matplotlib plotly faker requests umap-learn tabulate python-dotenv


Collecting psycopg2-binary
  Downloading psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (4.9 kB)
Collecting pandas
  Using cached pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (91 kB)
Collecting numpy
  Downloading numpy-2.3.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting plotly
  Downloading plotly-6.4.0-py3-none-any.whl.metadata (8.5 kB)
Collecting faker
  Downloading faker-37.12.0-py3-none-any.whl.metadata (15 kB)
Collecting umap-learn
  Downloading umap_learn-0.5.9.post2-py3-none-any.whl.metadata (25 kB)
Collecting tabulate
  Downloading tabulate-0.9.0-py3-none-any.whl.metadata (34 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Using cached contourpy-1.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.5

# 2. Imports

In [2]:
from dotenv import load_dotenv
import os, requests
import psycopg2
import pandas as pd
import numpy as np
from io import StringIO
import plotly.express as px
load_dotenv()


# 3. Connect to PostgreSQL with pgvector

In [3]:
conn = psycopg2.connect(
    dbname="patient_db",
    user="kevin",
    password="password123",
    host="localhost",
    port=6432
)
cur = conn.cursor()
cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
conn.commit()


# 4. Load CSVs

In [4]:
patients_df = pd.read_csv("./data/patients.csv")
allergies_df = pd.read_csv("./data/allergies.csv")

print("Patients:")
display(patients_df.head())

print("Allergies:")
display(allergies_df.head())


Patients:


Unnamed: 0,patient_id,vista_id,mrn,first_name,last_name,middle_name,gender,birthdate,age,race,...,genetic_markers,precision_markers,comorbidity_profile,care_plan_total,care_plan_completed,care_plan_overdue,care_plan_scheduled,deceased,death_date,death_primary_cause
0,a7896ff7-3cfb-4a4f-ae92-02f58d2d691e,1168769,MRN893749,Kari,Hutchinson,D,female,2018-11-01,7,White,...,[],[],[],0,0,0,0,False,,
1,c89b6067-c181-412f-be7f-5ee597961ad7,9206201,MRN955810,Joseph,Swanson,W,male,1935-11-22,90,White,...,[],[],[],3,0,3,0,True,2025-10-30,"Acute myocardial infarction, unspecified"
2,748503cd-2632-4e29-8c06-558ea9d64548,8857864,MRN854471,Mark,Henderson,B,other,1989-11-08,36,White,...,[],[],[],0,0,0,0,False,,
3,7701558f-0d71-492b-ae3c-e9e93aa79b86,6727969,MRN501764,Tina,Castillo,A,male,1965-11-14,60,White,...,[],[],"[{""primary"": ""Hypertension"", ""associated"": ""Di...",6,1,5,0,True,2021-10-31,"Intentional self-harm by unspecified means, in..."
4,3d0a4feb-f79e-4624-9c33-c2eb1c90489e,9256242,MRN248802,Jason,Smith,K,male,2001-11-05,24,Asian,...,[],[],[],0,0,0,0,False,,


Allergies:


Unnamed: 0,allergy_id,patient_id,substance,category,reaction,reaction_code,reaction_system,severity,severity_code,severity_system,rxnorm_code,unii_code,snomed_code,risk_level,registry_source,recorded_date,followup_summary
0,d8141746-f50b-4636-ba63-e7df88496831,a7896ff7-3cfb-4a4f-ae92-02f58d2d691e,Acremonium Strictum 50 Mg/Ml Injectable Solution,environment,Angioedema,41291007,http://snomed.info/sct,mild,255604002,http://snomed.info/sct,905073.0,,,standard,warehouse,2021-04-22,risk: standard | severity: mild
1,292baf74-fac6-4862-b3c8-720d1460ea3c,c89b6067-c181-412f-be7f-5ee597961ad7,Administration Of First Dose Of Vaccine Produc...,drug,Nausea,422587007,http://snomed.info/sct,mild,255604002,http://snomed.info/sct,416591003.0,,416591003.0,standard,warehouse,1956-12-14,risk: standard | severity: mild
2,5cf42c32-c2ed-451c-ad42-da9c600a001f,c89b6067-c181-412f-be7f-5ee597961ad7,Buckwheat 100 Mg/Ml Injectable Solution,food,Nausea,422587007,http://snomed.info/sct,mild,255604002,http://snomed.info/sct,904800.0,,,standard,warehouse,1969-09-19,risk: standard | severity: mild
3,52d3a0b5-5a55-4a54-8dd2-4589ca069618,748503cd-2632-4e29-8c06-558ea9d64548,Banana 100 Mg/Ml Injectable Solution,food,Urticaria,126485001,http://snomed.info/sct,mild,255604002,http://snomed.info/sct,891833.0,,,standard,warehouse,1995-11-08,risk: standard | severity: mild
4,3a3027b9-8477-447b-971e-6dbd3fbdac0c,748503cd-2632-4e29-8c06-558ea9d64548,Bee Venom,insect,Wheezing,56018004,http://snomed.info/sct,moderate,6736007,http://snomed.info/sct,,,,high,curated,1993-12-11,risk: high | severity: moderate


# 5. Create normalized tables

In [5]:
conn.rollback()

cur.execute("""
DROP TABLE IF EXISTS allergies CASCADE;
DROP TABLE IF EXISTS patients CASCADE;

CREATE TABLE patients (
    patient_id TEXT PRIMARY KEY,
    vista_id TEXT,
    mrn TEXT,
    first_name TEXT,
    last_name TEXT,
    middle_name TEXT,
    gender TEXT,
    birthdate TEXT,          -- changed from DATE to TEXT for flexibility
    age INT,
    race TEXT,
    ethnicity TEXT,
    address TEXT,
    city TEXT,
    state TEXT,
    zip TEXT,
    country TEXT,
    phone TEXT,
    email TEXT,
    marital_status TEXT,
    language TEXT,
    insurance TEXT,
    ssn TEXT,
    smoking_status TEXT,
    alcohol_use TEXT,
    education TEXT,
    employment_status TEXT,
    income TEXT,
    housing_status TEXT,
    sdoh_risk_score FLOAT,
    sdoh_risk_factors TEXT,
    community_deprivation_index FLOAT,
    access_to_care_score FLOAT,
    transportation_access TEXT,
    language_access_barrier TEXT,
    social_support_score FLOAT,
    sdoh_care_gaps TEXT,
    genetic_risk_score FLOAT,
    genetic_markers TEXT,
    precision_markers TEXT,
    comorbidity_profile TEXT,
    care_plan_total INT,
    care_plan_completed INT,
    care_plan_overdue INT,
    care_plan_scheduled INT,
    deceased BOOLEAN,
    death_date TEXT,         -- also TEXT now
    death_primary_cause TEXT
);

CREATE TABLE allergies (
    allergy_id TEXT PRIMARY KEY,
    patient_id TEXT,
    substance TEXT,
    category TEXT,
    reaction TEXT,
    reaction_code TEXT,
    reaction_system TEXT,
    severity TEXT,
    severity_code TEXT,
    severity_system TEXT,
    rxnorm_code TEXT,
    unii_code TEXT,
    snomed_code TEXT,
    risk_level TEXT,
    registry_source TEXT,
    recorded_date TEXT,
    followup_summary TEXT
);
""")
conn.commit()
print("‚úÖ Tables recreated successfully")


‚úÖ Tables recreated successfully


# 6. Bulk-insert dataframes into Postgres

In [6]:
from io import StringIO

def copy_dataframe(df, table_name):
    # Make sure nulls are proper SQL NULLs
    df = df.replace({np.nan: None})

    # Write DataFrame to CSV buffer with quoting handled by pandas
    buffer = StringIO()
    df.to_csv(
        buffer,
        index=False,
        header=False,
        sep=",",
        quoting=1,  # csv.QUOTE_ALL
        escapechar="\\"
    )
    buffer.seek(0)

    try:
        cur.copy_expert(
            sql=f"COPY {table_name} FROM STDIN WITH (FORMAT CSV, HEADER FALSE, DELIMITER ',', QUOTE '\"', ESCAPE '\\')",
            file=buffer
        )
        conn.commit()
        print(f"‚úÖ Loaded {len(df)} rows into {table_name}")
    except Exception as e:
        conn.rollback()
        print(f"‚ùå Error loading {table_name}: {e}")

# --- Clean and normalize the patients dataframe ---
patients_df = patients_df.replace({np.nan: None})  # Convert NaN to None

# Normalize date columns: replace blanks with None, cast to str for CSV writing
for col in ["birthdate", "death_date"]:
    patients_df[col] = patients_df[col].apply(lambda x: None if pd.isna(x) or x == "" else str(x))

# --- Filter allergies to valid patient_ids ---
allergies_df = allergies_df[allergies_df["patient_id"].isin(patients_df["patient_id"])].copy()

# Confirm how many rows remain
print(f"Patients: {len(patients_df)} | Allergies (filtered): {len(allergies_df)}")
# Run it
copy_dataframe(patients_df, "patients")
copy_dataframe(allergies_df, "allergies")


Patients: 1000 | Allergies (filtered): 1026
‚úÖ Loaded 1000 rows into patients
‚úÖ Loaded 1026 rows into allergies


# 7. Join data into patient ‚Äúcontext‚Äù paragraphs

In [7]:
# --- Step 7 (replace completely) ---
query = """
WITH agg AS (
  SELECT
    p.patient_id,
    p.first_name, p.last_name, p.gender, p.age, p.race, p.ethnicity,
    p.sdoh_risk_score,
    CASE
      WHEN p.sdoh_risk_score >= 0.70 THEN 'HIGH'
      WHEN p.sdoh_risk_score >= 0.40 THEN 'MEDIUM'
      ELSE 'LOW'
    END AS sdoh_bucket,
    COALESCE(p.insurance, 'Unknown') AS insurance,
    COALESCE(p.smoking_status, 'Unknown') AS smoking_status,
    COALESCE(p.deceased, FALSE) AS deceased,
    NULLIF(p.death_date, '') AS death_date_raw,
    COUNT(a.allergy_id) AS allergy_count,
    MAX(CASE WHEN LOWER(COALESCE(a.severity,'')) IN ('severe','high','life-threatening') THEN 1 ELSE 0 END) AS any_severe_allergy,
    STRING_AGG(a.substance || ' (' || COALESCE(a.severity,'unknown') || ')', '; ' ORDER BY a.substance) AS allergy_list
  FROM patients p
  LEFT JOIN allergies a ON p.patient_id = a.patient_id
  GROUP BY p.patient_id, p.first_name, p.last_name, p.gender, p.age, p.race, p.ethnicity,
           p.sdoh_risk_score, insurance, smoking_status, deceased, p.death_date
)
SELECT
  patient_id,
  CONCAT(
    'Patient: ', first_name, ' ', last_name, '. ',
    'Gender: ', COALESCE(gender,'unknown'), '; Age: ', COALESCE(age::text,'unknown'), '. ',
    'Race: ', COALESCE(race,'unspecified'), '; Ethnicity: ', COALESCE(ethnicity,'unspecified'), '. ',
    'SDOH: ', sdoh_bucket, ' (', COALESCE(sdoh_risk_score::text,'n/a'), '). ',
    'Insurance: ', insurance, '; Smoking: ', smoking_status, '. ',
    'Deceased: ', CASE WHEN deceased THEN 'yes' ELSE 'no' END,
    CASE WHEN deceased AND death_date_raw IS NOT NULL THEN CONCAT(' (death_date=', death_date_raw, ')') ELSE '' END, '. ',
    'Allergies: ', COALESCE(allergy_list,'none'), '. ',
    'Allergy_count: ', allergy_count::text, '; Any_severe_allergy: ', any_severe_allergy::text, '.'
  ) AS context_text
FROM agg;
"""

df_context = pd.read_sql(query, conn)
print(f"‚úÖ Created {len(df_context)} patient context records with labeled facts")
df_context.head(2)


‚úÖ Created 1000 patient context records with labeled facts


  df_context = pd.read_sql(query, conn)


Unnamed: 0,patient_id,context_text
0,004b6780-7827-4030-a9e7-b7c9f44b2fab,Patient: John Weaver. Gender: female; Age: 54....
1,00840f4a-78ed-4b56-99d4-b3b375ccc277,Patient: Andrew Lopez. Gender: other; Age: 115...


# 8a. Configure embedding backend (Ollama or LLM API)

Set `EMBEDDING_BACKEND` to `"ollama"` for the local Ollama workflow or `"llm_api"` to call an OpenAPI-compatible embedding service. When using the hosted path, export `OPENAI_API_KEY` (and optionally `OPENAI_BASE_URL`, `OPENAI_EMBED_MODEL`, `OPENAI_ORG`) before running the cell. Use `EMBEDDING_DIM` if the remote model emits vectors larger/smaller than the pgvector column, and set `SKIP_EMBEDDING_SMOKETEST=1` to bypass the quick connectivity check.


In [11]:
EMBEDDING_BACKEND = os.getenv('EMBEDDING_BACKEND', 'ollama').strip().lower()
if EMBEDDING_BACKEND not in {'ollama', 'llm_api'}:
    raise ValueError(f'Unsupported EMBEDDING_BACKEND: {EMBEDDING_BACKEND}')

OLLAMA_EMBED_MODEL = os.getenv('OLLAMA_EMBED_MODEL', 'phi4-mini')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1').rstrip('/')
OPENAI_EMBED_MODEL = os.getenv('OPENAI_EMBED_MODEL', 'text-embedding-3-small')
OPENAI_ORG = os.getenv('OPENAI_ORG')
EMBEDDING_TIMEOUT = int(os.getenv('EMBEDDING_TIMEOUT', '30'))
EMBEDDING_DIM_OVERRIDE = os.getenv('EMBEDDING_DIM')
EMBEDDING_DIM_OVERRIDE = int(EMBEDDING_DIM_OVERRIDE) if EMBEDDING_DIM_OVERRIDE else None
SKIP_EMBEDDING_SMOKETEST = os.getenv('SKIP_EMBEDDING_SMOKETEST', '0') == '1'

if EMBEDDING_BACKEND == 'llm_api' and not OPENAI_API_KEY:
    raise RuntimeError('Set OPENAI_API_KEY before selecting the llm_api backend.')

print(f'üì¶ Embedding backend: {EMBEDDING_BACKEND}')
if EMBEDDING_BACKEND == 'ollama':
    print(f' - Ollama model: {OLLAMA_EMBED_MODEL}')
else:
    print(f' - API base URL: {OPENAI_BASE_URL}/embeddings')
    print(f' - Remote model: {OPENAI_EMBED_MODEL}')
    if EMBEDDING_DIM_OVERRIDE:
        print(f' - Expected vector length override: {EMBEDDING_DIM_OVERRIDE}')
if SKIP_EMBEDDING_SMOKETEST:
    print('‚è≠Ô∏è Skipping embedding smoke test will defer dimension discovery to Step 9.')


üì¶ Embedding backend: ollama
 - Ollama model: phi4-mini


# 8b. Embedding helper functions

Dispatch requests to the selected backend, normalize outputs, and guard against vector length mismatches before storing embeddings in Postgres.


In [10]:
EXPECTED_EMBEDDING_DIM = EMBEDDING_DIM_OVERRIDE


def _embed_with_ollama(text, model):
    payload = {'model': model, 'prompt': text}
    response = requests.post('http://localhost:11434/api/embeddings', json=payload, timeout=EMBEDDING_TIMEOUT)
    response.raise_for_status()
    data = response.json()
    if 'embedding' not in data:
        raise ValueError(f'Unexpected Ollama response: {data}')
    return data['embedding']


def _embed_with_llm_api(text, model):
    headers = {'Authorization': f'Bearer {OPENAI_API_KEY}'}
    if OPENAI_ORG:
        headers['OpenAI-Organization'] = OPENAI_ORG
    payload = {'model': model, 'input': text}
    response = requests.post(f'{OPENAI_BASE_URL}/embeddings', headers=headers, json=payload, timeout=EMBEDDING_TIMEOUT)
    response.raise_for_status()
    body = response.json()
    try:
        return body['data'][0]['embedding']
    except (KeyError, IndexError) as exc:
        raise ValueError(f'Unexpected LLM API response: {body}') from exc


def get_embedding(text, *, model=None):
    global EXPECTED_EMBEDDING_DIM
    if not text:
        raise ValueError('Text to embed must be non-empty.')
    if EMBEDDING_BACKEND == 'ollama':
        vector = _embed_with_ollama(text, model or OLLAMA_EMBED_MODEL)
    else:
        vector = _embed_with_llm_api(text, model or OPENAI_EMBED_MODEL)
    length = len(vector)
    if EXPECTED_EMBEDDING_DIM is None:
        EXPECTED_EMBEDDING_DIM = length
    elif length != EXPECTED_EMBEDDING_DIM:
        raise ValueError(f'Embedding length mismatch: expected {EXPECTED_EMBEDDING_DIM}, got {length}')
    return vector


if not SKIP_EMBEDDING_SMOKETEST:
    smoke_vec = get_embedding('test patient embedding')
    print(f'‚úÖ {EMBEDDING_BACKEND} embedding length: {len(smoke_vec)}')
else:
    print('‚ÑπÔ∏è Smoke test deferred; the first embedding call will determine vector length.')


ConnectionError: HTTPConnectionPool(host='localhost', port=11434): Max retries exceeded with url: /api/embeddings (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x78240012dd90>: Failed to establish a new connection: [Errno 111] Connection refused'))

# 9. Create and fill patient_embeddings

The table uses the discovered embedding dimension (from the smoke test or the first context embedding) so it works with either backend.


In [None]:
if df_context.empty:
    raise ValueError('df_context is empty; run Step 7 before generating embeddings.')

cached_first_embedding = None
if EXPECTED_EMBEDDING_DIM is None:
    cached_first_embedding = get_embedding(df_context.iloc[0].context_text)
    vector_length = len(cached_first_embedding)
else:
    vector_length = EXPECTED_EMBEDDING_DIM

cur.execute('DROP TABLE IF EXISTS patient_embeddings;')
cur.execute(f"""
CREATE TABLE patient_embeddings (
  patient_id TEXT PRIMARY KEY,
  context_text TEXT,
  embedding VECTOR({vector_length})
);
""")
conn.commit()

for idx, row in enumerate(df_context.itertuples(index=False)):
    if idx == 0 and cached_first_embedding is not None:
        emb = cached_first_embedding
    else:
        emb = get_embedding(row.context_text)
    cur.execute(
        'INSERT INTO patient_embeddings (patient_id, context_text, embedding) VALUES (%s, %s, %s)',
        (row.patient_id, row.context_text, emb)
    )
conn.commit()
print(f'‚úÖ Re-embedded {len(df_context)} patients with {vector_length}-d vectors via {EMBEDDING_BACKEND}')


# 10. Train UMAP reducer

In [None]:

import umap
import numpy as np
import ast


cur.execute("SELECT patient_id, embedding FROM patient_embeddings;")
rows = cur.fetchall()

embeddings = np.vstack([
    np.array(ast.literal_eval(r[1]), dtype=np.float32) if isinstance(r[1], str) else np.array(r[1], dtype=np.float32)
    for r in rows
])
patient_ids = [r[0] for r in rows]

reducer = umap.UMAP(n_neighbors=5, min_dist=0.3, metric="cosine", random_state=42)
embedding_2d = reducer.fit_transform(embeddings)

print(f"‚úÖ UMAP trained on {len(patient_ids)} patient embeddings.")


# 11. Semantic search query

In [None]:
query_text = 'patients allergic to penicillin with high sdoh risk'
query_emb = get_embedding(query_text)

sql = """
SELECT p.patient_id,
       p.context_text,
       1 - (p.embedding <=> %s::vector) AS similarity
FROM patient_embeddings p
ORDER BY similarity DESC
LIMIT 5;
"""

cur.execute(sql, (query_emb,))
results = cur.fetchall()

pd.DataFrame(results, columns=['patient_id', 'context_text', 'similarity'])


# 12. Visualization (Matplotlib + Plotly)

In [None]:
# --- Step 12: Visualize patient embeddings with optional query overlay ---
import matplotlib.pyplot as plt
import plotly.express as px
import numpy as np
import pandas as pd

plt.figure(figsize=(8,6))
plt.scatter(embedding_2d[:,0], embedding_2d[:,1], s=40, alpha=0.8)
plt.title("Patient Embedding Clusters (œÜ4-mini)")
plt.xlabel("UMAP-1")
plt.ylabel("UMAP-2")
plt.show()

# Optional: Interactive hover view
df_plot = pd.DataFrame({
    "x": embedding_2d[:,0],
    "y": embedding_2d[:,1],
    "patient_id": patient_ids,
    "context_text": [r[1] for r in rows]
})

fig = px.scatter(
    df_plot,
    x="x",
    y="y",
    hover_data={"patient_id": True, "context_text": True},
    title="Interactive Semantic Map of Patients",
    width=800,
    height=600
)
fig.update_traces(marker=dict(size=10, opacity=0.8))
fig.show()


# 12. Add NLP Semantic Search for Patients

In [None]:
import json
import numpy as np
import pandas as pd
from tabulate import tabulate
from sklearn.metrics.pairwise import cosine_similarity


def embed_query_vector(text):
    return np.array(get_embedding(text), dtype=np.float32)


def semantic_search_fused(query_text, top_k=5):
    # Parse rule-based filters
    filters = parse_filters(query_text)
    candidate_ids = candidate_ids_from_filters(filters)

    if not candidate_ids:
        cur.execute('SELECT patient_id, embedding FROM patient_embeddings;')
    else:
        cur.execute(
            """
            SELECT patient_id, embedding
            FROM patient_embeddings
            WHERE patient_id = ANY(%s);
        """,
            (candidate_ids,)
        )
    rows = cur.fetchall()

    if not rows:
        print('No candidates match the prefilters.')
        return None, pd.DataFrame()

    q_emb = embed_query_vector(query_text).reshape(1, -1)

    ids, mats = [], []
    for pid, emb in rows:
        if isinstance(emb, str):
            emb = json.loads(emb)
        mats.append(np.array(emb, dtype=np.float32))
        ids.append(pid)
    E = np.vstack(mats)

    sims = cosine_similarity(q_emb, E)[0]
    order = np.argsort(-sims)[:top_k]
    top = [(ids[i], float(sims[i])) for i in order]

    results = []

    for pid, score in top:
        cur.execute(
            """
            SELECT first_name, last_name, gender, age, race, ethnicity,
                   sdoh_risk_score, insurance, smoking_status, deceased, death_date
            FROM patients WHERE patient_id = %s;
        """,
            (pid,)
        )
        p = cur.fetchone()

        cur.execute(
            """
            SELECT substance, severity, reaction
            FROM allergies WHERE patient_id = %s
            ORDER BY severity DESC NULLS LAST, substance
            LIMIT 3;
        """,
            (pid,)
        )
        alls = cur.fetchall()
        allergy_summary = '; '.join([f"{a[0]} ({a[1] or 'unknown'})" for a in alls]) if alls else 'None'

        sdoh_bucket = 'HIGH' if (p[6] and p[6] >= 0.70) else ('MEDIUM' if (p[6] and p[6] >= 0.40) else 'LOW')

        # --- EXPLANATION SECTION ---
        explanations = []
        if filters['gender']:
            explanations.append(f"gender={p[2]} {'‚úÖ' if p[2] and p[2].lower()==filters['gender'] else '‚ùå'}")
        if filters['sdoh_bucket']:
            explanations.append(f"SDOH={sdoh_bucket} {'‚úÖ' if sdoh_bucket==filters['sdoh_bucket'] else '‚ùå'}")
        if filters['deceased'] is not None:
            explanations.append(f"deceased={p[9]} {'‚úÖ' if bool(p[9])==filters['deceased'] else '‚ùå'}")
        if filters['recent_days']:
            explanations.append(f"recent_check={filters['recent_days']}d window")
        explanation = '; '.join(explanations) if explanations else 'No explicit filter matches'

        results.append({
            'Patient ID': pid,
            'Name': f"{p[0]} {p[1]}",
            'Gender': p[2],
            'Age': p[3],
            'SDOH': f"{sdoh_bucket} ({p[6]})",
            'Insurance': p[7],
            'Smoking': p[8],
            'Deceased': p[9],
            'Death Date': p[10],
            'Allergies (sample)': allergy_summary,
            'Similarity': round(score, 3),
            'Filter Match Summary': explanation
        })

    df = pd.DataFrame(results)

    print(f"\nüîé Query: {query_text}")
    print(tabulate(df, headers='keys', tablefmt='fancy_grid', showindex=False))

    return q_emb, df


# 13. Filtering + fused semantic search

In [None]:
import re
from datetime import datetime, timedelta

def parse_filters(query_text):
    q = query_text.lower()

    filters = {
        "gender": None,
        "sdoh_bucket": None,
        "deceased": None,
        "recent_days": None
    }

    if "female" in q: filters["gender"] = "female"
    if "male" in q and "female" not in q: filters["gender"] = "male"

    if "high sdoh" in q or "high social risk" in q:   filters["sdoh_bucket"] = "HIGH"
    elif "low sdoh" in q or "low social risk" in q:   filters["sdoh_bucket"] = "LOW"
    elif "medium sdoh" in q or "medium social risk" in q: filters["sdoh_bucket"] = "MEDIUM"

    if "deceased" in q or "death" in q:
        filters["deceased"] = True
        if "recent" in q or "recently" in q:
            filters["recent_days"] = 180  # tweak as needed

    return filters
def candidate_ids_from_filters(filters):
    clauses, params = [], []

    if filters["gender"]:
        clauses.append("LOWER(gender) = %s")
        params.append(filters["gender"])

    if filters["sdoh_bucket"]:
        clauses.append("""
        CASE
          WHEN sdoh_risk_score >= 0.70 THEN 'HIGH'
          WHEN sdoh_risk_score >= 0.40 THEN 'MEDIUM'
          ELSE 'LOW'
        END = %s
        """)
        params.append(filters["sdoh_bucket"])

    if filters["deceased"] is True and filters["recent_days"]:
        days = int(filters["recent_days"])
        clauses.append(f"""
          COALESCE(deceased, FALSE) = TRUE
          AND NULLIF(death_date,'') IS NOT NULL
          AND death_date ~ '^[0-9]{{4}}-[0-9]{{2}}-[0-9]{{2}}$'
          AND TO_DATE(death_date, 'YYYY-MM-DD') >= (CURRENT_DATE - INTERVAL '{days} days')
        """)
    
    elif filters["deceased"] is True:
        clauses.append("COALESCE(deceased, FALSE) = TRUE")

    where_sql = ("WHERE " + " AND ".join(clauses)) if clauses else ""
    sql = f"SELECT patient_id FROM patients {where_sql};"
    cur.execute(sql, params)
    return [r[0] for r in cur.fetchall()]


# Run queries

In [None]:
q_emb, df_results = semantic_search_fused("patients experiencing headaches", top_k=5)

semantic_search_fused("recently deceased patient with respiratory reaction and low income", top_k=5)

semantic_search_fused("female patient deceased recently with severe drug allergy", top_k=5)


In [None]:
# --- Step 13: Project query embedding into same UMAP ---
q_emb = np.array(q_emb, dtype=np.float32).reshape(1, -1)
q_emb_2d = reducer.transform(q_emb)

plt.figure(figsize=(8,6))
plt.scatter(embedding_2d[:,0], embedding_2d[:,1], alpha=0.3, label='Patients')
plt.scatter(q_emb_2d[:,0], q_emb_2d[:,1], color='red', s=120, label='Query')
plt.legend()
plt.title("Query Position in Patient Semantic Space")
plt.show()
