In [3]:
#!/usr/bin/env python3
"""
Inference script for Optimized Hierarchical ClinicalBERT
========================================================
Returns disease names instead of numeric labels
"""

import torch
from transformers import AutoTokenizer
from safetensors.torch import load_file
from typing import List, Dict
import pandas as pd
from tqdm import tqdm

from optimized_clinical_bert import (
    ModelConfig,
    OptimizedHierarchicalClinicalBERT
)

# -----------------------------
# Label mappings (IMPORTANT)
# -----------------------------
PARENT_ID2LABEL = {
    0: "general pathological conditions",
    1: "specific disease"
}

CHILD_ID2LABEL = {
    0: "neoplasms",
    1: "digestive system diseases",
    2: "nervous system diseases",
    3: "cardiovascular diseases"
}


# -----------------------------
# Load Model
# -----------------------------
def load_model(
    model_path: str,
    model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
    device: str = None
):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    config = ModelConfig(
        model_name=model_name,
        use_lora=True,
        use_gradient_checkpointing=False  # disable for inference
    )

    model = OptimizedHierarchicalClinicalBERT(config)
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer, device


# -----------------------------
# Single Text Inference
# -----------------------------
@torch.no_grad()
def predict_single(
    text: str,
    model,
    tokenizer,
    device,
    max_len: int = 256
) -> Dict:
    encoding = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=max_len,
        return_tensors="pt"
    )

    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    outputs = model(input_ids, attention_mask)

    # Full softmax probabilities
    parent_probs = torch.softmax(outputs["parent_logits"], dim=-1).squeeze(0)  # (2,)
    child_probs = torch.softmax(outputs["child_logits"], dim=-1).squeeze(0)    # (4,)

    parent_pred = parent_probs.argmax().item()
    child_pred = child_probs.argmax().item()
    parent_conf = parent_probs.max().item()


    # Hierarchical decoding
    if parent_conf >= 0.8:  # general
        predicted_disease = PARENT_ID2LABEL[0]
        child_label = None
    else:
        predicted_disease = CHILD_ID2LABEL[child_pred]
        child_label = CHILD_ID2LABEL[child_pred]

    # Return ALL probabilities as dicts
    parent_probs_dict = {label: float(prob) for label, prob in zip(PARENT_ID2LABEL.values(), parent_probs.tolist())}
    child_probs_dict = {label: float(prob) for label, prob in zip(CHILD_ID2LABEL.values(), child_probs.tolist())}

    return {
        "text": text,
        "predicted_disease": predicted_disease,
        "parent_label": PARENT_ID2LABEL[parent_pred],
        "parent_probs": parent_probs_dict,
        "child_label": child_label,
        "child_probs": child_probs_dict
    }



# -----------------------------
# Batch Inference
# -----------------------------
@torch.no_grad()
def predict_batch(
    texts: List[str],
    model,
    tokenizer,
    device,
    batch_size: int = 8,
    max_len: int = 256
) -> List[Dict]:

    results = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        encoding = tokenizer(
            batch_texts,
            truncation=True,
            padding=True,
            max_length=max_len,
            return_tensors="pt"
        )

        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        outputs = model(input_ids, attention_mask)

        parent_probs = torch.softmax(outputs["parent_logits"], dim=-1)
        child_probs = torch.softmax(outputs["child_logits"], dim=-1)

        for j, text in enumerate(batch_texts):
            parent_pred = parent_probs[j].argmax().item()
            child_pred = child_probs[j].argmax().item()

            parent_conf = parent_probs[j].max().item()
            child_conf = child_probs[j].max().item()

            # -----------------------------
            # SAME hierarchical logic as single inference
            # -----------------------------
            if parent_conf > 0.8:
                predicted_disease = PARENT_ID2LABEL[0]
                child_label = None
                child_confidence = None
            else:
                predicted_disease = CHILD_ID2LABEL[child_pred]
                child_label = CHILD_ID2LABEL[child_pred]
                child_confidence = child_conf

            results.append({
                "text": text,
                "predicted_disease": predicted_disease,
                "parent_label": PARENT_ID2LABEL[parent_pred],
                "parent_confidence": parent_conf,
                "child_label": child_label,
                "child_confidence": child_confidence
            })


    return results


