<a href="https://colab.research.google.com/github/kondratevakate/medllm-triage-eval/blob/main/BioBert_BGEm3_MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 1. Setup and Imports
import os
import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

import os
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from transformers import AutoTokenizer, AutoModel

df = pd.read_csv('random_stratified_test.csv')

# 3.1 Define raw→friendly mapping and rename
col_map = {
    'age':                  'Age'
}
df = df.rename(columns=col_map)

# 3. Preprocess Dataset
# Identify vital sign columns (already human-friendly)
vital_cols = [
    'Temperature', 'HeartRate', 'RespiratoryRate',
    'Oxygen', 'SystolicBP', 'DiastolicBP'
]
# Demographics and target (renamed)
other_cols = ['Sex', 'Age', 'ESI']
cc_cols    = [c for c in df.columns if c.startswith('cc_')]

# Validate presence
missing = set(vital_cols + cc_cols + other_cols) - set(df.columns)
if missing:
    raise KeyError(f"Missing columns: {missing}")

df = df[vital_cols + cc_cols + other_cols]

# 3.4 Clean target
df = df.dropna(subset=['ESI']).reset_index(drop=True)
df['ESI'] = df['ESI'].astype(int)

# %%
# 4. Build ComplaintText and embed via BioBERT
# Create readable complaint string
def complaint_text(row):
    names = [col.replace('cc_', '').replace('_', ' ').title()
             for col in cc_cols if row[col] == 1]
    return ' '.join(names) if names else 'NoComplaint'

df['ComplaintText'] = df.apply(complaint_text, axis=1)

# 5. Serialize cases (CC as single variable)
vital_names = ['Temperature','HeartRate','RespiratoryRate','Oxygen','SystolicBP','DiastolicBP']
demo_names = ['Age','Sex']

def serialize_row(row):
    parts = [f"{col}: {row[col] if pd.notnull(row[col]) else 'Missing'}"
             for col in vital_names]
    parts.append(f"ChiefComplaint: {row['ComplaintText']}")
    parts += [f"{col}: {row[col]}" for col in demo_names]
    return '; '.join(parts)

df['Serialized'] = df.apply(serialize_row, axis=1)


In [None]:
!git lfs install
!git clone https://huggingface.co/BAAI/bge-m3 ./bge-m3


