# EPL test

In [None]:
import html
import io
import json
import os
import re

import dotenv
import xgboost as xgb
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from openai import Client
import streamlit as st
from PyPDF2 import PdfReader
import shap
import torch
from transformers import AutoTokenizer, AutoModel

In [None]:
def predict_ctg(ctg_model, features):
    """Classify a single CTG recording with the fetal-health classifier.

    Args:
        ctg_model: Fitted classifier exposing ``predict`` (an
            ``xgb.XGBClassifier`` in this file) whose labels 0/1/2 map to
            Normal/Suspect/Pathological.
        features: CTG feature values for ONE recording. Any array-like
            (list, tuple, 1-D array or ``(1, n)`` array) is accepted; it is
            reshaped into a single ``(1, n_features)`` row.

    Returns:
        "Normal", "Suspect" or "Pathological", or "Unknown" when the model
        emits a label outside that map.
    """
    ctg_class_map = {0: "Normal", 1: "Suspect", 2: "Pathological"}

    # np.asarray lets callers pass plain lists (e.g. parsed form input), not only ndarrays.
    row = np.asarray(features).reshape(1, -1)
    pred = ctg_model.predict(row)[0]
    return ctg_class_map.get(pred, "Unknown")


def predict_epl(MA, EM, GSD, EL, YSD, EHR):
    """
    Score Early Pregnancy Loss (EPL) risk from six clinical/ultrasound findings.

    Each finding contributes 0-3 points and, when it contributes any, a
    human-readable explanation. The summed score is mapped to an approximate
    risk percentage and a risk band.

    Inputs:
    - MA: Maternal Age (years)
    - EM: Endometrium thickness on transfer day (mm)
    - GSD: Gestational Sac Diameter (mm)
    - EL: Embryo Length (mm)
    - YSD: Yolk Sac Diameter (mm)
    - EHR: Embryonic Heart Rate (bpm)

    Returns a dict with keys "model", "score", "risk" (percentage string),
    "risk_level" and "reasons" (explanations in factor order).
    """

    def graded(value, normal_from, borderline_from, borderline_msg, severe_msg):
        # Three-tier factor: value >= normal_from -> 0 points,
        # value >= borderline_from -> 1 point, anything else (incl. NaN) -> 2.
        if value >= normal_from:
            return 0, None
        if value >= borderline_from:
            return 1, borderline_msg
        return 2, severe_msg

    # Maternal age is the only four-tier factor.
    if MA < 30:
        age_finding = (0, None)
    elif MA < 35:
        age_finding = (1, "Maternal age between 30–34 years slightly increases EPL risk.")
    elif MA < 40:
        age_finding = (2, "Maternal age between 35–39 years moderately increases EPL risk due to decreased egg quality.")
    else:
        age_finding = (3, "Maternal age ≥40 years strongly increases EPL risk due to higher chromosomal abnormalities.")

    # Yolk sac is a two-tier factor: normal only inside the closed 3–4 mm window.
    yolk_sac_normal = 3 <= YSD <= 4

    findings = [
        age_finding,
        graded(
            EM, 9, 7,
            "Endometrium thickness between 7–8.9 mm shows borderline uterine receptivity.",
            "Endometrium thickness <7 mm indicates poor uterine lining and reduced implantation potential.",
        ),
        graded(
            GSD, 18, 14,
            "Gestational sac diameter between 14–17.9 mm is slightly smaller than expected.",
            "Gestational sac diameter <14 mm suggests abnormal or delayed growth.",
        ),
        graded(
            EL, 3, 1.5,
            "Embryo length between 1.5–2.9 mm indicates slower growth than expected.",
            "Embryo length <1.5 mm indicates poor embryonic development and higher risk of EPL.",
        ),
        (0, None) if yolk_sac_normal else
        (1, "Abnormal yolk sac size (<3 mm or >4 mm) increases risk of developmental abnormalities."),
        graded(
            EHR, 100, 60,
            "Embryonic heart rate between 60–99 bpm indicates possible early fetal distress.",
            "Embryonic heart rate <60 bpm strongly predicts early pregnancy loss.",
        ),
    ]

    score = sum(points for points, _ in findings)
    reasons = [message for _, message in findings if message is not None]

    # Piecewise mapping: scores below 5 scale linearly into 0–30%;
    # from 5 upwards each extra point adds 10 percentage points.
    if score >= 5:
        risk_percentage = round(30 + (score - 5) * 10, 1)
        if risk_percentage >= 70:
            risk_level = "Very High"
        elif risk_percentage >= 50:
            risk_level = "High"
        else:
            risk_level = "Moderate"
    else:
        risk_percentage = round((score / 5) * 30, 1)
        risk_level = "Low"

    return {
        "model": "EPL",
        "score": score,
        "risk": f"{risk_percentage}%",
        "risk_level": risk_level,
        "reasons": reasons,
    }


