# Capstone 2 – Healthcare Claim Coverage Agent
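This notebook implements the claim coverage agent. For each patient record it builds summaries of the record and its insurance policy, then applies rule-based coverage checks (procedure, diagnosis, age, gender, preauthorization). It can optionally rewrite the decision reason with an LLM, and it writes the decisions to `submission.csv`.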

In [None]:

import json, csv
from dataclasses import dataclass
from typing import Dict, Any, List
from pathlib import Path

# Absolute path of the working directory the notebook is run from.
ROOT = Path(".").resolve()
# Directory holding the input JSON files (records, policies, reference codes).
DATA = ROOT.joinpath("data")

def load_json(path: Path):
    """Parse and return the JSON document stored at ``path``, decoded as UTF-8."""
    with open(path, encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload

def exists(path: Path) -> bool:
    """Return True when ``path`` (a str or Path) names an existing filesystem entry."""
    as_path = Path(path)
    return as_path.exists()


In [None]:

@dataclass
class PatientSummary:
    """Normalized view of one patient record used by the coverage check.

    In this file it is built by summarize_patient_record, which passes the
    fields positionally, so the field order below is part of the contract.
    """
    patient_id: str  # stringified "patient_id" from the raw record
    age: int | None  # raw "age" value, not validated; None means unknown (the age check then passes)
    gender: str | None  # upper-cased gender, or None when missing/empty
    diagnoses: List[str]  # diagnosis codes passed through normalize_code
    procedures: List[str]  # procedure codes passed through normalize_code
    claim_procedure_code: str | None  # normalized procedure code of the claim being decided
    preauth_id: str | None  # claim's preauthorization ID, None when absent

@dataclass
class PolicySummary:
    """Normalized view of one insurance policy's coverage rules.

    In this file it is built by summarize_policy_guideline and consumed by
    check_claim_coverage.
    """
    policy_id: str  # stringified policy identifier
    covered_procedure_codes: List[str]  # the claim's procedure code must be one of these
    diagnosis_requirements: List[str]  # some patient diagnosis must match one of these; empty = no requirement
    min_age: int | None  # inclusive lower age bound; None = unbounded
    max_age: int | None  # inclusive upper age bound; None = unbounded
    gender_restriction: str | None  # required gender; falsy = no restriction
    preauth_required: bool  # True when the claim must carry a preauthorization ID

def normalize_code(code: str | None) -> str | None:
    if not code: return None
    return code.strip().upper().replace('.', '')


In [None]:

def summarize_patient_record(record_str: Dict[str, Any], ref_codes: Dict[str, Any]) -> PatientSummary:
    """Extract the fields the coverage check needs from one raw patient record.

    Args:
        record_str: The parsed record dict (despite the name, not a string).
            Reads "patient_id", "age", "gender", "diagnoses", "procedures"
            and the nested "claim" object's "procedure_code" and
            "preauthorization_id".
        ref_codes: Reference code data; currently unused, kept for interface
            compatibility.

    Returns:
        A PatientSummary. Codes are passed through normalize_code, and entries
        that normalize to nothing are dropped, so the code lists hold only
        non-empty strings. Missing or null list/claim fields count as empty.
    """
    # `or {}` also covers an explicit JSON null for "claim".
    claim = record_str.get("claim") or {}
    gender = (record_str.get("gender") or "").strip().upper() or None

    def _codes(values) -> List[str]:
        normalized = (normalize_code(v) for v in values or [])
        # Drop None/"" so downstream membership/prefix checks never see them.
        return [c for c in normalized if c]

    return PatientSummary(
        str(record_str.get("patient_id")),
        record_str.get("age"),
        gender,
        _codes(record_str.get("diagnoses")),
        _codes(record_str.get("procedures")),
        normalize_code(claim.get("procedure_code")),
        claim.get("preauthorization_id"),
    )

def summarize_policy_guideline(policy_id: str, policies: Dict[str, Any]) -> PolicySummary:
    """Build a PolicySummary for ``policy_id`` from the ``policies`` mapping.

    Args:
        policy_id: Policy identifier; looked up as ``str(policy_id)``.
        policies: Mapping of policy id (string) to the raw policy dict.

    Returns:
        A PolicySummary whose procedure/diagnosis codes are normalized. Entries
        that normalize to nothing (None, empty or whitespace-only strings) are
        dropped, so blank entries cannot act as a requirement that every
        diagnosis trivially satisfies. A string "preauthorization_required"
        flag is read as true only for "true"/"yes"/"y"/"1" (case-insensitive).

    Raises:
        KeyError: if ``policies`` has no entry for ``str(policy_id)``.
    """
    key = str(policy_id)
    try:
        p = policies[key]
    except KeyError:
        raise KeyError(f"Unknown policy_id: {policy_id!r}") from None

    def _codes(field: str) -> List[str]:
        # `or []` also tolerates an explicit JSON null for the list.
        normalized = (normalize_code(c) for c in p.get(field) or [])
        return [c for c in normalized if c]

    raw_preauth = p.get("preauthorization_required", False)
    if isinstance(raw_preauth, str):
        # bool("false") is True, so interpret string spellings explicitly.
        preauth_required = raw_preauth.strip().lower() in {"true", "yes", "y", "1"}
    else:
        preauth_required = bool(raw_preauth)

    return PolicySummary(
        policy_id=key,
        covered_procedure_codes=_codes("covered_procedure_codes"),
        diagnosis_requirements=_codes("diagnosis_requirements"),
        min_age=p.get("min_age"),
        max_age=p.get("max_age"),
        gender_restriction=p.get("gender_restriction"),
        preauth_required=preauth_required,
    )

def _age_ok(age, min_age, max_age):
    if age is None: return True
    if min_age is not None and age < min_age: return False
    if max_age is not None and age > max_age: return False
    return True

def _gender_ok(gender, restriction):
    if not restriction: return True
    if not gender: return False
    return gender.upper().startswith(restriction.upper())

def _diagnosis_ok(patient_dx, required_dx):
    if not required_dx: return True
    return any(req in d for d in patient_dx for req in required_dx)

def check_claim_coverage(record_summary: PatientSummary, policy_summary: PolicySummary) -> Dict[str, str]:
    """Apply the policy's coverage rules to one claim, in a fixed order.

    The rules run in this order: procedure covered, diagnosis requirement,
    age range, gender restriction, preauthorization. The first failing rule
    decides the outcome.

    Returns:
        A dict with "Decision" ("APPROVE" or "ROUTE FOR REVIEW") and a short
        "Reason".
    """
    def _review(reason: str) -> Dict[str, str]:
        return {"Decision": "ROUTE FOR REVIEW", "Reason": reason}

    code = record_summary.claim_procedure_code
    # A missing procedure code must never count as covered, even if the
    # policy's code list happens to contain a None/empty entry.
    if not code or code not in policy_summary.covered_procedure_codes:
        return _review("Procedure not covered.")
    if not _diagnosis_ok(record_summary.diagnoses, policy_summary.diagnosis_requirements):
        return _review("Diagnosis not met.")
    if not _age_ok(record_summary.age, policy_summary.min_age, policy_summary.max_age):
        return _review("Age outside range.")
    if not _gender_ok(record_summary.gender, policy_summary.gender_restriction):
        return _review("Gender restriction not satisfied.")
    preauth = record_summary.preauth_id
    # A whitespace-only preauthorization ID is treated as missing.
    if policy_summary.preauth_required and not (preauth and str(preauth).strip()):
        return _review("Preauthorization missing.")
    return {"Decision": "APPROVE", "Reason": "Meets all criteria."}


In [None]:

def run_on_records(records_path: Path, policies_path: Path, refs_path: Path, out_csv: Path|None=None):
    """Run the deterministic coverage check over every record in ``records_path``.

    Args:
        records_path: JSON file containing a list of patient records.
        policies_path: JSON file holding either a list of policy objects
            (each with "policy_id") or a mapping of policy id -> policy.
        refs_path: JSON reference-code file, passed through to
            summarize_patient_record.
        out_csv: When given, results are also written there as CSV with
            columns ``patient_id`` and ``generated_response``.

    Returns:
        A list of ``{"patient_id", "generated_response"}`` dicts in input
        order. ``generated_response`` is ``"Decision: <D>\\nReason: <R>"``.
        A record whose claim references a policy that is missing from the
        policies file is routed for review ("Policy not found.") instead of
        aborting the whole batch.
    """
    records = load_json(records_path)
    policies_raw = load_json(policies_path)
    refs = load_json(refs_path)
    if isinstance(policies_raw, list):
        policies = {str(p['policy_id']): p for p in policies_raw}
    else:
        policies = policies_raw
    results = []
    for rec in records:
        ps = summarize_patient_record(rec, refs)
        # `or {}` tolerates an explicit JSON null for "claim".
        policy_id = (rec.get("claim") or {}).get("policy_id")
        if str(policy_id) in policies:
            pol = summarize_policy_guideline(policy_id, policies)
            decision = check_claim_coverage(ps, pol)
        else:
            decision = {"Decision": "ROUTE FOR REVIEW", "Reason": "Policy not found."}
        resp = f"Decision: {decision['Decision']}\nReason: {decision['Reason']}"
        results.append({"patient_id": ps.patient_id, "generated_response": resp})
    if out_csv:
        with open(out_csv, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=['patient_id', 'generated_response'])
            writer.writeheader()
            writer.writerows(results)
    return results


In [None]:

from langchain_core.messages import SystemMessage, HumanMessage

def polish_reason_with_llm(raw_decision: dict, chat=None) -> dict:
    """Optionally rewrite ``raw_decision['Reason']`` as one professional sentence.

    Args:
        raw_decision: Dict with "Decision" and "Reason" keys.
        chat: Object exposing ``invoke(messages)`` that returns a message with
            a ``.content`` attribute (presumably a LangChain chat model —
            confirm). When None, no model call is made.

    Returns:
        A new dict with the original Decision and the polished Reason.
        ``raw_decision`` itself is returned when ``chat`` is None, the call
        fails, or the reply has no usable text. The Decision is never taken
        from the model.
    """
    if chat is None:
        return raw_decision
    sys_prompt = "You are an insurance claim agent. Rewrite the Reason in one professional sentence."
    human = f"Decision: {raw_decision['Decision']}\nReason: {raw_decision['Reason']}"
    try:
        msg = chat.invoke([SystemMessage(content=sys_prompt), HumanMessage(content=human)])
    except Exception as e:  # best-effort: any model/transport failure falls back
        print("LLM polish failed:", e)
        return raw_decision
    content = getattr(msg, "content", msg)
    if not isinstance(content, str):
        # e.g. None or a list of content parts; not handled, keep the original.
        print("LLM polish failed:", f"non-text reply of type {type(content).__name__}")
        return raw_decision
    # Models often echo the prompt's "Decision:" line and "Reason:" label; drop
    # them and collapse to one line so the caller's two-line format stays intact.
    kept = [ln.strip() for ln in content.splitlines()]
    kept = [ln for ln in kept if ln and not ln.lower().startswith("decision:")]
    polished = " ".join(" ".join(kept).split())
    if polished.lower().startswith("reason:"):
        polished = polished[len("reason:"):].strip()
    if polished:
        return {"Decision": raw_decision["Decision"], "Reason": polished}
    return raw_decision

def _value_after_label(line: str) -> str:
    """Return the text after the first ':' in ``line`` (the whole line if it has none), stripped."""
    _, sep, value = line.partition(':')
    return (value if sep else line).strip()


def run_on_records_with_llm(records_path, policies_path, refs_path, out_csv=None, chat=None):
    """Run the deterministic pipeline, then optionally polish each Reason with ``chat``.

    Args:
        records_path, policies_path, refs_path: Same inputs as run_on_records.
        out_csv: When given, the polished results are written there as CSV
            with columns ``patient_id`` and ``generated_response``.
        chat: Optional chat model handed to polish_reason_with_llm. With None,
            the output matches run_on_records.

    Returns:
        A list of ``{"patient_id", "generated_response"}`` dicts in input
        order. Each response is exactly two lines: ``"Decision: ..."`` and
        ``"Reason: ..."``.
    """
    base = run_on_records(records_path, policies_path, refs_path)
    results = []
    for r in base:
        lines = r['generated_response'].splitlines()
        dec = _value_after_label(lines[0]) if lines else ''
        reas = _value_after_label(lines[1]) if len(lines) > 1 else ''
        polished = polish_reason_with_llm({'Decision': dec, 'Reason': reas}, chat)
        # Flatten any whitespace/newlines so the response stays two lines.
        reason_line = " ".join(str(polished['Reason']).split())
        results.append({
            'patient_id': r['patient_id'],
            'generated_response': f"Decision: {polished['Decision']}\nReason: {reason_line}",
        })
    if out_csv:
        with open(out_csv, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=['patient_id', 'generated_response'])
            writer.writeheader()
            writer.writerows(results)
    return results


In [None]:

# Input locations for the validation and test runs.
paths = dict(
    validation_records=DATA / 'validation_records.json',
    test_records=DATA / 'test_records.json',
    policies=DATA / 'insurance_policies.json',
    refs=DATA / 'reference_codes.json',
)
# Only run the validation pass when every input file it needs is present.
if not all(map(exists, (paths['validation_records'], paths['policies'], paths['refs']))):
    print("Validation skipped")
else:
    val_results = run_on_records(paths['validation_records'], paths['policies'], paths['refs'])
    # Show a small sample rather than every result.
    print(val_results[:3])


In [None]:

out_path = Path('submission.csv')
if not all(map(exists, (paths['test_records'], paths['policies'], paths['refs']))):
    print("Test skipped")
else:
    try:
        # `chat_client` is optional: when it was never defined, chat is None
        # and polishing is skipped, so the run is deterministic.
        results = run_on_records_with_llm(
            paths['test_records'],
            paths['policies'],
            paths['refs'],
            out_csv=out_path,
            chat=globals().get('chat_client'),
        )
        print("Wrote", out_path, "(LLM-polished)" if globals().get('chat_client') else "(deterministic)")
    except Exception as e:
        # Any failure on the LLM path falls back to the deterministic runner.
        print("LLM runner failed:", e)
        _ = run_on_records(paths['test_records'], paths['policies'], paths['refs'], out_csv=out_path)
        print("Wrote", out_path, "(deterministic fallback)")
