## 0. Libraries 📚

In [None]:
import ast
import gc
import random
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from utils import read_cie10_file
import optuna

warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# Seed every random number generator used in the notebook with one value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
torch.backends.cudnn.deterministic = True

# Use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

## 1. Load data 📥

In [None]:
# Lookup table built by the project helper `read_cie10_file` from the CSV below.
# Later cells index it with an ICD-10 code (cie10_map[code]) and treat the
# result as a label — presumably code → standard diagnosis name; confirm in utils.
cie10_map = read_cie10_file("data/diagnosticos_tipos.csv")
cie10_map

In [None]:
def _parse_list_cell(cell):
    """Evaluate a stringified Python literal; any non-string cell becomes an empty list."""
    if isinstance(cell, str):
        return ast.literal_eval(cell)
    return []


# Ground-truth table; the two list-valued columns are stored as text in the CSV
# and are converted back into Python objects here.
diagnoses_df = pd.read_csv("data/ground_truth_df.csv")
diagnoses_df["Codigos_diagnosticos"] = diagnoses_df["Codigos_diagnosticos"].apply(_parse_list_cell)
diagnoses_df["Diagnosticos_estandar"] = diagnoses_df["Diagnosticos_estandar"].apply(_parse_list_cell)
diagnoses_df

In [None]:
import pandas as pd

# --- 1. Define high-level ICD-10 families ------------------------------ #
# Family name → list of diagnosis codes that belong to it.
groups = {
    "SUBSTANCE_USE": [
        "F10.0", "F11.0", "F12.0", "F13.0", "F14.0",
        "F17.0", "F19.0"
    ],
    "SCHIZOPHRENIA_SPECTRUM": [
        "F20.0", "F21.0", "F22.0", "F25.0", "F29.0"
    ],
    "MOOD_DISORDERS": [
        "F30.9", "F31.0", "F31.1", "F31.3", "F31.81", "F31.9",
        "F32.0", "F32.3", "F32.9", "F33.0", "F33.9",
        "F34.0", "F34.1", "F39.0"
    ],
    "ANXIETY_STRESS": [
        "F4.0", "F40.0", "F40.9",
        "F41.0", "F41.1", "F41.2",
        "F42.0",
        "F43.0", "F43.1", "F43.2",
        "F44.0", "F45.0", "F48.0"
    ],
    "BEHAVIORAL_PHYSIOLOGICAL": [
        "F50.0", "F50.2", "F50.9", "ATRACON",
        "F51.0", "F52.0", "F53.0", "F54.0"
    ],
    "PERSONALITY_PARAPHILIC": [
        "F60.0", "F60.1", "F60.2", "F60.3", "F60.4",
        "F60.5", "F60.6", "F60.7", "F60.8", "F60.9",
        "F63.0", "F64.0", "F65.0", "F68.0"
    ],
    "INTELLECTUAL_DISABILITY": ["F79.0"],
    "DEVELOPMENTAL": ["F84.9", "F89.0"],
    "CHILD_ADOLESCENT": ["F90.9", "F91.3", "F94.0", "F95.0", "F98.9"],
    "UNSPECIFIED_MENTAL": ["F99.0"],
    "NON_F_MISC": [
        "E65_E68", "M79.7", "R45.851", "T14.91", "X6_",
        "Z63", "Z63.4",
        "COGNITIV", "FAM_APO", "LAB_MOB", "PAREJ",
        "No_DX", "altas_capacidades"
    ]
}

# --- 2. Reverse the mapping: code → family ----------------------------- #
code_to_group = {
    code: family
    for family, codes in groups.items()
    for code in codes
}

# Family assigned to any code that is not listed in `groups`.
DEFAULT_FAMILY = "UNMAPPED"

