<a href="https://colab.research.google.com/github/nerr22/2015/blob/master/PoC_Testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import numpy as np
import pandas as pd

# --- Setup (same as before) ---
features = ["fatigue", "on_metformin", "high_A1C", "pregnant", "GI_upset", "high_BMI"]
actions = ["order_TSH", "increase_metformin", "check_B12", "start_GLP1"]

weights = pd.DataFrame(0.0, index=actions, columns=features)

# initialize some random-ish weights
weights.loc["order_TSH", ["fatigue", "high_A1C"]] = [1.5, 0.5]
weights.loc["increase_metformin", ["on_metformin", "high_A1C", "GI_upset"]] = [1.0, 1.5, -2.0]
weights.loc["check_B12", ["on_metformin", "GI_upset"]] = [2.0, 1.0]
weights.loc["start_GLP1", ["on_metformin", "high_A1C", "pregnant"]] = [1.0, 2.0, -3.0]

# --- Learning parameters ---
eta = 0.1  # learning rate

# --- Function to get scores ---
def score_action(patient, action):
    return float((weights.loc[action] * patient).sum())

# --- Function to update weights from feedback ---
def update_weights(patient, action, feedback):
    score = score_action(patient, action)
    y_hat = 1 / (1 + np.exp(-score))  # sigmoid
    y = 1.0 if feedback == "approve" else 0.0
    for f, val in patient.items():
        weights.loc[action, f] += eta * (y - y_hat) * val

# --- Step 1: create a patient ---
patient = pd.Series({
    "fatigue": 1, "on_metformin": 1, "high_A1C": 1,
    "pregnant": 0, "GI_upset": 0, "high_BMI": 1
})

# --- Step 2: score before feedback ---
print("Before feedback:")
for a in actions:
    print(f"{a:20s} {score_action(patient,a):+.2f}")

# --- Step 3: doctor gives feedback ---
# Doctor approves 'start_GLP1' (makes sense for high BMI)
update_weights(patient, "start_GLP1", "approve")

print("\nAfter feedback:")
for a in actions:
    print(f"{a:20s} {score_action(patient,a):+.2f}")


Before feedback:
order_TSH            +2.00
increase_metformin   +2.50
check_B12            +2.00
start_GLP1           +3.00

After feedback:
order_TSH            +2.00
increase_metformin   +2.50
check_B12            +2.00
start_GLP1           +3.02


In [2]:
weights

Unnamed: 0,fatigue,on_metformin,high_A1C,pregnant,GI_upset,high_BMI
order_TSH,1.5,0.0,0.5,0.0,0.0,0.0
increase_metformin,0.0,1.0,1.5,0.0,-2.0,0.0
check_B12,0.0,2.0,0.0,0.0,1.0,0.0
start_GLP1,0.004743,1.004743,2.004743,-3.0,0.0,0.004743


In [2]:
import numpy as np
import pandas as pd
from collections import defaultdict

# -----------------------------
# 1) Registries (extensible)
# -----------------------------
FEATURES = set()
ACTIONS  = {}                          # {action_name: {"desc": str}}
GUIDELINE_MODULES = {}                 # {module_name: {action: {"logic": {...}, "strength": "must|should|could"}}}
STRENGTH_PRIOR = {"must": 2.0, "should": 1.0, "could": 0.0}

def add_features(*feature_names):
    for f in feature_names: FEATURES.add(f)

def add_actions(**actions_kv):
    for name, desc in actions_kv.items():
        ACTIONS[name] = {"desc": desc}

def add_guideline_module(name, rules_dict):
    """
    rules_dict action entries may be either:
      - NEW (JSON-logic):
        {"logic": {"and":[{"exists":"on_metformin"},{">=":[{"var":"A1C"},6.5]},{"not":[{"exists":"pregnant"}]}]},
         "strength":"should"}
      - LEGACY (pre/contra):
        {"pre":["on_metformin","high_A1C"], "contra":["pregnant"], "strength":"should"}
    """
    GUIDELINE_MODULES[name] = rules_dict

