# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder

import lightning as L
from torch.utils.data import TensorDataset, DataLoader

import json
import polars as pl
import pandas as pd

In [None]:
# Read the shared run parameters and expose the 'dataset', 'version' and
# 'data_folder' entries as notebook-level constants.
with open('../../params.json', 'r') as params_file:
    params = json.load(params_file)

DATASET = params['dataset']
VERSION = params['version']
DATA_FOLD = params['data_folder']

print(f'Working on {DATASET} dataset {VERSION}')

In [None]:
# Load the five-day temporal dataset for the configured data folder, version and dataset.
five_days_path = f'{DATA_FOLD}/{VERSION}/2.clean_data/{DATASET}/temporal/five_days_dataset.parquet'
df = pl.read_parquet(five_days_path)

# Pré-traitement

In [None]:
# Derive the prediction targets from the raw frame.
death_delay = pl.col('deces_datediff')

df_pre_treated = (
    df
    # Floor the death delay at 5. The explicit alias matters: without it polars
    # names a when/then/otherwise result after its `then` branch (here the
    # literal), so the floored values ended up in a stray "literal" column and
    # `deces_datediff` itself was never modified.
    .with_columns(
        pl.when(death_delay < 5)
        .then(pl.lit(5))
        .otherwise(death_delay)
        .alias('deces_datediff')
    )
    .with_columns(
        # True when the remaining delay before death, counted from the current
        # time step (delta_hour / 24 — presumably hours converted to days), is
        # below 1. A null comparison (no death delay) falls through to False.
        pl.when((death_delay - pl.col('delta_hour') / 24) < 1)
        .then(True)
        .otherwise(False)
        .alias('survival_inf24'),
        # True when the death delay is missing or at least 90.
        ((death_delay >= 90) | death_delay.is_null()).alias('j90_survival'),
    )
)

# Class balance of the 90-day target (counted per row, not per encounter).
df_pre_treated['j90_survival'].value_counts()

In [None]:
# List every available column, then preview the first rows.
column_names = df_pre_treated.columns
print(column_names)
df_pre_treated.head()


# Datasets 

In [None]:
from sklearn.model_selection import train_test_split

## Identification des features

In [None]:
# Model inputs, in the column order used for the feature axis of the 3D arrays.
features = [
    'pam',
    'pas',
    'pad',
    'heart_rate',
    'spo2',
    'nad_dose_poids',
    'fio2_corr',
    'is_ventilated',
    'gender',
    'age',
]

# Candidate target columns carried alongside the features.
targets = [
    'survival_inf24',
    'j90_survival',
    'los',
]

In [None]:
# Distinct (encounterId, j90_survival) pairs, as a pandas frame, so the
# train/test split can be done on encounters rather than on individual rows.
encounter_set = (
    df_pre_treated
    .select(['encounterId', 'j90_survival'])
    .unique()
    .to_pandas()
)

## Train et Test sets

In [None]:
# Encounter-level 70/30 split, stratified on the 90-day survival label.
# A fixed random_state keeps the split (and every downstream metric) identical
# across kernel restarts.
train_encounters, test_encounters = train_test_split(
    encounter_set,
    test_size=0.3,
    stratify=encounter_set['j90_survival'],
    random_state=42,
)

In [None]:
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Work on an independent pandas copy; later cells modify `data_trans` in place.
pre_treated_pandas = df_pre_treated.to_pandas()
data_trans = pre_treated_pandas.copy()

In [None]:
# Columns to standardise vs. columns to integer-encode.
numeric_data = ['pam', 'pas', 'pad', 'heart_rate', 'spo2', 'nad_dose_poids', 'fio2_corr', 'age']
cat_data = ['is_ventilated', 'gender', 'survival_inf24', 'j90_survival']

# Fit the scaler on the training encounters only, then apply it to every row.
# Fitting on the whole frame would leak test-set means/variances into the
# features the model is trained on.
is_train_row = data_trans['encounterId'].isin(train_encounters['encounterId'])
scaler = StandardScaler()
scaler.fit(data_trans.loc[is_train_row, numeric_data])
data_trans[numeric_data] = scaler.transform(data_trans[numeric_data])