# --- 3. Helper to convert each list of codes --------------------------- #
def map_codes_to_groups(code_list):
    """Translate a list of diagnosis codes into the list of families they belong to.

    Each family appears once, in the order of its first occurrence in
    ``code_list``. Codes missing from ``code_to_group`` map to
    ``DEFAULT_FAMILY``.

    Args:
        code_list: iterable of diagnosis codes (strings).

    Returns:
        A list of unique family names; empty when ``code_list`` is empty.
    """
    # dict.fromkeys removes duplicates while keeping insertion order, so the
    # result is the same on every run (a set of strings has no stable order).
    return list(dict.fromkeys(
        code_to_group.get(code, DEFAULT_FAMILY)
        for code in code_list
    ))

# --- 4. Apply to the dataframe ---------------------------------------- #
# One list of family names per row, derived from that row's diagnosis codes.
family_lists = diagnoses_df["Codigos_diagnosticos"].apply(map_codes_to_groups)
diagnoses_df["Familias_diagnosticos"] = family_lists

# Distinct families that actually occur in the data, sorted alphabetically.
unique_codes = sorted({
    family
    for families in diagnoses_df["Familias_diagnosticos"]
    for family in families
})

print(f"Total family codes: {len(unique_codes)}")
for code in unique_codes:
    print("-", code)

diagnoses_df

In [None]:
# Choose what to work on:
#   - a key of `groups` (e.g. "MOOD_DISORDERS"): keep only that family's rows and codes
#   - "general" or "full": leave the dataframe untouched
chosen_group = "SUBSTANCE_USE"
# chosen_group = "NON_F_MISC"
# chosen_group = "general"
# chosen_group = "full"

# "general" and "full" are not keys of `groups`, so the membership test alone
# decides whether any filtering happens.
if chosen_group in groups:
    # Get all codes belonging to the chosen group
    group_codes = set(groups[chosen_group])

    def filter_codes(codes):
        """Return the entries of `codes` that belong to the chosen family."""
        return [code for code in codes if code in group_codes]

    def filter_estandar(row):
        """Return the standard names whose same-position code belongs to the chosen family.

        Must be called while row["Codigos_diagnosticos"] still holds the
        ORIGINAL, unfiltered codes, because names are matched to codes by position.
        """
        return [
            name
            for code, name in zip(row["Codigos_diagnosticos"], row["Diagnosticos_estandar"])
            if code in group_codes
        ]

    # Filter rows where at least one code in Codigos_diagnosticos belongs to the chosen group
    mask = diagnoses_df["Codigos_diagnosticos"].apply(lambda codes: any(code in group_codes for code in codes))
    diagnoses_df = diagnoses_df[mask].copy()

    # Filter the names FIRST, while codes and names are still aligned. Doing it
    # after the codes were filtered would keep the first k names instead of the
    # names that correspond to the kept codes.
    diagnoses_df["Diagnosticos_estandar"] = diagnoses_df.apply(filter_estandar, axis=1)

    # Keep only codes of the chosen group in Codigos_diagnosticos
    diagnoses_df["Codigos_diagnosticos"] = diagnoses_df["Codigos_diagnosticos"].apply(filter_codes)

diagnoses_df

## 2. Pre-process and splits 🧹✂️

In [None]:
# Binarize the targets: family names when training the "general" model,
# standard diagnosis names in every other configuration.
mlb = MultiLabelBinarizer()
target_column = "Familias_diagnosticos" if chosen_group.lower() == "general" else "Diagnosticos_estandar"
y = mlb.fit_transform(diagnoses_df[target_column])

# Add the prefix required by e5-Large
texts_prefixed = ["query: " + text for text in diagnoses_df["Descripcion_diagnosticos"].tolist()]
X = np.array(texts_prefixed)

# Stage 1: multilabel-stratified 70 % / 30 % split, so that each label's
# distribution is preserved across both parts.
msss = MultilabelStratifiedShuffleSplit(
    n_splits=1, test_size=0.30, random_state=SEED
)
train_idx, tmp_idx = next(iter(msss.split(np.zeros(len(X)), y)))
X_train, y_train = X[train_idx], y[train_idx]
X_tmp, y_tmp = X[tmp_idx], y[tmp_idx]