# -----------------------------
# 2) JSON-logic evaluator
# -----------------------------
def get_val(patient, ref, default=None):
    # ref can be "A1C" or {"var":"A1C"} or {"var":["A1C", default]}
    if isinstance(ref, dict) and "var" in ref:
        v = ref["var"]
        if isinstance(v, list): return patient.get(v[0], v[1])
        return patient.get(v, default)
    if isinstance(ref, str): return patient.get(ref, default)
    return ref  # literal

def eval_logic(node, patient):
    # atoms
    if isinstance(node, bool): return node
    if isinstance(node, str):  return bool(patient.get(node, 0))
    if not isinstance(node, dict): return bool(node)

    # boolean ops
    if "and" in node: return all(eval_logic(n, patient) for n in node["and"])
    if "or"  in node: return any(eval_logic(n, patient) for n in node["or"])
    if "not" in node: return not any(eval_logic(n, patient) for n in node["not"])
    if "exists" in node: return bool(patient.get(node["exists"], 0))

    # comparators
    for op in (">=", ">", "<=", "<", "==", "!="):
        if op in node:
            left, right = node[op]
            L = get_val(patient, left)
            R = get_val(patient, right)
            if L is None: return False
            if op == ">=": return L >= R
            if op ==  ">": return L >  R
            if op == "<=": return L <= R
            if op ==  "<": return L <  R
            if op == "==": return L == R
            if op == "!=": return L != R
    return False

# -----------------------------
# 3) Composable guideline gating
# -----------------------------
def _collect_action_rules(action, modules=None):
    modules = modules or list(GUIDELINE_MODULES.keys())
    found = []
    for m in modules:
        r = GUIDELINE_MODULES.get(m, {}).get(action)
        if r: found.append(r)
    return found

def _legacy_to_logic(r):
    # Convert {"pre":[...], "contra":[...]} → JSON-logic
    pre = r.get("pre", [])
    contra = r.get("contra", [])
    parts = []
    if pre:
        parts.append({"and": pre})  # strings resolve via eval_logic
    if contra:
        parts.append({"not": contra})
    if not parts:
        return {"and": []}  # vacuously true (no info)
    if len(parts) == 1: return parts[0]
    return {"and": parts}

def is_action_allowed(patient: pd.Series, action: str, modules=None) -> bool:
    """
    Allowed if:
      (ANY module's logic evaluates True) AND (no module's logic encodes a direct contradiction)
    With legacy entries, 'pre/contra' are converted to equivalent logic.
    """
    rules = _collect_action_rules(action, modules)
    if not rules: return False

    any_true = False
    for r in rules:
        logic = r.get("logic")
        if logic is None:
            logic = _legacy_to_logic(r)
        if eval_logic(logic, patient):
            any_true = True
            break
    return any_true

def guideline_bias_for(action: str, modules=None) -> float:
    rules = _collect_action_rules(action, modules)
    if not rules: return 0.0
    vals = [STRENGTH_PRIOR.get(r.get("strength", "could"), 0.0) for r in rules]
    return max(vals) if vals else 0.0

def filter_safe_actions(patient: pd.Series, actions_list, modules=None):
    return [a for a in actions_list if is_action_allowed(patient, a, modules)]

# -----------------------------
# 4) Weights (learned prefs)
# -----------------------------
def init_weights(actions_list, feature_list):
    return pd.DataFrame(0.0, index=actions_list, columns=feature_list)

def score_action(patient: pd.Series, action: str, weights: pd.DataFrame, modules=None) -> float:
    bias = guideline_bias_for(action, modules)
    lin  = float((weights.loc[action] * patient).sum())
    return bias + lin

def rank_actions(patient: pd.Series, actions_list, weights: pd.DataFrame, modules=None, top_k=None):
    safe = filter_safe_actions(patient, actions_list, modules)
    scored = [(a, score_action(patient, a, weights, modules)) for a in safe]
    scored.sort(key=lambda t: t[1], reverse=True)
    return scored[:top_k] if top_k else scored