# --- Function to extract EPL data from PDF ---
def extract_epl_from_pdf(pdf_file, client):
    """Extract the six EPL model inputs from a PDF report using the LLM.

    Args:
        pdf_file: Path or binary file-like object accepted by ``PdfReader``
            (e.g. a Streamlit ``UploadedFile``).
        client: OpenAI-compatible client used for the chat completion.

    Returns:
        dict mapping each of "MA", "EM", "GSD", "EL", "YSD", "EHR" that the
        model returned with a non-null value to a float. Keys the model
        omitted are left out so callers can apply their own defaults.

    Raises:
        ValueError: if the reply cannot be parsed into a JSON object or one of
            the values is not numeric.
    """
    epl_keys = ("MA", "EM", "GSD", "EL", "YSD", "EHR")

    reader = PdfReader(pdf_file)
    # extract_text() may yield None/empty for pages without a text layer.
    text = "".join(page.extract_text() or "" for page in reader.pages)
    prompt_text = f"""
Extract Early Pregnancy Loss (EPL) clinical data from the text.
Return JSON with keys: MA, EM, GSD, EL, YSD, EHR and numeric values.
Text:
{text}
"""
    completion = client.chat.completions.create(
        model="qwen/qwen3-vl-30b-a3b-thinking",
        messages=[{"role": "user", "content": [{"type": "text", "text": prompt_text}]}],
    )
    try:
        content = completion.choices[0].message.content
        # The model may wrap the JSON in prose or Markdown code fences, so use
        # the tolerant parser rather than a bare json.loads.
        data = safe_parse_json(content)
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        # Coerce to float: string values such as "32" would otherwise break the
        # numeric comparisons in predict_epl.
        return {key: float(data[key]) for key in epl_keys if data.get(key) is not None}
    except Exception as e:
        raise ValueError(f"Could not extract EPL data from PDF: {e}") from e

# --- Function to extract CTG data from PDF ---
def extract_ctg_from_pdf(pdf_file, client):
    """Extract one CTG feature row from a PDF report using the LLM.

    Args:
        pdf_file: Path or binary file-like object accepted by ``PdfReader``.
        client: OpenAI-compatible client used for the chat completion.

    Returns:
        1-D float ndarray with the 21 CTG features of the first recording
        found, in the fixed column order below (not the order the LLM happens
        to emit keys in).

    Raises:
        ValueError: if the reply cannot be parsed, a column is missing, or a
            value is not numeric.
    """
    # The 21 predictor columns of the fetal-health CTG dataset, in dataset order.
    # ``fetal_health`` is deliberately NOT requested: it is the model's target
    # label, and including it produced a wrong-length feature vector.
    # NOTE(review): assumes ctg_model was trained on exactly these columns in
    # this order; also confirm whether it expects scaled values -- raw PDF
    # values are passed through unscaled here.
    ctg_columns = [
        "baseline value",
        "accelerations",
        "fetal_movement",
        "uterine_contractions",
        "light_decelerations",
        "severe_decelerations",
        "prolongued_decelerations",
        "abnormal_short_term_variability",
        "mean_value_of_short_term_variability",
        "percentage_of_time_with_abnormal_long_term_variability",
        "mean_value_of_long_term_variability",
        "histogram_width",
        "histogram_min",
        "histogram_max",
        "histogram_number_of_peaks",
        "histogram_number_of_zeroes",
        "histogram_mode",
        "histogram_mean",
        "histogram_median",
        "histogram_variance",
        "histogram_tendency",
    ]

    reader = PdfReader(pdf_file)
    # extract_text() may yield None/empty for pages without a text layer.
    text = "".join(page.extract_text() or "" for page in reader.pages)
    column_list = ", ".join(ctg_columns)
    prompt_text = f"""
Extract CTG data as a table from the text.
Columns: {column_list}

Return JSON with column names as keys and each value as a list of numeric values.
Text:
{text}
"""
    completion = client.chat.completions.create(
        model="qwen/qwen3-vl-30b-a3b-thinking",
        messages=[{"role": "user", "content": [{"type": "text", "text": prompt_text}]}],
    )
    try:
        content = completion.choices[0].message.content
        # Tolerate prose / Markdown fences around the JSON.
        data = safe_parse_json(content)
        first_row = []
        for col in ctg_columns:
            value = data[col]
            # Columns are requested as lists, but accept a bare scalar too.
            if isinstance(value, list):
                value = value[0]
            first_row.append(float(value))
        return np.array(first_row, dtype=float)
    except Exception as e:
        raise ValueError(f"Could not extract CTG data from PDF: {e}") from e