# Stage 2: halve the held-out 30 % into validation and test (15 % / 15 %).
msss_val = MultilabelStratifiedShuffleSplit(
    n_splits=1, test_size=0.50, random_state=SEED
)
val_idx, test_idx = next(iter(msss_val.split(np.zeros(len(X_tmp)), y_tmp)))
X_val, y_val = X_tmp[val_idx], y_tmp[val_idx]
X_test, y_test = X_tmp[test_idx], y_tmp[test_idx]

# Texts go back to plain Python lists (the tokenizer is fed lists of strings).
X_train, X_val, X_test = X_train.tolist(), X_val.tolist(), X_test.tolist()

# Labels become float32 arrays, the dtype the PyTorch loss works with.
y_train, y_val, y_test = (
    labels.astype(np.float32) for labels in (y_train, y_val, y_test)
)

In [None]:
# Check that each label maintains its ratio approx.
# For every label, the fraction of its positive examples that ended up in each split.
label_totals = y.sum(axis=0)
train_ratio, val_ratio, test_ratio = (
    split_labels.sum(axis=0) / label_totals
    for split_labels in (y_train, y_val, y_test)
)

print(f"train: {np.round(train_ratio.mean(), 3)}, val: {np.round(val_ratio.mean(), 3)}, test: {np.round(test_ratio.mean(), 3)}")

## 3. Dataset and model architecture ⚙️