def explain_action(patient: pd.Series, action: str, weights: pd.DataFrame, modules=None):
    bias = guideline_bias_for(action, modules)
    contribs = (weights.loc[action] * patient).replace(0, np.nan).dropna()
    return {"bias_prior": bias, "feature_contribs": contribs.sort_values(key=abs, ascending=False)}

def update_weights(patient: pd.Series, action: str, feedback: str, weights: pd.DataFrame, eta=0.1, modules=None):
    if not is_action_allowed(patient, action, modules):
        return
    score = score_action(patient, action, weights, modules)
    y_hat = 1 / (1 + np.exp(-score))
    y = 1.0 if feedback == "approve" else 0.0
    for f, val in patient.items():
        if val:
            weights.loc[action, f] += eta * (y - y_hat) * val
    weights.loc[action] = weights.loc[action].clip(-5, 5)

# -----------------------------
# 5) Minimal example (mixed: NEW + LEGACY)
# -----------------------------
# Features & Actions
add_features("fatigue", "on_metformin", "high_A1C", "pregnant", "GI_upset", "high_BMI", "A1C")
add_actions(
    order_TSH="Order TSH lab",
    increase_metformin="Increase metformin dose",
    check_B12="Check B12 level",
    start_GLP1="Start GLP-1 agonist"
)

# Module A (NEW JSON-logic): metabolic
add_guideline_module("metabolic", {
    "increase_metformin": {
        "logic": {"and":[{"exists":"on_metformin"}, {"or":[{"exists":"high_A1C"},{">=":[{"var":"A1C"},6.5]}]}, {"not":[{"exists":"GI_upset"}]}]},
        "strength": "could"
    },
    "start_GLP1": {
        "logic": {"and":[{"exists":"on_metformin"},{"or":[{"exists":"high_A1C"},{">=":[{"var":"A1C"},7.0]}]}, {"not":[{"exists":"pregnant"}]}]},
        "strength": "should"
    },
    "check_B12": {
        "logic": {"exists":"on_metformin"},
        "strength": "should"
    }
})

# Module B (LEGACY): fatigue work-up
add_guideline_module("fatigue_workup", {
    "order_TSH": {"pre":["fatigue"], "contra":[], "strength":"should"}
})

# Initialize weights
weights = init_weights(list(ACTIONS.keys()), list(FEATURES))
weights.loc["order_TSH", ["fatigue","high_A1C"]] = [1.5, 0.5]
weights.loc["increase_metformin", ["on_metformin","high_A1C","GI_upset"]] = [1.0, 1.5, -2.0]
weights.loc["check_B12", ["on_metformin","GI_upset"]] = [2.0, 1.0]
weights.loc["start_GLP1", ["on_metformin","high_A1C","pregnant"]] = [1.0, 2.0, -3.0]

# Patient
patient = pd.Series({"fatigue":1, "on_metformin":1, "high_A1C":1, "A1C":7.2,
                     "pregnant":0, "GI_upset":0, "high_BMI":1})

modules = ["metabolic", "fatigue_workup"]

print("Allowed by composed guidelines:")
print(filter_safe_actions(patient, list(ACTIONS.keys()), modules))

print("\nRanked (bias + weights) within guidelines:")
for a, s in rank_actions(patient, list(ACTIONS.keys()), weights, modules):
    print(f"{a:20s} {s:+.2f}")

top = rank_actions(patient, list(ACTIONS.keys()), weights, modules, top_k=1)[0][0]
print(f"\nExplain top action: {top}")
print(explain_action(patient, top, weights, modules))

# Learn from approval
update_weights(patient, top, "approve", weights, eta=0.2, modules=modules)
print("\nAfter feedback (nudged scores):")
for a, s in rank_actions(patient, list(ACTIONS.keys()), weights, modules):
    print(f"{a:20s} {s:+.2f}")