Git LFS initialized.
Cloning into './bge-m3'...
remote: Enumerating objects: 150, done.[K
remote: Counting objects: 100% (146/146), done.[K
remote: Compressing objects: 100% (144/144), done.[K
remote: Total 150 (delta 64), reused 0 (delta 0), pack-reused 4 (from 1)[K
Receiving objects: 100% (150/150), 3.22 MiB | 4.83 MiB/s, done.
Resolving deltas: 100% (64/64), done.
Filtering content: 100% (9/9), 4.27 GiB | 167.54 MiB/s, done.


In [None]:
!pip install -U bitsandbytes
import os
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
from torch.cuda.amp import autocast
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict

def load_and_preprocess(path: str) -> pd.DataFrame:
    df = pd.read_csv(path)
    col_map = {
        'triage_vital_temp': 'Temperature',
        'triage_vital_hr':   'HeartRate',
        'triage_vital_rr':   'RespiratoryRate',
        'triage_vital_o2':   'Oxygen',
        'triage_vital_sbp':  'SystolicBP',
        'triage_vital_dbp':  'DiastolicBP',
        'gender':            'Sex',
        'race':              'Race',
        'age':               'Age'
    }
    df.rename(columns=col_map, inplace=True)

    vital_cols = ['Temperature','HeartRate','RespiratoryRate','Oxygen','SystolicBP','DiastolicBP']
    cc_cols    = [c for c in df.columns if c.startswith('cc_')]
    other_cols = ['Sex','Age','ESI']
    df = df[vital_cols + cc_cols + other_cols].dropna(subset=['ESI']).reset_index(drop=True)
    df['ESI'] = df['ESI'].astype(int)

    df['ComplaintText'] = df.apply(
        lambda r: ' '.join(
            c.replace('cc_','').replace('_',' ').title()
            for c in cc_cols if r[c]==1
        ) or "NoComplaint",
        axis=1
    )

    def serialize(r):
        parts = [f"{c}: {r[c]}" for c in vital_cols]
        parts.append(f"ChiefComplaint: {r['ComplaintText']}")
        parts += [f"{c}: {r[c]}" for c in ['Age','Sex']]
        return '; '.join(parts)

    df['Serialized'] = df.apply(serialize, axis=1)
    return df

def get_bge_model(token: str = None):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    quant = BitsAndBytesConfig(load_in_8bit=True)
    kwargs = {'trust_remote_code': True, 'quantization_config': quant, 'device_map': 'auto'}
    if token:
        kwargs['use_auth_token'] = token

    tokenizer = AutoTokenizer.from_pretrained("./bge-m3", **kwargs)
    model     = AutoModel.from_pretrained("./bge-m3", **kwargs).eval()
    return tokenizer, model

def get_bge_model(token: str = None):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    quant = BitsAndBytesConfig(load_in_8bit=True)
    kwargs = {'trust_remote_code': True, 'quantization_config': quant, 'device_map': 'auto'}
    if token:
        kwargs['use_auth_token'] = token

    tokenizer = AutoTokenizer.from_pretrained("./bge-m3", **kwargs)
    model     = AutoModel.from_pretrained("./bge-m3", **kwargs).eval()
    return tokenizer, model

def embed_serialized(df: pd.DataFrame, tokenizer, model, bs: int = 64, max_len: int = 64) -> pd.DataFrame:
    device = model.device
    embs = []
    texts = df['Serialized'].tolist()
    for i in range(0, len(texts), bs):
        batch = texts[i:i+bs]
        enc = tokenizer(batch,
                        padding='longest',
                        truncation=True,
                        max_length=max_len,
                        return_tensors='pt')
        enc = {k:v.to(device) for k,v in enc.items()}
        with torch.no_grad(), autocast():
            out = model(**enc).last_hidden_state
            embs.append(out.mean(dim=1).cpu().numpy())
    df['TextEmb'] = list(np.vstack(embs))
    return df

def train_and_evaluate(df: pd.DataFrame):
    X = np.vstack(df['TextEmb'].values)
    X = StandardScaler().fit_transform(X)
    y = df['ESI'].values

    mlp = MLPClassifier(hidden_layer_sizes=(512,256),
                        early_stopping=True, max_iter=100,
                        random_state=42)
    cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

    y_pred = cross_val_predict(mlp, X, y, cv=cv)

    return y, y_pred


# Execute
df = load_and_preprocess('random_stratified_test.csv')
tokenizer, model = get_bge_model(token=os.getenv("HF_TOKEN"))
df = embed_serialized(df, tokenizer, model)
y, y_pred = train_and_evaluate(df)




  with torch.no_grad(), autocast():


In [None]:
metrics_row = compute_metrics_row(pd.Series(np.array(y)), pd.Series(y_pred))
df_metrics = pd.DataFrame([metrics_row], index=['ExampleModel'])

df_metrics.round(2)

Unnamed: 0,Acc.,P,R,F1,HR,Mod. F1,NP,OT,UT,ER
ExampleModel,0.6,0.58,0.45,0.48,0.6,0.63,0.71,0.2,0.2,0.4


In [None]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from torch.cuda.amp import autocast

# 1) Select device dynamically
device = torch.device("cuda")
print(f"Using device: {device}")

# 1) Load & prepare BioBERT on GPU + FP16

model = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')
model.to(device).eval()
model.half()   # switch weights to FP16

def embed_texts_fast(texts, batch_size=64, max_length=64):
    embs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        # 2) tokenize on CPU, then move to GPU
        enc = tokenizer(
            batch,
            padding='longest',
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )
        enc = {k: v.to('cuda', non_blocking=True) for k, v in enc.items()}

        # 3) run in mixed precision
        with torch.no_grad(), autocast():
            out = model(**enc).last_hidden_state  # [B, L, D]

        # 4) mean-pool in FP16 then move to CPU float32
        emb = out.mean(dim=1).cpu().float().numpy()  # [B, D]
        embs.append(emb)

    return np.vstack(embs)

# # 5) Compute your embeddings
# df['TextEmb'] = list(embed_texts_fast(df['Serialized'].tolist(), batch_size=64, max_length=64))

Using device: cuda


In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def compute_metrics_row(y_true, y_pred):
    """
    Compute triage performance metrics for a single model/run.

    Returns a dict with keys:
      'Acc.', 'P', 'R', 'F1', 'HR', 'Mod. F1', 'NP', 'OT', 'UT', 'ER'
    """
    # Overall
    acc = accuracy_score(y_true, y_pred)
    p   = precision_score(y_true, y_pred, average='macro', zero_division=0)
    r   = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1  = f1_score(y_true, y_pred, average='macro', zero_division=0)

    # High-risk recall (ESI 1&2)
    high_true = y_true.isin([1,2]).astype(int)
    high_pred = pd.Series(y_pred).isin([1,2]).astype(int)
    hr = recall_score(high_true, high_pred, zero_division=0)

    # Moderate (ESI-3) F1
    mod_true = (y_true == 3).astype(int)
    mod_pred = (pd.Series(y_pred) == 3).astype(int)
    mod_f1 = f1_score(mod_true, mod_pred, zero_division=0)

    # Non-urgent precision (ESI 4&5)
    non_true = y_true.isin([4,5]).astype(int)
    non_pred = pd.Series(y_pred).isin([4,5]).astype(int)
    nprec = precision_score(non_true, non_pred, zero_division=0)

    # Over- and Under-triage rates
    over = np.sum(pd.Series(y_pred) < y_true)
    under = np.sum(pd.Series(y_pred) > y_true)
    total = len(y_true)
    ot = over / total
    ut = under / total

    # Error rate
    er = 1.0 - acc

    return {
        'Acc.':     acc,
        'P':        p,
        'R':        r,
        'F1':       f1,
        'HR':       hr,
        'Mod. F1':  mod_f1,
        'NP':       nprec,
        'OT':       ot,
        'UT':       ut,
        'ER':       er
    }

In [None]:
X_emb = embed_texts_fast(df['Serialized'].tolist(), batch_size=64, max_length=64)

# 6. Scale embeddings
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_emb)
y = df['ESI']

# 7. Cross-validated MLP training + predictions
mlp = MLPClassifier(
    hidden_layer_sizes=(512,256),
    learning_rate_init=1e-3,
    max_iter=100,
    early_stopping=True,
    random_state=42
)
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
y_pred = cross_val_predict(mlp, X_scaled, y, cv=cv)
y_proba = cross_val_predict(mlp, X_scaled, y, cv=cv, method='predict_proba')



metrics_row = compute_metrics_row(pd.Series(np.array(y)), pd.Series(y_pred))
df_metrics = pd.DataFrame([metrics_row], index=['ExampleModel'])

df_metrics.round(2)

  with torch.no_grad(), autocast():


Unnamed: 0,Acc.,P,R,F1,HR,Mod. F1,NP,OT,UT,ER
ExampleModel,0.6,0.56,0.48,0.5,0.59,0.63,0.68,0.2,0.2,0.4