In [None]:
# === Load model and tokenizer ===
# Checkpoint identifier shared by the tokenizer here and by the
# AutoModel.from_pretrained call that builds the encoder later on.
MODEL_NAME = "intfloat/multilingual-e5-large"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# === Custom classifier model ===
class DiagnosisClassifier(nn.Module):
    """Encoder followed by a two-layer MLP head that emits one logit per label.

    Args:
        base_model: encoder; it is called with ``input_ids`` and
            ``attention_mask`` keywords and must return an object exposing
            ``last_hidden_state``.
        base_model_output_dim: size of the encoder's hidden vectors (input
            width of the head).
        hidden_dim: width of the head's hidden layer.
        num_labels: number of output logits.
        dropout: dropout probability used inside the head. Defaults to 0.1,
            the value that used to be hard-coded.
    """

    def __init__(self, base_model, base_model_output_dim=1024, hidden_dim=768, num_labels=10, dropout=0.1):
        super().__init__()
        self.base = base_model
        # The layer order is part of the checkpoint format: saved state dicts
        # address the two Linear layers as classifier.0.* and classifier.3.*.
        self.classifier = nn.Sequential(
            nn.Linear(base_model_output_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Return the raw (pre-sigmoid) logits for a batch."""
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        # Summarize each sequence by the hidden state at position 0 (CLS token).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        return self.classifier(pooled_output)

# === Dataset wrapper ===
class DiagnosisDataset(Dataset):
    """Dataset that tokenizes all texts once and serves (encoding, label-vector) samples."""

    def __init__(self, texts, labels, tokenizer, max_len=128):
        # The whole corpus is tokenized in a single call, so every sample
        # shares the same padded length.
        tokenizer_kwargs = dict(padding=True, truncation=True, max_length=max_len, return_tensors='pt')
        self.encodings = tokenizer(texts, **tokenizer_kwargs)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Slice every tokenizer output at `idx`, then attach the label vector.
        sample = {}
        for field_name, field_tensor in self.encodings.items():
            sample[field_name] = field_tensor[idx]
        sample["labels"] = self.labels[idx]
        return sample

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch

# Draws a schematic of the two-stage pipeline: input text → general model →
# per-family models → final diagnosis list.

# Create figure and axis
fig, ax = plt.subplots(figsize=(11, 7))
ax.axis("off")

# Define boxes: (x, y, width, height, label, dashed?)
boxes = {
    "input": (0.05, 0.5, 0.18, 0.12, "Descripción\npsiquiátrica", False),
    "general": (0.30, 0.5, 0.20, 0.12, "Modelo\ngeneral", False),
    "famA": (0.55, 0.77, 0.20, 0.12, "Modelo\nFamilia A", False),
    "famB": (0.55, 0.57, 0.20, 0.12, "Modelo\nFamilia B", False),
    "famC": (0.55, 0.37, 0.20, 0.12, "Modelo\nFamilia C", False),
    "famD": (0.55, 0.17, 0.20, 0.12, "Modelo\nFamilia D\n(No empleado)", True),
    "output": (0.80, 0.57, 0.18, 0.12, "Lista final\ndiagnósticos", False),
}

# Draw boxes in orange.
# The loop variables carry a `box_` prefix on purpose: this loop runs at
# notebook (module) level, so a variable named `y` here would overwrite the
# global label matrix `y` that later cells still read.
for box_x, box_y, box_w, box_h, box_label, box_dashed in boxes.values():
    patch = FancyBboxPatch(
        (box_x, box_y), box_w, box_h,
        boxstyle="round,pad=0.02",
        facecolor="orange",      # Orange fill
        edgecolor="black",       # Black border for clarity
        linestyle="--" if box_dashed else "-",
        linewidth=1.5,
    )
    ax.add_patch(patch)
    ax.text(box_x + box_w / 2, box_y + box_h / 2, box_label, ha="center", va="center", fontsize=10)


# Arrow helper
def draw_arrow(src, dst, curve=0.0, dashed=False):
    """Draw an arrow from the right edge of box `src` to the left edge of box `dst`.

    Args:
        src: key of the source box in `boxes`.
        dst: key of the destination box in `boxes`.
        curve: value passed as `rad` to the "arc3" connection style (0 = straight line).
        dashed: draw the arrow with a dashed line when True.
    """
    src_x, src_y, src_w, src_h = boxes[src][:4]
    dst_x, dst_y, _, dst_h = boxes[dst][:4]
    start = (src_x + src_w, src_y + src_h / 2)  # middle of the source's right edge
    end = (dst_x, dst_y + dst_h / 2)            # middle of the destination's left edge
    arrow = FancyArrowPatch(
        start,
        end,
        arrowstyle="-|>",
        mutation_scale=20,
        linewidth=1.5,
        linestyle="--" if dashed else "-",
        connectionstyle=f"arc3,rad={curve}",
    )
    ax.add_patch(arrow)


# Draw arrows for active path
draw_arrow("input", "general")
draw_arrow("general", "famA", curve=0.4)
draw_arrow("general", "famB", curve=0.1)
draw_arrow("general", "famC", curve=-0.2)

# Arrows from family models to output
draw_arrow("famA", "output", curve=0.4)
draw_arrow("famB", "output", curve=0.1)
draw_arrow("famC", "output", curve=-0.2)

plt.tight_layout()
plt.show()

## 4. Baseline single model 🌱

-------------------------------------------------
We fine-tune the whole encoder (it is NOT frozen: every parameter keeps `requires_grad = True`) together with a small two-layer classification head.

In [None]:
# Hyper-parameters of the baseline training run.
EPOCHS_BASELINE = 10  # number of passes over the training loader
LEARNING_RATE = 2e-5  # learning rate handed to the AdamW optimizer
WEIGHT_DECAY = 1e-2  # weight-decay coefficient intended for the optimizer
BATCH_SIZE = 16  # samples per DataLoader batch
HIDDEN_DIM = 1024  # width of the classification head's hidden layer
MAX_LEN = 256  # max_length given to the tokenizer (longer texts are truncated)

In [None]:
# Tokenize each split once and wrap it in a dataset.
train_ds, val_ds, test_ds = (
    DiagnosisDataset(split_texts, split_labels, tokenizer, MAX_LEN)
    for split_texts, split_labels in (
        (X_train, y_train),
        (X_val, y_val),
        (X_test, y_test),
    )
)

# Only the training loader shuffles; validation and test keep their order.
train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE)

In [None]:
# === Initialize model ===
base_model = AutoModel.from_pretrained(MODEL_NAME)
# Keep every encoder weight trainable: the encoder is fine-tuned together
# with the classification head.
for p in base_model.parameters():
    p.requires_grad = True

# Size the head from the fitted binarizer rather than from the global `y`,
# a short name that is easy to overwrite between notebook cells.
baseline = DiagnosisClassifier(
    base_model=base_model,
    hidden_dim=HIDDEN_DIM,
    num_labels=len(mlb.classes_),
).to(device)

# === Training setup ===
# WEIGHT_DECAY is passed explicitly so that changing the constant actually
# affects training (its current value, 1e-2, equals AdamW's default).
optimizer = torch.optim.AdamW(
    baseline.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# Multi-label objective: an independent sigmoid / binary cross-entropy per label.
criterion = nn.BCEWithLogitsLoss()

In [None]:
def epoch_loop(model, data_loader, criterion, optimizer, scheduler=None, train=False):
    """Run one pass over `data_loader`, optionally updating the model.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` and
            returning logits.
        data_loader: yields dict batches with the keys "input_ids",
            "attention_mask" and "labels".
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer stepped after every batch when ``train`` is True;
            ignored otherwise.
        scheduler: optional LR scheduler stepped after every optimizer step.
        train: True for a training pass, False for an evaluation pass.

    Returns:
        Tuple ``(mean_loss, micro_f1, micro_precision, micro_recall)``; the
        metrics use predictions thresholded at a probability of 0.5.
    """
    if train:
        model.train()
    else:
        model.eval()

    losses, prob_batches, label_batches = [], [], []

    # Autograd is only needed while training. Disabling it for evaluation
    # passes avoids building the computation graph, which saves memory and
    # time without changing any output.
    with torch.set_grad_enabled(train):
        for batch in tqdm(data_loader, desc=f"Batches ({'train' if train else 'eval'})"):
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attn_mask)
            loss = criterion(outputs, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()
            losses.append(loss.item())

            # Keep probabilities and targets on the CPU for the metric computation.
            prob_batches.append(torch.sigmoid(outputs).detach().cpu())
            label_batches.append(labels.detach().cpu())

    y_pred = (torch.vstack(prob_batches) > 0.5).int().numpy()
    y_true = torch.vstack(label_batches).numpy()
    return (
        np.mean(losses),
        f1_score(y_true, y_pred, average="micro", zero_division=0),
        precision_score(y_true, y_pred, average="micro", zero_division=0),
        recall_score(y_true, y_pred, average="micro", zero_division=0),
    )

In [None]:
# Train for a fixed number of epochs, tracking the best validation F1 seen.
# The weights evaluated below are those left after the last epoch; `best_f1`
# is recorded but not used to select a checkpoint.
best_f1 = 0
for epoch in range(1, EPOCHS_BASELINE + 1):
    print(f"🔹 Epoch {epoch:02d} / {EPOCHS_BASELINE}")
    train_loss, train_f1, _, _ = epoch_loop(baseline, train_dl, criterion, optimizer, train=True)
    val_loss, val_f1, _, _ = epoch_loop(baseline, val_dl, criterion, optimizer, train=False)
    print(f"Epoch [{epoch:02d}]  train f1={train_f1:.4f} | val f1={val_f1:.4f}")
    if val_f1 > best_f1:
        best_f1 = val_f1

# Final evaluation of the trained model on all three splits (no weight updates).
final_metrics = [
    epoch_loop(baseline, loader, criterion, optimizer, train=False)
    for loader in (train_dl, val_dl, test_dl)
]
_, f1_train, precision_train, recall_train = final_metrics[0]
_, f1_val, precision_val, recall_val = final_metrics[1]
_, f1_test, precision_test, recall_test = final_metrics[2]

print(f"🔹 TRAIN: F1 {f1_train:.4f} / Precision {precision_train:.4f} / Recall {recall_train:.4f}")
print(f"🔹 VAL:   F1 {f1_val:.4f} / Precision {precision_val:.4f} / Recall {recall_val:.4f}")
print(f"🔹 TEST:  F1 {f1_test:.4f} / Precision {precision_test:.4f} / Recall {recall_test:.4f}")

In [None]:
import os

# Save the trained weights together with what is needed to rebuild the model:
# the encoder checkpoint name and the label order of the binarizer.
save_path = f"models/Family_model_{chosen_group if 'chosen_group' in locals() else 'general'}.pt"

# torch.save does not create missing parent folders, so make sure it exists.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# NOTE(review): the inference cells load checkpoints from "models/family-models";
# files written here have to be moved there (or the two paths aligned) — confirm.
torch.save({
    "model_state": baseline.state_dict(),
    "tokenizer": MODEL_NAME,
    "mlb_classes": mlb.classes_.tolist(),
}, save_path)
print(f"📦 Model saved as {save_path}")

In [None]:
# Drop every reference to the baseline run so its memory can be reclaimed
# before the next models are loaded.
del train_ds, val_ds, test_ds
del train_dl, val_dl, test_dl
del base_model, baseline, optimizer, criterion

# Collect garbage first: objects kept alive only by reference cycles are freed
# here, so the cache flush that follows can also release their GPU blocks.
gc.collect()
torch.cuda.empty_cache()

## 5. Full model 🧠📦

In [None]:
# Rebuild the three datasets for the full (two-stage) evaluation.
train_ds, val_ds, test_ds = (
    DiagnosisDataset(split_texts, split_labels, tokenizer, MAX_LEN)
    for split_texts, split_labels in (
        (X_train, y_train),
        (X_val, y_val),
        (X_test, y_test),
    )
)

# No loader shuffles here: predictions are later compared row by row with
# y_train / y_val / y_test, so the sample order has to be kept.
train_dl, val_dl, test_dl = (
    DataLoader(split_ds, batch_size=BATCH_SIZE)
    for split_ds in (train_ds, val_ds, test_ds)
)

In [None]:
import os

# Folder the saved family checkpoints are read from.
MODELS_DIR = "models/family-models"

def load_model(model_path, device):
    """Rebuild a DiagnosisClassifier from a saved checkpoint.

    The checkpoint is a dict with the keys "model_state" (state dict),
    "tokenizer" (encoder checkpoint name) and "mlb_classes" (label names in
    output order).

    Args:
        model_path: path of the checkpoint file.
        device: device the weights are mapped to and the model is moved to.

    Returns:
        Tuple ``(model, mlb_classes)`` with the model already in eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    model_state = checkpoint["model_state"]
    mlb_classes = checkpoint["mlb_classes"]

    # The first Linear layer's weight has shape (hidden_dim, encoder_dim), so
    # both head sizes are read from the checkpoint instead of relying on the
    # class defaults.
    hidden_dim_loaded, encoder_dim_loaded = model_state["classifier.0.weight"].shape

    model = DiagnosisClassifier(
        base_model=AutoModel.from_pretrained(checkpoint["tokenizer"]),
        base_model_output_dim=encoder_dim_loaded,
        hidden_dim=hidden_dim_loaded,
        num_labels=len(mlb_classes),
    ).to(device)
    model.load_state_dict(model_state)
    model.eval()

    return model, mlb_classes

In [None]:
def predict_labels(model, mlb_classes, data_loader, threshold=0.5):
    """Predict the label names of every sample served by `data_loader`.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` and
            returning one logit per label.
        mlb_classes: label names; position i names output column i.
        data_loader: yields dict batches with "input_ids" and "attention_mask".
        threshold: a label is predicted when its sigmoid probability is
            strictly greater than this value.

    Returns:
        One list of predicted label names per sample, in loader order.
    """
    model.eval()
    prob_batches = []

    # Pure inference: without autograd no computation graph is built, which
    # keeps memory use down and does not change the predictions.
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Batches (eval)"):
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)

            outputs = model(input_ids, attn_mask)
            prob_batches.append(torch.sigmoid(outputs).cpu())

    y_pred = (torch.vstack(prob_batches) > threshold).int().numpy()
    return [
        [mlb_classes[i] for i, flag in enumerate(row) if flag == 1]
        for row in y_pred
    ]

In [None]:
# === Two-stage inference on the TRAIN split ===
# Stage 1: the general model assigns each sample to zero or more families.
# train_dl must not shuffle, so that row i of the predictions lines up with y_train[i].
general_model, general_mlb_classes = load_model(os.path.join(MODELS_DIR, "Family_model_general.pt"), device)
general_preds = predict_labels(general_model, general_mlb_classes, train_dl)
del general_model, general_mlb_classes
torch.cuda.empty_cache()
gc.collect()

final_preds = [[] for _ in range(len(general_preds))]
# Only families with more than one code get their own model.
groups_names = [group_name for group_name in groups.keys() if len(groups[group_name]) > 1]

# Single-code families: the family prediction already pins down the diagnosis,
# so the code is translated directly through cie10_map.
for i, preds in enumerate(general_preds):
    if "INTELLECTUAL_DISABILITY" in preds:
        final_preds[i].append(cie10_map[groups["INTELLECTUAL_DISABILITY"][0]])
    if "UNSPECIFIED_MENTAL" in preds:
        final_preds[i].append(cie10_map[groups["UNSPECIFIED_MENTAL"][0]])

# Stage 2: each family model refines the samples routed to its family.
for group_name in groups_names:
    print(f"*** {group_name} ***")
    group_filter = [group_name in preds for preds in general_preds]
    if not any(group_filter):
        # No sample was routed to this family: its predictions would all be
        # discarded, so skip loading and running the model.
        continue

    group_model, group_mlb_classes = load_model(os.path.join(MODELS_DIR, f"Family_model_{group_name}.pt"), device)
    group_preds = predict_labels(group_model, group_mlb_classes, train_dl)

    # Keep a family model's output only where the general model chose that family.
    for i, (pred, keep) in enumerate(zip(group_preds, group_filter)):
        if keep:
            final_preds[i].extend(pred)

    del group_model, group_mlb_classes
    torch.cuda.empty_cache()
    gc.collect()

y_pred = mlb.transform(final_preds)
train_f1_score = f1_score(y_train, y_pred, average="micro", zero_division=0)
train_precision = precision_score(y_train, y_pred, average="micro", zero_division=0)
train_recall = recall_score(y_train, y_pred, average="micro", zero_division=0)
print(f"🔹 Train F1 score:  {train_f1_score:.4f}")
print(f"🔹 Train Precision: {train_precision:.4f}")
print(f"🔹 Train Recall: {train_recall:.4f}")

In [None]:
# === Two-stage inference on the VALIDATION split ===
# Stage 1: the general model assigns each sample to zero or more families.
# val_dl must not shuffle, so that row i of the predictions lines up with y_val[i].
general_model, general_mlb_classes = load_model(os.path.join(MODELS_DIR, "Family_model_general.pt"), device)
general_preds = predict_labels(general_model, general_mlb_classes, val_dl)
del general_model, general_mlb_classes
torch.cuda.empty_cache()
gc.collect()

final_preds = [[] for _ in range(len(general_preds))]
# Only families with more than one code get their own model.
groups_names = [group_name for group_name in groups.keys() if len(groups[group_name]) > 1]

# Single-code families: the family prediction already pins down the diagnosis,
# so the code is translated directly through cie10_map.
for i, preds in enumerate(general_preds):
    if "INTELLECTUAL_DISABILITY" in preds:
        final_preds[i].append(cie10_map[groups["INTELLECTUAL_DISABILITY"][0]])
    if "UNSPECIFIED_MENTAL" in preds:
        final_preds[i].append(cie10_map[groups["UNSPECIFIED_MENTAL"][0]])

# Stage 2: each family model refines the samples routed to its family.
for group_name in groups_names:
    print(f"*** {group_name} ***")
    group_filter = [group_name in preds for preds in general_preds]
    if not any(group_filter):
        # No sample was routed to this family: its predictions would all be
        # discarded, so skip loading and running the model.
        continue

    group_model, group_mlb_classes = load_model(os.path.join(MODELS_DIR, f"Family_model_{group_name}.pt"), device)
    group_preds = predict_labels(group_model, group_mlb_classes, val_dl)

    # Keep a family model's output only where the general model chose that family.
    for i, (pred, keep) in enumerate(zip(group_preds, group_filter)):
        if keep:
            final_preds[i].extend(pred)

    del group_model, group_mlb_classes
    torch.cuda.empty_cache()
    gc.collect()

y_pred = mlb.transform(final_preds)
val_f1_score = f1_score(y_val, y_pred, average="micro", zero_division=0)
val_precision = precision_score(y_val, y_pred, average="micro", zero_division=0)
val_recall = recall_score(y_val, y_pred, average="micro", zero_division=0)
print(f"🔹 Val F1 score:  {val_f1_score:.4f}")
print(f"🔹 Val Precision: {val_precision:.4f}")
print(f"🔹 Val Recall: {val_recall:.4f}")

In [None]:
# === Two-stage inference on the TEST split ===
# Stage 1: the general model assigns each sample to zero or more families.
# test_dl must not shuffle, so that row i of the predictions lines up with y_test[i].
general_model, general_mlb_classes = load_model(os.path.join(MODELS_DIR, "Family_model_general.pt"), device)
general_preds = predict_labels(general_model, general_mlb_classes, test_dl)
del general_model, general_mlb_classes
torch.cuda.empty_cache()
gc.collect()

final_preds = [[] for _ in range(len(general_preds))]
# Only families with more than one code get their own model.
groups_names = [group_name for group_name in groups.keys() if len(groups[group_name]) > 1]

# Single-code families: the family prediction already pins down the diagnosis,
# so the code is translated directly through cie10_map.
for i, preds in enumerate(general_preds):
    if "INTELLECTUAL_DISABILITY" in preds:
        final_preds[i].append(cie10_map[groups["INTELLECTUAL_DISABILITY"][0]])
    if "UNSPECIFIED_MENTAL" in preds:
        final_preds[i].append(cie10_map[groups["UNSPECIFIED_MENTAL"][0]])

# Stage 2: each family model refines the samples routed to its family.
for group_name in groups_names:
    print(f"*** {group_name} ***")
    group_filter = [group_name in preds for preds in general_preds]
    if not any(group_filter):
        # No sample was routed to this family: its predictions would all be
        # discarded, so skip loading and running the model.
        continue

    group_model, group_mlb_classes = load_model(os.path.join(MODELS_DIR, f"Family_model_{group_name}.pt"), device)
    group_preds = predict_labels(group_model, group_mlb_classes, test_dl)

    # Keep a family model's output only where the general model chose that family.
    for i, (pred, keep) in enumerate(zip(group_preds, group_filter)):
        if keep:
            final_preds[i].extend(pred)

    del group_model, group_mlb_classes
    torch.cuda.empty_cache()
    gc.collect()

y_pred = mlb.transform(final_preds)
test_f1_score = f1_score(y_test, y_pred, average="micro", zero_division=0)
test_precision = precision_score(y_test, y_pred, average="micro", zero_division=0)
test_recall = recall_score(y_test, y_pred, average="micro", zero_division=0)
print(f"🔹 Test F1 score:  {test_f1_score:.4f}")
print(f"🔹 Test Precision: {test_precision:.4f}")
print(f"🔹 Test Recall: {test_recall:.4f}")