# Pregnancy flip (hard mask)
preg = patient.copy(); preg["pregnant"] = 1
print("\n[Pregnant] Allowed by composed guidelines:")
print(filter_safe_actions(preg, list(ACTIONS.keys()), modules))
print("\n[Pregnant] Ranked:")
for a, s in rank_actions(preg, list(ACTIONS.keys()), weights, modules):
    print(f"{a:20s} {s:+.2f}")


Allowed by composed guidelines:
['order_TSH', 'increase_metformin', 'check_B12', 'start_GLP1']

Ranked (bias + weights) within guidelines:
start_GLP1           +4.00
order_TSH            +3.00
check_B12            +3.00
increase_metformin   +2.50

Explain top action: start_GLP1
{'bias_prior': 1.0, 'feature_contribs': high_A1C        2.0
on_metformin    1.0
dtype: float64}

After feedback (nudged scores):
start_GLP1           +4.20
order_TSH            +3.00
check_B12            +3.00
increase_metformin   +2.50

[Pregnant] Allowed by composed guidelines:
['order_TSH', 'increase_metformin', 'check_B12']

[Pregnant] Ranked:
order_TSH            +3.00
check_B12            +3.00
increase_metformin   +2.50


In [3]:
# ---------- ADD MORE FEATURES & ACTIONS ----------
add_features(
    # existing-ish
    "A1C", "eGFR", "BP_systolic", "LDL",
    "on_metformin", "GI_upset", "pregnant", "fatigue", "high_BMI",
    # new binary convenience flags (can be set by LLM or derived rules)
    "high_A1C", "LDL_high", "ASCVD_high", "CKD_stage3plus",
    "on_SGLT2", "on_ACEi",
    "on_SSRI", "triptan_use",
    "UTI_symptoms", "fever", "dysuria",
    "childbearing_potential", "QTc_prolonged"
)

add_actions(
    # metabolic / DM2
    start_SGLT2="Start SGLT2 inhibitor",
    switch_metformin_ER="Switch to metformin ER",
    order_microalbumin="Order urine albumin/creatinine ratio",
    # CKD / kidney safety
    start_ACEi="Start ACE inhibitor",
    avoid_metformin_escalation="Avoid metformin escalation",
    # lipids / ASCVD
    start_statin="Start statin therapy",
    order_lipid_panel="Order lipid panel",
    # HTN
    start_thiazide="Start thiazide diuretic",
    # psych / neuro interaction
    start_SSRI="Start SSRI",
    order_ECG="Order ECG",
    # reproductive safety
    order_pregnancy_test="Order pregnancy test",
    counsel_contraception="Counsel on contraception",
    # UTI route
    order_urine_culture="Order urine culture",
    start_nitrofurantoin="Start nitrofurantoin"
)

# ---------- ADD COMPOSABLE GUIDELINE MODULES (JSON-logic + legacy mix) ----------

# Metabolic / DM2 (JSON-logic)
add_guideline_module("dm2", {
    "start_SGLT2": {
        "logic": {"and":[
            {"or":[{"exists":"high_A1C"},{">=":[{"var":"A1C"},7.0]}]},
            {"not":[{"exists":"pregnant"}]}
        ]},
        "strength":"should"
    },
    "switch_metformin_ER": {
        "logic": {"and":[{"exists":"on_metformin"},{"exists":"GI_upset"}]},
        "strength":"should"
    },
    "order_microalbumin": {
        "logic": {"and":[{"exists":"on_metformin"},{"or":[{"exists":"CKD_stage3plus"},{">=":[{"var":"A1C"},6.5]}]}]},
        "strength":"should"
    },
    "increase_metformin": {   # from your original set; still allowed if GI_ok
        "logic": {"and":[{"exists":"on_metformin"},{"or":[{"exists":"high_A1C"},{">=":[{"var":"A1C"},7.0]}]},{"not":[{"exists":"GI_upset"}]}]},
        "strength":"could"
    },
    "check_B12": {
        "logic": {"exists":"on_metformin"},
        "strength":"should"
    }
})

