# 01 - Organizando a estrutura de diretórios no Google Colab

In [3]:
# 1) Definir o diretório raiz do projeto

from google.colab import drive
from pathlib import Path
import os, json

PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
PROJECT_ROOT.mkdir(parents=True, exist_ok=True)

# 2) Criar um arquivo de configuração mínimo (ajuda a manter consistência entre todos os notebooks do projeto)
config = {
    "project_root": str(PROJECT_ROOT),
    "data_dir": str(PROJECT_ROOT / "data"),
    "raw_dir": str(PROJECT_ROOT / "data" / "raw"),
    "processed_dir": str(PROJECT_ROOT / "data" / "processed"),
    "synthetic_dir": str(PROJECT_ROOT / "data" / "synthetic"),
    "models_dir": str(PROJECT_ROOT / "models"),
    "logs_dir": str(PROJECT_ROOT / "logs"),
}

#Criar apenas os diretórios de nível superior "data" e "logs" por enquanto (estrutura inicial leve)
Path(config["data_dir"]).mkdir(parents=True, exist_ok=True)
Path(config["logs_dir"]).mkdir(parents=True, exist_ok=True)

config_path = PROJECT_ROOT / "config.json"
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)

print("✅ Drive mounted.")
print("✅ PROJECT_ROOT:", PROJECT_ROOT)
print("✅ Wrote config:", config_path)
print("📁 Existing dirs:", [p.name for p in PROJECT_ROOT.iterdir() if p.is_dir()])


✅ Drive mounted.
✅ PROJECT_ROOT: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3
✅ Wrote config: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/config.json
📁 Existing dirs: ['data', 'logs']


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


<p style="white-space: nowrap;"><strong>Comentário (Bloco 1)</strong> – Inicialização do ambiente do projeto no Google Colab, estabelecendo uma raiz única no Google Drive para persistência dos artefatos gerados. Essa abordagem assegura organização, reprodutibilidade e governança dos dados, modelos e logs ao longo de todas as etapas da construção do assistente médico.

# 02 - Baixando assuntos selecionados do MedQuad

In [4]:
# Baixar o dataset MedQuAD e manter apenas os tópicos selecionados no Google Drive

import shutil
from pathlib import Path

# Caminhos do projeto (Drive)
PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
MEDQUAD_RAW_DIR = PROJECT_ROOT / "data" / "raw" / "medquad"
MEDQUAD_RAW_DIR.mkdir(parents=True, exist_ok=True)

# Local temporário no Colab para clonar o repositório
TMP_DIR = Path("/content/medquad_tmp")

# Tópicos de interesse a serem mantidos
selected_topics = [
    "4_MPlus_Health_Topics_QA",
    "7_SeniorHealth_QA"
]

# 0) Garantir que o diretório temporário esteja limpo antes de iniciar
if TMP_DIR.exists():
    shutil.rmtree(TMP_DIR)

# 1) Clone repositório
!git clone https://github.com/abachaa/MedQuAD.git /content/medquad_tmp

# 2) Copiar apenas os tópicos selecionados (os tópicos ficam na raiz do repositório)
for topic in selected_topics:
    src = TMP_DIR / topic
    dst = MEDQUAD_RAW_DIR / topic

    if not src.exists():
        # Debug útil: listar o que existe na raiz do repositório
        existing = sorted([p.name for p in TMP_DIR.iterdir() if p.is_dir()])
        raise FileNotFoundError(
            f"Topic not found: {src}\n"
            f"Folders found at repo root:\n{existing[:50]}"
        )

    if dst.exists():
        shutil.rmtree(dst)

    shutil.copytree(src, dst)
    print(f"✅ Copied: {topic} -> {dst}")

# 3) Remove o repositório clonado temporariamente no ambiente do Colab
shutil.rmtree(TMP_DIR)

# 4) Verificação rápida: contar quantos arquivos XML foram copiados
print("\n📁 Final MedQuAD raw structure:")
for topic in selected_topics:
    topic_dir = MEDQUAD_RAW_DIR / topic
    xml_count = len(list(topic_dir.rglob("*.xml")))
    print(f" - {topic}: {xml_count} XML files")