# Embed text with a Hugging Face encoder (attention-mask-aware mean pooling)
def compute_embedding(text, emb_model, tokenizer):
    """Return a mean-pooled embedding for ``text``.

    Args:
        text: A single string, or a list of strings embedded as one batch.
        emb_model: Encoder whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer matching ``emb_model``.

    Returns:
        For a single string, a 1-D ``(hidden_size,)`` numpy array; for a batch
        of several strings, a ``(batch, hidden_size)`` array.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = emb_model(**inputs)
    hidden = outputs.last_hidden_state  # presumably (batch, seq_len, hidden_size)

    attention_mask = inputs.get("attention_mask")
    if attention_mask is None:
        # No mask available: plain mean over all token positions.
        pooled = hidden.mean(dim=1)
    else:
        # Average only real tokens. For a batch, padding=True adds pad positions
        # that would otherwise dilute the embeddings of shorter texts.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    return pooled.squeeze().numpy()


def get_relevant_docs(doc_embeddings, emb_model, tokenizer, query, top_k=5):
    """Return indices of the ``top_k`` documents most similar to ``query``.

    Args:
        doc_embeddings: 2-D array with one embedding row per document.
        emb_model, tokenizer: Used to embed ``query`` via compute_embedding.
        query: Free-text query.
        top_k: Number of indices to return; ``<= 0`` returns none.

    Returns:
        1-D integer array of row indices into ``doc_embeddings``, ordered by
        descending cosine similarity.
    """
    # Guard first: scores.argsort()[-0:] would otherwise return EVERY index.
    if top_k <= 0:
        return np.array([], dtype=int)
    query_vec = compute_embedding(query, emb_model, tokenizer)
    scores = cosine_similarity([query_vec], doc_embeddings).flatten()
    # Ascending argsort -> take the last top_k -> reverse for best-first order.
    return scores.argsort()[-top_k:][::-1]


def llm_generate(prompt, client: "Client", max_output_tokens=2048, model_name="openai/gpt-oss-120b"):
    """Run a plain-text completion and return the generated text.

    Args:
        prompt: Full prompt string.
        client: OpenAI-compatible client (configured for OpenRouter in this file).
        max_output_tokens: Upper bound on generated tokens.
        model_name: Model identifier passed to the completions endpoint.

    Returns:
        The text of the first completion choice.

    Raises:
        ValueError: if the response has no usable ``choices[0].text``.
    """
    response = client.completions.create(
        prompt=prompt,
        model=model_name,
        max_tokens=max_output_tokens,
        temperature=0.2,  # low temperature keeps the structured output more repeatable
    )

    # The SDK returns attribute-style objects, so a malformed response raises
    # AttributeError/IndexError/TypeError -- never KeyError.
    try:
        text_output = response.choices[0].text
    except (AttributeError, IndexError, TypeError) as e:
        raise ValueError(f"Unexpected response format: {response}") from e
    if text_output is None:
        raise ValueError(f"Unexpected response format: {response}")
    return text_output


def safe_parse_json(text: str):
    """Parse JSON out of an LLM reply that may contain extra text around it.

    Strategies, in order:
    1. the whole string as JSON;
    2. the greedy span starting at the first "{" (or "[") and ending at the
       last "}" (or "]") -- handles Markdown code fences and surrounding prose;
    3. the first position where a complete JSON value can be decoded,
       ignoring anything after it -- handles several JSON blocks in one reply.

    Raises:
        ValueError: if ``text`` is empty / not a string, or no JSON value can
            be found (never returns None).
    """
    if not text or not isinstance(text, str):
        raise ValueError("Input text must be a non-empty string.")

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    match = re.search(r'(\{.*\}|\[.*\])', text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            pass

    # raw_decode parses one value starting at an index and tolerates trailing text.
    decoder = json.JSONDecoder()
    for opening in re.finditer(r"[\[{]", text):
        try:
            value, _end = decoder.raw_decode(text, opening.start())
        except json.JSONDecodeError:
            continue
        return value

    raise ValueError(f"Failed to parse JSON from text: {text}")



# Load models and data

In [None]:
# Load XGBoost model
ctg_model = xgb.XGBClassifier()
ctg_model.load_model("models/fetal_xgb_model.json")

# Load MedEmbed model
model_name = "abhinand/MedEmbed-base-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
emb_model = AutoModel.from_pretrained(model_name)

# Load advice documents. Each record must have an "advice" text field;
# "source" and "page_number" are read later when citing documents.
# Despite the .jsonl extension the file may hold a single JSON array, so accept
# both a JSON array and JSON Lines (one object per line).
with open("data/advices.jsonl", "r", encoding="utf-8") as advice_file:
    advice_raw = advice_file.read()
try:
    advice_docs = json.loads(advice_raw)
except json.JSONDecodeError:
    advice_docs = [json.loads(line) for line in advice_raw.splitlines() if line.strip()]
if isinstance(advice_docs, dict):
    # A JSON Lines file holding a single record parses as one object.
    advice_docs = [advice_docs]
print(advice_docs[0])

doc_embeddings = np.array([compute_embedding(doc['advice'], emb_model, tokenizer) for doc in advice_docs])


# OpenRouter credentials come from the environment (optionally via a .env file).
dotenv.load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
client = Client(
    api_key=OPENROUTER_API_KEY,
    base_url="https://openrouter.ai/api/v1",
)


# Full pipeline: EPL scoring, CTG classification and LLM recommendations

In [None]:
def run_risk_system(epl_inputs, ctg_features, top_k_docs=3, derived_features=None):
    """
    Complete risk assessment system:
    - Predict EPL risk (score, risk percentage and risk level) with predict_epl
    - Predict CTG classification
    - Retrieve relevant advice documents
    - Generate structured LLM-based recommendations with citations

    Args:
        epl_inputs: dict with the EPL predictors MA, EM, GSD, EL, YSD, EHR.
            Missing or None values are replaced by 0.
        ctg_features: array-like with the CTG features of one recording.
        top_k_docs: number of advice documents retrieved and cited.
        derived_features: optional dict merged over ``epl_inputs``. Only the
            six EPL keys are forwarded to predict_epl; other keys are ignored.

    Returns a structured report dict:
        {"EPL": <predict_epl result>,
         "CTG": {"class": ..., "note": ...},
         "Recommendations": <parsed LLM JSON object, or
                             {"error": ..., "raw_output": ...} if parsing failed>}
    """
    expected_epl_keys = ["MA", "EM", "GSD", "EL", "YSD", "EHR"]
    # Largest score predict_epl can produce: 3 + 2 + 2 + 2 + 1 + 2 points.
    epl_max_score = 12

    # --- 1. Normalize / validate inputs ---
    merged_inputs = dict(epl_inputs)
    if derived_features:
        merged_inputs.update(derived_features)
    # NOTE(review): defaulting a missing value to 0 puts EM, GSD, EL, YSD and
    # EHR in their highest-point band in predict_epl; consider rejecting
    # incomplete input instead.
    epl_args = {
        key: 0 if merged_inputs.get(key) is None else merged_inputs[key]
        for key in expected_epl_keys
    }
    # Guard against negative heart-rate entries.
    if epl_args["EHR"] < 0:
        epl_args["EHR"] = 0

    # --- 2. Predict EPL ---
    # predict_epl already maps the score to a risk percentage and level; use its
    # result unchanged so the report, the LLM prompt and the UI stay consistent.
    epl_result = predict_epl(**epl_args)

    # --- 3. CTG Prediction ---
    ctg_features_array = np.array(ctg_features).reshape(1, -1)
    ctg_pred = predict_ctg(ctg_model, ctg_features_array)

    # --- 4. Retrieve relevant documents ---
    query = f"EPL risk: {epl_result['risk_level']}, CTG class: {ctg_pred}"
    # Pass top_k explicitly so more than the default five documents can be
    # requested; the slice keeps top_k_docs <= 0 returning no documents.
    relevant_docs_indices = get_relevant_docs(
        doc_embeddings, emb_model, tokenizer, query, top_k=top_k_docs
    )[:max(top_k_docs, 0)]
    relevant_docs = [advice_docs[i] for i in relevant_docs_indices]

    if not relevant_docs:
        doc_text = "No medical literature documents found for the query."
    else:
        doc_text = "\n".join(
            f"[Document {i+1}] {d['advice']}\nSource: {d.get('source', 'N/A')}, Page: {d.get('page_number','N/A')}"
            for i, d in enumerate(relevant_docs)
        )

    # Each reason is a full sentence, so join with spaces.
    reasons_text = " ".join(epl_result['reasons']) or "None identified"

    # --- 5. Structured LLM prompt ---
    prompt = f"""
