In [21]:
import pandas as pd
import numpy as np
import re
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay, precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt


In [23]:

# ==== Paramètres BERT ====
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT" # ou "bert-base-uncased" si indisponible
MAX_LEN = 128
EPOCHS = 3
BATCH_SIZE = 8
SEED = 42



In [25]:
# ==== Extraction âge ====
def extract_age(text):
    if not isinstance(text, str):
        return -1
    mapping = {'newborn': 0.1, 'infant': 0.5, 'toddler': 2, 'child': 7, 'adolescent': 15, 'teenager': 15}
    decades = {"twenties": 25, "thirties": 35, "forties": 45, "fifties": 55, "sixties": 65, "seventies": 75, "eighties": 85, "nineties": 95}
    patterns = [
        r'\bI\s*am\s*(\d{1,3})\b', r"\bI'm\s*(\d{1,3})\b", r'\bI\s*am\s*(\d{1,3})\s*yrs?\b', r"\bI'm\s*(\d{1,3})\s*yrs?\b",
        r'\b(\d{1,3})\s*(years? old|yrs? old|yr old|yo|ans?)\b', r'\b(\d{1,3})\s*-\s*year[- ]*old\b',
        r'\b(\d{1,3})\s?(?:M|F|m|f)[, ]', r'\b(age|âge|Age|Âge)\s*:?[\s]*?(\d{1,3})\b',
        r'\b(\d{1,3})\s*[.,]', r'\bmy\s+\d{1,3}-year-old', r'\b(\d{1,3})\b(?=\s*[-]*year[- ]*old)',
        r'\baged\s*(\d{1,3})\b', r'\bwhen I was (\d{1,3})\s*yrs?\b']
    for pat in patterns:
        match = re.search(pat, text, re.IGNORECASE)
        if match:
            for group in match.groups():
                if group and group.isdigit():
                    return int(group)
            if match.lastindex and match.lastindex >= 2 and match.group(2).isdigit():
                return int(match.group(2))
    match = re.search(r'\b([a-z]+)\s+year[s]? old\b', text.lower())
    if match:
        words2num = {'zero': 0, 'one': 1, 'two':2, 'three':3, 'four':4, 'five':5, 'six':6, 'seven':7, 'eight':8,
            'nine':9, 'ten':10, 'eleven':11, 'twelve':12, 'thirteen':13, 'fourteen':14, 'fifteen':15, 
            'sixteen':16, 'seventeen':17, 'eighteen':18, 'nineteen':19, 'twenty':20}
        val = match.group(1)
        if val in words2num:
            return words2num[val]
    match = re.search(r'(\d{1,2})\s*(months?|mos?|mo\.?)\s*(old)?\b', text, re.IGNORECASE)
    if match:
        mois = int(match.group(1))
        return round(mois / 12, 2)
    relatives = [
        'daughter', 'son', 'wife', 'mother[- ]in[- ]law', 'father[- ]in[- ]law', 'husband',
        'mother', 'father', 'sister', 'brother'
    ]
    rel_pat = r'\bmy\s+(?:' + '|'.join(relatives) + r')\s*(?:is|aged)?\s*([\d]{1,3})\b'
    match = re.search(rel_pat, text, re.IGNORECASE)
    if match:
        return int(match.group(1))
    match = re.search(r'in\s+(?:his|her|their)?\s*(twenties|thirties|forties|fifties|sixties|seventies|eighties|nineties)\b', text.lower())
    if match:
        return decades[match.group(1)]
    match = re.search(r'(\w+) and a half years?', text.lower())
    if match:
        words = {'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5, 'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10}
        val = match.group(1)
        if val in words:
            return words[val] + 0.5
    for label, val in mapping.items():
        if label in text.lower():
            return val
    return -1

# ==== Extraction sexe ====
def extract_sex(text):
    t = text.lower()
    if re.search(r'\bmale\b|\bman\b|\bm\b|\bboy\b|\bson\b|\bhusband\b|\bhe\b|\bhis\b|\bprostate\b|\bpenis\b|\btesticle\b|\bgentleman\b|\bdad\b|\bfather\b|(semen|sperm|testical|motility|penis|testicles)', t, re.I):
        return 'M'
    elif re.search(r'\bfemale\b|\bwoman\b|\bf\b|\bgirl\b|\bdaughter\b|\bwife\b|\bshe\b|\bher\b|\bpregnant\b|\bvagina\b|\buterus\b|\bovary\b|\bmenstruation\b|\bmother\b|(pregnan|embryo|IVF|fetal|delivery|uterus|ovary|menstruation|period|girlfriend|wife|partner is pregnant|married woman|dysmenorrhea)', t, re.I):
        return 'F'
    if re.search(r'(sexual dysfunction|semen analysis|sperm count|erectile dysfunction|testosterone|gynecomastia|foreskin|errection|scrotum|testicle|sack|my brother)', t):
        return 'M'
    if re.search(r'(fertility treatment|fetal|fetus|oocyte|embryo transfer|gestational|insémination|dysmenorrhea|menstruation|vaginal|cervix|ovary|uterine|period|breast|menopause|ivf|iui|endometrio|cervix|pregnan|period|contraceptive|pregnancy|labia|pcos)|my sister', t):
        return 'F'
    return 'U'


In [27]:
# ==== Chargement et extraction ====
df = pd.read_csv("/Users/ines/NLP/emergency_chatbot/data/triage_dataset_preprocess.csv", encoding="utf8")
df['age'] = df['question'].apply(extract_age)
df['sex'] = df['question'].apply(extract_sex)
df['label'] = df['triage'].map({'non-urgent': 0, 'urgent': 1})