# Integer-encode each categorical/boolean column. The single encoder is
# re-fitted for every column, so after the loop `le` only remembers the
# classes of the last one.
le = LabelEncoder()
for c in cat_data:
    data_trans[c] = le.fit_transform(data_trans[c])

In [None]:
# Number of distinct time steps present in the data (missing values counted as one value).
data_trans['delta_hour'].nunique(dropna=False)

In [None]:
# Long-format feature / target frames, keyed by encounter and time step.
key_cols = ['encounterId', 'delta_hour']
X = data_trans[key_cols + features]
Y = data_trans[key_cols + targets]


def _rows_for(frame, encounter_ids):
    """Rows of `frame` belonging to `encounter_ids`, one per (encounterId, delta_hour), sorted by those keys."""
    return (
        frame[frame['encounterId'].isin(encounter_ids)]
        .sort_values(by=key_cols)
        .drop_duplicates(subset=key_cols, keep='first')
        .reset_index(drop=True)
    )


def _to_sequences(frame):
    """Reshape a sorted long-format feature frame to (n_encounters, n_timesteps, n_features).

    Raises:
        ValueError: if some encounter does not have exactly one row per time step,
            since a plain reshape would then mix rows of different encounters.
    """
    n_encounters = frame['encounterId'].nunique()
    n_timesteps = frame['delta_hour'].nunique()
    if len(frame) != n_encounters * n_timesteps:
        raise ValueError(
            f'Expected {n_encounters} encounters x {n_timesteps} time steps, got {len(frame)} rows'
        )
    return frame.drop(columns=key_cols).to_numpy().reshape(n_encounters, n_timesteps, len(features))


# Train and test now go through the same path: previously only the test frames
# were de-duplicated, and the train array was reshaped with the test set's
# number of time steps.
X_train = _rows_for(X, train_encounters['encounterId'])
y_train = _rows_for(Y, train_encounters['encounterId'])
X_test = _rows_for(X, test_encounters['encounterId'])
y_test = _rows_for(Y, test_encounters['encounterId'])

# One label per encounter, in the same encounterId order as the first axis of the 3D arrays.
y_train_dropped = y_train.drop_duplicates('encounterId').sort_values(by='encounterId').reset_index(drop=True)
y_test_dropped = y_test.drop_duplicates('encounterId').sort_values(by='encounterId').reset_index(drop=True)
y_train_numpy = y_train_dropped['j90_survival'].to_numpy()
y_test_numpy = y_test_dropped['j90_survival'].to_numpy()

X_train_3d = _to_sequences(X_train)
X_test_3d = _to_sequences(X_test)

# Every sequence must have exactly one label.
assert X_train_3d.shape[0] == len(y_train_numpy)
assert X_test_3d.shape[0] == len(y_test_numpy)

## Datasets Early/Late/Full

In [None]:
# Three temporal views of each set: the whole sequence, its first 48 time
# steps ("early") and its last 48 time steps ("late").
X_train_full, X_test_full = X_train_3d, X_test_3d
X_train_early, X_test_early = X_train_3d[:, :48, :], X_test_3d[:, :48, :]
X_train_late, X_test_late = X_train_3d[:, -48:, :], X_test_3d[:, -48:, :]

In [None]:
# Shapes of the train and test sequence arrays.
for sequence_array in (X_train_3d, X_test_3d):
    print(sequence_array.shape)

# LSTM standard

