## 0. Libraries 📚

In [None]:
import ast
import gc
import random
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

import optuna

warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# One seed for every random number generator the notebook touches.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Ask cuDNN to pick deterministic kernels.
torch.backends.cudnn.deterministic = True

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

## 1. Load data 📥

In [None]:
# Read the ground-truth table from disk.
diagnoses_df = pd.read_csv("data/ground_truth_df.csv")

# These two columns arrive as strings holding Python literals; parse the strings
# and replace every non-string cell with an empty list.
for col in ("Codigos_diagnosticos", "Diagnosticos_estandar"):
    diagnoses_df[col] = diagnoses_df[col].apply(
        lambda cell: [] if not isinstance(cell, str) else ast.literal_eval(cell)
    )
diagnoses_df

## 2. Pre-process and splits 🧹✂️

In [None]:
# Multi-hot encode the standardized diagnoses (one column per distinct label).
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(diagnoses_df["Diagnosticos_estandar"])

# Every description gets the "query: " prefix used with the e5 model.
texts_prefixed = ["query: " + text for text in diagnoses_df["Descripcion_diagnosticos"].tolist()]
X = np.array(texts_prefixed)

# Stage 1: multilabel-stratified split, 70 % train / 30 % held out.
# n_splits=1, so the split generator yields exactly one (train, held-out) index pair.
msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.30, random_state=SEED)
train_idx, tmp_idx = next(msss.split(np.zeros(len(X)), y))
X_train, y_train = X[train_idx], y[train_idx]
X_tmp, y_tmp = X[tmp_idx], y[tmp_idx]

# Stage 2: halve the held-out 30 % into validation and test (15 % / 15 % overall).
msss_val = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.50, random_state=SEED)
val_idx, test_idx = next(msss_val.split(np.zeros(len(X_tmp)), y_tmp))
X_val, y_val = X_tmp[val_idx], y_tmp[val_idx]
X_test, y_test = X_tmp[test_idx], y_tmp[test_idx]

# Texts go back to plain Python lists of strings.
X_train, X_val, X_test = [split.tolist() for split in (X_train, X_val, X_test)]

# Labels become float32 arrays, the dtype the datasets turn into float tensors.
y_train = y_train.astype(np.float32)
y_val = y_val.astype(np.float32)
y_test = y_test.astype(np.float32)

In [None]:
# Sanity check: for every label, the fraction of its positive examples that ended
# up in each split; the printed value is that fraction averaged over labels.
label_totals = y.sum(axis=0)
train_ratio = y_train.sum(axis=0) / label_totals
val_ratio = y_val.sum(axis=0) / label_totals
test_ratio = y_test.sum(axis=0) / label_totals

print(f"train: {np.round(train_ratio.mean(), 3)}, val: {np.round(val_ratio.mean(), 3)}, test: {np.round(test_ratio.mean(), 3)}")

## 3. Dataset and model architecture ⚙️