#!/usr/bin/env python3
"""
Run inference on a test CSV file
================================
Input: condition_label, medical_abstract
Output: predictions with disease names
"""
GT_ID2LABEL = {
    1: "neoplasms",
    2: "digestive system diseases",
    3: "nervous system diseases",
    4: "cardiovascular diseases",
    5: "general pathological conditions"
}

# -----------------------------
# Run inference on CSV
# -----------------------------
def run_inference_on_csv(
    csv_path: str ="./dataset/test.csv",
    model_path: str = "./results/best_model.safetensors",
    output_path: str = "./dataset/prediction.csv"
):
    # Load data
    df = pd.read_csv(csv_path)

    assert "medical_abstract" in df.columns, "CSV must contain medical_abstract"
    assert "condition_label" in df.columns, "CSV must contain condition_label"

    # Load model
    model, tokenizer, device = load_model(model_path)

    predictions = []

    for _, row in tqdm(df.iterrows(), total=len(df)):
        text = row["medical_abstract"]
        true_label = int(row["condition_label"])

        result = predict_single(
            text=text,
            model=model,
            tokenizer=tokenizer,
            device=device
        )

        predictions.append({
            "medical_abstract": text,
            "true_label_id": true_label,
            "true_disease": GT_ID2LABEL[true_label],
            "predicted_disease": result["predicted_disease"],
            "parent_label": result["parent_label"],
            "parent_confidence": result["parent_confidence"],
            "child_label": result["child_label"],
            "child_confidence": result["child_confidence"]
        })

    pred_df = pd.DataFrame(predictions)
    pred_df.to_csv(output_path, index=False)

    print(f"âœ… Inference complete. Saved to {output_path}")


In [4]:
# run_inference_on_csv()

In [45]:
!pip install gradio plotly captum

Collecting captum
  Downloading captum-0.8.0-py3-none-any.whl.metadata (26 kB)
Collecting matplotlib (from captum)
  Downloading matplotlib-3.10.8-cp311-cp311-win_amd64.whl.metadata (52 kB)
Collecting numpy<3.0,>=1.0 (from gradio)
  Using cached numpy-1.26.4-cp311-cp311-win_amd64.whl.metadata (61 kB)
Collecting contourpy>=1.0.1 (from matplotlib->captum)
  Downloading contourpy-1.3.3-cp311-cp311-win_amd64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib->captum)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib->captum)
  Downloading fonttools-4.61.1-cp311-cp311-win_amd64.whl.metadata (116 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib->captum)
  Downloading kiwisolver-1.4.9-cp311-cp311-win_amd64.whl.metadata (6.4 kB)
Collecting pyparsing>=3 (from matplotlib->captum)
  Using cached pyparsing-3.3.1-py3-none-any.whl.metadata (5.6 kB)
Downloading captum-0.8.0-py3-none-any.whl (1.4 MB)
   --------------------------

  You can safely remove it manually.
  You can safely remove it manually.


In [5]:
import gradio as gr
import torch
import plotly.graph_objects as go


# -----------------------------
# Load model once
# -----------------------------
MODEL_PATH = "./results/best_model.safetensors"
model, tokenizer, device = load_model(MODEL_PATH)


# -----------------------------
# Function to create bar plots
# -----------------------------
def plot_probs(prob_dict, title="Probabilities"):
    labels = list(prob_dict.keys())
    values = [v if v is not None else 0 for v in prob_dict.values()]
    
    fig = go.Figure([go.Bar(x=labels, y=values, text=[round(v,3) for v in values], textposition='auto')])
    fig.update_layout(title_text=title, yaxis=dict(range=[0,1]), margin=dict(t=40, b=20))
    return fig


# -----------------------------
# Prediction function for Gradio
# -----------------------------
def classify_text(text: str):
    result = predict_single(text, model, tokenizer, device)

    # Parent and child probability dicts
    parent_display = {label: round(prob,3) for label, prob in result["parent_probs"].items()}

    if result["child_label"] is None:
        child_display = {label: 0 for label in CHILD_ID2LABEL.values()}
    else:
        child_display = {label: round(prob,3) for label, prob in result["child_probs"].items()}

    # Build plots
    parent_plot = plot_probs(parent_display, title="Parent Probabilities")
    child_plot = plot_probs(child_display, title="Child Probabilities")

    return result["predicted_disease"], parent_plot, child_plot