df['age'] = df['age'].replace(-1, np.nan)
df['age'] = df['age'].fillna(df['age'].median())
df['sex'] = df['sex'].fillna('U')

counts = df['sex'].value_counts()
total = len(df)
print(counts)
print("\nRATIO (%) :")
for cat in ['F', 'M', 'U']:
    pct = 100 * counts.get(cat, 0) / total
    print(f"{cat}: {counts.get(cat, 0)} ({pct:.1f}%)")

nb_age_detected = (df['age'] != -1).sum()
nb_total = len(df)
taux_age = nb_age_detected / nb_total * 100
print(f"Taux de lignes avec âge détecté : {nb_age_detected} / {nb_total} = {taux_age:.1f}%")



sex
U    21876
M    14082
F     8947
Name: count, dtype: int64

RATIO (%) :
F: 8947 (19.9%)
M: 14082 (31.4%)
U: 21876 (48.7%)
Taux de lignes avec âge détecté : 44905 / 44905 = 100.0%


In [29]:
display(df.head(10))

Unnamed: 0,question,triage,question_clean,age,sex,label
0,"I am 35 years old unmarried , i was diagonized...",non-urgent,35 year old unmarried diagonize hepatitis b su...,35.0,U,0
1,I have been having abdominal pain and burning ...,non-urgent,abdominal pain burn no relieve omeprazole bowe...,22.0,U,0
2,"sir, Day before yesterday i had an oil fried i...",urgent,sir day yesterday oil fry item snack chappathi...,22.0,M,1
3,"friend has a lump where their coccyx is, has b...",urgent,friend lump coccyx complete agony literally sc...,22.0,U,1
4,Which demographic should raise suspicion of a ...,non-urgent,demographic raise suspicion possible rubella i...,22.0,U,0
5,What bacterial infection can lead to the devel...,non-urgent,bacterial infection lead development risus sar...,22.0,U,0
6,Hi my daughter is two years old and lately she...,urgent,daughter year old lately loose stool black col...,2.0,F,1
7,I have a large anechoic cyst in my right kidne...,non-urgent,large anechoic cyst right kidney year reach 10...,13.0,U,0
8,What is the patella reflex and which nerve roo...,non-urgent,patella reflex nerve root test,22.0,U,0
9,"Hello doctor, I am 42 years old. I had a heart...",non-urgent,doctor 42 year old heart attack 12 year ago st...,42.0,U,0


In [31]:
# ==== Construction du texte d'entrée BERT ====
def build_input(row):
    text = row['question_clean']
    age = int(row['age']) if not pd.isnull(row['age']) else "unknown"
    sex = row['sex'] if row['sex'] != "U" else "unknown"
    return f"Age: {age}; Sex: {sex}; Symptoms: {text}"

df['input_text'] = df.apply(build_input, axis=1)



In [33]:
# ==== Split Train/Test ====
train_texts, test_texts, train_labels, test_labels = train_test_split(
    df['input_text'], df['label'], test_size=0.2, stratify=df['label'], random_state=SEED
)

# ==== Dataset HuggingFace ====
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

class TriageDataset(torch.utils.data.Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.encodings = tokenizer(texts.tolist(), padding=True, truncation=True, max_length=max_len)
        self.labels = labels.tolist()
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item
    def __len__(self):
        return len(self.labels)

train_dataset = TriageDataset(train_texts, train_labels, tokenizer, MAX_LEN)
test_dataset = TriageDataset(test_texts, test_labels, tokenizer, MAX_LEN)



In [35]:
import torch

if torch.cuda.is_available():
    print("✅ GPU disponible :", torch.cuda.get_device_name(0))
else:
    print("❌ GPU NON disponible, uniquement CPU")


❌ GPU NON disponible, uniquement CPU


In [37]:
# ==== Modèle & entraînement ====
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_strategy="epoch",
    save_strategy="no",
    logging_dir='./logs',
    logging_steps=20,
    seed=SEED
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

trainer.train()


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# ==== Prédiction & évaluation ====
preds = trainer.predict(test_dataset)
y_pred = np.argmax(preds.predictions, axis=1)
print(classification_report(test_labels, y_pred, target_names=['non-urgent', 'urgent']))


In [None]:
# ==== Courbes ROC et PR ====
y_prob = torch.nn.functional.softmax(torch.tensor(preds.predictions), dim=1)[:,1].numpy()
fpr, tpr, _ = roc_curve(test_labels, y_prob)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(7,5))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')
plt.xlabel('Taux de faux positifs (1-Spécificité)')
plt.ylabel('Taux de vrais positifs (Sensibilité)')
plt.title('Courbe ROC - Modèle de triage urgence (BERT)')
plt.legend(loc="lower right")
plt.grid(True)
plt.show()

cm = confusion_matrix(test_labels, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Non-urgent", "Urgent"])
disp.plot(cmap=plt.cm.Blues)
plt.title('Matrice de confusion')
plt.show()

precision, recall, thresholds = precision_recall_curve(test_labels, y_prob)
avg_prec = average_precision_score(test_labels, y_prob)

plt.figure(figsize=(7,5))
plt.plot(recall, precision, color='purple', lw=2, label=f'Courbe PR (AP={avg_prec:.2f})')
plt.xlabel('Recall (Sensibilité)')
plt.ylabel('Precision')
plt.title('Courbe Precision-Recall - Modèle de triage urgence (BERT)')
plt.legend(loc="upper right")
plt.grid(True)
plt.show()