In [None]:
# === Checkpoint name and tokenizer ===
# Only the tokenizer is loaded here; the encoder weights are loaded later with
# AutoModel.from_pretrained(MODEL_NAME) each time a model is built.
MODEL_NAME = "intfloat/multilingual-e5-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# === Custom classifier model ===
class DiagnosisClassifier(nn.Module):
    """Encoder followed by a small MLP head that emits one logit per label.

    Args:
        base_model: Encoder; it is called with ``input_ids`` and ``attention_mask``
            keyword arguments and must return an object exposing ``last_hidden_state``.
        base_model_output_dim: Size of the encoder's hidden vectors.
        hidden_dim: Size of the head's hidden layer.
        num_labels: Number of logits produced per example.
    """

    def __init__(self, base_model, base_model_output_dim=1024, hidden_dim=768, num_labels=10):
        super().__init__()
        self.base = base_model
        # Linear -> ReLU -> Dropout -> Linear, kept in this order so the
        # state_dict keys stay "classifier.0.*" and "classifier.3.*".
        head_layers = [
            nn.Linear(base_model_output_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_labels),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return the head's raw logits (no sigmoid applied)."""
        encoded = self.base(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded.last_hidden_state
        # The vector at sequence position 0 (the CLS token) represents the whole text.
        first_token = hidden_states[:, 0, :]
        return self.classifier(first_token)

# === Dataset wrapper ===
class DiagnosisDataset(Dataset):
    """Dataset that tokenizes every text once, up front, and serves tensor dicts.

    Args:
        texts: Texts handed to ``tokenizer`` in a single call.
        labels: Per-example label vectors; stored as a float32 tensor.
        tokenizer: Callable returning a mapping of field name -> tensor whose first
            dimension indexes the examples.
        max_len: Value forwarded as ``max_length`` (truncation is enabled).
    """

    def __init__(self, texts, labels, tokenizer, max_len=128):
        # padding=True is applied across this one call, i.e. over the whole split.
        self.encodings = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors='pt',
        )
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return the tokenizer fields for example ``idx`` plus its ``"labels"`` row."""
        sample = {}
        for name, tensor in self.encodings.items():
            sample[name] = tensor[idx]
        sample["labels"] = self.labels[idx]
        return sample

## 4. Baseline single model 🌱

In [None]:
# Hand-picked hyperparameters for the baseline run (no search involved).
EPOCHS_BASELINE = 10  # number of training epochs
LEARNING_RATE = 2e-5  # AdamW learning rate
BATCH_SIZE = 16  # batch size of the baseline DataLoaders
HIDDEN_DIM = 1024  # hidden width of the classifier head
MAX_LEN = 256  # max_length passed to the tokenizer (truncation length)

In [None]:
# Tokenize each split once, truncating at MAX_LEN tokens.
train_ds = DiagnosisDataset(texts=X_train, labels=y_train, tokenizer=tokenizer, max_len=MAX_LEN)
val_ds = DiagnosisDataset(texts=X_val, labels=y_val, tokenizer=tokenizer, max_len=MAX_LEN)
test_ds = DiagnosisDataset(texts=X_test, labels=y_test, tokenizer=tokenizer, max_len=MAX_LEN)

# Only the training loader shuffles; validation and test keep dataset order.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# === Load the pretrained encoder ===
base_model = AutoModel.from_pretrained(MODEL_NAME)

# === Wrap it with the classification head (one logit per label) and move it to `device` ===
baseline = DiagnosisClassifier(
    base_model=base_model,
    hidden_dim=HIDDEN_DIM,
    num_labels=y.shape[1],
).to(device)

# === Training setup: multi-label BCE on logits, AdamW over all model parameters ===
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(baseline.parameters(), lr=LEARNING_RATE)

In [None]:
def epoch_loop(model, data_loader, criterion, optimizer, train=False):
    """Run one full pass over ``data_loader`` and report loss and micro metrics.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``; its
            output is treated as logits.
        data_loader: Iterable of batches, each a mapping with ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` tensors.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Stepped once per batch when ``train`` is true; unused otherwise.
        train: When true, put the model in training mode and update its weights;
            when false, put it in eval mode and only measure.

    Returns:
        Tuple ``(mean_loss, f1, precision, recall)``. The loss is the mean of the
        per-batch losses; the three metrics are micro-averaged over the whole pass
        after thresholding the sigmoid of the logits at 0.5.
    """
    if train:
        model.train()
    else:
        model.eval()

    losses, probs_list, labels_list = [], [], []

    for batch in tqdm(data_loader, desc=f"Batches ({'train' if train else 'eval'})"):
        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Autograd is only needed when the weights are updated. Disabling it for
        # evaluation passes avoids building (and holding) the graph for nothing.
        with torch.set_grad_enabled(train):
            outputs = model(input_ids=input_ids, attention_mask=attn_mask)
            loss = criterion(outputs, labels)

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        losses.append(loss.item())
        # Keep probabilities and targets on the CPU so the GPU holds one batch at a time.
        probs_list.append(torch.sigmoid(outputs).detach().cpu())
        labels_list.append(labels.detach().cpu())

    y_pred = (torch.vstack(probs_list) > 0.5).int().numpy()
    y_true = torch.vstack(labels_list).numpy()
    return (
        np.mean(losses),
        f1_score(y_true, y_pred, average="micro", zero_division=0),
        precision_score(y_true, y_pred, average="micro", zero_division=0),
        recall_score(y_true, y_pred, average="micro", zero_division=0),
    )

In [None]:
# Baseline training: EPOCHS_BASELINE training passes, each followed by a validation pass.
# best_f1 records the highest validation F1 seen so far; nothing in this cell reads it
# back and no checkpoint is kept, so the final metrics below use the last-epoch weights.
best_f1 = 0
for epoch in range(1, EPOCHS_BASELINE+1):
    print(f"🔹 Epoch {epoch:02d} / {EPOCHS_BASELINE}")
    train_loss, train_f1, _, _ = epoch_loop(baseline, train_dl, criterion, optimizer, train=True)
    val_loss, val_f1, _, _     = epoch_loop(baseline, val_dl, criterion, optimizer, train=False)
    print(f"Epoch [{epoch:02d}]  train loss={train_loss:.4f} | val loss={val_loss:.4f}")
    print(f"Epoch [{epoch:02d}]  train f1={train_f1:.4f} | val f1={val_f1:.4f}")
    best_f1 = max(best_f1, val_f1)

# Final evaluation of the trained baseline on all three splits (train=False: no weight updates).
_, f1_train, precision_train, recall_train = epoch_loop(baseline, train_dl, criterion, optimizer, train=False)
_, f1_val, precision_val, recall_val = epoch_loop(baseline, val_dl, criterion, optimizer, train=False)
_, f1_test, precision_test, recall_test = epoch_loop(baseline, test_dl, criterion, optimizer, train=False)
print(f"🔹 TRAIN: F1 {f1_train:.4f} / Precision {precision_train:.4f} / Recall {recall_train:.4f}")
print(f"🔹 VAL:   F1 {f1_val:.4f} / Precision {precision_val:.4f} / Recall {recall_val:.4f}")
print(f"🔹 TEST:  F1 {f1_test:.4f} / Precision {precision_test:.4f} / Recall {recall_test:.4f}")

In [None]:
# Save the baseline model trained with the hand-picked hyperparameters above:
# the weights plus the metadata needed to rebuild it (encoder checkpoint name,
# label order from the binarizer, and the hyperparameters used).
torch.save({
    "model_state": baseline.state_dict(),
    "tokenizer": MODEL_NAME,
    "mlb_classes": mlb.classes_.tolist(),
    "params": {
        "hidden_dim": HIDDEN_DIM,
        "max_len": MAX_LEN,
        "lr": LEARNING_RATE,
        "epochs": EPOCHS_BASELINE
    }
}, "models/baseline_embeddings_and_nn.pt")
print("📦 Baseline model saved as models/baseline_embeddings_and_nn.pt")

In [None]:
# Drop every name that still references the baseline run (model, optimizer,
# loaders, datasets), then force a garbage-collection pass.
del baseline, base_model
del optimizer, criterion
del train_dl, val_dl, test_dl
del train_ds, val_ds, test_ds
gc.collect()

## 5. Hyperparameter Search with Optuna 🚀
-------------------------------------------------
We search over the learning rate, `hidden_dim`, `max_len`, and the number of epochs; the batch size is kept fixed.  
We use pruning to stop underperforming experiments.

In [None]:
# Batch size used by every Optuna trial and by the final retraining; it is a fixed
# constant, not one of the searched hyperparameters.
BATCH_SIZE_OPTUNA = 16

In [None]:
def objective(trial):
    """Optuna objective: fine-tune one model and return its last-epoch validation F1.

    Samples ``lr``, ``hidden_dim``, ``max_len`` and ``epochs`` from the trial,
    trains a fresh ``DiagnosisClassifier`` on the training split, and reports the
    validation micro-F1 after every epoch so the study's pruner can stop the trial.

    Args:
        trial: The ``optuna`` trial supplying hyperparameters and pruning decisions.

    Returns:
        Validation micro-F1 measured after the final epoch.

    Raises:
        optuna.TrialPruned: When the pruner decides to stop the trial early.
    """
    # --- Hyperparameters to explore ---
    lr = trial.suggest_float("lr", 1e-5, 5e-3, log=True)
    hidden_dim = trial.suggest_int("hidden_dim", 256, 1536, step=64)
    max_len = trial.suggest_int("max_len", 128, 256, step=64)
    epochs = trial.suggest_int("epochs", 3, 20)

    print(f"🔍 Trial {trial.number} | lr={lr:.6f}, hidden_dim={hidden_dim}, max_len={max_len}, epochs={epochs}")

    # ---- Datasets (re-tokenized for this trial's max_len) ----
    train_ds_t = DiagnosisDataset(X_train, y_train, tokenizer, max_len)
    val_ds_t = DiagnosisDataset(X_val, y_val, tokenizer, max_len)
    train_dl_t = DataLoader(train_ds_t, batch_size=BATCH_SIZE_OPTUNA, shuffle=True)
    val_dl_t = DataLoader(val_ds_t, batch_size=BATCH_SIZE_OPTUNA)

    # ---- Fresh encoder + head for every trial; the whole encoder is fine-tuned ----
    base_model = AutoModel.from_pretrained(MODEL_NAME)
    for p in base_model.parameters():
        p.requires_grad = True
    model_t = DiagnosisClassifier(base_model=base_model, hidden_dim=hidden_dim, num_labels=y.shape[1]).to(device)

    optimizer_t = torch.optim.AdamW(model_t.parameters(), lr=lr)
    criterion_t = nn.BCEWithLogitsLoss()

    # ---- Train with per-epoch pruning ----
    try:
        for epoch in range(epochs):
            print(f"🔹 Trial {trial.number} - Epoch {epoch+1}/{epochs}")
            train_loss_epoch, train_f1_epoch, _, _ = epoch_loop(model_t, train_dl_t, criterion_t, optimizer_t, train=True)
            val_loss_epoch, val_f1_epoch, _, _ = epoch_loop(model_t, val_dl_t, criterion_t, optimizer_t, train=False)
            # The reported step stays 0-based so it matches steps already stored in the study.
            trial.report(val_f1_epoch, epoch)
            # Printed epoch numbers are 1-based, matching the header line above.
            print(f"Epoch [{epoch+1:02d}]  train loss={train_loss_epoch:.4f} | val loss={val_loss_epoch:.4f}")
            print(f"Epoch [{epoch+1:02d}]  train f1={train_f1_epoch:.4f} | val f1={val_f1_epoch:.4f}")

            if trial.should_prune():
                raise optuna.TrialPruned()
    finally:
        # ---- Free memory even when the trial is pruned or fails mid-training ----
        del train_ds_t, val_ds_t
        del train_dl_t, val_dl_t
        del base_model, model_t
        del optimizer_t, criterion_t
        gc.collect()

    # ---- Return the final validation F1 score ----
    return val_f1_epoch

In [None]:
# Create the study backed by a SQLite file; with load_if_exists=True an existing
# study of the same name is loaded instead of raising, so the search can be resumed.
# The direction is "maximize" because `objective` returns a validation F1.
study = optuna.create_study(
    direction="maximize",
    study_name="diagnosis_cls_baseline",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=1),
    storage="sqlite:///optuna/Embeddings_and_classification_layer.db",
    load_if_exists=True
)

# Run only the trials still missing to reach TOTAL_TRIALS, so re-running this cell
# on a resumed study does not start the budget over (0 trials once it is reached).
# NOTE(review): len(study.trials) presumably counts trials in every state
# (pruned / failed included) — confirm that is the intended budget.
TOTAL_TRIALS = 18
remaining_trials = max(TOTAL_TRIALS - len(study.trials), 0)
study.optimize(objective, n_trials=remaining_trials, n_jobs=1)

In [None]:
# Summarize the search: best trial number, its hyperparameters, and its objective
# value (the validation F1 returned by `objective`).
print("✅ Best trial:", study.best_trial.number)
print("🏆 Best configuration:", study.best_params)
print("🔝 Best F1 val:", study.best_value)
# Last expression of the cell, so the notebook displays the returned history plot.
optuna.visualization.plot_optimization_history(study)

## 6. Retrain the model with the best configuration 🔄

In [None]:
# --- 1) Retrieve the best hyperparameters found by the search ---

best_params = study.best_params
print(best_params)

BEST_LR = best_params["lr"]
BEST_HIDDEN = best_params["hidden_dim"]
BEST_MAX_LEN = best_params["max_len"]
BEST_EPOCHS = best_params["epochs"]

# --- 2) Datasets: the same train / val / test splits, re-tokenized with the best max_len ---
token_tmp = tokenizer
train_ds_t = DiagnosisDataset(X_train, y_train, token_tmp, BEST_MAX_LEN)
val_ds_t = DiagnosisDataset(X_val, y_val, token_tmp, BEST_MAX_LEN)
test_ds_t = DiagnosisDataset(X_test, y_test, token_tmp, BEST_MAX_LEN)
train_dl_t = DataLoader(train_ds_t, batch_size=BATCH_SIZE_OPTUNA, shuffle=True)
val_dl_t = DataLoader(val_ds_t, batch_size=BATCH_SIZE_OPTUNA)
test_dl_t = DataLoader(test_ds_t, batch_size=BATCH_SIZE_OPTUNA)

# --- 3) Final model: fresh encoder, fully trainable, with the best head size ---
base_model = AutoModel.from_pretrained(MODEL_NAME)
for p in base_model.parameters():
    p.requires_grad = True
model_t = DiagnosisClassifier(base_model=base_model, hidden_dim=BEST_HIDDEN, num_labels=y.shape[1]).to(device)

optimizer_t = torch.optim.AdamW(model_t.parameters(), lr=BEST_LR)
criterion_t = nn.BCEWithLogitsLoss()

# --- 4) Train for the tuned number of epochs (fixed length, no early stopping) ---
for epoch in range(BEST_EPOCHS):
    print(f"🔹 Epoch {epoch+1:02d} / {BEST_EPOCHS}")
    train_loss_epoch, train_f1_epoch, _, _ = epoch_loop(model_t, train_dl_t, criterion_t, optimizer_t, train=True)
    val_loss_epoch, val_f1_epoch, _, _ = epoch_loop(model_t, val_dl_t, criterion_t, optimizer_t, train=False)
    print(f"Epoch [{epoch+1:02d}]  train loss={train_loss_epoch:.4f} | val loss={val_loss_epoch:.4f}")
    print(f"Epoch [{epoch+1:02d}]  train f1={train_f1_epoch:.4f} | val f1={val_f1_epoch:.4f}")

# --- 5) Final metrics on the three splits ---
# epoch_loop returns (loss, f1, precision, recall). The whole tuple is unpacked here;
# indexing it with [1] first would leave a single float and make the unpacking fail.
_, f1_train, precision_train, recall_train = epoch_loop(model_t, train_dl_t, criterion_t, optimizer_t, train=False)
_, f1_val, precision_val, recall_val = epoch_loop(model_t, val_dl_t, criterion_t, optimizer_t, train=False)
_, f1_test, precision_test, recall_test = epoch_loop(model_t, test_dl_t, criterion_t, optimizer_t, train=False)
print(f"🔹 TRAIN: F1 {f1_train:.4f} / Precision {precision_train:.4f} / Recall {recall_train:.4f}")
print(f"🔹 VAL:   F1 {f1_val:.4f} / Precision {precision_val:.4f} / Recall {recall_val:.4f}")
print(f"🔹 TEST:  F1 {f1_test:.4f} / Precision {precision_test:.4f} / Recall {recall_test:.4f}")

In [None]:
# Save the model retrained with the best hyperparameters found by Optuna: the
# weights plus the metadata needed to rebuild it (encoder checkpoint name, label
# order from the binarizer, and the tuned hyperparameters).
torch.save({
    "model_state": model_t.state_dict(),
    "tokenizer": MODEL_NAME,
    "mlb_classes": mlb.classes_.tolist(),
    "params": best_params
}, "models/optimized_embeddings_and_nn.pt")
print("📦 Optuna best model saved as models/optimized_embeddings_and_nn.pt")