# -----------------------------
# Gradio Interface
# -----------------------------
iface = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(lines=1, placeholder="Enter medical abstract here..."),
    outputs=[
        gr.Label(label="Predicted Disease"),
        gr.Plot(label="Parent Probabilities"),
        gr.Plot(label="Child Probabilities")
    ],
    title="Hierarchical ClinicalBERT Disease Classifier",
    description="Enter a medical abstract to get predicted disease and probability plots for parent and child disease categories."
)

iface.launch(share=True)


  warn(


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://fd0f43955e3391ec39.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [46]:
import torch
from captum.attr import IntegratedGradients
from transformers import AutoTokenizer

# Use your model and tokenizer
model.eval()

tokenizer = tokenizer  # already loaded

# Encode text
text = "Noninvasive determination of pulmonary artery wedge pressure: comparative analysis of pulsed Doppler echocardiography and right heart catheterization. To compare left ventricular filling variables as derived by transmitral pulsed Doppler echocardiography (tpDE) and hemodynamic variables as assessed at right heart catheterization (RHC), 104 ICU patients (64 male, 40 female) aged 26 to 73 yr (mean 54.6 +/- 10.3) without valvular heart disease were examined. Simultaneously with RHC, transmitral flow velocity profiles were obtained by tpDE, and the ratio of the velocity-time integrals of late diastolic active (A wave) and early diastolic passive inflow into the left ventricle (E wave) was calculated (A/E ratio). Invasively determined pulmonary capillary wedge pressure (WP) ranged from 3 to 36 mm Hg (median 13.35, 5%/95% 6/31 mm Hg). Linear regression analysis showed a highly significant correlation between the A/E ratio and WP (r = .98, p less than .001, standard error of the estimate [SEE] = 0.10). The A/E ratio also correlated with other hemodynamic variables such as cardiac output (r = -.68, p less than .001, SEE = 0.33), cardiac index (r = -.74, p less than .001, SEE = 0.31), and stroke volume index (r = -.68, p less than .001, SEE = 0.34). The interobserver agreement (derived by intraclass correlation analysis between two examiners) on the A/E ratio was high (r = .95, p less than .001, n = 26). We conclude that WP can be accurately determined noninvasively by tpDE. For the assessment of systolic ventricular function, tpDE is of limited diagnostic value."
inputs = tokenizer(text, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]

# Define a wrapper for Captum
def forward_parent(input_ids):
    attention_mask = (input_ids != tokenizer.pad_token_id).long()
    outputs = model(input_ids, attention_mask)
    # Return logits for the predicted parent class
    parent_logits = outputs["parent_logits"]
    parent_class = parent_logits.argmax(dim=-1)
    return parent_logits[:, parent_class]

# Integrated Gradients
ig = IntegratedGradients(forward_parent)
attributions, delta = ig.attribute(input_ids, return_convergence_delta=True)

# Decode tokens
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

# Map token -> importance
word_importance = {tok: float(attr) for tok, attr in zip(tokens, attributions[0])}

# Sort by absolute importance
sorted_importance = dict(sorted(word_importance.items(), key=lambda x: abs(x[1]), reverse=True))

print(sorted_importance)


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

In [6]:
!pip install shap numpy==2.3
import shap

MODEL_PATH = "./results/best_model.safetensors"
model, tokenizer, device = load_model(MODEL_PATH)
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(["Patient presents with chest pain and shortness of breath."])

shap.plots.text(shap_values[0])




TypeError: OptimizedHierarchicalClinicalBERT.forward() missing 1 required positional argument: 'attention_mask'

In [8]:
from pprint import pprint 
from captum.attr import IntegratedGradients
import torch

model, tokenizer, device = load_model(
    model_path="./results/best_model.safetensors"
)

text = "Patient presents with chest pain and shortness of breath."
inputs = tokenizer(text, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# -----------------------------
# Wrapper for Captum
# -----------------------------
def forward_parent(input_ids):
    # Ensure correct dtype
    input_ids = input_ids.long()
    
    outputs = model(input_ids, attention_mask)
    
    # Get parent logits for the predicted class
    parent_logits = outputs["parent_logits"]
    parent_class = parent_logits.argmax(dim=-1)
    return parent_logits[:, parent_class]


# -----------------------------
# Integrated Gradients
# -----------------------------
ig = IntegratedGradients(forward_parent)

attributions, delta = ig.attribute(input_ids, return_convergence_delta=True)

tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
word_importance = {tok: float(attr) for tok, attr in zip(tokens, attributions[0])}


AssertionError: Target not provided when necessary, cannot take gradient with respect to multiple outputs.