In [None]:
class LSTM(L.LightningModule):
    """LSTM binary classifier producing one logit per input sequence.

    The LSTM output at the last time step is projected by a linear layer to a
    single logit; training uses binary cross-entropy on logits.
    """

    def __init__(self, input_size, hidden_size=32):
        """Build the recurrent layer, the one-unit output head and the loss.

        Args:
            input_size: number of features per time step.
            hidden_size: size of the LSTM hidden state.
        """
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
        self.fc = nn.Linear(hidden_size, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input):
        """Return the logits for a 3D batch.

        The first two axes are swapped before the LSTM call because `nn.LSTM`
        is built here with its default sequence-first layout.
        """
        seq_first = input.permute(1, 0, 2)
        outputs, _ = self.lstm(seq_first)
        last_step = outputs[-1]
        return self.fc(last_step).squeeze(-1)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        return Adam(self.parameters(), lr=1e-3)

    def _logits_and_loss(self, batch):
        """Run the model on an `(inputs, labels)` batch; return (logits, labels, loss)."""
        inputs, labels = batch
        logits = self(inputs)
        return logits, labels, self.loss_fn(logits, labels)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        _, _, loss = self._logits_and_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and the accuracy at a 0.5 probability threshold; return the loss."""
        logits, labels, loss = self._logits_and_loss(batch)
        predicted_positive = torch.sigmoid(logits) > 0.5
        acc = (predicted_positive == labels).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss


In [None]:
# Training inputs per temporal view; all three share the same per-encounter labels.
configs = dict(
    full=(X_train_full, y_train_numpy),
    early=(X_train_early, y_train_numpy),
    late=(X_train_late, y_train_numpy),
)


In [None]:
torch.set_float32_matmul_precision('medium')

# One trained model per temporal view, keyed by configuration name.
trained_models = {}

# Loop variables are deliberately not called `X` / `y`, so the feature frame
# `X` built earlier in the notebook is not silently overwritten.
for config_name, (X_config, y_config) in configs.items():
    print(f"\n🟢 Training config: {config_name} | shape = {X_config.shape}")

    # Re-seed before every run so weight initialisation and batch shuffling
    # are reproducible and identical across configurations.
    L.seed_everything(42, workers=True)

    # Build the dataset and its data loader
    inputs = torch.tensor(X_config).float()
    labels = torch.tensor(y_config).float()
    dataset = TensorDataset(inputs, labels)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Fresh LSTM for this configuration; the last array axis is the feature count
    model = LSTM(input_size=X_config.shape[2])

    # Lightning trainer
    trainer = L.Trainer(
        max_epochs=100,
        log_every_n_steps=2,
        enable_progress_bar=True,
        logger=False,  # enable a logger here if needed
        enable_checkpointing=False
    )

    # No validation loader is passed, so only the training loop runs.
    trainer.fit(model, train_dataloaders=dataloader)

    trained_models[config_name] = model


In [None]:
# Full test sequences as a float32 tensor.
# NOTE(review): `predict` is not read by any later cell visible in this notebook.
predict = torch.tensor(X_test_3d, dtype=torch.float32)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
from sklearn.metrics import roc_auc_score, f1_score, roc_curve
from sklearn.calibration import calibration_curve

# Expects from earlier cells: trained_models["full" | "early" | "late"],
# X_test_full / X_test_early / X_test_late and y_test_numpy.

X_test_dict = {
    "full": X_test_full,
    "early": X_test_early,
    "late": X_test_late
}

results, roc_curves, calib_curves = [], {}, {}

for config_name, model in trained_models.items():
    print(f"⏳ Évaluation : {config_name}")
    test_inputs = torch.tensor(X_test_dict[config_name]).float()

    # Inference only: switch to eval mode and disable gradient tracking.
    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(test_inputs)).cpu().numpy()

    # Threshold-free AUC, and F1 at a 0.5 probability cut-off.
    auc = roc_auc_score(y_test_numpy, probs)
    y_pred_bin = (probs > 0.5).astype(int)
    f1 = f1_score(y_test_numpy, y_pred_bin)

    # Store the summary row and the curves for this configuration
    results.append({"Configuration": config_name, "AUC": auc, "F1 Score": f1})

    fpr, tpr, _ = roc_curve(y_test_numpy, probs)
    roc_curves[config_name] = (fpr, tpr, auc)

    prob_true, prob_pred = calibration_curve(y_test_numpy, probs, n_bins=10, strategy='quantile')
    calib_curves[config_name] = (prob_pred, prob_true)

# After the loop, `probs` holds the probabilities of the last configuration
# iterated; the metric cells further down read it.

# Summary table: one row per configuration with its AUC and F1 score
results_df = pd.DataFrame.from_records(results)
print("\n📊 Tableau récapitulatif :")
print(results_df)

# ROC curves, one line per configuration, with the chance diagonal for reference
roc_fig, roc_ax = plt.subplots(figsize=(6, 5))
for name, (fpr, tpr, auc) in roc_curves.items():
    roc_ax.plot(fpr, tpr, label=f"{name} (AUC={auc:.2f})")
roc_ax.plot([0, 1], [0, 1], 'k--')
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("ROC Curves")
roc_ax.legend()
roc_ax.grid(True)
plt.show()

# Calibration curves, one line per configuration; the diagonal is perfect calibration
calib_fig, calib_ax = plt.subplots(figsize=(6, 5))
for name, (pred, true) in calib_curves.items():
    calib_ax.plot(pred, true, marker='o', label=name)
calib_ax.plot([0, 1], [0, 1], 'k--')
calib_ax.set_xlabel("Proba prédite (moyenne)")
calib_ax.set_ylabel("Fréquence observée")
calib_ax.set_title("Calibration Curve (10 bins)")
calib_ax.legend()
calib_ax.grid(True)
plt.show()


## Métriques

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve

# `probs` is whatever the evaluation loop assigned last, i.e. the predictions
# of the final configuration it iterated over.
auc = roc_auc_score(y_true=y_test_numpy, y_score=probs)
print(f"AUC: {auc:.3f}")

In [None]:
import numpy as np  # used below; it is not imported in the notebook's import cell
from sklearn.metrics import precision_recall_curve

precision, recall, thresholds = precision_recall_curve(y_test_numpy, probs)

# `precision` and `recall` hold one more entry than `thresholds`: the final
# point has no associated threshold. Drop it so every F2 score maps to a real
# threshold and `thresholds[best_idx]` can never go out of range.
precision_at_thresh = precision[:-1]
recall_at_thresh = recall[:-1]

# F-beta with beta = 2: (1 + 2^2) * P * R / (2^2 * P + R).
# Where the denominator is 0 (P = R = 0) the score is set to 0 instead of NaN,
# because argmax would otherwise select the NaN entry.
f2_denominator = 4 * precision_at_thresh + recall_at_thresh
f2_scores = np.divide(
    5 * precision_at_thresh * recall_at_thresh,
    f2_denominator,
    out=np.zeros_like(f2_denominator, dtype=float),
    where=f2_denominator > 0,
)
best_idx = np.argmax(f2_scores)
best_thresh = thresholds[best_idx]

print(f"Best F2-score: {f2_scores[best_idx]:.3f} at threshold {best_thresh:.2f}")


In [None]:
import matplotlib.pyplot as plt

# ROC curve for the predictions currently held in `probs`; the legend shows
# the `auc` value computed in an earlier cell.
fpr, tpr, _ = roc_curve(y_test_numpy, probs)

single_roc_fig, single_roc_ax = plt.subplots()
single_roc_ax.plot(fpr, tpr, label=f"AUC = {auc:.2f}")
single_roc_ax.plot([0, 1], [0, 1], 'k--')
single_roc_ax.set_xlabel("False Positive Rate")
single_roc_ax.set_ylabel("True Positive Rate")
single_roc_ax.set_title("ROC Curve")
single_roc_ax.legend()
single_roc_ax.grid(True)
plt.show()


In [None]:
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt

# y_test_numpy : true binary labels (0/1)
# probs        : predicted probabilities (after the sigmoid)

# Observed positive rate vs. mean predicted probability in 10 quantile bins
prob_true, prob_pred = calibration_curve(y_test_numpy, probs, n_bins=10, strategy='quantile')

# Plot the model's curve against the ideal diagonal
single_calib_fig, single_calib_ax = plt.subplots()
single_calib_ax.plot(prob_pred, prob_true, marker='o', label="Modèle")
single_calib_ax.plot([0, 1], [0, 1], 'k--', label="Idéal")
single_calib_ax.set_xlabel("Probabilité prédite (moyenne par bin)")
single_calib_ax.set_ylabel("Fréquence observée")
single_calib_ax.set_title("Courbe de calibration (10 bins)")
single_calib_ax.legend()
single_calib_ax.grid(True)
plt.show()