# CKD safety / renal dosing (JSON-logic)
add_guideline_module("ckd", {
    "start_ACEi": {
        "logic": {"and":[
            {"or":[{"exists":"CKD_stage3plus"},{ "<=":[{"var":"eGFR"},60]}]},
            {"not":[{"exists":"pregnant"}]}
        ]},
        "strength":"should"
    },
    "avoid_metformin_escalation": {
        "logic": {"or":[{"<=":[{"var":"eGFR"},30]}, {"exists":"CKD_stage3plus"}]},
        "strength":"must"
    }
})

# Lipids / ASCVD (JSON-logic + numeric)
add_guideline_module("lipids", {
    "start_statin": {
        "logic": {"and":[
            {"or":[{"exists":"ASCVD_high"},{"exists":"LDL_high"},{">=":[{"var":"LDL"},190]}]},
            {"not":[{"exists":"pregnant"}]}
        ]},
        "strength":"should"
    },
    "order_lipid_panel": {
        "logic": {"or":[{"exists":"ASCVD_high"},{"exists":"LDL_high"},{"exists":"high_A1C"}]},
        "strength":"should"
    }
})

# Hypertension (LEGACY example)
add_guideline_module("htn", {
    "start_thiazide": {"pre":["BP_systolic"], "contra":["pregnant"], "strength":"could"},
    "order_ECG":      {"pre":["BP_systolic"], "contra":[],            "strength":"could"}
})

# Psych / neuro interaction (JSON-logic)
add_guideline_module("psych_neuro", {
    "start_SSRI": {
        "logic": {"and":[
            {"exists":"fatigue"},             # placeholder for depression screen; swap for PHQ-9 in real system
            {"not":[{"exists":"triptan_use"}]},   # conservative demo rule for serotonin risk
            {"not":[{"exists":"pregnant"}]}
        ]},
        "strength":"could"
    },
    "order_ECG": {
        "logic": {"exists":"QTc_prolonged"},
        "strength":"should"
    }
})

# Repro safety (JSON-logic)
add_guideline_module("repro_safety", {
    "order_pregnancy_test": {
        "logic": {"and":[{"exists":"childbearing_potential"},{"not":[{"exists":"pregnant"}]}]},
        "strength":"should"
    },
    "counsel_contraception": {
        "logic": {"exists":"childbearing_potential"},
        "strength":"could"
    }
})

# UTI quick-path (LEGACY)
add_guideline_module("uti", {
    "order_urine_culture": {"pre":["UTI_symptoms"], "contra":[], "strength":"should"},
    "start_nitrofurantoin": {"pre":["UTI_symptoms"], "contra":["fever"], "strength":"could"}
})

# ---------- INIT WEIGHTS (simple priors so ranking is interesting) ----------
# Leave existing 'weights' if you already created; otherwise:
weights = init_weights(list(ACTIONS.keys()), list(FEATURES))

# A few intuitive priors
weights.loc["start_SGLT2",       ["high_A1C","A1C","on_metformin","high_BMI"]] = [2.0, 0.2, 0.5, 0.5]
weights.loc["switch_metformin_ER",["on_metformin","GI_upset"]]                 = [1.0, 2.0]
weights.loc["order_microalbumin", ["A1C","CKD_stage3plus"]]                    = [0.1, 1.0]

weights.loc["start_ACEi",         ["CKD_stage3plus","eGFR"]]                   = [1.5, -0.01]  # lower eGFR → slightly higher priority
weights.loc["avoid_metformin_escalation", ["eGFR","CKD_stage3plus"]]           = [-0.02, 1.5]  # negative means higher score when eGFR low after bias; fine for demo

weights.loc["start_statin",       ["ASCVD_high","LDL_high","LDL"]]             = [2.0, 1.0, 0.02]
weights.loc["order_lipid_panel",  ["ASCVD_high","high_A1C"]]                   = [0.5, 0.5]