Cloning into '/content/medquad_tmp'...
remote: Enumerating objects: 11310, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 11310 (delta 7), reused 4 (delta 4), pack-reused 11300 (from 1)[K
Receiving objects: 100% (11310/11310), 11.01 MiB | 6.28 MiB/s, done.
Resolving deltas: 100% (6807/6807), done.
✅ Copied: 4_MPlus_Health_Topics_QA -> /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/data/raw/medquad/4_MPlus_Health_Topics_QA
✅ Copied: 7_SeniorHealth_QA -> /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/data/raw/medquad/7_SeniorHealth_QA

📁 Final MedQuAD raw structure:
 - 4_MPlus_Health_Topics_QA: 981 XML files
 - 7_SeniorHealth_QA: 48 XML files


<p style="white-space: nowrap;"><strong>Comentário (Bloco 2)</strong> – Este bloco realiza o download do dataset MedQuAD a partir de sua fonte oficial e copia para o Google Drive apenas os tópicos clínicos relevantes ao projeto (tópico 4_MPlus_Health_Topics_QA, associado à saúde geral e 7_SeniorHealth_QA, associado à saúde de idosos). Essa curadoria reduz ruído, otimiza o uso de recursos computacionais e prepara a base de conhecimento médico para uso posterior em uma arquitetura de RAG, conforme os objetivos estabelecidos para o Tech Challenge 3.

# 03 - Pré-processamento dos arquivos (Parse MedQuAD XML → DataFrame)

In [5]:
# Ler os XML do MedQuAD e transformar em uma tabela limpa de Perguntas & Respostas (DataFrame + CSV)

from pathlib import Path
import xml.etree.ElementTree as ET
import pandas as pd
import re

PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
MEDQUAD_RAW_DIR = PROJECT_ROOT / "data" / "raw" / "medquad"
PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

topics = ["4_MPlus_Health_Topics_QA", "7_SeniorHealth_QA"]

def normalize_text(s: str) -> str:
    s = s or ""
    s = re.sub(r"\s+", " ", s).strip()
    return s

rows = []
xml_files = []
for t in topics:
    xml_files.extend(sorted((MEDQUAD_RAW_DIR / t).rglob("*.xml")))

if not xml_files:
    raise FileNotFoundError(f"No XML files found under: {MEDQUAD_RAW_DIR}")

for fp in xml_files:
    topic = fp.parts[fp.parts.index("medquad") + 1] if "medquad" in fp.parts else fp.parent.name
    try:
        tree = ET.parse(fp)
        root = tree.getroot()

        # O MedQuAD geralmente usa a tag <QAPair> contendo as tags <Question> e <Answer> dentro dela
             for qa in root.findall(".//QAPair"):
            q_el = qa.find(".//Question")
            a_el = qa.find(".//Answer")

            q = normalize_text(q_el.text if q_el is not None else "")
            a = normalize_text(a_el.text if a_el is not None else "")

            if q and a:
                rows.append({
                    "question": q,
                    "answer": a,
                    "topic": topic,
                    "source_file": str(fp),
                })

        # Fallback: if no QAPair found, try generic tags
        if not any(r["source_file"] == str(fp) for r in rows):
            q_el = root.find(".//Question")
            a_el = root.find(".//Answer")
            q = normalize_text(q_el.text if q_el is not None else "")
            a = normalize_text(a_el.text if a_el is not None else "")
            if q and a:
                rows.append({
                    "question": q,
                    "answer": a,
                    "topic": topic,
                    "source_file": str(fp),
                })

    except ET.ParseError as e:
        print(f"⚠️ Skipping malformed XML: {fp.name} ({e})")

df = pd.DataFrame(rows)

# Limpeza básica
df["question"] = df["question"].astype(str).map(normalize_text)
df["answer"] = df["answer"].astype(str).map(normalize_text)

df = df.dropna(subset=["question", "answer"])
df = df[(df["question"] != "") & (df["answer"] != "")]
df = df.drop_duplicates(subset=["question", "answer"]).reset_index(drop=True)

out_csv = PROCESSED_DIR / "medquad_qa.csv"
df.to_csv(out_csv, index=False)

print("✅ Parsed QA pairs:", len(df))
print("✅ Saved:", out_csv)
print("\nSample rows:")
display(df.sample(min(5, len(df))))


✅ Parsed QA pairs: 1750
✅ Saved: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/data/processed/medquad_qa.csv

Sample rows:


Unnamed: 0,question,answer,topic,source_file
1580,What is (are) Problems with Smell ?,You can help your doctor make a diagnosis by w...,7_SeniorHealth_QA,/content/drive/MyDrive/FIAP_PosTech/Tech_Chall...
1439,What is (are) Low Vision ?,Many agencies and organizations in the communi...,7_SeniorHealth_QA,/content/drive/MyDrive/FIAP_PosTech/Tech_Chall...
1104,What is (are) Cataract ?,A Clouding of the Lens in the Eye A cataract i...,7_SeniorHealth_QA,/content/drive/MyDrive/FIAP_PosTech/Tech_Chall...
913,What is (are) Trigeminal Neuralgia ?,Trigeminal neuralgia (TN) is a type of chronic...,4_MPlus_Health_Topics_QA,/content/drive/MyDrive/FIAP_PosTech/Tech_Chall...
1371,What is (are) High Blood Pressure ?,Normal blood pressure for adults is defined as...,7_SeniorHealth_QA,/content/drive/MyDrive/FIAP_PosTech/Tech_Chall...


<p style="white-space: nowrap;"><strong>Comentário (Bloco 3)</strong> – Este bloco realiza o parsing dos arquivos XML do MedQuAD e transforma o conteúdo em um dataset tabular limpo (pergunta–resposta), adicionando metadados de rastreabilidade (tema e arquivo de origem). O resultado é salvo em medquad_qa.csv, que será utilizado tanto para indexação em RAG quanto para o fine-tuning do modelo no Tech Challenge 3.

# 04 - Gerar dataset sintético de 250 pacientes contendo infos do prontuário

In [6]:
# Gerar prontuário sintético EHR (Electronic Health Record) com 250 pacientes

from pathlib import Path
import pandas as pd
import numpy as np
import random
import json
from datetime import datetime, timedelta

PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
SYN_DIR = PROJECT_ROOT / "data" / "synthetic" / "ehr"
SYN_DIR.mkdir(parents=True, exist_ok=True)

random.seed(42)
np.random.seed(42)

def rand_date(start_days_ago=365*3, end_days_ago=0):
    now = datetime.utcnow()
    start = now - timedelta(days=start_days_ago)
    end = now - timedelta(days=end_days_ago)
    delta = end - start
    return start + timedelta(seconds=random.randint(0, int(delta.total_seconds())))

def fmt_date(dt):
    return dt.strftime("%Y-%m-%d")

def make_id(prefix, n):
    return [f"{prefix}{i:05d}" for i in range(1, n+1)]

conditions = [
    ("Hypertension", "I10"),
    ("Type 2 diabetes mellitus", "E11.9"),
    ("Hyperlipidemia", "E78.5"),
    ("Osteoarthritis", "M19.90"),
    ("Chronic kidney disease (stage 3)", "N18.30"),
    ("Asthma", "J45.909"),
    ("Coronary artery disease", "I25.10"),
    ("Hypothyroidism", "E03.9"),
    ("Gastroesophageal reflux disease", "K21.9"),
    ("Depression", "F32.9"),
]

medications = [
    ("Lisinopril", "10 mg", "oral", "once daily"),
    ("Metformin", "500 mg", "oral", "twice daily"),
    ("Atorvastatin", "20 mg", "oral", "once daily"),
    ("Levothyroxine", "50 mcg", "oral", "once daily"),
    ("Omeprazole", "20 mg", "oral", "once daily"),
    ("Albuterol inhaler", "90 mcg", "inhalation", "as needed"),
    ("Amlodipine", "5 mg", "oral", "once daily"),
    ("Sertraline", "50 mg", "oral", "once daily"),
    ("Acetaminophen", "500 mg", "oral", "as needed"),
]

lab_tests = [
    ("Hemoglobin A1c", "%", (5.0, 11.5)),
    ("LDL cholesterol", "mg/dL", (50, 220)),
    ("Creatinine", "mg/dL", (0.6, 2.6)),
    ("TSH", "mIU/L", (0.1, 8.0)),
    ("Systolic blood pressure", "mmHg", (95, 190)),
]

visit_reasons = [
    "Routine follow-up",
    "Medication review",
    "Annual wellness visit",
    "Acute cough",
    "Joint pain evaluation",
    "Blood pressure check",
    "Diabetes management",
    "Lab results discussion",
]

providers = [
    ("Dr. Morgan", "Internal Medicine"),
    ("Dr. Patel", "Family Medicine"),
    ("Dr. Chen", "Geriatrics"),
    ("Dr. Rivera", "Cardiology"),
]

# 🔴 Quantidade de pacientes sintéticos gerados
N_PATIENTS = 250

patient_ids = make_id("P", N_PATIENTS)
sexes = ["Female", "Male"]
first_names = ["Alex", "Jordan", "Taylor", "Casey", "Riley", "Morgan", "Avery", "Cameron", "Quinn", "Parker"]
last_names = ["Smith", "Johnson", "Brown", "Garcia", "Miller", "Davis", "Rodriguez", "Martinez", "Lee", "Wilson"]

patients = []
for pid in patient_ids:
    dob = rand_date(start_days_ago=365*90, end_days_ago=365*18)
    age = int((datetime.utcnow() - dob).days / 365.25)
    patients.append({
        "patient_id": pid,
        "full_name": f"{random.choice(first_names)} {random.choice(last_names)}",
        "sex": random.choice(sexes),
        "date_of_birth": fmt_date(dob),
        "age": age,
        "phone": f"+1-555-{random.randint(100,999)}-{random.randint(1000,9999)}",
        "email": f"{pid.lower()}@example.com",
    })

patients_df = pd.DataFrame(patients)

visits, diagnoses_rows, prescriptions_rows, labs_rows = [], [], [], []
visit_id_counter = rx_id_counter = dx_id_counter = lab_id_counter = 1

for pid in patient_ids:
    for vdt in sorted([rand_date(start_days_ago=365*2) for _ in range(random.randint(2, 8))]):
        vid = f"V{visit_id_counter:06d}"
        visit_id_counter += 1

        provider_name, specialty = random.choice(providers)
        reason = random.choice(visit_reasons)

        visits.append({
            "visit_id": vid,
            "patient_id": pid,
            "visit_date": fmt_date(vdt),
            "provider_name": provider_name,
            "specialty": specialty,
            "reason_for_visit": reason,
            "notes": f"Patient seen for {reason.lower()}."
        })

        for _ in range(random.randint(1, 3)):
            cond, icd10 = random.choice(conditions)
            diagnoses_rows.append({
                "diagnosis_id": f"DX{dx_id_counter:06d}",
                "visit_id": vid,
                "patient_id": pid,
                "diagnosis_date": fmt_date(vdt),
                "diagnosis_name": cond,
                "icd10_code": icd10,
                "status": random.choice(["active", "resolved", "chronic"])
            })
            dx_id_counter += 1

        for _ in range(random.randint(0, 2)):
            drug, dose, route, freq = random.choice(medications)
            prescriptions_rows.append({
                "prescription_id": f"RX{rx_id_counter:06d}",
                "visit_id": vid,
                "patient_id": pid,
                "prescribed_date": fmt_date(vdt),
                "medication_name": drug,
                "dose": dose,
                "route": route,
                "frequency": freq,
                "duration_days": random.choice([7, 14, 30, 90]),
                "instructions": f"Take {drug} {dose} via {route}, {freq}."
            })
            rx_id_counter += 1

        if random.random() < 0.7:
            for _ in range(random.randint(1, 3)):
                test, unit, (lo, hi) = random.choice(lab_tests)
                labs_rows.append({
                    "lab_id": f"LAB{lab_id_counter:06d}",
                    "visit_id": vid,
                    "patient_id": pid,
                    "lab_date": fmt_date(vdt + timedelta(days=random.randint(0, 3))),
                    "test_name": test,
                    "value": round(random.uniform(lo, hi), 2),
                    "unit": unit,
                    "reference_range": f"{lo}-{hi} {unit}"
                })
                lab_id_counter += 1

patients_df.to_csv(SYN_DIR / "patients.csv", index=False)
pd.DataFrame(visits).to_csv(SYN_DIR / "visits.csv", index=False)
pd.DataFrame(diagnoses_rows).to_csv(SYN_DIR / "diagnoses.csv", index=False)
pd.DataFrame(prescriptions_rows).to_csv(SYN_DIR / "prescriptions.csv", index=False)
pd.DataFrame(labs_rows).to_csv(SYN_DIR / "labs.csv", index=False)

print("✅ Synthetic EHR regenerated")
print("patients:", len(patients_df))
print("visits:", len(visits))
print("diagnoses:", len(diagnoses_rows))
print("prescriptions:", len(prescriptions_rows))
print("labs:", len(labs_rows))


✅ Synthetic EHR regenerated
patients: 250
visits: 1252
diagnoses: 2480
prescriptions: 1267
labs: 1780


  now = datetime.utcnow()
  age = int((datetime.utcnow() - dob).days / 365.25)


<p style="white-space: nowrap;"><strong>Comentário (Bloco 4)</strong> – Este bloco gera um prontuário eletrônico sintético (EHR) para 250 pacientes, criando dados estruturados e realistas (pacientes, visitas, diagnósticos com CID-10, prescrições e exames laboratoriais). Para garantir reprodutibilidade, foram fixadas sementes (random.seed e np.random.seed). As datas de eventos clínicos são distribuídas ao longo do tempo para simular histórico de atendimento. Ao final, os dados são salvos em CSV no Google Drive (data/synthetic/ehr/), formando a base “patient-specific” do assistente. Este EHR permite testar consultas clínicas (ex.: diagnósticos ativos, exames faltantes) e validar o roteamento do assistente entre EHR (contexto do paciente) e MedQuAD (conhecimento geral).

# 05 - Visualizar as info disponiveis para um paciente selecionado

In [7]:
#Inspeção visual de prontuários eletrônicos de saúde (EHR) sintéticos para pacientes selecionados

from pathlib import Path
import pandas as pd

PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
EHR_DIR = PROJECT_ROOT / "data" / "synthetic" / "ehr"

# Carregar tabelas
patients = pd.read_csv(EHR_DIR / "patients.csv")
visits = pd.read_csv(EHR_DIR / "visits.csv")
diagnoses = pd.read_csv(EHR_DIR / "diagnoses.csv")
prescriptions = pd.read_csv(EHR_DIR / "prescriptions.csv")
labs = pd.read_csv(EHR_DIR / "labs.csv")

# ---------------------------
# Selecionar pacientes para inspecionar
# ---------------------------
# Opção A: selecionar pacientes de forma aleatória
selected_patients = patients.sample(2).patient_id.tolist()

# Opção B: Selecionar pacientes de interesse
# selected_patients = ["P00001", "P00042"]

print("🔎 Selected patients:", selected_patients)

# ---------------------------
# Mostrar dados por paciente
# ---------------------------
for pid in selected_patients:
    print("\n" + "="*80)
    print(f"PATIENT {pid}")
    print("="*80)

    print("\n🧍 Patient demographics")
    display(patients[patients.patient_id == pid])

    print("\n📅 Visits")
    display(
        visits[visits.patient_id == pid]
        .sort_values("visit_date", ascending=False)
    )

    print("\n🩺 Diagnoses")
    display(
        diagnoses[diagnoses.patient_id == pid]
        .sort_values("diagnosis_date", ascending=False)
    )

    print("\n💊 Prescriptions")
    display(
        prescriptions[prescriptions.patient_id == pid]
        .sort_values("prescribed_date", ascending=False)
    )

    print("\n🧪 Labs")
    display(
        labs[labs.patient_id == pid]
        .sort_values("lab_date", ascending=False)
    )


🔎 Selected patients: ['P00143', 'P00007']

PATIENT P00143

🧍 Patient demographics


Unnamed: 0,patient_id,full_name,sex,date_of_birth,age,phone,email
142,P00143,Alex Johnson,Female,1970-02-12,55,+1-555-790-1339,p00143@example.com



📅 Visits


Unnamed: 0,visit_id,patient_id,visit_date,provider_name,specialty,reason_for_visit,notes
724,V000725,P00143,2024-07-01,Dr. Morgan,Internal Medicine,Lab results discussion,Patient seen for lab results discussion.
723,V000724,P00143,2024-02-21,Dr. Patel,Family Medicine,Blood pressure check,Patient seen for blood pressure check.



🩺 Diagnoses


Unnamed: 0,diagnosis_id,visit_id,patient_id,diagnosis_date,diagnosis_name,icd10_code,status
1460,DX001461,V000725,P00143,2024-07-01,Type 2 diabetes mellitus,E11.9,resolved
1457,DX001458,V000724,P00143,2024-02-21,Hypothyroidism,E03.9,chronic
1458,DX001459,V000724,P00143,2024-02-21,Depression,F32.9,resolved
1459,DX001460,V000724,P00143,2024-02-21,Hypertension,I10,chronic



💊 Prescriptions


Unnamed: 0,prescription_id,visit_id,patient_id,prescribed_date,medication_name,dose,route,frequency,duration_days,instructions
714,RX000715,V000725,P00143,2024-07-01,Amlodipine,5 mg,oral,once daily,14,"Take Amlodipine 5 mg via oral, once daily."
715,RX000716,V000725,P00143,2024-07-01,Atorvastatin,20 mg,oral,once daily,30,"Take Atorvastatin 20 mg via oral, once daily."
713,RX000714,V000724,P00143,2024-02-21,Sertraline,50 mg,oral,once daily,14,"Take Sertraline 50 mg via oral, once daily."



🧪 Labs


Unnamed: 0,lab_id,visit_id,patient_id,lab_date,test_name,value,unit,reference_range
1025,LAB001026,V000725,P00143,2024-07-04,Hemoglobin A1c,6.59,%,5.0-11.5 %
1026,LAB001027,V000725,P00143,2024-07-03,Systolic blood pressure,147.88,mmHg,95-190 mmHg
1027,LAB001028,V000725,P00143,2024-07-01,Creatinine,0.88,mg/dL,0.6-2.6 mg/dL
1022,LAB001023,V000724,P00143,2024-02-24,Systolic blood pressure,100.64,mmHg,95-190 mmHg
1024,LAB001025,V000724,P00143,2024-02-24,Creatinine,2.45,mg/dL,0.6-2.6 mg/dL
1023,LAB001024,V000724,P00143,2024-02-23,Hemoglobin A1c,6.79,%,5.0-11.5 %



PATIENT P00007

🧍 Patient demographics


Unnamed: 0,patient_id,full_name,sex,date_of_birth,age,phone,email
6,P00007,Parker Miller,Female,1982-11-08,43,+1-555-847-8527,p00007@example.com



📅 Visits


Unnamed: 0,visit_id,patient_id,visit_date,provider_name,specialty,reason_for_visit,notes
33,V000034,P00007,2025-12-23,Dr. Rivera,Cardiology,Annual wellness visit,Patient seen for annual wellness visit.
32,V000033,P00007,2025-08-02,Dr. Chen,Geriatrics,Lab results discussion,Patient seen for lab results discussion.
31,V000032,P00007,2025-03-02,Dr. Rivera,Cardiology,Lab results discussion,Patient seen for lab results discussion.
30,V000031,P00007,2024-10-04,Dr. Patel,Family Medicine,Acute cough,Patient seen for acute cough.



🩺 Diagnoses


Unnamed: 0,diagnosis_id,visit_id,patient_id,diagnosis_date,diagnosis_name,icd10_code,status
64,DX000065,V000034,P00007,2025-12-23,Osteoarthritis,M19.90,active
65,DX000066,V000034,P00007,2025-12-23,Depression,F32.9,chronic
66,DX000067,V000034,P00007,2025-12-23,Type 2 diabetes mellitus,E11.9,active
63,DX000064,V000033,P00007,2025-08-02,Asthma,J45.909,resolved
62,DX000063,V000032,P00007,2025-03-02,Depression,F32.9,chronic
59,DX000060,V000031,P00007,2024-10-04,Gastroesophageal reflux disease,K21.9,resolved
60,DX000061,V000031,P00007,2024-10-04,Gastroesophageal reflux disease,K21.9,active
61,DX000062,V000031,P00007,2024-10-04,Hypertension,I10,active



💊 Prescriptions


Unnamed: 0,prescription_id,visit_id,patient_id,prescribed_date,medication_name,dose,route,frequency,duration_days,instructions
30,RX000031,V000033,P00007,2025-08-02,Metformin,500 mg,oral,twice daily,90,"Take Metformin 500 mg via oral, twice daily."
31,RX000032,V000033,P00007,2025-08-02,Omeprazole,20 mg,oral,once daily,7,"Take Omeprazole 20 mg via oral, once daily."



🧪 Labs


Unnamed: 0,lab_id,visit_id,patient_id,lab_date,test_name,value,unit,reference_range
47,LAB000048,V000034,P00007,2025-12-26,LDL cholesterol,167.39,mg/dL,50-220 mg/dL
48,LAB000049,V000034,P00007,2025-12-26,TSH,7.0,mIU/L,0.1-8.0 mIU/L
46,LAB000047,V000033,P00007,2025-08-05,Systolic blood pressure,111.2,mmHg,95-190 mmHg
45,LAB000046,V000033,P00007,2025-08-03,Systolic blood pressure,163.53,mmHg,95-190 mmHg
44,LAB000045,V000032,P00007,2025-03-05,Systolic blood pressure,162.33,mmHg,95-190 mmHg
43,LAB000044,V000032,P00007,2025-03-04,Systolic blood pressure,151.4,mmHg,95-190 mmHg
42,LAB000043,V000032,P00007,2025-03-03,Creatinine,1.39,mg/dL,0.6-2.6 mg/dL


<p style="white-space: nowrap;"><strong>Comentário (Bloco 5)</strong> – Este bloco realiza uma inspeção visual exploratória do EHR sintético (prontuário paciente), carregando os arquivos CSV gerados no Bloco 4 e exibindo, para 1 ou 2 pacientes selecionados, seus dados demográficos, histórico de visitas, diagnósticos, prescrições e exames laboratoriais. O objetivo é validar qualitativamente a consistência e o realismo dos dados gerados antes de utilizá-los no assistente médico. Esse passo funciona como uma etapa de sanity check, garantindo que o EHR suporta consultas clínicas reais (ex.: diagnósticos ativos, exames faltantes) e o correto roteamento entre EHR e MedQuAD.

# 06 - Camada de Consulta Estruturada do Prontuário (Acesso Seguro aos Dados Clínicos)

In [8]:
# Camada estruturada de consulta de prontuários (seguro, deterministico, sem LLM) ---

from pathlib import Path
import pandas as pd

# --------------------------------------------------
# Caminhos
# --------------------------------------------------
PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
EHR_DIR = PROJECT_ROOT / "data" / "synthetic" / "ehr"

# --------------------------------------------------
# Carregar tabelas de prontuários
# --------------------------------------------------
patients = pd.read_csv(EHR_DIR / "patients.csv")
visits = pd.read_csv(EHR_DIR / "visits.csv")
diagnoses = pd.read_csv(EHR_DIR / "diagnoses.csv")
prescriptions = pd.read_csv(EHR_DIR / "prescriptions.csv")
labs = pd.read_csv(EHR_DIR / "labs.csv")

# --------------------------------------------------
# Funcões de consulta
# --------------------------------------------------

def get_patient_summary(patient_id: str) -> dict:
    """Return basic demographic information for a patient."""
    p = patients[patients.patient_id == patient_id]
    if p.empty:
        return {"error": "Patient not found"}

    p = p.iloc[0].to_dict()

    return {
        "patient_id": p["patient_id"],
        "full_name": p["full_name"],
        "age": int(p["age"]),
        "sex": p["sex"],
        "contact": {
            "phone": p["phone"],
            "email": p["email"]
        }
    }


def get_active_diagnoses(patient_id: str):
    """Return active or chronic diagnoses for a patient."""
    df = diagnoses[
        (diagnoses.patient_id == patient_id) &
        (diagnoses.status.isin(["active", "chronic"]))
    ]

    return (
        df[["diagnosis_name", "icd10_code", "status", "diagnosis_date"]]
        .sort_values("diagnosis_date", ascending=False)
        .to_dict(orient="records")
    )


def get_latest_labs(patient_id: str, n: int = 5):
    """Return the most recent lab results for a patient."""
    df = labs[labs.patient_id == patient_id]
    if df.empty:
        return []

    df = df.sort_values("lab_date", ascending=False).head(n)

    return (
        df[["test_name", "value", "unit", "reference_range", "lab_date"]]
        .to_dict(orient="records")
    )


def check_missing_labs(patient_id: str, required_tests: list):
    """Check which required lab tests are missing for a patient."""
    df = labs[labs.patient_id == patient_id]
    existing_tests = set(df.test_name.unique())
    return [t for t in required_tests if t not in existing_tests]


def list_recent_visits(patient_id: str, n: int = 3):
    """Return recent visits for a patient."""
    df = visits[visits.patient_id == patient_id]
    if df.empty:
        return []

    df = df.sort_values("visit_date", ascending=False).head(n)

    return (
        df[["visit_date", "provider_name", "specialty", "reason_for_visit"]]
        .to_dict(orient="records")
    )

# --------------------------------------------------
# Teste de validação
# --------------------------------------------------
sample_patient = patients.sample(1).iloc[0].patient_id

print("🧪 Sample patient:", sample_patient)

print("\n📌 Patient summary")
print(get_patient_summary(sample_patient))

print("\n📌 Active diagnoses")
print(get_active_diagnoses(sample_patient))

print("\n📌 Latest labs")
print(get_latest_labs(sample_patient))

print("\n📌 Recent visits")
print(list_recent_visits(sample_patient))

print("\n📌 Missing labs (HbA1c, LDL, Creatinine)")
print(
    check_missing_labs(
        sample_patient,
        ["Hemoglobin A1c", "LDL cholesterol", "Creatinine"]
    )
)


🧪 Sample patient: P00206

📌 Patient summary
{'patient_id': 'P00206', 'full_name': 'Parker Johnson', 'age': 20, 'sex': 'Male', 'contact': {'phone': '+1-555-615-4629', 'email': 'p00206@example.com'}}

📌 Active diagnoses
[{'diagnosis_name': 'Chronic kidney disease (stage 3)', 'icd10_code': 'N18.30', 'status': 'chronic', 'diagnosis_date': '2025-10-16'}, {'diagnosis_name': 'Type 2 diabetes mellitus', 'icd10_code': 'E11.9', 'status': 'chronic', 'diagnosis_date': '2025-09-22'}, {'diagnosis_name': 'Coronary artery disease', 'icd10_code': 'I25.10', 'status': 'active', 'diagnosis_date': '2025-09-22'}, {'diagnosis_name': 'Depression', 'icd10_code': 'F32.9', 'status': 'chronic', 'diagnosis_date': '2025-09-18'}, {'diagnosis_name': 'Depression', 'icd10_code': 'F32.9', 'status': 'active', 'diagnosis_date': '2025-02-22'}, {'diagnosis_name': 'Chronic kidney disease (stage 3)', 'icd10_code': 'N18.30', 'status': 'chronic', 'diagnosis_date': '2025-02-22'}, {'diagnosis_name': 'Hypothyroidism', 'icd10_code'

<p style="white-space: nowrap;"><strong>Comentário (Bloco 6)</strong> – Este bloco implementa uma camada de consulta estruturada do EHR (prontuário eletrônico sintético) usando apenas operações determinísticas por meio da biblioteca Pandas, sem uso de LLM, garantindo respostas estáveis e seguras para perguntas sobre pacientes. São carregadas as tabelas do EHR (patients, visits, diagnoses, prescriptions, labs) e definidas funções para obter: resumo demográfico, diagnósticos ativos/crônicos, exames mais recentes, visitas recentes e identificação de exames faltantes (ex.: HbA1c, LDL, Creatinina). Ao final, um teste seleciona automaticamente um paciente e executa todas as funções para validar que a base suporta consultas clínicas. Esse bloco fornece a fonte confiável de “verdade” do paciente, reduzindo alucinações e permitindo que o assistente combine dados do prontuário (EHR) com conhecimento geral (MedQuAD) de forma rastreável.

 # 07 - Orquestração do Assistente: Roteamento entre Prontuário Clínico e Conhecimento Médico

In [9]:
#Ferramentas do LangChain + Roteador + Orquestração com LangGraph (EHR vs MedQuAD)

# 0) Instalar dependências (reiniciar runtime se o Colab pedir)
!pip -q install -U langchain langchain-core langchain-community langgraph \
  transformers accelerate sentence-transformers faiss-cpu

from pathlib import Path
import re
import json
import pandas as pd
from typing import TypedDict, Any, Dict, List

# LangChain / LangGraph
from langchain_core.tools import tool
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langgraph.graph import StateGraph, END

# ------------------------------------------------------------------------------
# 1) Caminhos +m carga de tabelas
# ------------------------------------------------------------------------------
PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")

# MedQuAD processado em um CSV
MEDQUAD_CSV = PROJECT_ROOT / "data" / "processed" / "medquad_qa.csv"
if not MEDQUAD_CSV.exists():
    raise FileNotFoundError(f"Missing MedQuAD processed file: {MEDQUAD_CSV}")

qa_df = pd.read_csv(MEDQUAD_CSV)

# Tabelas de dados sintéticos
EHR_DIR = PROJECT_ROOT / "data" / "synthetic" / "ehr"
patients = pd.read_csv(EHR_DIR / "patients.csv")
visits = pd.read_csv(EHR_DIR / "visits.csv")
diagnoses = pd.read_csv(EHR_DIR / "diagnoses.csv")
prescriptions = pd.read_csv(EHR_DIR / "prescriptions.csv")
labs = pd.read_csv(EHR_DIR / "labs.csv")

print("✅ Loaded:")
print(" - MedQuAD QAs:", len(qa_df))
print(" - EHR patients:", len(patients), "| visits:", len(visits), "| dx:", len(diagnoses), "| rx:", len(prescriptions), "| labs:", len(labs))

# ------------------------------------------------------------------------------
# 2) Construir o Vector Store do MedQuAD (RAG)
# ------------------------------------------------------------------------------
docs = []
for _, row in qa_df.iterrows():
    content = f"QUESTION: {str(row['question']).strip()}\nANSWER: {str(row['answer']).strip()}"
    meta = {"topic": row.get("topic", ""), "source_file": row.get("source_file", "")}
    docs.append(Document(page_content=content, metadata=meta))

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vs = FAISS.from_documents(docs, embeddings)

print("✅ Vector store ready (FAISS)")

# ------------------------------------------------------------------------------
# 3) Funçoes determinísticas de prontuário
# ------------------------------------------------------------------------------
def _get_patient_summary(patient_id: str) -> dict:
    p = patients[patients.patient_id == patient_id]
    if p.empty:
        return {"error": "Patient not found"}
    p = p.iloc[0].to_dict()
    return {
        "patient_id": p["patient_id"],
        "full_name": p["full_name"],
        "age": int(p["age"]),
        "sex": p["sex"],
        "contact": {"phone": p["phone"], "email": p["email"]},
    }

def _get_active_diagnoses(patient_id: str):
    df = diagnoses[(diagnoses.patient_id == patient_id) & (diagnoses.status.isin(["active", "chronic"]))]
    return (
        df[["diagnosis_name", "icd10_code", "status", "diagnosis_date"]]
        .sort_values("diagnosis_date", ascending=False)
        .to_dict(orient="records")
    )

def _get_latest_labs(patient_id: str, n: int = 5):
    df = labs[labs.patient_id == patient_id]
    if df.empty:
        return []
    df = df.sort_values("lab_date", ascending=False).head(n)
    return df[["test_name", "value", "unit", "reference_range", "lab_date"]].to_dict(orient="records")

def _list_recent_visits(patient_id: str, n: int = 3):
    df = visits[visits.patient_id == patient_id]
    if df.empty:
        return []
    df = df.sort_values("visit_date", ascending=False).head(n)
    return df[["visit_date", "provider_name", "specialty", "reason_for_visit"]].to_dict(orient="records")

def _check_missing_labs(patient_id: str, required_tests: list):
    df = labs[labs.patient_id == patient_id]
    existing = set(df.test_name.unique())
    return [t for t in required_tests if t not in existing]

# ------------------------------------------------------------------------------
# 4) Encapsular como ferramentas do LangChain
# ------------------------------------------------------------------------------
@tool("ehr_get_patient_summary")
def ehr_get_patient_summary(patient_id: str) -> str:
    """Get basic demographic information for a patient_id like P00001."""
    return json.dumps(_get_patient_summary(patient_id), indent=2)

@tool("ehr_get_active_diagnoses")
def ehr_get_active_diagnoses(patient_id: str) -> str:
    """Get active/chronic diagnoses for a patient_id."""
    return json.dumps(_get_active_diagnoses(patient_id), indent=2)

@tool("ehr_get_latest_labs")
def ehr_get_latest_labs(patient_id: str, n: int = 5) -> str:
    """Get latest lab results for a patient_id."""
    return json.dumps(_get_latest_labs(patient_id, n=n), indent=2)

@tool("ehr_list_recent_visits")
def ehr_list_recent_visits(patient_id: str, n: int = 3) -> str:
    """List recent visits for a patient_id."""
    return json.dumps(_list_recent_visits(patient_id, n=n), indent=2)

@tool("ehr_check_missing_labs")
def ehr_check_missing_labs(patient_id: str, required_tests: str = "Hemoglobin A1c, LDL cholesterol, Creatinine") -> str:
    """Check missing labs for a patient_id. required_tests is comma-separated."""
    tests = [t.strip() for t in required_tests.split(",") if t.strip()]
    return json.dumps(_check_missing_labs(patient_id, tests), indent=2)

@tool("medquad_retrieve")
def medquad_retrieve(query: str, k: int = 4) -> str:
    """Retrieve top-k MedQuAD Q&A snippets relevant to the query."""
    hits = vs.similarity_search(query, k=k)
    payload = [{"content": h.page_content, "metadata": h.metadata} for h in hits]
    return json.dumps(payload, indent=2)

# ------------------------------------------------------------------------------
# 5) Regras do roteador: decidir entre EHR vs MedQuAD vs AMBOS
# ------------------------------------------------------------------------------
PATIENT_ID_RE = re.compile(r"\bP\d{5}\b")

EHR_KEYWORDS = [
    "patient", "visit", "visits", "diagnosis", "diagnoses", "prescription", "medication",
    "lab", "labs", "a1c", "ldl", "creatinine", "blood pressure", "latest", "recent", "pending", "missing"
]

def route_intent(user_text: str) -> str:
    """
    Returns:
      - "EHR" if patient-specific / factual
      - "MEDQUAD" if general medical knowledge
      - "BOTH" if explanation + patient context
    """
    t = user_text.lower()
    has_pid = bool(PATIENT_ID_RE.search(user_text))
    has_ehr_kw = any(k in t for k in EHR_KEYWORDS)

    if has_pid and any(x in t for x in ["what is", "explain", "risks", "significance", "why"]):
        return "BOTH"
    if has_pid or has_ehr_kw:
        return "EHR"
    return "MEDQUAD"

# ------------------------------------------------------------------------------
# 6) LangGraph com esquema de State explícito
# ------------------------------------------------------------------------------
class AgentState(TypedDict, total=False):
    input: str
    route: str
    ehr: Dict[str, Any]
    medquad: List[Dict[str, Any]]
    final: str

def node_route(state: AgentState) -> AgentState:
    return {"route": route_intent(state["input"])}

def node_ehr_only(state: AgentState) -> AgentState:
    text = state["input"]
    pid_match = PATIENT_ID_RE.search(text)
    pid = pid_match.group(0) if pid_match else None

    if not pid:
        return {"ehr": {"error": "No patient_id found (expected like P00001)."}}

    t = text.lower()
    out = {"patient_id": pid}

    out["summary"] = json.loads(ehr_get_patient_summary.invoke({"patient_id": pid}))

    if "diagnos" in t or "condition" in t or "problem" in t:
        out["active_diagnoses"] = json.loads(ehr_get_active_diagnoses.invoke({"patient_id": pid}))

    if "lab" in t or "a1c" in t or "ldl" in t or "creatinine" in t or "blood pressure" in t or "latest" in t:
        out["latest_labs"] = json.loads(ehr_get_latest_labs.invoke({"patient_id": pid, "n": 5}))

    if "visit" in t or "recent" in t or "history" in t:
        out["recent_visits"] = json.loads(ehr_list_recent_visits.invoke({"patient_id": pid, "n": 3}))

    if "missing" in t or "pending" in t:
        out["missing_labs"] = json.loads(ehr_check_missing_labs.invoke({
            "patient_id": pid,
            "required_tests": "Hemoglobin A1c, LDL cholesterol, Creatinine"
        }))

    return {"ehr": out}

def node_medquad_only(state: AgentState) -> AgentState:
    hits = json.loads(medquad_retrieve.invoke({"query": state["input"], "k": 4}))
    return {"medquad": hits}

def node_ehr_plus_medquad(state: AgentState) -> AgentState:
    out = {}
    out.update(node_ehr_only(state))
    out.update(node_medquad_only(state))
    return out

def node_finalize(state: AgentState) -> AgentState:
    route = state.get("route", "UNKNOWN")
    user_q = state["input"]

    parts = [f"USER QUESTION: {user_q}\n", f"ROUTE: {route}\n"]

    if route in ("EHR", "BOTH"):
        parts.append("=== STRUCTURED EHR DATA (authoritative) ===")
        parts.append(json.dumps(state.get("ehr", {}), indent=2))

    if route in ("MEDQUAD", "BOTH"):
        parts.append("\n=== MEDQUAD RETRIEVAL (knowledge context) ===")
        parts.append(json.dumps(state.get("medquad", []), indent=2))

    parts.append(
        "\nNOTE: This is a deterministic finalizer (no LLM generation yet). "
        "Next we will plug your fine-tuned Llama to produce fluent answers, while keeping EHR facts deterministic."
    )

    return {"final": "\n".join(parts)}

graph = StateGraph(AgentState)
graph.add_node("route", node_route)
graph.add_node("ehr_only", node_ehr_only)
graph.add_node("medquad_only", node_medquad_only)
graph.add_node("ehr_plus_medquad", node_ehr_plus_medquad)
graph.add_node("finalize", node_finalize)

graph.set_entry_point("route")

graph.add_conditional_edges(
    "route",
    lambda s: "ehr_plus_medquad" if s["route"] == "BOTH" else ("ehr_only" if s["route"] == "EHR" else "medquad_only"),
    {"ehr_only": "ehr_only", "ehr_plus_medquad": "ehr_plus_medquad", "medquad_only": "medquad_only"}
)

graph.add_edge("ehr_only", "finalize")
graph.add_edge("ehr_plus_medquad", "finalize")
graph.add_edge("medquad_only", "finalize")
graph.add_edge("finalize", END)

app = graph.compile()
print("✅ LangGraph app compiled (FIXED)")

# ------------------------------------------------------------------------------
# 7) Demo runs
# ------------------------------------------------------------------------------
demo_pid = patients.sample(1).iloc[0].patient_id

demo_questions = [
    f"What are the active diagnoses for patient {demo_pid}?",
    f"Does patient {demo_pid} have missing labs for HbA1c, LDL cholesterol, and Creatinine?",
    "What is hypertension and what are common management approaches?",
    f"Explain the significance of elevated HbA1c in older adults for patient {demo_pid}."
]

for q in demo_questions:
    print("\n" + "="*100)
    print("QUERY:", q)
    result = app.invoke({"input": q})
    print(result["final"][:2000])
    if len(result["final"]) > 2000:
        print("\n... [truncated] ...")


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/484.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m484.9/484.9 kB[0m [31m39.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m101.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.8/23.8 MB[0m [31m88.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m66.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.7/64.7 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.0/51.0 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency c

  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

✅ Vector store ready (FAISS)
✅ LangGraph app compiled (FIXED)

QUERY: What are the active diagnoses for patient P00153?
USER QUESTION: What are the active diagnoses for patient P00153?

ROUTE: EHR

=== STRUCTURED EHR DATA (authoritative) ===
{
  "patient_id": "P00153",
  "summary": {
    "patient_id": "P00153",
    "full_name": "Jordan Smith",
    "age": 23,
    "sex": "Female",
    "contact": {
      "phone": "+1-555-826-3594",
      "email": "p00153@example.com"
    }
  },
  "active_diagnoses": [
    {
      "diagnosis_name": "Coronary artery disease",
      "icd10_code": "I25.10",
      "status": "active",
      "diagnosis_date": "2024-09-28"
    },
    {
      "diagnosis_name": "Hypothyroidism",
      "icd10_code": "E03.9",
      "status": "active",
      "diagnosis_date": "2024-09-28"
    },
    {
      "diagnosis_name": "Type 2 diabetes mellitus",
      "icd10_code": "E11.9",
      "status": "chronic",
      "diagnosis_date": "2024-02-08"
    },
    {
      "diagnosis_name": "Ast

<p style="white-space: nowrap;"><strong>Comentário (Bloco 7)</strong> – Nesse passo 7, é feita a orquestração do assistente médico usando LangChain Tools + LangGraph, integrando duas fontes: (1) consultas determinísticas ao EHR sintético (dados factuais do paciente) e (2) recuperação semântica (RAG) no MedQuAD via FAISS (conhecimento médico geral). As funções do prontuário são empacotadas como tools e um roteador de intenções decide automaticamente entre EHR, MedQuAD ou BOTH (quando a pergunta mistura explicação + contexto do paciente). O LangGraph define um fluxo de execução com estado tipado, garantindo robustez e evitando erros de input. A finalização nesta etapa é determinística (sem LLM), apenas para validar o pipeline de roteamento e recuperação antes de acoplar o modelo gerador. Esse bloco implementa o “cérebro” do assistente (RAG + ferramentas + roteamento), preparando o sistema para gerar respostas fluentes com LLM mantendo fatos clínicos controlados.

 # 08 - Geração de Respostas com LLM (substitui o finalizador determinístico mantendo os guardrails)

In [10]:
#LLM Finalizer + Deterministic Sources + Lab Guardrails + Drive Logs ---
# 0) Instalar dependências
!pip -q install -U "transformers>=4.41.0" "accelerate>=0.30.0" "bitsandbytes>=0.43.1" langgraph

import json
import re
from pathlib import Path
from datetime import datetime, timezone
from typing import TypedDict, Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from langgraph.graph import StateGraph, END

required = [
    "route_intent", "PATIENT_ID_RE",
    "ehr_get_patient_summary", "ehr_get_active_diagnoses", "ehr_get_latest_labs",
    "ehr_list_recent_visits", "ehr_check_missing_labs", "medquad_retrieve",
    "patients"  # loaded DataFrame from EHR
]
missing = [x for x in required if x not in globals()]
if missing:
    raise RuntimeError(
        "Missing objects from previous blocks: "
        + ", ".join(missing)
        + "\nRun Block 7 first (LangChain Tools + Router + LangGraph orchestration)."
    )

# -------------------------------------------------------------------
# 0) Caminhos
# -------------------------------------------------------------------
PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
LOG_DIR = PROJECT_ROOT / "logs"
LOG_DIR.mkdir(parents=True, exist_ok=True)
RUNS_LOG = LOG_DIR / "assistant_runs.jsonl"
print("✅ Logs will be written to:", RUNS_LOG)

# -------------------------------------------------------------------
# 1) Carregar LLM de instruções (Mistral por padrão; sem modelos restritos)
# -------------------------------------------------------------------
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"

def load_llm_robust(model_name: str):
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Try 4-bit quantization first
    try:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        mdl = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
        )
        mdl.eval()
        return tok, mdl, f"{model_name} (4-bit)"
    except Exception as e:
        print("⚠️ 4-bit load failed, falling back to fp16 on GPU. Reason:\n", e, "\n")

    # Fallback: fp16 on GPU
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available. Please switch runtime to GPU (T4/L4/A100).")

    mdl = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map={"": 0},
    )
    mdl.eval()
    return tok, mdl, f"{model_name} (fp16)"

tokenizer, model, loaded_model_name = load_llm_robust(BASE_MODEL)
print("✅ Loaded model:", loaded_model_name)

# -------------------------------------------------------------------
# 2) Prompt + generator (LLM will NOT output Sources; we add them deterministically)
# -------------------------------------------------------------------
def compact_medquad(medquad: List[Dict[str, Any]], limit: int = 3) -> str:
    snips = []
    for i, hit in enumerate(medquad[:limit], start=1):
        content = (hit.get("content") or "").strip()[:1100]
        meta = hit.get("metadata") or {}
        snips.append(
            f"[MedQuAD {i}] Topic: {meta.get('topic','')}\n"
            f"SourceFile: {meta.get('source_file','')}\n"
            f"{content}"
        )
    return "\n\n".join(snips) if snips else "None"

def build_prompt(user_question: str, route: str, ehr: Dict[str, Any], medquad: List[Dict[str, Any]]) -> str:
    ehr_block = json.dumps(ehr, indent=2) if ehr else "None"
    medquad_block = compact_medquad(medquad, limit=3)

    system = (
        "You are a clinical assistant for physicians.\n"
        "SAFETY RULES (must follow):\n"
        "1) Do NOT prescribe medications or provide definitive treatment plans.\n"
        "2) Patient-specific facts MUST come only from the STRUCTURED EHR JSON provided.\n"
        "3) If EHR lacks info, say what is missing and recommend checking the chart.\n"
        "4) For lab values: do NOT interpret/classify (normal/high/low/elevated/poor control) unless explicit thresholds are provided.\n"
        "5) For missing-labs checks: ONLY report which labs are missing vs present.\n"
        "6) Always answer in English.\n"
        "7) DO NOT include a 'Sources' section. The system will attach sources automatically.\n"
    )

    instruction = (
        f"ROUTE: {route}\n\n"
        f"USER QUESTION:\n{user_question}\n\n"
        f"STRUCTURED EHR JSON:\n{ehr_block}\n\n"
        f"MEDQUAD RETRIEVAL:\n{medquad_block}\n\n"
        "Write a helpful, concise answer following the safety rules.\n"
    )

    # Mistral chat template style (works well even without apply_chat_template)
    return f"<s>[INST] {system}\n\n{instruction} [/INST]"

@torch.inference_mode()
def generate_text(prompt: str, max_new_tokens: int = 260) -> str:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.35,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(out[0], skip_special_tokens=True).strip()

    # Robustly remove prompt echo for chat/instruct formats
    if "[/INST]" in text:
        text = text.split("[/INST]", 1)[1].strip()
    text = text.replace("</s>", "").strip()
    return text

def strip_sources_if_any(text: str) -> str:
    parts = re.split(r"\n\s*Sources\s*:\s*\n", text, flags=re.IGNORECASE)
    return parts[0].strip()

def log_run(payload: Dict[str, Any]):
    payload["timestamp_utc"] = datetime.now(timezone.utc).isoformat()
    payload["model"] = loaded_model_name
    with open(RUNS_LOG, "a", encoding="utf-8") as f:
        f.write(json.dumps(payload, ensure_ascii=False) + "\n")

# -------------------------------------------------------------------
# 3) Lab guardrails (block interpretation language if labs exist)
# -------------------------------------------------------------------
LAB_INTERPRETATION_RE = re.compile(
    r"\b("
    r"normal|abnormal|elevated|high|low|within (the )?normal|outside (the )?normal|"
    r"poor glycemic control|well[- ]controlled|uncontrolled|controlled|"
    r"indicates (diabetes|prediabetes)|suggests (diabetes|prediabetes)|"
    r"goal|target|above goal|below goal"
    r")\b",
    re.IGNORECASE
)

def lab_safe_summary(ehr: Dict[str, Any]) -> str:
    pid = ehr.get("patient_id", "this patient")
    labs = ehr.get("latest_labs", [])[:5] if isinstance(ehr.get("latest_labs"), list) else []
    lines = []
    for lab in labs:
        tn = lab.get("test_name")
        val = lab.get("value")
        unit = lab.get("unit", "")
        date = lab.get("lab_date", "")
        if tn is not None and val is not None:
            lines.append(f"- {tn}: {val} {unit} (date: {date})")

    return (
        f"For patient {pid}, the EHR reports the following recent lab values:\n"
        + ("\n".join(lines) if lines else "- No recent lab values found.")
        + "\n\nInterpretation (e.g., whether a value is high/low/normal or indicates control) depends on explicit clinical thresholds "
          "and patient context. Please verify thresholds and clinical context in the chart.\n"
    )

def enforce_no_lab_interpretation(answer: str, ehr: Dict[str, Any]) -> str:
    if not isinstance(ehr, dict):
        return answer
    has_labs = isinstance(ehr.get("latest_labs"), list) and len(ehr.get("latest_labs", [])) > 0
    if has_labs and LAB_INTERPRETATION_RE.search(answer):
        return lab_safe_summary(ehr)
    return answer

# -------------------------------------------------------------------
# 4) LangGraph (uses tools created in Block 7)
# -------------------------------------------------------------------
class AgentState(TypedDict, total=False):
    input: str
    route: str
    ehr: Dict[str, Any]
    medquad: List[Dict[str, Any]]
    final: str

def node_route(state: AgentState) -> AgentState:
    return {"route": route_intent(state["input"])}

def node_ehr_only(state: AgentState) -> AgentState:
    text = state["input"]
    pid_match = PATIENT_ID_RE.search(text)
    pid = pid_match.group(0) if pid_match else None

    if not pid:
        return {"ehr": {"error": "No patient_id found (expected like P00001).", "_task_type": "ehr_missing_patient_id", "_used_fields": []}}

    t = text.lower()
    out: Dict[str, Any] = {"patient_id": pid}

    out["summary"] = json.loads(ehr_get_patient_summary.invoke({"patient_id": pid}))
    used_fields = ["summary"]

    if "missing" in t or "pending" in t:
        out["_task_type"] = "ehr_missing_labs_check"
        out["missing_labs"] = json.loads(ehr_check_missing_labs.invoke({
            "patient_id": pid,
            "required_tests": "Hemoglobin A1c, LDL cholesterol, Creatinine"
        }))
        used_fields.append("missing_labs")
    else:
        out["_task_type"] = "ehr_general"

        if any(k in t for k in ["diagnos", "diagnoses", "condition", "conditions", "problem", "problems"]):
            out["active_diagnoses"] = json.loads(ehr_get_active_diagnoses.invoke({"patient_id": pid}))
            used_fields.append("active_diagnoses")

        if any(k in t for k in ["lab", "a1c", "hba1c", "ldl", "creatinine", "blood pressure", "latest"]):
            out["latest_labs"] = json.loads(ehr_get_latest_labs.invoke({"patient_id": pid, "n": 5}))
            used_fields.append("latest_labs")

        if any(k in t for k in ["visit", "visits", "recent", "history"]):
            out["recent_visits"] = json.loads(ehr_list_recent_visits.invoke({"patient_id": pid, "n": 3}))
            used_fields.append("recent_visits")

    out["_used_fields"] = used_fields
    return {"ehr": out}

def node_medquad_only(state: AgentState) -> AgentState:
    hits = json.loads(medquad_retrieve.invoke({"query": state["input"], "k": 4}))
    return {"medquad": hits}

def node_ehr_plus_medquad(state: AgentState) -> AgentState:
    out = {}
    out.update(node_ehr_only(state))
    out.update(node_medquad_only(state))
    return out

def deterministic_sources(ehr: Dict[str, Any], medq: List[Dict[str, Any]]) -> str:
    ehr_fields = ehr.get("_used_fields", []) if isinstance(ehr, dict) else []
    lines = []
    lines.append(f"- EHR fields: {', '.join(ehr_fields) if ehr_fields else 'None'}")
    if medq:
        used = []
        for i, hit in enumerate(medq[:3], start=1):
            meta = hit.get("metadata") or {}
            used.append(f"[MedQuAD {i}] {meta.get('topic','')} | {meta.get('source_file','')}")
        lines.append("- MedQuAD snippets: " + "; ".join(used))
    else:
        lines.append("- MedQuAD snippets: None")
    return "Sources:\n" + "\n".join(lines)

def node_finalize_llm(state: AgentState) -> AgentState:
    route = state.get("route", "UNKNOWN")
    user_q = state["input"]
    ehr = state.get("ehr", {}) or {}
    medq = state.get("medquad", []) or []

    # Hard template for missing labs (no LLM)
    if isinstance(ehr, dict) and ehr.get("_task_type") == "ehr_missing_labs_check":
        missing = ehr.get("missing_labs", [])
        pid = ehr.get("patient_id")
        if isinstance(missing, list) and len(missing) == 0:
            body = (
                f"According to the structured EHR data, patient {pid} has all of the requested labs "
                f"(Hemoglobin A1c, LDL cholesterol, Creatinine) on record."
            )
        else:
            body = (
                f"According to the structured EHR data, patient {pid} is missing the following requested lab(s): "
                f"{', '.join(missing) if isinstance(missing, list) else str(missing)}."
            )

        final = body + "\n\n" + deterministic_sources(ehr, [])
        log_run({
            "route": route, "question": user_q, "ehr_task_type": ehr.get("_task_type"),
            "ehr_used_fields": ehr.get("_used_fields"), "ehr_used": True, "medquad_used": False,
            "patient_id": pid,
        })
        return {"final": final}

    prompt = build_prompt(user_q, route, ehr, medq)
    body = generate_text(prompt, max_new_tokens=260)
    body = strip_sources_if_any(body)
    body = enforce_no_lab_interpretation(body, ehr)

    final = body.strip() + "\n\n" + deterministic_sources(ehr, medq)

    log_run({
        "route": route,
        "question": user_q,
        "ehr_task_type": ehr.get("_task_type") if isinstance(ehr, dict) else None,
        "ehr_used_fields": ehr.get("_used_fields") if isinstance(ehr, dict) else None,
        "ehr_used": bool(ehr),
        "medquad_used": bool(medq),
        "patient_id": ehr.get("patient_id") if isinstance(ehr, dict) else None,
    })

    return {"final": final}

graph_llm = StateGraph(AgentState)
graph_llm.add_node("route", node_route)
graph_llm.add_node("ehr_only", node_ehr_only)
graph_llm.add_node("medquad_only", node_medquad_only)
graph_llm.add_node("ehr_plus_medquad", node_ehr_plus_medquad)
graph_llm.add_node("finalize_llm", node_finalize_llm)

graph_llm.set_entry_point("route")
graph_llm.add_conditional_edges(
    "route",
    lambda s: "ehr_plus_medquad" if s["route"] == "BOTH" else ("ehr_only" if s["route"] == "EHR" else "medquad_only"),
    {"ehr_only": "ehr_only", "ehr_plus_medquad": "ehr_plus_medquad", "medquad_only": "medquad_only"}
)

graph_llm.add_edge("ehr_only", "finalize_llm")
graph_llm.add_edge("ehr_plus_medquad", "finalize_llm")
graph_llm.add_edge("medquad_only", "finalize_llm")
graph_llm.add_edge("finalize_llm", END)

app_llm = graph_llm.compile()
print("✅ LangGraph app compiled with deterministic Sources + strict lab safety")

# -------------------------------------------------------------------
# 5) Demo runs
# -------------------------------------------------------------------
demo_pid = patients.sample(1).iloc[0].patient_id

demo_questions = [
    f"What are the active diagnoses for patient {demo_pid}?",
    f"Does patient {demo_pid} have missing labs for HbA1c, LDL cholesterol, and Creatinine?",
    "What is hypertension and what are common management approaches?",
    f"Explain the clinical meaning of HbA1c and how it is used in older adults, using patient {demo_pid} as context.",
]

for q in demo_questions:
    print("\n" + "="*100)
    print("QUERY:", q)
    result = app_llm.invoke({"input": q})
    print(result["final"])


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m43.9 MB/s[0m eta [36m0:00:00[0m
[?25h✅ Logs will be written to: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/logs/assistant_runs.jsonl


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

⚠️ 4-bit load failed, falling back to fp16 on GPU. Reason:
 Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes` 



`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

✅ Loaded model: mistralai/Mistral-7B-Instruct-v0.2 (fp16)
✅ LangGraph app compiled with deterministic Sources + strict lab safety

QUERY: What are the active diagnoses for patient P00076?
The structured EHR JSON for patient P00076, Quinn Brown, lists the following active diagnoses: Hyperlipidemia with ICD-10 code E78.5 and chronic status, diagnosed on two separate occasions - May 1, 2024, and August 16, 2025. Additionally, there are two chronic diagnoses of Osteoarthritis with ICD-10 code M19.90, diagnosed on November 26, 2024, and Depression with ICD-10 code F32.9, also diagnosed on November 26, 2024.

Sources:
- EHR fields: summary, active_diagnoses
- MedQuAD snippets: None

QUERY: Does patient P00076 have missing labs for HbA1c, LDL cholesterol, and Creatinine?
According to the structured EHR data, patient P00076 is missing the following requested lab(s): Hemoglobin A1c.

Sources:
- EHR fields: summary, missing_labs
- MedQuAD snippets: None

QUERY: What is hypertension and what are 

<p style="white-space: nowrap;"><strong>Comentário (Bloco 8)</strong> – Este bloco acopla um LLM gerador ao fluxo do assistente (LangGraph), mantendo as consultas ao EHR determinísticas e o RAG MedQuAD como contexto de conhecimento. Ele adiciona fontes (“Sources”) de forma determinística ao final da resposta (em vez de deixar o modelo “inventar” fontes), melhora o roteamento para consultas clínicas (cobertura de palavras-chave de EHR) e implementa guardrails de segurança, especialmente para exames laboratoriais (evitando interpretar valores como “alto/baixo/normal” sem limiares explícitos). Também inclui um template determinístico para a tarefa “exames faltantes”, garantindo respostas estáveis sem LLM. Por fim, registra cada execução em log com metadados (rota, campos do EHR usados, paciente) para rastreabilidade. Esse bloco representa a transição do “pipeline validado” para um assistente com geração controlada + auditoria + segurança clínica.

# 09_Fine-tuning (QLoRA) do modelo com dados do MedQuAD

In [11]:
# --- GPU HARD RESET (Colab): free VRAM held by stuck processes ---

import os, signal, gc, subprocess, textwrap
import torch

print("=== nvidia-smi (before) ===")
!nvidia-smi

# Try to kill python processes that hold GPU memory (common in Colab after crashes)
try:
    smi = subprocess.check_output(
        ["bash", "-lc", "nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv,noheader"],
        text=True
    ).strip()

    if smi:
        print("\n=== GPU processes detected ===")
        print(smi)

        pids = []
        for line in smi.splitlines():
            parts = [p.strip() for p in line.split(",")]
            if parts and parts[0].isdigit():
                pid = int(parts[0])
                pname = parts[1] if len(parts) > 1 else ""
                # Only kill typical notebook/python processes; avoid killing system daemons if any
                if "python" in pname.lower() or "colab" in pname.lower() or "ipykernel" in pname.lower():
                    pids.append(pid)

        if pids:
            print("\n🔪 Killing GPU-hogging Python kernel processes:", pids)
            for pid in pids:
                try:
                    os.kill(pid, signal.SIGKILL)
                except Exception as e:
                    print(f"Could not kill PID {pid}: {e}")
        else:
            print("\nNo python/ipykernel processes found to kill.")
    else:
        print("\nNo GPU compute processes listed by nvidia-smi.")
except Exception as e:
    print("Could not parse/kill via nvidia-smi:", e)

# Clear torch memory
gc.collect()
if torch.cuda.is_available():
    try:
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    except Exception as e:
        print("torch cuda cleanup warning:", e)

print("\n=== nvidia-smi (after) ===")
!nvidia-smi

# Show torch view of memory
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info()
    print(f"\n✅ VRAM free/total (GB): {free/1e9:.2f} / {total/1e9:.2f}")


=== nvidia-smi (before) ===
Wed Dec 24 10:56:52 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   61C    P0             31W /   72W |   14579MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                    

In [12]:
# LoRA fine-tuning (NO bitsandbytes) + eval_strategy compatible ---
# Uses TinyLlama 1.1B Chat in FP16 + LoRA (fits on NVIDIA L4).
# Fix: uses eval_strategy (not evaluation_strategy) to match your installed transformers.

!pip -q install -U "transformers" "accelerate" "datasets" "peft" "evaluate"

import gc
from pathlib import Path
import pandas as pd
from datasets import Dataset
import torch

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model

# -------------------------------------------------------
# 0) Limpeza da VRAM
# -------------------------------------------------------
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

print("CUDA available:", torch.cuda.is_available())
if not torch.cuda.is_available():
    raise RuntimeError("This training block requires a GPU runtime in Colab (T4/L4/A100).")

print("GPU:", torch.cuda.get_device_name(0))
free, total = torch.cuda.mem_get_info()
print(f"VRAM free/total (GB): {free/1e9:.2f} / {total/1e9:.2f}")

# -------------------------------------------------------
# 1) Paths
# -------------------------------------------------------
PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
DATA_PATH = PROJECT_ROOT / "data" / "processed" / "medquad_qa.csv"
OUT_DIR = PROJECT_ROOT / "models" / "tinyllama_medquad_lora"
OUT_DIR.mkdir(parents=True, exist_ok=True)

if not DATA_PATH.exists():
    raise FileNotFoundError(f"Missing: {DATA_PATH}")

print("✅ Dataset:", DATA_PATH)
print("✅ Output dir:", OUT_DIR)

# -------------------------------------------------------
# 2) Carregar pares de perguntas e respostas (QA) do MedQuAD
# -------------------------------------------------------
df = pd.read_csv(DATA_PATH).dropna(subset=["question", "answer"]).copy()
df["question"] = df["question"].astype(str).str.strip()
df["answer"] = df["answer"].astype(str).str.strip()

print("✅ QA rows:", len(df))

# -------------------------------------------------------
# 3) Construir texto de treinamento no formato de instruções
# -------------------------------------------------------
SYSTEM = (
    "You are a helpful medical assistant. "
    "Answer clearly, in English, and avoid prescribing medications or giving definitive treatment plans. "
    "Provide general educational information and recommend consulting a licensed clinician for decisions."
)

def format_sample(q, a):
    return (
        f"<s>[INST] {SYSTEM}\n\n"
        f"Question: {q}\n\n"
        f"Answer in English. [/INST]\n"
        f"{a}</s>"
    )

df["text"] = [format_sample(q, a) for q, a in zip(df["question"], df["answer"])]

ds = Dataset.from_pandas(df[["text"]]).train_test_split(test_size=0.05, seed=42)
train_ds, eval_ds = ds["train"], ds["test"]

print("✅ Train size:", len(train_ds), "| Eval size:", len(eval_ds))

# -------------------------------------------------------
# 4) Modelo + Tokenizer (modelo pequeno de instruções, que cabe em fp16)
# -------------------------------------------------------
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

MAX_LEN = 256

def tokenize_fn(batch):
    tok = tokenizer(
        batch["text"],
        truncation=True,
        max_length=MAX_LEN,
        padding="max_length",
    )
    tok["labels"] = tok["input_ids"].copy()
    return tok

train_tok = train_ds.map(tokenize_fn, batched=True, remove_columns=["text"])
eval_tok  = eval_ds.map(tokenize_fn, batched=True, remove_columns=["text"])

# -------------------------------------------------------
# 5) Carregar modelo (fp16, sem quantização) + LoRA
# -------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
model.config.use_cache = False  # Melhora estabilidade do treinamento

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# -------------------------------------------------------
# 6)Argumentos de treinamento (COMPAT: eval_strategy)
# -------------------------------------------------------
training_args = TrainingArguments(
    output_dir=str(OUT_DIR),
    num_train_epochs=2,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    logging_steps=25,
    eval_strategy="steps",
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    fp16=True,
    report_to="none",
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# -------------------------------------------------------
# 7) Treinamento
# -------------------------------------------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    data_collator=data_collator,
)

train_result = trainer.train()

# -------------------------------------------------------
# 8) Salvar adaptador + tokenizer no Drive
# -------------------------------------------------------
trainer.model.save_pretrained(str(OUT_DIR))
tokenizer.save_pretrained(str(OUT_DIR))

print("✅ Training complete.")
print("✅ Saved LoRA adapter + tokenizer to:", OUT_DIR)
print("✅ Train output:", train_result)


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/512.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.3/512.3 kB[0m [31m38.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m53.2 MB/s[0m eta [36m0:00:00[0m
[?25hCUDA available: True
GPU: NVIDIA L4
VRAM free/total (GB): 8.75 / 23.80
✅ Dataset: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/data/processed/medquad_qa.csv
✅ Output dir: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/models/tinyllama_medquad_lora
✅ QA rows: 1750
✅ Train size: 1662 | Eval size: 88


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

Map:   0%|          | 0/1662 [00:00<?, ? examples/s]

Map:   0%|          | 0/88 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.


trainable params: 2,252,800 || all params: 1,102,301,184 || trainable%: 0.2044


Step,Training Loss,Validation Loss
200,1.0167,1.088678


✅ Training complete.
✅ Saved LoRA adapter + tokenizer to: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/models/tinyllama_medquad_lora
✅ Train output: TrainOutput(global_step=208, training_loss=1.1615644372426546, metrics={'train_runtime': 253.1689, 'train_samples_per_second': 13.13, 'train_steps_per_second': 0.822, 'total_flos': 5293374676402176.0, 'train_loss': 1.1615644372426546, 'epoch': 2.0})


<p style="white-space: nowrap;"><strong>Comentário (Bloco 9)</strong> – Este bloco realiza o fine-tuning supervisionado do modelo de linguagem usando LoRA (Low-Rank Adaptation) sobre um modelo de linguagem de pequeno porte (TinyLlama 1.1B), adequado às limitações de GPU do Google Colab. As perguntas e respostas do MedQuAD são formatadas em estilo instruction, tokenizadas e divididas em conjuntos de treino e validação. O treinamento ajusta apenas uma pequena fração dos parâmetros do modelo (≈0,2%), reduzindo custo computacional e risco de overfitting. Ao final, o adaptador LoRA e o tokenizer são salvos no Google Drive para posterior uso no assistente médico. No contexto do Tech Challenge 3, este bloco cria o modelo especializado em conhecimento médico geral que será integrado ao pipeline RAG + EHR.

# 10_Aplicação do Modelo Fine-Tuned: Carregamento do TinyLlama + Adapter LoRA no Assistente Médico

In [13]:
# Carregar o LoRA fine-tuned do TinyLlama e conectá-lo ao pipeline do assistente ---

!pip -q install -U transformers accelerate peft langgraph

import gc
import json
import re
from pathlib import Path
from datetime import datetime, timezone
from typing import TypedDict, Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from langgraph.graph import StateGraph, END

required = [
    "route_intent", "PATIENT_ID_RE",
    "ehr_get_patient_summary", "ehr_get_active_diagnoses", "ehr_get_latest_labs",
    "ehr_list_recent_visits", "ehr_check_missing_labs", "medquad_retrieve",
    "patients"
]
missing = [x for x in required if x not in globals()]
if missing:
    raise RuntimeError(
        "Missing objects from previous blocks: "
        + ", ".join(missing)
        + "\nRun Block 7 first (Tools + Router + VectorStore + EHR tables)."
    )

# -------------------------------------------------------------------
# 0) Caminhos
# -------------------------------------------------------------------
PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
ADAPTER_DIR = PROJECT_ROOT / "models" / "tinyllama_medquad_lora"
RUNS_LOG = PROJECT_ROOT / "logs" / "assistant_runs.jsonl"

BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

if not ADAPTER_DIR.exists():
    raise FileNotFoundError(f"Adapter directory not found: {ADAPTER_DIR}")

print("✅ Base model:", BASE_MODEL)
print("✅ LoRA adapter:", ADAPTER_DIR)
print("✅ Log file:", RUNS_LOG)

# -------------------------------------------------------------------
# 1) Limpeza da VRAM
# -------------------------------------------------------------------
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

print("CUDA available:", torch.cuda.is_available())
if not torch.cuda.is_available():
    raise RuntimeError("GPU runtime required (T4/L4/A100).")

print("GPU:", torch.cuda.get_device_name(0))
free, total = torch.cuda.mem_get_info()
print(f"VRAM free/total (GB): {free/1e9:.2f} / {total/1e9:.2f}")

# -------------------------------------------------------------------
# 2) Carregar o tokenizador e o modelo base (FP16) e aplicar o adaptador LoRA
# -------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
base_model.config.use_cache = False

# Apply LoRA adapter
model = PeftModel.from_pretrained(base_model, str(ADAPTER_DIR))
model.eval()

loaded_model_name = f"{BASE_MODEL} + LoRA({ADAPTER_DIR.name})"
print("✅ Loaded fine-tuned model:", loaded_model_name)

# -------------------------------------------------------------------
# 3) Auxiliar de geração (formato de instruções)
# -------------------------------------------------------------------
@torch.inference_mode()
def generate_text(prompt: str, max_new_tokens: int = 250) -> str:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.35,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(out[0], skip_special_tokens=True).strip()

        if "[/INST]" in text:
        text = text.split("[/INST]", 1)[1].strip()
    text = text.replace("</s>", "").strip()
    return text

def log_run(payload: Dict[str, Any]):
    payload["timestamp_utc"] = datetime.now(timezone.utc).isoformat()
    payload["model"] = loaded_model_name
    RUNS_LOG.parent.mkdir(parents=True, exist_ok=True)
    with open(RUNS_LOG, "a", encoding="utf-8") as f:
        f.write(json.dumps(payload, ensure_ascii=False) + "\n")

# -------------------------------------------------------------------
# 4)Segurança: remover fontes acidentais e aplicar interpretação estrita de laboratório
# -------------------------------------------------------------------
def strip_sources_if_any(text: str) -> str:
    parts = re.split(r"\n\s*Sources\s*:\s*\n", text, flags=re.IGNORECASE)
    return parts[0].strip()

LAB_INTERPRETATION_RE = re.compile(
    r"\b("
    r"normal|abnormal|elevated|high|low|within (the )?normal|outside (the )?normal|"
    r"poor glycemic control|well[- ]controlled|uncontrolled|controlled|"
    r"indicates (diabetes|prediabetes)|suggests (diabetes|prediabetes)|"
    r"goal|target|above goal|below goal"
    r")\b",
    re.IGNORECASE
)

def lab_safe_summary(ehr: Dict[str, Any]) -> str:
    pid = ehr.get("patient_id", "this patient")
    labs = ehr.get("latest_labs", [])[:5] if isinstance(ehr.get("latest_labs"), list) else []
    lines = []
    for lab in labs:
        tn = lab.get("test_name")
        val = lab.get("value")
        unit = lab.get("unit", "")
        date = lab.get("lab_date", "")
        if tn is not None and val is not None:
            lines.append(f"- {tn}: {val} {unit} (date: {date})")

    return (
        f"For patient {pid}, the EHR reports the following recent lab values:\n"
        + ("\n".join(lines) if lines else "- No recent lab values found.")
        + "\n\nInterpretation (e.g., whether a value is high/low/normal or indicates control) depends on explicit clinical thresholds "
          "and patient context. Please verify thresholds and clinical context in the chart.\n"
    )

def enforce_no_lab_interpretation(answer: str, ehr: Dict[str, Any]) -> str:
    if not isinstance(ehr, dict):
        return answer
    has_labs = isinstance(ehr.get("latest_labs"), list) and len(ehr.get("latest_labs", [])) > 0
    if has_labs and LAB_INTERPRETATION_RE.search(answer):
        return lab_safe_summary(ehr)
    return answer

# -------------------------------------------------------------------
# 5) Gerador de prompt (o LLM não produz fontes; a anexação é determinística)
# -------------------------------------------------------------------
def compact_medquad(medquad: List[Dict[str, Any]], limit: int = 3) -> str:
    snips = []
    for i, hit in enumerate(medquad[:limit], start=1):
        content = (hit.get("content") or "").strip()[:900]
        meta = hit.get("metadata") or {}
        snips.append(
            f"[MedQuAD {i}] Topic: {meta.get('topic','')}\n"
            f"SourceFile: {meta.get('source_file','')}\n"
            f"{content}"
        )
    return "\n\n".join(snips) if snips else "None"

def build_prompt(user_question: str, route: str, ehr: Dict[str, Any], medquad: List[Dict[str, Any]]) -> str:
    SYSTEM = (
        "You are a helpful medical assistant for physicians. "
        "Answer in English, clearly, and avoid prescribing medications or giving definitive treatment plans. "
        "Use the STRUCTURED EHR JSON only for patient-specific facts. "
        "Use MedQuAD only as general medical knowledge context. "
        "Do NOT include a Sources section; the system will attach sources automatically."
    )
    ehr_block = json.dumps(ehr, indent=2) if ehr else "None"
    medquad_block = compact_medquad(medquad, limit=3)

    return (
        f"<s>[INST] {SYSTEM}\n\n"
        f"ROUTE: {route}\n\n"
        f"USER QUESTION:\n{user_question}\n\n"
        f"STRUCTURED EHR JSON:\n{ehr_block}\n\n"
        f"MEDQUAD RETRIEVAL:\n{medquad_block}\n\n"
        f"Write a concise, helpful answer following the rules. [/INST]\n"
    )

def deterministic_sources(ehr: Dict[str, Any], medq: List[Dict[str, Any]]) -> str:
    ehr_fields = ehr.get("_used_fields", []) if isinstance(ehr, dict) else []
    lines = []
    lines.append(f"- EHR fields: {', '.join(ehr_fields) if ehr_fields else 'None'}")
    if medq:
        used = []
        for i, hit in enumerate(medq[:3], start=1):
            meta = hit.get("metadata") or {}
            used.append(f"[MedQuAD {i}] {meta.get('topic','')} | {meta.get('source_file','')}")
        lines.append("- MedQuAD snippets: " + "; ".join(used))
    else:
        lines.append("- MedQuAD snippets: None")
    return "Sources:\n" + "\n".join(lines)

# -------------------------------------------------------------------
# 6) Checagem básica
# -------------------------------------------------------------------
sanity_q = "What is hypertension and what are common management approaches?"
sanity_prompt = build_prompt(sanity_q, "MEDQUAD", ehr={}, medquad=[])
print("\n" + "="*80)
print("SANITY CHECK (model generation):")
print(generate_text(sanity_prompt, max_new_tokens=140))

# -------------------------------------------------------------------
# 7) Aplicação LangGraph usando o modelo ajustado (fine-tuned)
# -------------------------------------------------------------------
class AgentState(TypedDict, total=False):
    input: str
    route: str
    ehr: Dict[str, Any]
    medquad: List[Dict[str, Any]]
    final: str

def node_route(state: AgentState) -> AgentState:
    return {"route": route_intent(state["input"])}

def node_ehr_only(state: AgentState) -> AgentState:
    text = state["input"]
    pid_match = PATIENT_ID_RE.search(text)
    pid = pid_match.group(0) if pid_match else None

    if not pid:
        return {"ehr": {"error": "No patient_id found (expected like P00001).", "_task_type": "ehr_missing_patient_id", "_used_fields": []}}

    t = text.lower()
    out: Dict[str, Any] = {"patient_id": pid}

    out["summary"] = json.loads(ehr_get_patient_summary.invoke({"patient_id": pid}))
    used_fields = ["summary"]

    if "missing" in t or "pending" in t:
        out["_task_type"] = "ehr_missing_labs_check"
        out["missing_labs"] = json.loads(ehr_check_missing_labs.invoke({
            "patient_id": pid,
            "required_tests": "Hemoglobin A1c, LDL cholesterol, Creatinine"
        }))
        used_fields.append("missing_labs")
    else:
        out["_task_type"] = "ehr_general"

        if any(k in t for k in ["diagnos", "diagnoses", "condition", "conditions", "problem", "problems"]):
            out["active_diagnoses"] = json.loads(ehr_get_active_diagnoses.invoke({"patient_id": pid}))
            used_fields.append("active_diagnoses")

        if any(k in t for k in ["lab", "a1c", "hba1c", "ldl", "creatinine", "blood pressure", "latest"]):
            out["latest_labs"] = json.loads(ehr_get_latest_labs.invoke({"patient_id": pid, "n": 5}))
            used_fields.append("latest_labs")

        if any(k in t for k in ["visit", "visits", "recent", "history"]):
            out["recent_visits"] = json.loads(ehr_list_recent_visits.invoke({"patient_id": pid, "n": 3}))
            used_fields.append("recent_visits")

    out["_used_fields"] = used_fields
    return {"ehr": out}

def node_medquad_only(state: AgentState) -> AgentState:
    hits = json.loads(medquad_retrieve.invoke({"query": state["input"], "k": 4}))
    return {"medquad": hits}

def node_ehr_plus_medquad(state: AgentState) -> AgentState:
    out = {}
    out.update(node_ehr_only(state))
    out.update(node_medquad_only(state))
    return out

def node_finalize_llm(state: AgentState) -> AgentState:
    route = state.get("route", "UNKNOWN")
    user_q = state["input"]
    ehr = state.get("ehr", {}) or {}
    medq = state.get("medquad", []) or []

    # Resposta determinística para dados laboratoriais ausentes
    if isinstance(ehr, dict) and ehr.get("_task_type") == "ehr_missing_labs_check":
        missing = ehr.get("missing_labs", [])
        pid = ehr.get("patient_id")
        if isinstance(missing, list) and len(missing) == 0:
            body = (
                f"According to the structured EHR data, patient {pid} has all of the requested labs "
                f"(Hemoglobin A1c, LDL cholesterol, Creatinine) on record."
            )
        else:
            body = (
                f"According to the structured EHR data, patient {pid} is missing the following requested lab(s): "
                f"{', '.join(missing) if isinstance(missing, list) else str(missing)}."
            )
        final = body + "\n\n" + deterministic_sources(ehr, [])
        log_run({"route": route, "question": user_q, "ehr_used": True, "medquad_used": False, "patient_id": pid})
        return {"final": final}

    prompt = build_prompt(user_q, route, ehr, medq)
    body = generate_text(prompt, max_new_tokens=260)
    body = strip_sources_if_any(body)
    body = enforce_no_lab_interpretation(body, ehr)

    final = body.strip() + "\n\n" + deterministic_sources(ehr, medq)

    log_run({
        "route": route,
        "question": user_q,
        "ehr_used": bool(ehr),
        "medquad_used": bool(medq),
        "patient_id": ehr.get("patient_id") if isinstance(ehr, dict) else None,
    })
    return {"final": final}

graph = StateGraph(AgentState)
graph.add_node("route", node_route)
graph.add_node("ehr_only", node_ehr_only)
graph.add_node("medquad_only", node_medquad_only)
graph.add_node("ehr_plus_medquad", node_ehr_plus_medquad)
graph.add_node("finalize_llm", node_finalize_llm)

graph.set_entry_point("route")
graph.add_conditional_edges(
    "route",
    lambda s: "ehr_plus_medquad" if s["route"] == "BOTH" else ("ehr_only" if s["route"] == "EHR" else "medquad_only"),
    {"ehr_only": "ehr_only", "ehr_plus_medquad": "ehr_plus_medquad", "medquad_only": "medquad_only"}
)
graph.add_edge("ehr_only", "finalize_llm")
graph.add_edge("ehr_plus_medquad", "finalize_llm")
graph.add_edge("medquad_only", "finalize_llm")
graph.add_edge("finalize_llm", END)

app_llm_ft = graph.compile()
print("\n✅ app_llm_ft compiled (fine-tuned TinyLlama plugged in)")

# -------------------------------------------------------------------
# 8) Rodadas de demonstração
# -------------------------------------------------------------------
demo_pid = patients.sample(1).iloc[0].patient_id

demo_questions = [
    f"What are the active diagnoses for patient {demo_pid}?",
    f"Does patient {demo_pid} have missing labs for HbA1c, LDL cholesterol, and Creatinine?",
    "What is hypertension and what are common management approaches?",
    f"Explain the clinical meaning of HbA1c and how it is used in older adults, using patient {demo_pid} as context.",
]

for q in demo_questions:
    print("\n" + "="*100)
    print("QUERY:", q)
    result = app_llm_ft.invoke({"input": q})
    print(result["final"])


✅ Base model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
✅ LoRA adapter: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/models/tinyllama_medquad_lora
✅ Log file: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/logs/assistant_runs.jsonl
CUDA available: True
GPU: NVIDIA L4
VRAM free/total (GB): 7.09 / 23.80
✅ Loaded fine-tuned model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 + LoRA(tinyllama_medquad_lora)

SANITY CHECK (model generation):
Summary : Hypertension is high blood pressure. It can be caused by many factors, including age, race, gender, genetics, and lifestyle. In most cases, it does not cause any symptoms. But if you have high blood pressure, you may experience - Headaches - Nausea - Fatigue - Shortness of breath - Weakness - Swelling in your legs and feet - Chest pain - Back pain - Painful urination - Blurred vision - Dizziness - Poor sleep - Weight gain - High blood pressure is a serious problem that can lead to heart disease, stroke, kidney failure, and other complications

✅ a

<p style="white-space: nowrap;"><strong>Comentário (Bloco 10)</strong> –Carga do modelo base TinyLlama (1.1B) e aplica o adaptador LoRA treinado no bloco anterior (fine-tuning com MedQuAD), formando a versão “especializada” do LLM. Em seguida, ele conecta esse modelo ao pipeline do assistente construído anteriormente (LangGraph), mantendo a lógica de roteamento entre consultas específicas do paciente (EHR) e conhecimento médico geral (MedQuAD/RAG). O bloco também reforça regras de segurança (ex.: evitar prescrição e bloquear interpretações de exames) e gera uma seção de fontes determinística baseada nos campos do EHR e nos trechos recuperados do MedQuAD. Por fim, executa perguntas de demonstração e salva logs das interações em /logs/assistant_runs.jsonl, permitindo auditoria e evidência de funcionamento do assistente.

# 11_Ligar o assistente médico e vê-lo funcionando

In [14]:
#  Carga de dados + FAISS + TinyLlama com LoRA + aplicação + demonstrações

!pip -q install -U transformers accelerate peft langchain langchain-community langchain-core langgraph sentence-transformers faiss-cpu

import gc
import json
import re
from pathlib import Path
from datetime import datetime, timezone
from typing import TypedDict, Any, Dict, List

import pandas as pd
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

from langgraph.graph import StateGraph, END

# -----------------------------
# 0) Caminhos
# -----------------------------
PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")

MEDQUAD_CSV = PROJECT_ROOT / "data" / "processed" / "medquad_qa.csv"
EHR_DIR = PROJECT_ROOT / "data" / "synthetic" / "ehr"

# ✅ Resultado final do adaptador
ADAPTER_DIR = PROJECT_ROOT / "models" / "tinyllama_medquad_lora"
RUNS_LOG = PROJECT_ROOT / "logs" / "assistant_runs.jsonl"

# ✅ O modelo base deve corresponder ao adaptador
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

required_files = [
    MEDQUAD_CSV,
    EHR_DIR / "patients.csv",
    EHR_DIR / "visits.csv",
    EHR_DIR / "diagnoses.csv",
    EHR_DIR / "prescriptions.csv",
    EHR_DIR / "labs.csv",
]
for p in required_files:
    if not p.exists():
        raise FileNotFoundError(f"Missing required file: {p}")

if not ADAPTER_DIR.exists():
    raise FileNotFoundError(f"Missing adapter dir: {ADAPTER_DIR}")

print("✅ Paths OK")
print(" - MedQuAD:", MEDQUAD_CSV)
print(" - EHR dir:", EHR_DIR)
print(" - Adapter:", ADAPTER_DIR)
print(" - Logs:", RUNS_LOG)

# -----------------------------
# 1) Limpeza da VRAM
# -----------------------------
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

print("CUDA available:", torch.cuda.is_available())
if not torch.cuda.is_available():
    raise RuntimeError("GPU runtime required (T4/L4/A100).")

print("GPU:", torch.cuda.get_device_name(0))
free, total = torch.cuda.mem_get_info()
print(f"VRAM free/total (GB): {free/1e9:.2f} / {total/1e9:.2f}")

# -----------------------------
# 2) Carregar dataframes
# -----------------------------
qa_df = pd.read_csv(MEDQUAD_CSV)

patients = pd.read_csv(EHR_DIR / "patients.csv")
visits = pd.read_csv(EHR_DIR / "visits.csv")
diagnoses = pd.read_csv(EHR_DIR / "diagnoses.csv")
prescriptions = pd.read_csv(EHR_DIR / "prescriptions.csv")
labs = pd.read_csv(EHR_DIR / "labs.csv")

print("✅ Loaded:")
print(" - MedQuAD QAs:", len(qa_df))
print(" - EHR patients:", len(patients), "| visits:", len(visits), "| dx:", len(diagnoses), "| rx:", len(prescriptions), "| labs:", len(labs))

# -----------------------------
# 3) Reconstruir o FAISS a partir do MedQuAD
# -----------------------------
docs = []
for _, row in qa_df.iterrows():
    q = str(row["question"]).strip()
    a = str(row["answer"]).strip()
    content = f"QUESTION: {q}\nANSWER: {a}"
    meta = {"topic": row.get("topic", ""), "source_file": row.get("source_file", "")}
    docs.append(Document(page_content=content, metadata=meta))

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vs = FAISS.from_documents(docs, embeddings)
print("✅ FAISS vector store ready")

def medquad_search(query: str, k: int = 4) -> List[Dict[str, Any]]:
    hits = vs.similarity_search(query, k=k)
    return [{"content": h.page_content, "metadata": h.metadata} for h in hits]

# -----------------------------
# 4) Utilitários de prontuário (determinísticos)
# -----------------------------
PATIENT_ID_RE = re.compile(r"\bP\d{5}\b")

def get_patient_summary(pid: str) -> dict:
    p = patients[patients.patient_id == pid]
    if p.empty:
        return {"error": "Patient not found"}
    p = p.iloc[0].to_dict()
    return {
        "patient_id": p["patient_id"],
        "full_name": p["full_name"],
        "age": int(p["age"]),
        "sex": p["sex"],
        "contact": {"phone": p["phone"], "email": p["email"]},
    }

def get_active_diagnoses(pid: str) -> list:
    df = diagnoses[(diagnoses.patient_id == pid) & (diagnoses.status.isin(["active", "chronic"]))]
    if df.empty:
        return []
    return (
        df[["diagnosis_name", "icd10_code", "status", "diagnosis_date"]]
        .sort_values("diagnosis_date", ascending=False)
        .to_dict(orient="records")
    )

def get_latest_labs(pid: str, n: int = 5) -> list:
    df = labs[labs.patient_id == pid]
    if df.empty:
        return []
    df = df.sort_values("lab_date", ascending=False).head(n)
    return df[["test_name", "value", "unit", "reference_range", "lab_date"]].to_dict(orient="records")

def list_recent_visits(pid: str, n: int = 3) -> list:
    df = visits[visits.patient_id == pid]
    if df.empty:
        return []
    df = df.sort_values("visit_date", ascending=False).head(n)
    return df[["visit_date", "provider_name", "specialty", "reason_for_visit"]].to_dict(orient="records")

def check_missing_labs(pid: str, required_tests: list) -> list:
    df = labs[labs.patient_id == pid]
    existing = set(df.test_name.unique())
    return [t for t in required_tests if t not in existing]

# -----------------------------
# 5) Roteador
# -----------------------------
EHR_KEYWORDS = [
    "patient", "visit", "visits", "diagnosis", "diagnoses", "condition", "conditions",
    "prescription", "medication", "lab", "labs", "a1c", "hba1c", "ldl", "creatinine",
    "blood pressure", "latest", "recent", "missing", "pending", "history"
]

def route_intent(user_text: str) -> str:
    t = user_text.lower()
    has_pid = bool(PATIENT_ID_RE.search(user_text))
    has_ehr_kw = any(k in t for k in EHR_KEYWORDS)
    if has_pid and any(x in t for x in ["what is", "explain", "significance", "meaning", "risks", "why"]):
        return "BOTH"
    if has_pid or has_ehr_kw:
        return "EHR"
    return "MEDQUAD"

# -----------------------------
# 6) Carregar modelo com fine-tuning (TinyLlama + adaptador LoRA), sem usar bitsandbytes
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
base_model.config.use_cache = False

model = PeftModel.from_pretrained(base_model, str(ADAPTER_DIR))
model.eval()

loaded_model_name = f"{BASE_MODEL} + LoRA({ADAPTER_DIR.name})"
print("✅ Loaded model:", loaded_model_name)

@torch.inference_mode()
def generate_text(prompt: str, max_new_tokens: int = 260) -> str:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.35,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(out[0], skip_special_tokens=True).strip()
    if "[/INST]" in text:
        text = text.split("[/INST]", 1)[1].strip()
    return text.replace("</s>", "").strip()

def log_run(payload: Dict[str, Any]):
    payload["timestamp_utc"] = datetime.now(timezone.utc).isoformat()
    payload["model"] = loaded_model_name
    RUNS_LOG.parent.mkdir(parents=True, exist_ok=True)
    with open(RUNS_LOG, "a", encoding="utf-8") as f:
        f.write(json.dumps(payload, ensure_ascii=False) + "\n")

# -----------------------------
# 7) Guardrails: anexação determinística de fontes + interpretação laboratorial estrita
# -----------------------------
def strip_sources_if_any(text: str) -> str:
    parts = re.split(r"\n\s*Sources\s*:\s*\n", text, flags=re.IGNORECASE)
    return parts[0].strip()

LAB_INTERPRETATION_RE = re.compile(
    r"\b("
    r"normal|abnormal|elevated|high|low|within (the )?normal|outside (the )?normal|"
    r"poor glycemic control|well[- ]controlled|uncontrolled|controlled|"
    r"indicates (diabetes|prediabetes)|suggests (diabetes|prediabetes)|"
    r"goal|target|above goal|below goal"
    r")\b", re.IGNORECASE
)

def lab_safe_summary(ehr: Dict[str, Any]) -> str:
    pid = ehr.get("patient_id", "this patient")
    labs_list = ehr.get("latest_labs", [])[:5] if isinstance(ehr.get("latest_labs"), list) else []
    lines = []
    for lab in labs_list:
        tn = lab.get("test_name"); val = lab.get("value")
        unit = lab.get("unit", ""); date = lab.get("lab_date", "")
        if tn is not None and val is not None:
            lines.append(f"- {tn}: {val} {unit} (date: {date})")
    return (
        f"For patient {pid}, the EHR reports the following recent lab values:\n"
        + ("\n".join(lines) if lines else "- No recent lab values found.")
        + "\n\nInterpretation depends on explicit clinical thresholds and patient context. "
          "Please verify thresholds and clinical context in the chart.\n"
    )

def enforce_no_lab_interpretation(answer: str, ehr: Dict[str, Any]) -> str:
    has_labs = isinstance(ehr, dict) and isinstance(ehr.get("latest_labs"), list) and len(ehr["latest_labs"]) > 0
    if has_labs and LAB_INTERPRETATION_RE.search(answer):
        return lab_safe_summary(ehr)
    return answer

def deterministic_sources(ehr: Dict[str, Any], medq: List[Dict[str, Any]]) -> str:
    ehr_fields = ehr.get("_used_fields", []) if isinstance(ehr, dict) else []
    lines = [f"- EHR fields: {', '.join(ehr_fields) if ehr_fields else 'None'}"]
    if medq:
        used = []
        for i, hit in enumerate(medq[:3], start=1):
            meta = hit.get("metadata") or {}
            used.append(f"[MedQuAD {i}] {meta.get('topic','')} | {meta.get('source_file','')}")
        lines.append("- MedQuAD snippets: " + "; ".join(used))
    else:
        lines.append("- MedQuAD snippets: None")
    return "Sources:\n" + "\n".join(lines)

def build_prompt(user_question: str, route: str, ehr: Dict[str, Any], medquad: List[Dict[str, Any]]) -> str:
    SYSTEM = (
        "You are a helpful medical assistant for physicians. "
        "Answer in English, clearly, and avoid prescribing medications or giving definitive treatment plans. "
        "Use the STRUCTURED EHR JSON only for patient-specific facts. "
        "Use MedQuAD only as general medical knowledge context. "
        "Do NOT include a Sources section; the system will attach sources automatically."
    )
    ehr_block = json.dumps(ehr, indent=2) if ehr else "None"
    medquad_block = "\n\n".join(
        [
            f"[MedQuAD {i}] Topic: {h.get('metadata',{}).get('topic','')}\n"
            f"SourceFile: {h.get('metadata',{}).get('source_file','')}\n"
            f"{(h.get('content') or '')[:900]}"
            for i, h in enumerate(medquad[:3], start=1)
        ]
    ) if medquad else "None"

    return (
        f"<s>[INST] {SYSTEM}\n\n"
        f"ROUTE: {route}\n\n"
        f"USER QUESTION:\n{user_question}\n\n"
        f"STRUCTURED EHR JSON:\n{ehr_block}\n\n"
        f"MEDQUAD RETRIEVAL:\n{medquad_block}\n\n"
        f"Write a concise, helpful answer following the rules. [/INST]\n"
    )

# -----------------------------
# 8) Construir aplicação LangGraph
# -----------------------------
class AgentState(TypedDict, total=False):
    input: str
    route: str
    ehr: Dict[str, Any]
    medquad: List[Dict[str, Any]]
    final: str

def node_route(state: AgentState) -> AgentState:
    return {"route": route_intent(state["input"])}

def node_ehr(state: AgentState) -> AgentState:
    text = state["input"]
    pid_match = PATIENT_ID_RE.search(text)
    pid = pid_match.group(0) if pid_match else None
    if not pid:
        return {"ehr": {"error": "No patient_id found (expected like P00001).", "_task_type": "ehr_missing_patient_id", "_used_fields": []}}

    t = text.lower()
    out: Dict[str, Any] = {"patient_id": pid}
    out["summary"] = get_patient_summary(pid)
    used_fields = ["summary"]

    if "missing" in t or "pending" in t:
        out["_task_type"] = "ehr_missing_labs_check"
        out["missing_labs"] = check_missing_labs(pid, ["Hemoglobin A1c", "LDL cholesterol", "Creatinine"])
        used_fields.append("missing_labs")
    else:
        out["_task_type"] = "ehr_general"

        if any(k in t for k in ["diagnos", "diagnoses", "condition", "conditions", "problem", "problems"]):
            out["active_diagnoses"] = get_active_diagnoses(pid)
            used_fields.append("active_diagnoses")

        if any(k in t for k in ["lab", "a1c", "hba1c", "ldl", "creatinine", "blood pressure", "latest"]):
            out["latest_labs"] = get_latest_labs(pid, n=5)
            used_fields.append("latest_labs")

        if any(k in t for k in ["visit", "recent", "history"]):
            out["recent_visits"] = list_recent_visits(pid, n=3)
            used_fields.append("recent_visits")

    out["_used_fields"] = used_fields
    return {"ehr": out}

def node_medquad(state: AgentState) -> AgentState:
    return {"medquad": medquad_search(state["input"], k=4)}

def node_both(state: AgentState) -> AgentState:
    out = {}
    out.update(node_ehr(state))
    out.update(node_medquad(state))
    return out

def node_finalize(state: AgentState) -> AgentState:
    route = state.get("route", "UNKNOWN")
    user_q = state["input"]
    ehr = state.get("ehr", {}) or {}
    medq = state.get("medquad", []) or []

    # Template determinístico para exames faltantes (sem LLM)
    if isinstance(ehr, dict) and ehr.get("_task_type") == "ehr_missing_labs_check":
        missing = ehr.get("missing_labs", [])
        pid = ehr.get("patient_id")
        if isinstance(missing, list) and len(missing) == 0:
            body = (
                f"According to the structured EHR data, patient {pid} has all of the requested labs "
                f"(Hemoglobin A1c, LDL cholesterol, Creatinine) on record."
            )
        else:
            body = (
                f"According to the structured EHR data, patient {pid} is missing the following requested lab(s): "
                f"{', '.join(missing) if isinstance(missing, list) else str(missing)}."
            )
        final = body + "\n\n" + deterministic_sources(ehr, [])
        log_run({"route": route, "question": user_q, "ehr_used": True, "medquad_used": False, "patient_id": pid})
        return {"final": final}

    prompt = build_prompt(user_q, route, ehr, medq)
    body = generate_text(prompt, max_new_tokens=260)
    body = strip_sources_if_any(body)
    body = enforce_no_lab_interpretation(body, ehr)

    final = body.strip() + "\n\n" + deterministic_sources(ehr, medq)
    log_run({"route": route, "question": user_q, "ehr_used": bool(ehr), "medquad_used": bool(medq), "patient_id": ehr.get("patient_id") if isinstance(ehr, dict) else None})
    return {"final": final}

graph = StateGraph(AgentState)
graph.add_node("route", node_route)
graph.add_node("ehr", node_ehr)
graph.add_node("medquad", node_medquad)
graph.add_node("both", node_both)
graph.add_node("finalize", node_finalize)

graph.set_entry_point("route")
graph.add_conditional_edges(
    "route",
    lambda s: "both" if s["route"] == "BOTH" else ("ehr" if s["route"] == "EHR" else "medquad"),
    {"ehr": "ehr", "both": "both", "medquad": "medquad"}
)
graph.add_edge("ehr", "finalize")
graph.add_edge("both", "finalize")
graph.add_edge("medquad", "finalize")
graph.add_edge("finalize", END)

app_llm_ft = graph.compile()
print("✅ app_llm_ft compiled (run-from-scratch ready)")

# -----------------------------
# 9) Rodadas de demonstração
# -----------------------------
demo_pid = patients.sample(1).iloc[0].patient_id
demo_questions = [
    f"What are the active diagnoses for patient {demo_pid}?",
    f"Does patient {demo_pid} have missing labs for HbA1c, LDL cholesterol, and Creatinine?",
    "What is hypertension and what are common management approaches?",
    f"Explain the clinical meaning of HbA1c and how it is used in older adults, using patient {demo_pid} as context.",
]

for q in demo_questions:
    print("\n" + "="*100)
    print("QUERY:", q)
    result = app_llm_ft.invoke({"input": q})
    print(result["final"])


✅ Paths OK
 - MedQuAD: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/data/processed/medquad_qa.csv
 - EHR dir: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/data/synthetic/ehr
 - Adapter: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/models/tinyllama_medquad_lora
 - Logs: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/logs/assistant_runs.jsonl
CUDA available: True
GPU: NVIDIA L4
VRAM free/total (GB): 7.04 / 23.80
✅ Loaded:
 - MedQuAD QAs: 1750
 - EHR patients: 250 | visits: 1252 | dx: 2480 | rx: 1267 | labs: 1780
✅ FAISS vector store ready
✅ Loaded model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 + LoRA(tinyllama_medquad_lora)
✅ app_llm_ft compiled (run-from-scratch ready)

QUERY: What are the active diagnoses for patient P00206?
Summary : Patient has chronic kidney disease stage 3. The diagnosis is chronic kidney disease (stage 3) and the status is chronic. Diagnosis date is 2025-10-16. Diagnosis code is N18.30. Status is chronic. Diagnosis date is 2025-09-22. Diag

<p style="white-space: nowrap;"><strong>Comentário (Bloco 11)</strong> –Este bloco reconstrói todo o pipeline do assistente do zero em uma única execução: ele carrega o MedQuAD processado e as tabelas do EHR sintético, recria o índice FAISS para busca semântica (RAG) no MedQuAD e define funções determinísticas para consultar dados clínicos do paciente (diagnósticos, exames, visitas, etc.). Em seguida, ele carrega o modelo TinyLlama e aplica o adaptador LoRA treinado, formando o LLM especializado. O bloco monta um fluxo no LangGraph que roteia a pergunta para EHR, MedQuAD ou ambos, gera a resposta em inglês com o LLM e anexa fontes determinísticas (campos do EHR e trechos do MedQuAD), além de aplicar guardrails para evitar interpretação indevida de exames e recomendações prescritivas. Por fim, executa perguntas de demonstração e registra cada execução em /logs/assistant_runs.jsonl, deixando evidências auditáveis para o assistente médico.

# 12_Gerar um pacote de demos em .jsonl

In [15]:
# Regerar JSONL de demos com remoção do prompt-echo

import json
import random
from pathlib import Path
from datetime import datetime, timezone

# -----------------------------
# 0) Validações de segurança
# -----------------------------
missing = []
for name in ["patients", "PATIENT_ID_RE", "app_llm_ft"]:
    if name not in globals():
        missing.append(name)

if missing:
    raise RuntimeError(
        "Missing required objects from the previous blocks: "
        + ", ".join(missing)
        + ".\nRun BLOCK 11 first (run-from-scratch) and then run this block."
    )

# Reprodutibilidade
random.seed(42)

# -----------------------------
# 1) Caminho de saída
# -----------------------------
PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
DEMO_DIR = PROJECT_ROOT / "demos"
DEMO_DIR.mkdir(parents=True, exist_ok=True)

ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
OUT_PATH = DEMO_DIR / f"demo_runs_clean_{ts}.jsonl"

print("✅ Demo output will be saved to:", OUT_PATH)

# -----------------------------
# 2) Construir um conjunto de questões de demonstração
# -----------------------------
unique_pids = list(patients["patient_id"].unique())
k_pids = min(5, len(unique_pids))
sample_pids = random.sample(unique_pids, k=k_pids)

ehr_templates = [
    "What are the active diagnoses for patient {pid}?",
    "List the 3 most recent visits for patient {pid}.",
    "What are the most recent lab results for patient {pid}?",
    "Does patient {pid} have missing labs for HbA1c, LDL cholesterol, and Creatinine?",
]

both_templates = [
    "Explain hypertension in plain language, and relate it to the active diagnoses for patient {pid}.",
    "Explain what HbA1c measures and show the most recent HbA1c-related context for patient {pid}.",
    "Explain chronic kidney disease briefly and relate it to patient {pid}'s diagnoses or labs if present.",
]

medquad_questions = [
    "What is hypertension and what are common management approaches?",
    "What is diabetes and how is it usually managed?",
    "What is asthma and what are common triggers and treatments?",
    "What is hypothyroidism and what are common symptoms?",
    "What is gastroesophageal reflux disease (GERD) and how is it managed?",
    "What is osteoarthritis and what are common management strategies?",
]

demo_questions = []
for pid in sample_pids:
    demo_questions.append(ehr_templates[0].format(pid=pid))
    demo_questions.append(ehr_templates[3].format(pid=pid))
    demo_questions.append(both_templates[0].format(pid=pid))

k_med = min(6, len(medquad_questions))
demo_questions += random.sample(medquad_questions, k=k_med)
random.shuffle(demo_questions)

print("✅ Total demo questions:", len(demo_questions))
print("Sample:", demo_questions[:3])

# -----------------------------
# 3) Auxiliares de limpeza
# -----------------------------
def strip_inst_echo(text: str) -> str:
    """
    If model echoed the full prompt, remove everything up to [/INST].
    Also remove leading <s> token if present.
    """
    t = (text or "").strip()
    t = t.replace("<s>", "").strip()
    if "[/INST]" in t:
        t = t.split("[/INST]", 1)[1].strip()
    t = t.replace("</s>", "").strip()
    return t

def extract_sources_block(text: str):
    if "\n\nSources:\n" in text:
        answer, sources = text.split("\n\nSources:\n", 1)
        return answer.strip(), "Sources:\n" + sources.strip()
    if "\nSources:\n" in text:
        answer, sources = text.split("\nSources:\n", 1)
        return answer.strip(), "Sources:\n" + sources.strip()
    return text.strip(), None

def infer_patient_id(question: str):
    m = PATIENT_ID_RE.search(question)
    return m.group(0) if m else None

# -----------------------------
# 4) Executar assistente e salvar JSONL
# -----------------------------
records = []
for i, q in enumerate(demo_questions, start=1):
    result = app_llm_ft.invoke({"input": q})
    final_raw = result.get("final", "")

    # clean prompt echo FIRST
    final_clean = strip_inst_echo(final_raw)

    answer, sources = extract_sources_block(final_clean)
    pid = infer_patient_id(q)

    rec = {
        "timestamp_utc": datetime.now(timezone.utc).isoformat(),
        "demo_id": i,
        "question": q,
        "patient_id": pid,
        "answer": answer,
        "sources": sources,
    }
    records.append(rec)

    if i % 5 == 0:
        print(f"  ... completed {i}/{len(demo_questions)}")

with open(OUT_PATH, "w", encoding="utf-8") as f:
    for r in records:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print("✅ Saved CLEAN demo JSONL:", OUT_PATH)

# -----------------------------
# 5) Pré-visualização
# -----------------------------
print("\n--- Preview (first 3) ---")
for r in records[:3]:
    print("\n" + "="*100)
    print("Q:", r["question"])
    print("Patient:", r["patient_id"])
    print("Answer:", (r["answer"][:600] + (" ..." if len(r["answer"]) > 600 else "")))
    print(r["sources"])


This is a friendly reminder - the current text generation call has exceeded the model's predefined maximum length (2048). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


✅ Demo output will be saved to: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/demos/demo_runs_clean_20251224_110523.jsonl
✅ Total demo questions: 21
Sample: ['Does patient P00007 have missing labs for HbA1c, LDL cholesterol, and Creatinine?', 'Explain hypertension in plain language, and relate it to the active diagnoses for patient P00029.', 'Explain hypertension in plain language, and relate it to the active diagnoses for patient P00190.']
  ... completed 5/21
  ... completed 10/21
  ... completed 15/21
  ... completed 20/21
✅ Saved CLEAN demo JSONL: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/demos/demo_runs_clean_20251224_110523.jsonl

--- Preview (first 3) ---

Q: Does patient P00007 have missing labs for HbA1c, LDL cholesterol, and Creatinine?
Patient: P00007
Answer: According to the structured EHR data, patient P00007 is missing the following requested lab(s): Hemoglobin A1c.
Sources:
- EHR fields: summary, missing_labs
- MedQuAD snippets: None

Q: Explain hyperte

<p style="white-space: nowrap;"><strong>Comentário (Bloco 12)</strong> –Este passo executa o assistente final já treinado e orquestrado para gerar um conjunto padronizado de demonstrações (demos), salvando os resultados em um arquivo JSONL limpo e reprodutível. Ele constrói perguntas que cobrem três cenários essenciais a serem demonstrados: consultas clínicas estruturadas (EHR), conhecimento médico geral (MedQuAD) e perguntas híbridas (EHR + explicação). O bloco remove automaticamente qualquer eco de prompt do modelo (como [INST]...[/INST]), separa resposta e fontes determinísticas e registra cada interação com timestamp e identificador do paciente quando aplicável. Como saída, ele cria um arquivo demo_runs_clean_*.jsonl na pasta /demos, que serve como evidência final de funcionamento, facilitando avaliação, auditoria e apresentação do projeto.auditáveis para o assistente médico.

# 13_Avaliação Base vs Fine-Tuned: Comparação de Respostas (Antes e Depois do QLoRA)

In [16]:
#  Comparação entre modelo base e fine-tuned (TinyLlama base vs TinyLlama+LoRA)

!pip -q install -U transformers accelerate peft

import gc
import json
import random
import re
from pathlib import Path
from datetime import datetime, timezone
from typing import Dict, Any, List

import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM

# ============================================================================
# 0) Pré-condições
#    - Requer que o app_llm_ft já esteja disponível
#    - Requer que patients e PATIENT_ID_RE já estejam disponíveis
# ============================================================================
required_globals = ["app_llm_ft", "patients", "PATIENT_ID_RE"]
missing = [g for g in required_globals if g not in globals()]
if missing:
    raise RuntimeError(
        "Missing required objects in the current runtime: "
        + ", ".join(missing)
        + "\nRun Block 11 first to build the assistant objects."
    )

random.seed(42)

PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
OUT_DIR = PROJECT_ROOT / "eval"
OUT_DIR.mkdir(parents=True, exist_ok=True)
ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
OUT_PATH = OUT_DIR / f"base_vs_ft_{ts}.jsonl"

print("✅ Output will be saved to:", OUT_PATH)

# ============================================================================
# 1) Carregar modelo base (sem LoRA), compatível com a base usada no fine-tuning
# ============================================================================
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if base_tokenizer.pad_token is None:
    base_tokenizer.pad_token = base_tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map={"": 0} if torch.cuda.is_available() else "cpu",
)
base_model.eval()

print("✅ Loaded BASE model:", BASE_MODEL)

@torch.inference_mode()
def base_generate(prompt: str, max_new_tokens: int = 220) -> str:
    inputs = base_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(base_model.device) for k, v in inputs.items()}
    out = base_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.35,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=base_tokenizer.eos_token_id,
        eos_token_id=base_tokenizer.eos_token_id,
    )
    text = base_tokenizer.decode(out[0], skip_special_tokens=True)

    # robustly remove prompt echo if present
    t = text.strip().replace("<s>", "").replace("</s>", "").strip()
    if "[/INST]" in t:
        t = t.split("[/INST]", 1)[1].strip()
    return t

# ============================================================================
# 2) Gerador de prompt no mesmo formato utilizado durante o treinamento
# ============================================================================
SYSTEM = (
    "You are a helpful medical assistant. "
    "Answer clearly, in English, and avoid prescribing. "
    "Provide general educational information and recommend consulting a licensed clinician for decisions."
)

def build_inst_prompt(question: str, medquad_hits: List[Dict[str, Any]] = None) -> str:
    ctx = ""
    if medquad_hits:
        pieces = []
        for i, h in enumerate(medquad_hits[:2], start=1):
            meta = h.get("metadata", {})
            pieces.append(
                f"[MedQuAD {i}] Topic: {meta.get('topic','')}\n"
                f"SourceFile: {meta.get('source_file','')}\n"
                f"{(h.get('content') or '')[:900]}"
            )
        ctx = "\n\nCONTEXT (MedQuAD retrieval):\n" + "\n\n".join(pieces) + "\n"

    return (
        f"<s>[INST] {SYSTEM}\n\n"
        f"Question: {question}\n"
        f"{ctx}\n"
        f"Answer in English. [/INST]\n"
    )

# ============================================================================
# 3) Conjunto de questões de avaliação
# ============================================================================
general_questions = [
    "What is hypertension and what are common management approaches?",
    "What is type 2 diabetes and how is it usually managed?",
    "What is asthma and what are common triggers and treatments?",
    "What is hypothyroidism and what are common symptoms?",
    "What is gastroesophageal reflux disease (GERD) and how is it managed?",
    "What is chronic kidney disease and why is it clinically important?",
    "What is HbA1c and what does it measure?",
    "How can older adults reduce fall risk at home?",
    "What are common warning signs of a stroke?",
    "What are common side effects of blood pressure medicines in general terms?",
]

sample_pids = random.sample(list(patients["patient_id"].unique()), k=min(3, len(patients)))
patient_questions = (
    [f"What are the active diagnoses for patient {pid}?" for pid in sample_pids] +
    [f"Does patient {pid} have missing labs for HbA1c, LDL cholesterol, and Creatinine?" for pid in sample_pids[:2]]
)

eval_questions = general_questions + patient_questions
random.shuffle(eval_questions)

print("✅ Evaluation questions:", len(eval_questions))
print("Sample:", eval_questions[:3])

# ============================================================================
# 4) Executar comparações
# ============================================================================
def extract_sources_block(text: str):
    if "\n\nSources:\n" in text:
        answer, sources = text.split("\n\nSources:\n", 1)
        return answer.strip(), "Sources:\n" + sources.strip()
    if "\nSources:\n" in text:
        answer, sources = text.split("\nSources:\n", 1)
        return answer.strip(), "Sources:\n" + sources.strip()
    return text.strip(), None

def infer_patient_id(q: str):
    m = PATIENT_ID_RE.search(q)
    return m.group(0) if m else None

records = []
for i, q in enumerate(eval_questions, start=1):
    pid = infer_patient_id(q)

    medquad_hits = None
    if pid is None and "medquad_search" in globals():
        medquad_hits = medquad_search(q, k=4)

    base_prompt = build_inst_prompt(q, medquad_hits=medquad_hits)
    base_ans = base_generate(base_prompt, max_new_tokens=220).strip()

    ft_result = app_llm_ft.invoke({"input": q})
    ft_final = ft_result["final"]
    ft_answer, ft_sources = extract_sources_block(ft_final)

    rec = {
        "timestamp_utc": datetime.now(timezone.utc).isoformat(),
        "eval_id": i,
        "question": q,
        "patient_id": pid,
        "base_model": BASE_MODEL,
        "base_answer": base_ans,
        "fine_tuned_answer": ft_answer,
        "fine_tuned_sources": ft_sources,
        "note": "Patient-specific questions: BASE model is not given EHR JSON; fine-tuned assistant may use EHR+guardrails."
    }
    records.append(rec)

    if i % 5 == 0:
        print(f"  ... completed {i}/{len(eval_questions)}")

with open(OUT_PATH, "w", encoding="utf-8") as f:
    for r in records:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print("✅ Saved comparison JSONL:", OUT_PATH)

print("\n--- Preview (first 3) ---")
for r in records[:3]:
    print("\n" + "="*100)
    print("Q:", r["question"])
    print("\nBASE ANSWER:\n", r["base_answer"][:900], "..." if len(r["base_answer"]) > 900 else "")
    print("\nFINE-TUNED ANSWER:\n", r["fine_tuned_answer"][:900], "..." if len(r["fine_tuned_answer"]) > 900 else "")
    if r["fine_tuned_sources"]:
        print("\n", r["fine_tuned_sources"])


✅ Output will be saved to: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/eval/base_vs_ft_20251224_110841.jsonl
CUDA available: True
GPU: NVIDIA L4
✅ Loaded BASE model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
✅ Evaluation questions: 15
Sample: ['How can older adults reduce fall risk at home?', 'What are the active diagnoses for patient P00164?', 'What is chronic kidney disease and why is it clinically important?']
  ... completed 5/15
  ... completed 10/15
  ... completed 15/15
✅ Saved comparison JSONL: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/eval/base_vs_ft_20251224_110841.jsonl

--- Preview (first 3) ---

Q: How can older adults reduce fall risk at home?

BASE ANSWER:
 [MedQuAD 3] Topic: 8_Medicine_QA
SourceFile: /content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/data/raw/medquad/8_Medicine_QA/0000196.xml
QUESTION: Can you provide me with general information about how to reduce the risk of falls in older adults?
ANSWER: Yes, here are some general tips for reducing the

<p style="white-space: nowrap;"><strong>Comentário (Bloco 13)</strong> – Esse último processo executa uma comparação sistemática entre o modelo base (sem LoRA) e o assistente fine-tuned (com LoRA), utilizando o mesmo formato de prompt para ambos. São avaliadas perguntas médicas gerais e perguntas com contexto de paciente. Para questões gerais, ambos os modelos recebem apenas contexto MedQuAD; para perguntas com paciente, apenas o assistente fine-tuned utiliza dados estruturados de EHR e guardrails de segurança, enquanto o modelo base não recebe dados clínicos. As respostas de ambos os modelos são registradas em formato JSONL, permitindo análise qualitativa do impacto do fine-tuning, da recuperação de conhecimento e das regras de segurança clínica.

# Geração do README do trabalho do TC3 (Tech Challenge 3)

In [18]:
from pathlib import Path

# --------------------------------------------------
# Path
# --------------------------------------------------
PROJECT_ROOT = Path("/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3")
README_PATH = PROJECT_ROOT / "README.md"

# --------------------------------------------------
# Conteúdo do README
# --------------------------------------------------
readme_lines = [
"# 🏥 Assistente Médico Clínico com RAG (Retrieval-Augmented Generation), Prontuário Sintético & Fine-Tuning com qLoRA (Quantized Low-Rank Adaptation)",
"",
"## 📌 Visão Geral",
"Este projeto implementa um **assistente clínico inteligente e auditável**, combinando:",
"",
"- 📚 Conhecimento médico geral (MedQuAD + FAISS (Facebook AI Similarity Search))",
"- 🧾 Dados clínicos estruturados (EHR (Electronic Health Record) sintético)",
"- 🧭 Roteamento determinístico de intenções",
"- 🛡️ Guardrails clínicos rígidos",
"- 🧠 Fine-tuning eficiente com LoRA / QLoRA",
"- 🔎 Fontes rastreáveis e explicáveis",
"",
"O objetivo é demonstrar uma **arquitetura segura e reprodutível** para aplicações de IA em saúde.",
"",
"---",
"",
"## 🧩 Pipeline Completo — Blocos 1 a 13",
"",
"### 🔹 Bloco 1 — Setup e Estrutura de Diretórios",
"Configuração do Google Colab + Google Drive.",
"",
"### 🔹 Bloco 2 — Download do Dataset MedQuAD",
"Clonagem e seleção de tópicos médicos relevantes.",
"",
"### 🔹 Bloco 3 — Processamento do MedQuAD",
"Parsing XML, limpeza e geração do `medquad_qa.csv`.",
"",
"### 🔹 Bloco 4 — Geração de EHR Sintético",
"Geração de pacientes, visitas, diagnósticos, prescrições e exames.",
"",
"### 🔹 Bloco 5 — Funções Determinísticas de EHR",
"Camada segura de consulta clínica sem LLM.",
"",
"### 🔹 Bloco 6 — FAISS + Tools + Router + LangGraph",
"Construção do RAG e roteamento de intenções.",
"",
"### 🔹 Bloco 7 — Pipeline RAG Completo",
"Integração EHR + MedQuAD.",
"",
"### 🔹 Bloco 8 — LLM com Guardrails Clínicos",
"Geração textual controlada, sem interpretação clínica.",
"",
"### 🔹 Bloco 9 — Fine-Tuning com LoRA / QLoRA",
"Ajuste eficiente com fallback automático.",
"",
"### 🔹 Bloco 10 — Integração do Modelo Fine-Tuned",
"Substituição transparente do modelo base.",
"",
"### 🔹 Bloco 11 — Execução 'do zero'",
"Inicialização completa do sistema.",
"",
"### 🔹 Bloco 12 — Geração de Demos",
"Geração de JSONL limpo para avaliação.",
"",
"### 🔹 Bloco 13 — Comparação Modelos Base vs Fine-Tuned",
"Comparação qualitativa antes/depois do ajuste.",
"",
"---",
"",
"## ⭐ Diferenciais",
"",
"- Separação clara entre dados clínicos e geração de linguagem",
"- Zero alucinação em prontuários (EHR) - Dados determinísticos",
"- Fontes auditáveis",
"- Guardrails clínicos explícitos",
"- Arquitetura modular e reprodutível",
"",
"---",
"",
"## 🚀 Potenciais Melhorias",
"",
"- Integração FHIR (Fast Healthcare Interoperability Resources)) real",
"- Avaliação quantitativa",
"- Interface web",
"- Dados 100% em Português, tanto do prontuário, quanto das bases de perguntas e respostas",
"",
"---",
"",
"## 📁 Estrutura de Diretórios",
"",
"```",
"Tech_Challenge_3/",
"├── data/",
"│   ├── raw/",
"│   ├── processed/",
"│   └── synthetic/",
"├── models/",
"├── logs/",
"├── demos/",
"├── eval/",
"└── README.md",
"```",
"",
"---",
"",
"⚠️ Esse projeto é a saída principal do Tech Challenge 3 (propósito educacional). Não utilizar para decisões clínicas reais."
]

readme_text = "\n".join(readme_lines)

# --------------------------------------------------
# Criar o README
# --------------------------------------------------
README_PATH.write_text(readme_text, encoding="utf-8")

print("✅ README.md criado com sucesso em:")
print(README_PATH)


✅ README.md criado com sucesso em:
/content/drive/MyDrive/FIAP_PosTech/Tech_Challenge_3/README.md