You are a clinical decision support assistant. Generate structured, evidence-based analysis
using ONLY the provided clinical data and referenced medical literature.

CLINICAL DATA:
- Maternal EPL Risk Score: {epl_result['score']}/{epl_max_score} (estimated EPL risk: {epl_result['risk']})
- Risk Classification: {epl_result['risk_level']}
- Contributing Risk Factors: {reasons_text}
- Cardiotocography (CTG) Classification: {ctg_pred}
  (Note: Early CTG in first trimester cannot reliably exclude EPL)

MEDICAL LITERATURE REFERENCES:
{doc_text}

TASK:
1. RISK ASSESSMENT:
   - Interpret EPL score and clinical significance.
   - Explain contribution of each risk factor.
   - Explain implications of CTG classification.

2. EVIDENCE-BASED RECOMMENDATIONS:
   - Provide recommendations STRICTLY based on the documents.
   - Each recommendation MUST cite at least one document (e.g., [Document 1]).
   - Prioritize recommendations by clinical urgency and importance.

3. MONITORING AND FOLLOW-UP:
   - Suggest monitoring frequency and key parameters.
   - Highlight warning signs requiring immediate attention.

OUTPUT FORMAT:
Return ONLY a valid JSON object in this exact structure:
{{
    "Risk_Assessment": "text summary here",
    "Recommendations": [
        {{
            "recommendation": "text here",
            "sources": "source with page number if available"
        }},
        ...
    ],
    "Monitoring": "text here"
}}
STRICTLY follow JSON format. Use precise medical terminology. Do not add external information.
"""

    # --- 6. Generate LLM output ---
    llm_text = llm_generate(prompt, client)

    # --- 7. Safely parse JSON ---
    # Parse failures are reported in-band using the {"error", "raw_output"}
    # shape the Streamlit UI renders, instead of aborting the whole report.
    try:
        llm_output = safe_parse_json(llm_text)
    except ValueError as e:
        llm_output = {"error": str(e), "raw_output": llm_text}
    else:
        if not isinstance(llm_output, dict):
            llm_output = {"error": "LLM did not return a JSON object.", "raw_output": llm_text}

    # --- 8. Build final report ---
    report = {
        "EPL": epl_result,
        "CTG": {"class": ctg_pred, "note": "Early CTG alone may not exclude EPL in first trimester"},
        "Recommendations": llm_output
    }

    return report

In [None]:
# Sample CTG features from test data: one row of 21 features.
# NOTE(review): the values look standardized (z-scored) -- presumably the same
# preprocessing the CTG model was trained with; confirm before feeding raw values.
ctg_sample = np.array([[-0.1111143 ,  0.46860644, -0.2071137 , -0.80777555,  0.39270884,
                        -0.05431254, -0.27217386, -1.25480395,  0.18897025, -0.53840522,
                         0.5167156 ,  0.55443516, -0.5904121 ,  0.23322547, -0.02605247,
                         2.53435472,  0.22815306,  0.22936866,  0.14451318, -0.09831217,
                         1.11838295]])

# Dummy EPL inputs (replace with real test data). Keys and units follow
# predict_epl; every value here adds points there (total score 10), so this
# sample exercises the high-risk path.
epl_sample_inputs = {
    "MA": 32,        # Maternal Age (years)
    "EM": 1,         # Endometrium thickness on transfer day (mm)
    "GSD": 7.5,      # Gestational Sac Diameter (mm)
    "EL": 1.2,       # Embryo Length (mm)
    "YSD": 0.8,      # Yolk Sac Diameter (mm)
    "EHR": 1         # Embryonic Heart Rate (bpm)
}

# Run the system
report = run_risk_system(epl_sample_inputs, ctg_sample)

# Display output
print(json.dumps(report, indent=4))

In [None]:
@st.cache_resource
def load_embedding_model():
    """Load the MedEmbed tokenizer and encoder.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects instead of loading them again.

    Returns:
        ``(tokenizer, emb_model)`` tuple, in that order.
    """
    checkpoint = "abhinand/MedEmbed-base-v0.1"
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    loaded_model = AutoModel.from_pretrained(checkpoint)
    return loaded_tokenizer, loaded_model



@st.cache_resource
def precompute_doc_embeddings(docs, _emb_model, _tokenizer=None):
    """Embed every advice document's "advice" text (cached by Streamlit).

    Args:
        docs: list of dicts, each with an "advice" string.
        _emb_model: encoder passed to compute_embedding. The leading
            underscore tells ``st.cache_resource`` not to hash this argument.
        _tokenizer: tokenizer matching ``_emb_model`` (also left unhashed).
            Defaults to the module-level ``tokenizer`` for backward
            compatibility; pass it explicitly to avoid that hidden global
            dependency.

    Returns:
        ndarray with one embedding row per document.
    """
    tok = tokenizer if _tokenizer is None else _tokenizer
    return np.array([compute_embedding(doc['advice'], _emb_model, tok) for doc in docs])
    

# Configure the page before anything else: some Streamlit versions require
# set_page_config to be the first Streamlit command, and the cached loaders
# below may render spinners.
st.set_page_config(page_title="Early Pregnancy Loss Risk System", layout="wide")

tokenizer, emb_model = load_embedding_model()
doc_embeddings = precompute_doc_embeddings(advice_docs, emb_model)


@st.cache_data(show_spinner="Extracting data from PDF...")
def _cached_pdf_extraction(kind, pdf_bytes):
    """Run the LLM PDF extractor once per distinct uploaded file.

    Streamlit re-executes this script on every widget interaction; caching on
    the file bytes avoids a fresh (and possibly different) LLM extraction on
    each rerun. ``kind`` is "epl" or "ctg".
    """
    extractor = extract_epl_from_pdf if kind == "epl" else extract_ctg_from_pdf
    return extractor(io.BytesIO(pdf_bytes), client)


# --- Streamlit app ---
st.title("🍼 Early Pregnancy Loss (EPL) & CTG Risk Assessment")
st.markdown("""
This app allows you to:
- Input maternal and embryonic data manually or via PDF upload
- Get risk assessment for early pregnancy loss
- Review CTG classification
- See relevant evidence-based recommendations with sources
""")

# --- Two-column layout ---
col_elp, col_ctg = st.columns(2)

# --- EPL Column ---
with col_elp:
    st.header("🧬 EPL Test")
    elp_method = st.radio("Input Method", ["Manual Input", "Upload PDF"], key="elp_method")

    epl_inputs = {}
    if elp_method == "Manual Input":
        epl_inputs["MA"] = st.number_input("Maternal Age (years)", min_value=18, max_value=50)
        epl_inputs["EM"] = st.number_input("Endometrium Thickness (mm)", min_value=0.0, max_value=20.0, step=0.1)
        epl_inputs["GSD"] = st.number_input("Gestational Sac Diameter (mm)", min_value=0.0, max_value=50.0, step=0.1)
        epl_inputs["EL"] = st.number_input("Embryo Length (mm)", min_value=0.0, max_value=10.0, step=0.1)
        epl_inputs["YSD"] = st.number_input("Yolk Sac Diameter (mm)", min_value=0.0, max_value=10.0, step=0.1)
        epl_inputs["EHR"] = st.number_input("Embryonic Heart Rate (bpm)", min_value=0, max_value=200)
    else:
        uploaded_file = st.file_uploader("Upload EPL PDF", type=["pdf"])
        if uploaded_file:
            try:
                extracted = _cached_pdf_extraction("epl", uploaded_file.getvalue())
            except ValueError as e:
                st.error(str(e))
            else:
                st.text_area("Extracted EPL Data", json.dumps(extracted, indent=2), height=150)
                # Values missing from the PDF fall back to 0, matching run_risk_system's default.
                epl_inputs = {k: extracted.get(k, 0) for k in ["MA", "EM", "GSD", "EL", "YSD", "EHR"]}

# --- CTG Column ---
with col_ctg:
    st.header("💓 CTG Features")
    ctg_method = st.radio("Input Method", ["Manual Input", "Upload PDF"], key="ctg_method")

    ctg_features = None
    if ctg_method == "Manual Input":
        ctg_text = st.text_area("Enter 21 CTG features separated by commas", height=200)
        # Only parse once something was typed, so the page does not open with
        # an "invalid input" warning.
        if ctg_text.strip():
            # Skip empty pieces so a trailing comma is tolerated.
            pieces = [x.strip() for x in ctg_text.split(",") if x.strip()]
            try:
                parsed = np.array([float(x) for x in pieces])
            except ValueError as e:
                st.warning(f"Invalid CTG input. Enter numeric values separated by commas. {e}")
            else:
                if len(parsed) != 21:
                    st.warning("Please provide exactly 21 features.")
                else:
                    ctg_features = parsed
    else:
        uploaded_file = st.file_uploader("Upload CTG PDF", type=["pdf"], key="ctg_pdf")
        if uploaded_file:
            try:
                ctg_features = _cached_pdf_extraction("ctg", uploaded_file.getvalue())
            except ValueError as e:
                st.error(str(e))
            else:
                st.text_area("Extracted CTG Features", ", ".join(map(str, ctg_features)), height=150)

# --- Run Risk Assessment ---
if st.button("Run Risk Assessment"):
    if epl_inputs and ctg_features is not None:
        with st.spinner("Calculating risk and retrieving recommendations..."):
            report = run_risk_system(epl_inputs, ctg_features)

        # --- Display EPL Results ---
        st.subheader("📊 EPL Risk Assessment")
        epl = report["EPL"]
        # Two trailing spaces force Markdown line breaks; otherwise these lines
        # collapse into a single paragraph.
        st.markdown(
            f"**Score:** {epl['score']}  \n"
            f"**Risk:** {epl['risk']}  \n"
            f"**Risk Level:** {epl['risk_level']}  \n"
            "**Contributing Factors:**"
        )
        for r in epl['reasons']:
            st.markdown(f"- {r}")

        # --- Display CTG Results ---
        st.subheader("💓 CTG Classification")
        ctg = report["CTG"]
        st.markdown(f"**Class:** {ctg['class']}")
        st.markdown(f"**Note:** {ctg.get('note','')}")

        # --- Recommendations ---
        st.subheader("📚 Evidence-Based Recommendations")
        recs = report["Recommendations"]
        if not isinstance(recs, dict):
            recs = {"error": "Unexpected recommendation format.", "raw_output": str(recs)}
        if "Recommendations" in recs:
            for item in recs["Recommendations"]:
                if isinstance(item, dict):
                    rec_text = item.get("recommendation", "")
                    rec_source = item.get("sources", "N/A")
                else:
                    rec_text, rec_source = item, "N/A"
                # LLM text is HTML-escaped before injection so any markup in the
                # model's reply is shown literally instead of being rendered.
                st.markdown(
                    f"<p style='font-family:Courier New; font-size:16px;'>• {html.escape(str(rec_text))} <br><i>Source: {html.escape(str(rec_source))}</i></p>",
                    unsafe_allow_html=True
                )
        elif "error" in recs:
            st.error(f"LLM output error: {recs['error']}")
            st.text(recs.get("raw_output", ""))

        # --- Monitoring ---
        st.subheader("🩺 Monitoring & Follow-Up")
        monitoring_text = recs.get("Monitoring", "No monitoring instructions available.")
        st.markdown(
            f"<p style='font-family:Courier New; font-size:16px;'>{html.escape(str(monitoring_text))}</p>",
            unsafe_allow_html=True
        )
    else:
        st.warning("Please provide all required EPL and CTG inputs before running the assessment.")