weights.loc["start_thiazide",     ["BP_systolic"]]                             = [0.02]        # higher BP nudges score
weights.loc["order_ECG",          ["QTc_prolonged","BP_systolic"]]             = [2.0, 0.005]

weights.loc["start_SSRI",         ["fatigue"]]                                 = [0.5]
weights.loc["order_pregnancy_test",["childbearing_potential"]]                 = [1.0]
weights.loc["counsel_contraception",["childbearing_potential"]]                = [0.5]

weights.loc["order_urine_culture",["UTI_symptoms"]]                            = [1.5]
weights.loc["start_nitrofurantoin",["UTI_symptoms","fever"]]                   = [1.0, -2.0]   # fever lowers this (prefer eval)

# ---------- DEMOS: MULTIPLE PATIENTS TO SHOW DECISIONING ----------
modules = ["dm2","ckd","lipids","htn","psych_neuro","repro_safety","uti","fatigue_workup"]  # include your earlier module if defined

def show(patient_dict, title):
    patient = pd.Series(patient_dict)
    print(f"\n=== {title} ===")
    allowed = filter_safe_actions(patient, list(ACTIONS.keys()), modules)
    print("Allowed:", allowed)
    ranked = rank_actions(patient, list(ACTIONS.keys()), weights, modules)
    for a, s in ranked[:8]:
        print(f"{a:24s} {s:+.2f}")
    if ranked:
        top = ranked[0][0]
        print("\nExplain top:", top, "→", explain_action(patient, top, weights, modules))

print(weights)

# Case 1: Uncontrolled DM2, CKD stage 3, GI upset on metformin, ASCVD high
show({
    "on_metformin":1, "GI_upset":1, "A1C":8.2, "high_A1C":1,
    "eGFR":45, "CKD_stage3plus":1,
    "ASCVD_high":1, "LDL":165, "LDL_high":1,
    "high_BMI":1, "pregnant":0, "QTc_prolonged":0
}, "DM2 + CKD + ASCVD (GI upset)")

# Case 2: Same but pregnant → repro mask should block SGLT2/ACEi; emphasize tests/counseling
show({
    "on_metformin":1, "GI_upset":1, "A1C":7.6, "high_A1C":1,
    "eGFR":70, "CKD_stage3plus":0,
    "ASCVD_high":0, "LDL":120, "LDL_high":0,
    "childbearing_potential":1, "pregnant":1
}, "Pregnant metabolic patient")

# Case 3: Elevated BP with prolonged QTc → ECG first, careful on SSRIs
show({
    "BP_systolic":162, "QTc_prolonged":1, "fatigue":1,
    "pregnant":0
}, "Hypertension + QTc prolonged")

# Case 4: UTI symptoms without fever
show({
    "UTI_symptoms":1, "fever":0, "dysuria":1
}, "UTI uncomplicated")

# Case 5: Migraine on triptan, fatigue (depressive symptoms) → block SSRI start, rank other safe steps
show({
    "triptan_use":1, "fatigue":1, "pregnant":0
}, "Migraine on triptan (SSRI caution)")


                            fever  on_metformin  ASCVD_high  BP_systolic  \
order_TSH                     0.0           0.0         0.0        0.000   
increase_metformin            0.0           0.0         0.0        0.000   
check_B12                     0.0           0.0         0.0        0.000   
start_GLP1                    0.0           0.0         0.0        0.000   
start_SGLT2                   0.0           0.5         0.0        0.000   
switch_metformin_ER           0.0           1.0         0.0        0.000   
order_microalbumin            0.0           0.0         0.0        0.000   
start_ACEi                    0.0           0.0         0.0        0.000   
avoid_metformin_escalation    0.0           0.0         0.0        0.000   
start_statin                  0.0           0.0         2.0        0.000   
order_lipid_panel             0.0           0.0         0.5        0.000   
start_thiazide                0.0           0.0         0.0        0.020   
start_SSRI  