In [5]:
%pip install joblib

import pandas as pd
import torch
import re
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, classification_report
from tqdm.auto import tqdm
import joblib

# ====================
#  CONFIGURACIÓN
# ====================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

class Config:
    USE_FP16 = False
    STRIDE = 128  # Overlap between windows
    MODEL_NAME = 'dccuchile/bert-base-spanish-wwm-cased'
    #MODEL_NAME = 'PlanTL-GOB-ES/bsc-bio-ehr-es'
    #MODEL_NAME = 'IIC/bsc-bio-ehr-es-caresA'
    #MODEL_NAME = 'PlanTL-GOB-ES/roberta-base-biomedical-clinical-es'
    #MODEL_NAME = 'IIC/bert-base-spanish-wwm-cased-ctebmsp'
    #MODEL_NAME = 'IIC/xlm-roberta-large-ehealth_kd'
    THRESHOLD_TUNING_INTERVAL = 3  # Cada cuántas épocas ajustar umbrales
    USE_FEATURE_PYRAMID = True # Usar Feature Pyramid Network
    FEATURE_LAYER_WEIGHTS = [0.1, 0.3, 0.6]  # Pesos para las últimas 3 capas
    CLASS_WEIGHT_SMOOTHING = 0.1  # Suavizado para pesos de clases
    EARLY_STOP_PATIENCE = 50  # Número de épocas sin mejora para parar
    IMPROVEMENT_MARGIN = 0.0005
    MAX_LENGTH = 512 # Máxima longitud de secuencia, definida en el BERT pre-entrenado
    TRAIN_BATCH_SIZE = 4
    VAL_BATCH_SIZE = 16
    TEST_BATCH_SIZE = 32
    EPOCHS = 1000
    GRADIENT_ACCUMULATION_STEPS = 1
    WARMUP_EPOCHS = 2
    HIERARCHICAL_WEIGHTS = {'parent': 1.5, 'child': 1.0}
    LEARNING_RATE = 3e-5 # 2e-2
    DATA_PATHS = {
        'train': 'codiesp_csvs/codiesp_D_source_train.csv',
        'test': 'codiesp_csvs/codiesp_D_source_test.csv',
        'val': 'codiesp_csvs/codiesp_D_source_validation.csv'
    }
    SAVE_TOKENIZER_PATH = 'snapshots/cie10_tokenizer'
    SAVE_PATH = 'snapshots/best_hierarchical_model'
    SAVE_STATE_PATH = 'snapshots/best_hierarchical_model_state.bin'
    SAVE_MLB_PARENT_PATH = 'snapshots/best_hierarchical_model_mlb_parent'
    SAVE_MLB_CHILD_PATH = 'snapshots/best_hierarchical_model_mlb_child'
    THRESHOLDS = {'parent': 0.041, 'child': 0.12}
    PRETRAIN_EPOCHS = 10
    PRETRAIN_BATCH_SIZE = 8
    PRETRAIN_DATA_PATH = '../csv_import_scripts/cie10-es-diagnoses-expanded.csv'
    FORCE_NEW_MODEL = True



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
Using device: cuda


In [6]:
#  PREPROCESAMIENTO
# ====================
def parse_code(code):
    """Divide el código en niveles jerárquicos"""
    parts = code.split('.')
    hierarchy = []
    if len(parts) >= 1:
        parent = parts[0]  # Primera categoría (ej: S62)
        hierarchy.append(parent)

        if len(parts) >= 2:
            child_part = parts[1]
            child = f"{parent}.{child_part}"  # Segunda categoría (ej: S62.14S)
            hierarchy.append(child)

    return hierarchy

def calculate_mlb_classes():
    df = pd.read_csv(Config.PRETRAIN_DATA_PATH)
    # Try to load saved MLBs first
    try:
        mlb_parent = joblib.load(Config.SAVE_MLB_PARENT_PATH)
        mlb_child = joblib.load(Config.SAVE_MLB_CHILD_PATH)
        print("Loaded saved MLBs")
        return mlb_parent, mlb_child
    except:
        print("Creating new MLBs")

        # Construir jerarquía de códigos
        all_parents = set()
        all_children = set()

        for code in df['code']:
            levels = parse_code(code.strip())
            if len(levels) >= 1: all_parents.add(levels[0])
            if len(levels) >= 2: all_children.add(levels[1])

        # Inicializar MLB
        mlb_parent = MultiLabelBinarizer().fit([all_parents])
        mlb_child = MultiLabelBinarizer().fit([all_children])

        # Save MLBs
        joblib.dump(mlb_parent, Config.SAVE_MLB_PARENT_PATH)
        joblib.dump(mlb_child, Config.SAVE_MLB_CHILD_PATH)

    print(f"Padres: {len(mlb_parent.classes_)} - Hijos: {len(mlb_child.classes_)}")
    return mlb_parent, mlb_child

In [7]:
#  PLT DE MÉTRICAS
# ====================

def plot_metrics():
    # Load metrics
    metrics_history = pd.read_csv('training_metrics.csv')

    plt.figure(figsize=(10, 6))
    plt.plot(metrics_history['epoch'], metrics_history['loss'], label='Loss')
    plt.plot(metrics_history['epoch'], metrics_history['f1_micro'], label='F1 Micro')
    plt.plot(metrics_history['epoch'], metrics_history['f1_macro'], label='F1 Macro')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.title('Training Metrics Over Time')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_metrics.png')
    plt.close()


In [None]:
#  MODELO JERÁRQUICO
# ====================

class HierarchicalBERT(torch.nn.Module):
    def __init__(self, num_parents, num_children):
        super().__init__()
        try:
            self.bert = AutoModel.from_pretrained(Config.SAVE_PATH)
            print("Loaded saved BERT model")
        except:
            self.bert = AutoModel.from_pretrained(Config.MODEL_NAME)
            print("Using default BERT model")

        hidden_size = self.bert.config.hidden_size  # This will be 768 for base models

        self.parent_classifier = torch.nn.Linear(hidden_size, num_parents)
        self.child_classifier = torch.nn.Linear(hidden_size + num_parents, num_children)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0, :]

        # Clasificación padre
        parent_logits = self.parent_classifier(pooled)

        # Clasificación hijo con contexto de padres
        parent_probs = torch.sigmoid(parent_logits)
        child_input = torch.cat([pooled, parent_probs], dim=1)
        child_logits = self.child_classifier(child_input)

        return parent_logits, child_logits, pooled

# ====================
# Función de pérdida
# ====================
def hierarchical_loss(parent_logits, child_logits,
                      parent_labels, child_labels):

    loss_parent = torch.nn.BCEWithLogitsLoss()(parent_logits, parent_labels)
    loss_child = torch.nn.BCEWithLogitsLoss()(child_logits, child_labels)

    return (Config.HIERARCHICAL_WEIGHTS['parent'] * loss_parent +
            Config.HIERARCHICAL_WEIGHTS['child'] * loss_child)

In [None]:
# v3 (Sliding Window)
# ====================

import os
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup, AutoConfig
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_recall_curve
from tqdm import tqdm


# ====================
#  Modelo v3
# ====================
class HierarchicalBERTv2(torch.nn.Module):
    def __init__(self, num_parents, num_children):
        super().__init__()
        config = AutoConfig.from_pretrained(Config.MODEL_NAME, output_hidden_states=True)
        self.bert = AutoModel.from_pretrained(Config.MODEL_NAME, config=config)

        hidden_size = self.bert.config.hidden_size  # This will be 768 for base models

        self.parent_classifier = torch.nn.Linear(hidden_size, num_parents)
        self.child_classifier = torch.nn.Linear(hidden_size + num_parents, num_children)
        self.dropout = torch.nn.Dropout(self.bert.config.hidden_dropout_prob)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)

        if Config.USE_FEATURE_PYRAMID:
            # Combine last 3 layers' [CLS] embeddings
            hidden_states = outputs.hidden_states[-3:]  # Get last 3 layers
            # Stack [CLS] embeddings (shape: [3, batch_size, hidden_size])
            pooled = torch.stack([state[:, 0] for state in hidden_states])
            # Weighted combination of layers (weights should sum to 1)
            if sum(Config.FEATURE_LAYER_WEIGHTS) != 1:
                raise ValueError("FEATURE_LAYER_WEIGHTS must sum to 1")

            # Apply weights to layers
            pooled = torch.einsum('lbd,l->bd', pooled,
                                torch.tensor(Config.FEATURE_LAYER_WEIGHTS).to(pooled.device))
        else:
            pooled = outputs.last_hidden_state[:, 0, :]

        pooled = self.dropout(pooled)

        # Jerarquía de clasificación
        parent_logits = self.parent_classifier(pooled)
        parent_probs = torch.sigmoid(parent_logits)
        child_input = torch.cat([pooled, parent_probs], dim=1)
        child_logits = self.child_classifier(child_input)

        return parent_logits, child_logits, pooled

# ====================
#  FUNCIÓN DE PÉRDIDA MEJORADA
# ====================
def hierarchical_lossv2(parent_logits, child_logits,
                     parent_labels, child_labels,
                     parent_weights, child_weights):

    loss_parent = F.binary_cross_entropy_with_logits(
        parent_logits,
        parent_labels,
        pos_weight=parent_weights
    )

    loss_child = F.binary_cross_entropy_with_logits(
        child_logits,
        child_labels,
        pos_weight=child_weights
    )

    return (Config.HIERARCHICAL_WEIGHTS['parent'] * loss_parent +
            Config.HIERARCHICAL_WEIGHTS['child'] * loss_child)

# ====================
#  AJUSTE DINÁMICO DE UMBRALES
# ====================
def calculate_optimal_thresholds(y_true, y_probs):
    thresholds = {}
    for i in range(y_probs.shape[1]):
        if np.sum(y_true[:, i]) > 0:  # Solo clases presentes
            precision, recall, threshs = precision_recall_curve(y_true[:, i], y_probs[:, i])
            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
            best_idx = np.nanargmax(f1_scores)
            thresholds[i] = threshs[best_idx]
    return thresholds

# ====================
#  DATASET v3 (Sliding Window)
# ====================
class HierarchicalMedicalDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child):
        self.texts = df['text'].tolist()
        self.tokenizer = tokenizer
        self.examples = []

        # Procesar etiquetas
        self.parent_labels = []
        self.child_labels = []

        # Same label processing as before
        for codes in df['labels'].apply(eval):
            parents, children = set(), set()
            for code in codes:
                code = code.strip().upper()
                levels = parse_code(code)
                if len(levels) >= 1: parents.add(levels[0])
                if len(levels) >= 2: children.add(levels[1])

            self.parent_labels.append(mlb_parent.transform([parents])[0])
            self.child_labels.append(mlb_child.transform([children])[0])

        # Generate sliding windows for each text
        for idx, text in enumerate(self.texts):
            # Tokenize whole text
            tokens = self.tokenizer(
                text,
                truncation=False,
                return_offsets_mapping=True,
                add_special_tokens=False
            )

            # Generate sliding windows
            window_size = Config.MAX_LENGTH - 2  # Account for [CLS] and [SEP]
            stride = Config.STRIDE

            for i in range(0, len(tokens['input_ids']), stride):
                # Extract window
                window_start = i
                window_end = min(i + window_size, len(tokens['input_ids']))

                # Add special tokens
                input_ids = (
                    [self.tokenizer.cls_token_id] +
                    tokens['input_ids'][window_start:window_end] +
                    [self.tokenizer.sep_token_id]
                )

                attention_mask = [1] * len(input_ids)

                # Pad if necessary
                padding_length = Config.MAX_LENGTH - len(input_ids)
                if padding_length > 0:
                    input_ids += [self.tokenizer.pad_token_id] * padding_length
                    attention_mask += [0] * padding_length

                self.examples.append({
                    'input_ids': torch.tensor(input_ids),
                    'attention_mask': torch.tensor(attention_mask),
                    'parent_labels': torch.FloatTensor(self.parent_labels[idx]),
                    'child_labels': torch.FloatTensor(self.child_labels[idx]),
                    'text_id': idx  # To group windows later
                })

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


# ====================
#  ENTRENAMIENTO con Sliding Window
# ====================
def train(best_thresholds=Config.THRESHOLDS):
    epochs_without_improvement = 0
    early_stop = False
    best_f1 = 0

    # Cargar datos
    train_df = pd.read_csv(Config.DATA_PATHS['train'])
    val_df = pd.read_csv(Config.DATA_PATHS['val'])

    # Construir binarizadores
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Preparar datasets

    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Created new tokenizer")

    train_dataset = HierarchicalMedicalDataset(train_df, tokenizer, mlb_parent, mlb_child)
    val_dataset = HierarchicalMedicalDataset(val_df, tokenizer, mlb_parent, mlb_child)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=Config.TRAIN_BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.VAL_BATCH_SIZE)

    # Modelo y optimizador
    model = HierarchicalBERTv2(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    # Load best model if available
    if not Config.FORCE_NEW_MODEL:
        if os.path.exists(f"{Config.SAVE_PATH}_2"):
            model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
            print("Loaded best model - 2")
        elif os.path.exists(Config.SAVE_PATH):
            model.load_state_dict(torch.load(Config.SAVE_PATH))
            print("Loaded best model")
        else:
            print("Starting training from scratch")

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps = Config.WARMUP_EPOCHS * len(train_loader),
        num_training_steps = Config.EPOCHS * len(train_loader)
    )

    # Bucle de entrenamiento
    scaler = torch.amp.GradScaler('cuda', enabled=Config.USE_FP16)

    # Calcular pesos de clases
    parent_counts = np.sum(train_dataset.parent_labels, axis=0)
    parent_weights = (len(train_dataset) - parent_counts) / (parent_counts + Config.CLASS_WEIGHT_SMOOTHING)
    parent_weights = torch.tensor(parent_weights).to(device)

    child_counts = np.sum(train_dataset.child_labels, axis=0)
    child_weights = (len(train_dataset) - child_counts) / (child_counts + Config.CLASS_WEIGHT_SMOOTHING)
    child_weights = torch.tensor(child_weights).to(device)

    for epoch in range(Config.EPOCHS):
        if early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

        model.train()
        total_loss = 0

        # Ajuste periódico de umbrales
        if (epoch + 1) % Config.THRESHOLD_TUNING_INTERVAL == 0:
            val_probs, val_labels = get_validation_probabilities(model, val_loader, device)

            # Calcular mejores umbrales por clase
            parent_thresholds = calculate_optimal_thresholds(
                val_labels['parent'], val_probs['parent']
            )
            child_thresholds = calculate_optimal_thresholds(
                val_labels['child'], val_probs['child']
            )

            # Actualizar umbrales globales
            best_thresholds['parent'] = np.mean(list(parent_thresholds.values()))
            best_thresholds['child'] = np.mean(list(child_thresholds.values()))
            print(f"Nuevos umbrales: Parent={best_thresholds['parent']:.3f}, Child={best_thresholds['child']:.3f}")

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        text_predictions = defaultdict(lambda: {'parent': [], 'child': []})
        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(device) for k, v in batch.items()}

            with torch.amp.autocast('cuda', enabled=Config.USE_FP16):
                parent_logits, child_logits, _ = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask']
                )

                # Store predictions by original text
                for i, text_id in enumerate(batch['text_id'].cpu().numpy()):
                    text_predictions[text_id]['parent'].append(parent_logits[i])
                    text_predictions[text_id]['child'].append(child_logits[i])

                # Immediate window-level loss
                loss = hierarchical_lossv2(
                    parent_logits,
                    child_logits,
                    batch['parent_labels'],  # Use batch labels directly
                    batch['child_labels'],   # Not the dataset's labels
                    parent_weights,
                    child_weights
                )

            # Backpropagate
            scaler.scale(loss).backward()
            if (step + 1) % Config.GRADIENT_ACCUMULATION_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scheduler.step()
                optimizer.zero_grad()
                scaler.update()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr()[0])

        # After epoch completes, calculate aggregated loss
        agg_loss = 0
        for text_id in text_predictions:
            if text_id >= len(train_dataset.parent_labels):
                continue  # Skip invalid text_ids

            # Aggregate predictions
            parent_agg = torch.stack(text_predictions[text_id]['parent']).max(dim=0)[0]
            child_agg = torch.stack(text_predictions[text_id]['child']).max(dim=0)[0]

            # Get true labels from dataset
            parent_label = torch.FloatTensor(train_dataset.parent_labels[text_id]).to(device)
            child_label = torch.FloatTensor(train_dataset.child_labels[text_id]).to(device)

            # Calculate aggregated loss
            agg_loss += hierarchical_lossv2(
                parent_agg.unsqueeze(0),
                child_agg.unsqueeze(0),
                parent_label.unsqueeze(0),
                child_label.unsqueeze(0),
                parent_weights,
                child_weights
            )

        # Combine losses
        total_loss += agg_loss.item() / len(text_predictions)
        print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | LR: {scheduler.get_last_lr()[0]:.2E}")

        # Validación
        val_metrics = evaluate(model, val_loader, device, mlb_parent, mlb_child, best_thresholds)

        # Store metrics
        loss_metric = val_metrics['f1_macro']
        if epoch == 0:
            best_f1 = loss_metric

        if loss_metric > (best_f1 + Config.IMPROVEMENT_MARGIN):
            print(f"Saving best model... {best_f1:.5f} -> {loss_metric:.5f}")
            torch.save(model.state_dict(), f"{Config.SAVE_PATH}_3")
            best_f1 = loss_metric
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= Config.EARLY_STOP_PATIENCE:
                early_stop = True

        metrics_data = {
            'epoch': epoch + 1,
            'loss': total_loss/len(train_loader),
            'f1_micro': val_metrics['f1_micro'],
            'f1_macro': val_metrics['f1_macro'],
            'f1_micro_parent': val_metrics['f1_micro_parent'],
            'f1_macro_parent': val_metrics['f1_macro_parent'],
            'f1_micro_child': val_metrics['f1_micro_child'],
            'f1_macro_child': val_metrics['f1_macro_child'],
            'lr': scheduler.get_last_lr()[0],
            'epochs_without_improvement': epochs_without_improvement,
            'parent_threshold': best_thresholds['parent'],
            'child_threshold': best_thresholds['child']
        }

        # Write metrics to CSV
        metrics_df = pd.DataFrame([metrics_data])
        if epoch == 0:
            metrics_df.to_csv('training_metrics.csv', mode='a', index=False)
        else:
            metrics_df.to_csv('training_metrics.csv', mode='a', header=False, index=False)

        print(f"F1 Validation | Micro: {val_metrics['f1_micro']:.5f} | Macro: {val_metrics['f1_macro']:.5f} | Best: {best_f1:.5f} | Epochs without improvement: {epochs_without_improvement + 1}")

# ====================
#  FUNCIONES AUXILIARES (Con sliding windows)
# ====================
def get_validation_probabilities(model, dataloader, device):
    model.eval()
    text_predictions = defaultdict(lambda: {'parent': [], 'child': []})
    parent_labels = {}
    child_labels = {}

    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device)
            }
            text_ids = batch['text_id'].numpy()

            # Store labels by text_id
            for i, text_id in enumerate(text_ids):
                if text_id not in parent_labels:
                    parent_labels[text_id] = batch['parent_labels'][i].numpy()
                    child_labels[text_id] = batch['child_labels'][i].numpy()

            # Get predictions
            p_logits, c_logits, _ = model(**inputs)

            # Store predictions by text_id
            for i, text_id in enumerate(text_ids):
                text_predictions[text_id]['parent'].append(p_logits[i].cpu())
                text_predictions[text_id]['child'].append(c_logits[i].cpu())

    # Aggregate predictions per text
    parent_probs, child_probs = [], []
    final_parent_labels, final_child_labels = [], []

    for text_id in text_predictions:
        # Aggregate using max pooling (same as training)
        parent_agg = torch.stack(text_predictions[text_id]['parent']).max(dim=0)[0]
        child_agg = torch.stack(text_predictions[text_id]['child']).max(dim=0)[0]

        parent_probs.append(torch.sigmoid(parent_agg).numpy())
        child_probs.append(torch.sigmoid(child_agg).numpy())

        # Get original labels
        final_parent_labels.append(parent_labels[text_id])
        final_child_labels.append(child_labels[text_id])

    return {
        'parent': np.array(parent_probs),
        'child': np.array(child_probs)
    }, {
        'parent': np.array(final_parent_labels),
        'child': np.array(final_child_labels)
    }

# ====================
#  EVALUACIÓN Con sliding windows
# ====================
def evaluate(model, dataloader, device, mlb_parent, mlb_child, thresholds):
    val_probs, val_labels = get_validation_probabilities(model, dataloader, device)

    # Convert probabilities to predictions
    parent_preds = (val_probs['parent'] > thresholds['parent']).astype(int)
    child_preds = (val_probs['child'] > thresholds['child']).astype(int)

    # Print example comparison
    if len(val_labels['parent']) > 0:
        idx = 0  # First example
        parent_true = np.array(mlb_parent.classes_)[val_labels['parent'][idx].astype(bool)]
        parent_pred = np.array(mlb_parent.classes_)[parent_preds[idx].astype(bool)]

        child_true = np.array(mlb_child.classes_)[val_labels['child'][idx].astype(bool)]
        child_pred = np.array(mlb_child.classes_)[child_preds[idx].astype(bool)]

        print("\nExample Validation Results:")
        print(f"Expected parent: {sorted(parent_true)}")
        print(f"Predicted parent: {sorted(parent_pred)}")
        print(f"Expected child: {sorted(child_true)}")
        print(f"Predicted child: {sorted(child_pred)}")

        common_parent = len(set(parent_true) & set(parent_pred))
        common_child = len(set(child_true) & set(child_pred))
        print(f"Parent Accuracy: {common_parent/len(parent_true):.2%} | Child Accuracy: {common_child/len(child_true):.2%}")

    # Calculate metrics
    metrics = {
        'f1_micro_parent': f1_score(val_labels['parent'], parent_preds, average='micro', zero_division=0),
        'f1_macro_parent': f1_score(val_labels['parent'], parent_preds, average='macro', zero_division=0),
        'f1_micro_child': f1_score(val_labels['child'], child_preds, average='micro', zero_division=0),
        'f1_macro_child': f1_score(val_labels['child'], child_preds, average='macro', zero_division=0)
    }

    # Weighted averages
    total_weight = sum(Config.HIERARCHICAL_WEIGHTS.values())
    metrics['f1_micro'] = (Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_micro_parent'] +
                          Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_micro_child']) / total_weight

    metrics['f1_macro'] = (Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_macro_parent'] +
                          Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_macro_child']) / total_weight

    return metrics

# ====================
#  PREDICCIÓN
# ====================
def predict(text, model, tokenizer, mlb_parent, mlb_child, device, thresholds):
    encoding = tokenizer(
        text,
        max_length=Config.MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        parent_logits, child_logits = model(**encoding)

    # Obtener predicciones
    parent_probs = torch.sigmoid(parent_logits).cpu().numpy()
    child_probs = torch.sigmoid(child_logits).cpu().numpy()

    # Decodificar etiquetas
    parent_preds = mlb_parent.inverse_transform((parent_probs > thresholds['parent']).astype(int))
    child_preds = mlb_child.inverse_transform((child_probs > thresholds['child']).astype(int))

    # Combinar y asegurar jerarquía
    final_codes = set()
    for parent in parent_preds[0]:
        final_codes.add(parent)
        for child in child_preds[0]:
            if child.startswith(parent):
                final_codes.add(child)

    return sorted(final_codes)

# ====================
#  EJECUCIÓN
# ====================
if __name__ == "__main__":
    best_thresholds = Config.THRESHOLDS

    train(best_thresholds)

    # Cargar datos de test
    test_df = pd.read_csv(Config.DATA_PATHS['test'])
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Cargar modelo
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        print("Using default tokenizer")

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    if os.path.exists(f"{Config.SAVE_PATH}_2"):
        model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
        print("Loaded best model - 2")
    elif os.path.exists(Config.SAVE_PATH):
        model.load_state_dict(torch.load(Config.SAVE_PATH))
        print("Loaded best model")

    # Evaluar en test
    test_dataset = HierarchicalMedicalDataset(test_df, tokenizer, mlb_parent, mlb_child)
    test_loader = DataLoader(test_dataset, batch_size=Config.TEST_BATCH_SIZE)

    test_metrics = evaluate(model, test_loader, device, mlb_parent, mlb_child)
    print("\nResultados en Test:")
    print(f"Micro F1: {test_metrics['f1_micro']:.4f}")
    print(f"Macro F1: {test_metrics['f1_macro']:.4f}")

    # Ejemplo de predicción
    sample_text = "Paciente con diabetes mellitus tipo 2 y complicaciones renales..."
    prediction = predict(sample_text, model, tokenizer, mlb_parent, mlb_child, device, best_thresholds)
    print("\nPredicción de ejemplo:", prediction)

    plot_metrics()


Token indices sequence length is longer than the specified maximum sequence length for this model (828 > 512). Running this sequence through the model will result in indexing errors


Loaded saved MLBs
Loaded saved tokenizer


Some weights of BertModel were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 603/603 [01:17<00:00,  7.76it/s, loss=4.13, lr=1.5e-5] 


Epoch 1 | Loss: 3.4021 | LR: 1.50E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['A00', 'A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07', 'A08', 'A09', 'A15', 'A17', 'A18', 'A19', 'A20', 'A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27', 'A28', 'A30', 'A31', 'A32', 'A33', 'A34', 'A35', 'A36', 'A37', 'A38', 'A39', 'A40', 'A41', 'A42', 'A43', 'A44', 'A46', 'A48', 'A49', 'A50', 'A51', 'A52', 'A53', 'A54', 'A55', 'A56', 'A57', 'A58', 'A59', 'A60', 'A63', 'A64', 'A65', 'A66', 'A67', 'A68', 'A69', 'A70', 'A71', 'A74', 'A75', 'A77', 'A78', 'A79', 'A80', 'A81', 'A82', 'A83', 'A84', 'A85', 'A86', 'A87', 'A88', 'A89', 'A90', 'A91', 'A92', 'A93', 'A94', 'A95', 'A96', 'A98', 'A99', 'B00', 'B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B09', 'B10', 'B15', 'B16', 'B17', 'B18', 'B19', 'B20', 'B25', 'B26', 'B27', 'B30', 'B33', 'B34', 'B35', 'B36', 'B37', 'B38', 'B39', 'B40', 'B41', 'B42', 'B43', 'B44', 

Epoch 2: 100%|██████████| 603/603 [01:17<00:00,  7.79it/s, loss=0.863, lr=3e-5]  


Epoch 2 | Loss: 1.9889 | LR: 3.00E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['A01', 'A02', 'A03', 'A04', 'A09', 'A15', 'A17', 'A18', 'A23', 'A24', 'A25', 'A28', 'A31', 'A32', 'A35', 'A38', 'A40', 'A41', 'A43', 'A44', 'A48', 'A49', 'A53', 'A54', 'A55', 'A60', 'A63', 'A64', 'A69', 'A74', 'A78', 'A79', 'A87', 'A90', 'A91', 'A93', 'B00', 'B01', 'B02', 'B05', 'B06', 'B07', 'B08', 'B10', 'B15', 'B17', 'B18', 'B19', 'B20', 'B25', 'B27', 'B30', 'B33', 'B34', 'B35', 'B37', 'B44', 'B45', 'B46', 'B48', 'B55', 'B57', 'B58', 'B59', 'B65', 'B67', 'B69', 'B74', 'B83', 'B91', 'B95', 'B96', 'B97', 'B99', 'C01', 'C02', 'C04', 'C07', 'C08', 'C10', 'C12', 'C15', 'C16', 'C17', 'C18', 'C21', 'C22', 'C23', 'C25', 'C26', 'C31', 'C34', 'C40', 'C41', 'C43', 'C44', 'C48', 'C49', 'C50', 'C52', 'C53', 'C54', 'C55', 'C56', 'C60', 'C61', 'C62', 'C63', 'C64', 'C65', 'C66', 'C67', 'C68', 'C69', 'C70', 'C71', 'C72', 'C73', 

Epoch 3: 100%|██████████| 603/603 [01:16<00:00,  7.88it/s, loss=0.789, lr=3e-5]


Epoch 3 | Loss: 1.1925 | LR: 3.00E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['A01', 'A23', 'A32', 'A41', 'A49', 'A53', 'A55', 'A63', 'A69', 'B02', 'B06', 'B10', 'B19', 'B20', 'B25', 'B37', 'B57', 'B58', 'B59', 'B83', 'B95', 'B96', 'B97', 'B99', 'C04', 'C15', 'C16', 'C17', 'C18', 'C22', 'C26', 'C44', 'C48', 'C49', 'C50', 'C55', 'C56', 'C61', 'C62', 'C63', 'C64', 'C67', 'C70', 'C73', 'C74', 'C77', 'C78', 'C79', 'C80', 'C91', 'C94', 'D10', 'D12', 'D16', 'D18', 'D24', 'D30', 'D35', 'D44', 'D49', 'D50', 'D61', 'D64', 'D65', 'D68', 'D69', 'D70', 'D72', 'D73', 'D75', 'D76', 'D84', 'E03', 'E04', 'E05', 'E10', 'E11', 'E21', 'E29', 'E46', 'E66', 'E78', 'E79', 'E80', 'E83', 'E85', 'E86', 'E87', 'E88', 'F10', 'F12', 'F14', 'F17', 'F29', 'F31', 'F43', 'F50', 'F70', 'G06', 'G12', 'G25', 'G40', 'G45', 'G54', 'G56', 'G72', 'G81', 'G82', 'G83', 'G89', 'G93', 'G95', 'G97', 'H02', 'H05', 'H10', 'H11', 'H17', 

Epoch 4: 100%|██████████| 603/603 [01:18<00:00,  7.71it/s, loss=0.752, lr=2.99e-5]


Epoch 4 | Loss: 0.8319 | LR: 2.99E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['A23', 'A41', 'A55', 'A63', 'B02', 'B19', 'B20', 'B95', 'B96', 'B99', 'C18', 'C22', 'C48', 'C50', 'C61', 'C63', 'C64', 'C67', 'C70', 'C73', 'C77', 'C78', 'C80', 'D11', 'D12', 'D16', 'D49', 'D50', 'D64', 'D70', 'D72', 'E03', 'E05', 'E10', 'E11', 'E29', 'E66', 'E79', 'E80', 'E85', 'F10', 'F17', 'G45', 'G83', 'G93', 'H17', 'H18', 'H30', 'I10', 'I12', 'I25', 'I26', 'I48', 'I50', 'I51', 'I85', 'I86', 'I89', 'I95', 'I96', 'J44', 'J45', 'J90', 'J98', 'K02', 'K04', 'K12', 'K31', 'K35', 'K38', 'K40', 'K44', 'K51', 'K52', 'K55', 'K59', 'K62', 'K63', 'K65', 'K66', 'K75', 'K76', 'K85', 'L53', 'L92', 'L98', 'M06', 'M19', 'M25', 'M32', 'M45', 'M46', 'M54', 'M79', 'M89', 'N13', 'N18', 'N19', 'N20', 'N23', 'N26', 'N28', 'N30', 'N31', 'N32', 'N36', 'N39', 'N40', 'N43', 'N44', 'N45', 'N48', 'N50', 'N64', 'N92', 'Q28', 'Q51', 'Q61', 

Epoch 5: 100%|██████████| 603/603 [01:17<00:00,  7.79it/s, loss=0.276, lr=2.99e-5]


Epoch 5 | Loss: 0.5969 | LR: 2.99E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'C48', 'C50', 'C64', 'C67', 'C78', 'D49', 'D72', 'E03', 'E11', 'E66', 'E79', 'E80', 'E85', 'F17', 'I10', 'I82', 'I85', 'I86', 'I96', 'J45', 'J98', 'K31', 'K40', 'K52', 'K59', 'K62', 'K63', 'K65', 'K66', 'L53', 'L98', 'M06', 'M19', 'M25', 'M54', 'M79', 'M89', 'N13', 'N18', 'N20', 'N28', 'N30', 'N32', 'N39', 'N40', 'Q61', 'Q64', 'R10', 'R14', 'R16', 'R19', 'R20', 'R22', 'R30', 'R31', 'R33', 'R35', 'R50', 'R52', 'R53', 'R59', 'R60', 'R61', 'R63', 'R68', 'R69', 'Z79', 'Z87', 'Z90', 'Z92', 'Z99']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A04.5', 'A15.0', 'A15.4', 'A15.6', 'A15.9', 'A18.09', 'A18.10', 'A18.12', 'A18.15', 'A18.2', 'A18.4', 'A18.50', 'A23.9', 'A32.11', 'A32.9', 'A40.3', 'A41.9', 'A43.9', 'A

Epoch 6: 100%|██████████| 603/603 [01:16<00:00,  7.89it/s, loss=0.234, lr=2.99e-5]


Epoch 6 | Loss: 0.3623 | LR: 2.99E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'D72', 'E11', 'F17', 'I10', 'I86', 'I96', 'K52', 'K59', 'K63', 'L53', 'L98', 'M25', 'M54', 'N13', 'N18', 'N28', 'N32', 'N39', 'N50', 'R10', 'R19', 'R20', 'R31', 'R50', 'R52', 'R53', 'R59', 'R60', 'R63', 'R69', 'Z90']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A04.5', 'A15.4', 'A15.9', 'A18.10', 'A18.12', 'A18.4', 'A23.9', 'A32.9', 'A40.3', 'A41.9', 'A53.9', 'A63.0', 'B02.29', 'B10.89', 'B18.2', 'B27.99', 'B45.9', 'B74.9', 'B95.4', 'B95.5', 'B95.61', 'B95.8', 'B96.1', 'B96.20', 'B96.3', 'B96.4', 'B96.5', 'B96.7', 'B99.9', 'C02.1', 'C04.9', 'C08.0', 'C15.9', 'C16.0', 'C16.9', 'C18.6', 'C18.9', 'C22.0', 'C34.90', 'C34.91', 'C40.20', 'C41.0', 'C44.49', 'C48.0', 'C48.2', 'C49.21', 'C49.5', 'C49.A2'

Epoch 7: 100%|██████████| 603/603 [01:17<00:00,  7.78it/s, loss=0.271, lr=2.98e-5]


Epoch 7 | Loss: 0.2988 | LR: 2.98E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'D72', 'E03', 'E11', 'F17', 'I10', 'I86', 'I96', 'K31', 'K52', 'K55', 'K59', 'K63', 'L53', 'L98', 'M54', 'N13', 'N18', 'N28', 'N32', 'R10', 'R19', 'R20', 'R30', 'R31', 'R50', 'R52', 'R59', 'R60', 'R63', 'R69', 'Z90']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A04.5', 'A15.9', 'A18.10', 'A18.12', 'A40.3', 'A41.9', 'B02.29', 'B45.9', 'B95.4', 'B95.5', 'B95.61', 'B95.62', 'B96.1', 'B96.20', 'B96.3', 'B96.4', 'B96.5', 'B96.7', 'B99.9', 'C02.1', 'C04.9', 'C08.0', 'C16.9', 'C18.6', 'C34.91', 'C40.20', 'C41.0', 'C44.49', 'C48.0', 'C48.2', 'C50.312', 'C50.519', 'C50.91', 'C50.919', 'C53.9', 'C62.90', 'C62.91', 'C62.92', 'C63.10', 'C64.1', 'C64.2', 'C64.9', 'C67.7', 'C67.9', 'C70.9', 'C74.90', 'C75.0',

Epoch 8: 100%|██████████| 603/603 [01:17<00:00,  7.78it/s, loss=0.0866, lr=2.98e-5]


Epoch 8 | Loss: 0.2240 | LR: 2.98E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'D72', 'E03', 'E11', 'F17', 'I10', 'I86', 'I96', 'K52', 'K59', 'K63', 'L53', 'L98', 'M25', 'M54', 'N18', 'N28', 'N32', 'N39', 'Q64', 'R10', 'R19', 'R30', 'R31', 'R50', 'R52', 'R59', 'R60', 'R63', 'R69']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A04.5', 'A15.9', 'A18.10', 'A18.12', 'A41.9', 'B95.4', 'B95.5', 'B95.61', 'B96.20', 'B96.3', 'B96.5', 'B96.7', 'B99.9', 'C02.1', 'C16.9', 'C18.6', 'C48.0', 'C48.2', 'C50.312', 'C50.919', 'C53.9', 'C62.90', 'C62.91', 'C62.92', 'C63.10', 'C64.1', 'C64.2', 'C64.9', 'C67.7', 'C67.9', 'C78.00', 'C78.1', 'C78.5', 'C78.6', 'C78.7', 'C79.02', 'C79.51', 'C79.52', 'C79.70', 'C80.0', 'D11.0', 'D12.1', 'D16.5', 'D23.11', 'D30.01', 'D30.3', 'D37.030', 'D40.10', 'D4

Epoch 9: 100%|██████████| 603/603 [01:17<00:00,  7.79it/s, loss=0.161, lr=2.98e-5] 


Epoch 9 | Loss: 0.1430 | LR: 2.98E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'D72', 'E11', 'F17', 'I10', 'I86', 'I96', 'K59', 'L53', 'L98', 'M54', 'N18', 'N28', 'N32', 'R10', 'R19', 'R31', 'R50', 'R52', 'R59', 'R60', 'R63', 'R69']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A04.5', 'A15.9', 'A18.10', 'A18.12', 'A41.9', 'B95.4', 'B95.5', 'B95.61', 'B96.20', 'B96.3', 'B96.5', 'B99.9', 'C02.1', 'C41.1', 'C48.0', 'C48.2', 'C50.312', 'C50.919', 'C62.90', 'C62.91', 'C62.92', 'C63.10', 'C64.2', 'C64.9', 'C67.7', 'C67.9', 'C78.00', 'C78.6', 'C78.7', 'C79.51', 'C79.52', 'C79.70', 'D11.0', 'D16.5', 'D23.11', 'D30.01', 'D30.3', 'D37.030', 'D49.0', 'D49.2', 'D49.4', 'D49.511', 'D49.59', 'D49.89', 'D69.0', 'D72.0', 'D72.829', 'E03.9', 'E11.42', 'E11.9', 'E66.3', 'E78.00', 'E78.5', '

Epoch 10: 100%|██████████| 603/603 [01:17<00:00,  7.77it/s, loss=0.149, lr=2.98e-5] 


Epoch 10 | Loss: 0.1363 | LR: 2.98E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'D72', 'E03', 'E11', 'F17', 'I10', 'I86', 'I96', 'K31', 'K52', 'K59', 'K63', 'K75', 'L53', 'L98', 'M25', 'M32', 'M54', 'N13', 'N18', 'N28', 'N32', 'N39', 'R10', 'R11', 'R19', 'R31', 'R35', 'R50', 'R52', 'R53', 'R59', 'R60', 'R63', 'R69', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A04.5', 'A15.9', 'A18.10', 'A41.9', 'B95.4', 'B95.5', 'B95.61', 'B96.20', 'B96.3', 'B96.5', 'B99.9', 'C48.2', 'C62.92', 'C63.10', 'C64.9', 'C67.7', 'C67.9', 'C78.00', 'C78.6', 'C79.51', 'C79.70', 'C80.0', 'D16.5', 'D23.11', 'D30.01', 'D30.3', 'D37.030', 'D49.0', 'D49.4', 'D49.59', 'D49.89', 'D50.9', 'D72.0', 'D72.829', 'E03.9', 'E11.42', 'E11.9', 'E66.3', 'E78.00', 'E78.5', 'E79.0', 'E80.7', 'E85.9', 'F10.21',

Epoch 11: 100%|██████████| 603/603 [01:17<00:00,  7.76it/s, loss=0.0831, lr=2.97e-5]


Epoch 11 | Loss: 0.1085 | LR: 2.97E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'E11', 'I10', 'I86', 'I96', 'L53', 'L98', 'N18', 'N28', 'R10', 'R31', 'R50', 'R52', 'R60', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A41.9', 'B95.4', 'B95.5', 'B95.61', 'B96.20', 'B96.5', 'B99.9', 'C48.2', 'C64.9', 'C67.7', 'C67.9', 'C78.6', 'D11.0', 'D16.5', 'D23.11', 'D30.01', 'D30.3', 'D49.0', 'D49.2', 'D49.4', 'D49.59', 'D72.0', 'D72.829', 'E03.9', 'E11.42', 'E11.9', 'E66.3', 'E78.00', 'E78.5', 'F17.210', 'F17.290', 'I12.9', 'I80.8', 'I82.90', 'I86.1', 'I86.4', 'J98.4', 'J98.51', 'K08.109', 'K11.1', 'K11.9', 'K12.1', 'K12.2', 'K31.1', 'K31.89', 'K40.90', 'K44.9', 'K52.9', 'K57.32', 'K59.8', 'K76.81', 'K83.3', 'K85.20', 'L53.9', 'L92.9', 'L98.8', 'L98.9', 'M25.40', 'M25.50', 'M25.5

Epoch 12: 100%|██████████| 603/603 [01:16<00:00,  7.89it/s, loss=0.0605, lr=2.97e-5]


Epoch 12 | Loss: 0.0703 | LR: 2.97E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'I10', 'I86', 'L53', 'N18', 'N28', 'R10', 'R31', 'R50', 'R52', 'R60']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A15.9', 'B95.5', 'B95.61', 'B96.20', 'B96.3', 'B96.5', 'B99.9', 'C48.2', 'C62.92', 'C63.10', 'C64.9', 'C67.7', 'C67.9', 'D16.5', 'D23.11', 'D30.01', 'D30.3', 'D49.0', 'D49.2', 'D49.4', 'D49.511', 'D49.59', 'D72.0', 'D72.829', 'E11.42', 'E11.9', 'E66.3', 'E78.00', 'E78.5', 'F17.210', 'F17.290', 'G40.119', 'I12.9', 'I86.1', 'J98.51', 'K08.109', 'K11.1', 'K11.9', 'K12.2', 'K31.1', 'K44.9', 'K52.9', 'K59.8', 'K76.81', 'K83.3', 'K85.20', 'L53.9', 'L92.9', 'L98.8', 'M25.40', 'M25.50', 'M25.551', 'M26.601', 'M26.609', 'M32.9', 'M89.9', 'N13.30', 'N18.5', 'N18.9', 'N20.0', 'N20.1', 'N28.1'

Epoch 13: 100%|██████████| 603/603 [01:18<00:00,  7.70it/s, loss=0.124, lr=2.97e-5] 


Epoch 13 | Loss: 0.0710 | LR: 2.97E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'D72', 'E11', 'I10', 'I86', 'I96', 'K52', 'K59', 'L53', 'L98', 'N18', 'N28', 'R10', 'R50', 'R52', 'R60', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A15.9', 'A41.9', 'B95.5', 'B95.61', 'B96.20', 'B96.3', 'B96.5', 'B99.9', 'C48.2', 'C67.9', 'D23.11', 'D30.01', 'D49.0', 'D49.4', 'D49.59', 'D72.0', 'D72.829', 'E11.9', 'E66.3', 'E78.00', 'E78.5', 'F17.210', 'F17.290', 'I12.9', 'I82.90', 'I86.4', 'J98.51', 'K08.109', 'K11.1', 'K11.9', 'K12.1', 'K31.1', 'K31.89', 'K40.90', 'K52.9', 'K57.32', 'K59.00', 'K59.8', 'K75.9', 'K76.81', 'K83.3', 'K85.20', 'L53.9', 'L92.9', 'L98.8', 'L98.9', 'M19.90', 'M25.40', 'M25.50', 'M25.551', 'M26.601', 'M26.609', 'M32.9', 'M85.8', 'M89.9', 'N18.2', 'N18.5', 'N1

Epoch 14: 100%|██████████| 603/603 [01:18<00:00,  7.71it/s, loss=0.0806, lr=2.96e-5]


Epoch 14 | Loss: 0.0583 | LR: 2.96E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B19', 'B95', 'B96', 'B99', 'D72', 'E11', 'I10', 'I86', 'K59', 'L53', 'N18', 'N28', 'R10', 'R30', 'R31', 'R50', 'R52', 'R60', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A41.9', 'B95.4', 'B95.5', 'B95.61', 'B96.20', 'B96.5', 'B99.9', 'C48.2', 'C64.9', 'C67.7', 'C67.9', 'C78.6', 'C79.51', 'D49.0', 'D49.4', 'D72.0', 'D72.829', 'E11.9', 'E78.5', 'F10.21', 'F17.210', 'F17.290', 'I82.90', 'I85.00', 'I86.4', 'J98.51', 'K31.1', 'K31.89', 'K40.90', 'K52.9', 'K59.00', 'K59.8', 'K75.9', 'K76.7', 'K76.81', 'K83.3', 'K85.20', 'L53.9', 'L92.9', 'L98.9', 'M25.40', 'M25.50', 'M26.601', 'M26.609', 'M32.9', 'M89.9', 'N18.5', 'N18.9', 'N20.1', 'N28.1', 'N28.82', 'N28.9', 'N32.0', 'N32.89', 'N32.9', 'N39.0', 'N45.3', 'N48.6', 'N48.8

Epoch 15: 100%|██████████| 603/603 [01:17<00:00,  7.81it/s, loss=0.0331, lr=2.96e-5]


Epoch 15 | Loss: 0.0372 | LR: 2.96E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'I10', 'I86', 'L53', 'N18', 'N28', 'R50', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'B96.5', 'B99.9', 'C67.9', 'C79.51', 'D49.0', 'D49.4', 'D49.59', 'D72.0', 'D72.829', 'E11.9', 'E78.5', 'F17.210', 'F17.290', 'J98.51', 'K31.1', 'K31.89', 'K59.8', 'K85.20', 'L53.9', 'M25.40', 'M25.50', 'M26.601', 'M26.609', 'N18.5', 'N18.9', 'N20.1', 'N28.1', 'N28.82', 'N28.89', 'N28.9', 'N32.0', 'N32.89', 'N32.9', 'N39.0', 'N48.89', 'N60.01', 'Q64.4', 'R10.13', 'R10.84', 'R30.0', 'R31.0', 'R31.9', 'R40.4', 'R50.9', 'R60.0', 'R63.4', 'Z53.31', 'Z87.891', 'Z90.710']
Parent Accuracy: 10.00% | Child Accuracy: 36.36%
Saving best model... 0.01798 -> 0.01859
F1 Validation | Micro: 0

Epoch 16: 100%|██████████| 603/603 [01:18<00:00,  7.73it/s, loss=0.0471, lr=2.96e-5]


Epoch 16 | Loss: 0.0388 | LR: 2.96E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'I10', 'L53', 'N18', 'N28', 'R50', 'R52', 'R60']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.5', 'B95.61', 'B96.20', 'B96.5', 'B99.9', 'D49.0', 'D49.4', 'D72.0', 'D72.829', 'E11.9', 'E78.5', 'F17.210', 'F17.290', 'J98.51', 'K11.9', 'K12.1', 'K52.9', 'K59.8', 'K85.20', 'L53.9', 'L92.9', 'M25.40', 'M26.601', 'M26.609', 'M32.9', 'M89.9', 'N18.5', 'N18.9', 'N20.1', 'N28.1', 'N28.82', 'N39.0', 'N48.6', 'N48.81', 'N48.89', 'N60.01', 'Q53.10', 'Q64.4', 'R30.0', 'R31.9', 'R40.4', 'R50.9', 'R60.0', 'R60.9', 'Z53.31', 'Z87.891', 'Z90.710']
Parent Accuracy: 10.00% | Child Accuracy: 18.18%
Saving best model... 0.01859 -> 0.01942
F1 Validation | Micro: 0.14526 | Macro: 0.01942 | Best: 0.01942 | Epochs without

Epoch 17: 100%|██████████| 603/603 [01:17<00:00,  7.75it/s, loss=0.0436, lr=2.95e-5]


Epoch 17 | Loss: 0.0322 | LR: 2.95E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D49', 'I10', 'I86', 'L53', 'N18', 'N28', 'R50', 'R52', 'R60']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'B96.5', 'B99.9', 'D49.0', 'D49.4', 'D72.0', 'D72.829', 'E11.9', 'E78.5', 'F17.210', 'F17.290', 'J98.51', 'K52.9', 'K59.8', 'K83.3', 'L53.9', 'M25.40', 'M26.601', 'M26.609', 'M32.9', 'M89.9', 'N18.5', 'N18.9', 'N20.1', 'N28.1', 'N28.82', 'N32.89', 'N32.9', 'N39.0', 'N48.89', 'Q53.10', 'Q64.4', 'R30.0', 'R31.0', 'R31.9', 'R50.9', 'R60.0', 'R60.9', 'Z53.31', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.15456 | Macro: 0.01881 | Best: 0.01942 | Epochs without improvement: 10
Nuevos umbrales: Parent=0.254, Child=0.241


Epoch 18: 100%|██████████| 603/603 [01:14<00:00,  8.09it/s, loss=0.0297, lr=2.95e-5] 


Epoch 18 | Loss: 0.0199 | LR: 2.95E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'I10', 'I86', 'L53', 'N28', 'R50', 'R52']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'B99.9', 'D49.0', 'D49.4', 'D72.0', 'D72.829', 'E78.5', 'F17.210', 'F17.290', 'J98.51', 'K59.00', 'K59.8', 'K85.20', 'L53.9', 'M26.601', 'M26.609', 'M32.9', 'N18.5', 'N18.9', 'N28.1', 'N28.82', 'N32.9', 'N39.0', 'Q64.4', 'R30.0', 'R31.0', 'R31.9', 'R50.9', 'R63.4', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 27.27%
F1 Validation | Micro: 0.16789 | Macro: 0.01894 | Best: 0.01942 | Epochs without improvement: 11


Epoch 19: 100%|██████████| 603/603 [01:16<00:00,  7.93it/s, loss=0.0225, lr=2.95e-5] 


Epoch 19 | Loss: 0.0218 | LR: 2.95E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'I10', 'I86', 'L53', 'N18', 'N28', 'N32', 'R30', 'R31', 'R52']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'B96.5', 'B99.9', 'C67.9', 'D49.0', 'D49.4', 'D72.0', 'E11.9', 'E78.5', 'F17.290', 'J98.51', 'K59.8', 'K85.20', 'L53.9', 'L92.9', 'M19.90', 'M25.40', 'M26.609', 'N18.5', 'N18.9', 'N20.1', 'N28.1', 'N28.82', 'N28.9', 'N32.89', 'N32.9', 'N39.0', 'Q53.10', 'Q61.3', 'Q64.4', 'R30.0', 'R31.0', 'R31.9', 'R50.9', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 27.27%
F1 Validation | Micro: 0.15850 | Macro: 0.01795 | Best: 0.01942 | Epochs without improvement: 12


Epoch 20: 100%|██████████| 603/603 [01:13<00:00,  8.16it/s, loss=0.0104, lr=2.95e-5] 


Epoch 20 | Loss: 0.0190 | LR: 2.95E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'I10', 'I86', 'K59', 'L53', 'N18', 'N28', 'N32', 'Q64', 'R50', 'R52', 'R60', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['A15.9', 'B95.61', 'B96.20', 'B96.5', 'B99.9', 'D49.4', 'D72.0', 'D72.829', 'E11.9', 'E78.5', 'F17.290', 'J98.51', 'K40.90', 'K44.9', 'K59.00', 'K59.8', 'K85.20', 'L53.9', 'M26.609', 'M32.9', 'N18.5', 'N18.9', 'N28.1', 'N28.82', 'N28.9', 'N32.89', 'N32.9', 'N39.0', 'Q61.3', 'Q64.4', 'R30.0', 'R31.9', 'R50.9', 'R60.0', 'R60.9', 'R63.4', 'Z87.891', 'Z90.710']
Parent Accuracy: 20.00% | Child Accuracy: 36.36%
F1 Validation | Micro: 0.15992 | Macro: 0.01935 | Best: 0.01942 | Epochs without improvement: 13
Nuevos umbrales: Parent=0.241, Child=0.213


Epoch 21: 100%|██████████| 603/603 [01:14<00:00,  8.14it/s, loss=0.00798, lr=2.94e-5]


Epoch 21 | Loss: 0.0110 | LR: 2.94E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B99', 'I10', 'L53', 'N18', 'N28', 'R50', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'B99.9', 'D49.4', 'D72.0', 'D72.829', 'E78.5', 'F17.290', 'J98.51', 'K59.00', 'K59.8', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N32.89', 'N32.9', 'N39.0', 'Q64.4', 'R30.0', 'R31.9', 'R50.9', 'R60.0', 'R63.4', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 27.27%
Saving best model... 0.01942 -> 0.02010
F1 Validation | Micro: 0.18508 | Macro: 0.02010 | Best: 0.02010 | Epochs without improvement: 13


Epoch 22: 100%|██████████| 603/603 [01:13<00:00,  8.22it/s, loss=0.00588, lr=2.94e-5]


Epoch 22 | Loss: 0.0121 | LR: 2.94E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'I10', 'L53', 'N28', 'R50', 'R52', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'B96.5', 'B99.9', 'D49.4', 'D72.0', 'D72.829', 'E78.5', 'F17.290', 'J98.51', 'K59.8', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N32.9', 'N39.0', 'Q64.4', 'R31.9', 'R50.9', 'R60.0', 'R63.4', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.16866 | Macro: 0.01954 | Best: 0.02010 | Epochs without improvement: 14


Epoch 23: 100%|██████████| 603/603 [01:13<00:00,  8.22it/s, loss=0.0121, lr=2.94e-5] 


Epoch 23 | Loss: 0.0105 | LR: 2.94E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'I10', 'K59', 'L53', 'N18', 'N28', 'N32', 'R50', 'R52', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'B96.5', 'B99.9', 'D49.4', 'D72.0', 'D72.829', 'E78.5', 'F17.290', 'J98.51', 'K59.00', 'K59.8', 'L53.9', 'N18.9', 'N28.82', 'N32.9', 'N39.0', 'Q64.4', 'R30.0', 'R31.9', 'R50.9', 'R60.0', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 27.27%
F1 Validation | Micro: 0.17639 | Macro: 0.01967 | Best: 0.02010 | Epochs without improvement: 15
Nuevos umbrales: Parent=0.211, Child=0.174


Epoch 24: 100%|██████████| 603/603 [01:12<00:00,  8.34it/s, loss=0.00324, lr=2.93e-5]


Epoch 24 | Loss: 0.0060 | LR: 2.93E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B99', 'I10', 'K59', 'L53', 'N18', 'N28', 'R50', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'D49.4', 'D72.0', 'E78.5', 'F17.290', 'J98.51', 'K59.00', 'K59.8', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N32.9', 'N39.0', 'Q64.4', 'R50.9', 'R60.0', 'R63.4', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18938 | Macro: 0.01999 | Best: 0.02010 | Epochs without improvement: 16


Epoch 25: 100%|██████████| 603/603 [01:13<00:00,  8.22it/s, loss=0.00945, lr=2.93e-5]


Epoch 25 | Loss: 0.0071 | LR: 2.93E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'B99', 'D72', 'I10', 'K59', 'L53', 'N18', 'N28', 'N32', 'N39', 'Q64', 'R30', 'R31', 'R50', 'R52', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.4', 'B95.61', 'B96.20', 'B96.5', 'B99.9', 'C67.9', 'C79.51', 'D49.4', 'D72.0', 'D72.829', 'E11.9', 'E78.5', 'F17.290', 'K44.9', 'K59.00', 'K59.8', 'K85.20', 'L53.9', 'L92.9', 'M32.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'N39.0', 'Q61.3', 'Q64.4', 'R18.8', 'R30.0', 'R31.0', 'R31.29', 'R31.9', 'R50.9', 'R63.0', 'R63.4', 'Z87.891', 'Z90.710']
Parent Accuracy: 40.00% | Child Accuracy: 36.36%
F1 Validation | Micro: 0.16880 | Macro: 0.01959 | Best: 0.02010 | Epochs without improvement: 17


Epoch 26: 100%|██████████| 603/603 [01:13<00:00,  8.22it/s, loss=0.00608, lr=2.93e-5]


Epoch 26 | Loss: 0.0072 | LR: 2.93E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B99', 'D11', 'D49', 'G83', 'I86', 'J98', 'K59', 'L53', 'N28', 'R52', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.5', 'B95.61', 'B99.9', 'C78.7', 'C79.51', 'C80.0', 'D11.0', 'D49.0', 'E78.5', 'F17.290', 'G83.9', 'I85.00', 'I86.1', 'J98.51', 'K11.9', 'K59.00', 'K59.8', 'K76.9', 'K85.20', 'L53.9', 'L92.9', 'M32.9', 'M54.40', 'M89.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'Q61.3', 'Q64.4', 'R60.0', 'R63.0', 'R63.4', 'T85.41X', 'Z80.3', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.16182 | Macro: 0.01908 | Best: 0.02010 | Epochs without improvement: 18
Nuevos umbrales: Parent=0.231, Child=0.169


Epoch 27: 100%|██████████| 603/603 [01:12<00:00,  8.34it/s, loss=0.00303, lr=2.92e-5]


Epoch 27 | Loss: 0.0044 | LR: 2.92E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B99', 'D49', 'I10', 'I86', 'K31', 'K59', 'L53', 'N18', 'N28', 'R50', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'C79.51', 'D11.0', 'D49.59', 'E78.5', 'F17.290', 'I85.00', 'I86.1', 'J98.51', 'K59.00', 'K59.8', 'L53.9', 'L92.9', 'M32.9', 'N13.30', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'N39.0', 'Q61.3', 'Q64.4', 'R19.00', 'R30.0', 'R31.9', 'R50.9', 'R63.0', 'R63.4', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 36.36%
F1 Validation | Micro: 0.18245 | Macro: 0.01984 | Best: 0.02010 | Epochs without improvement: 19


Epoch 28: 100%|██████████| 603/603 [01:13<00:00,  8.22it/s, loss=0.00801, lr=2.92e-5]


Epoch 28 | Loss: 0.0043 | LR: 2.92E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I10', 'I86', 'K59', 'L53', 'N18', 'N28', 'R50', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'C79.51', 'D49.4', 'D72.0', 'E78.5', 'J98.51', 'K59.00', 'K59.8', 'L53.9', 'L92.9', 'M32.9', 'N18.9', 'N28.1', 'N28.89', 'N32.9', 'N39.0', 'Q64.4', 'R30.0', 'R50.9', 'R60.0', 'R63.0', 'R63.4', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 27.27%
F1 Validation | Micro: 0.18167 | Macro: 0.01970 | Best: 0.02010 | Epochs without improvement: 20


Epoch 29: 100%|██████████| 603/603 [01:13<00:00,  8.22it/s, loss=0.00171, lr=2.92e-5]


Epoch 29 | Loss: 0.0035 | LR: 2.92E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I10', 'I85', 'I86', 'K59', 'L53', 'N18', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'E78.5', 'F17.290', 'I85.00', 'K59.00', 'K59.8', 'L53.9', 'M32.9', 'N18.9', 'N39.0', 'Q64.4', 'R30.0', 'R63.0', 'R63.4', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18581 | Macro: 0.01950 | Best: 0.02010 | Epochs without improvement: 21
Nuevos umbrales: Parent=0.201, Child=0.143


Epoch 30: 100%|██████████| 603/603 [01:12<00:00,  8.34it/s, loss=0.000538, lr=2.92e-5]


Epoch 30 | Loss: 0.0019 | LR: 2.92E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I10', 'I86', 'K59', 'L53', 'N18', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'K59.00', 'K59.8', 'L53.9', 'N18.9', 'N28.1', 'R30.0', 'R50.9', 'R63.4', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.19562 | Macro: 0.01978 | Best: 0.02010 | Epochs without improvement: 22


Epoch 31: 100%|██████████| 603/603 [01:13<00:00,  8.22it/s, loss=0.00072, lr=2.91e-5] 


Epoch 31 | Loss: 0.0025 | LR: 2.91E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I10', 'I86', 'K59', 'L53', 'N28', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B99.9', 'E78.5', 'I85.00', 'K31.89', 'K59.00', 'K59.8', 'L53.9', 'M32.9', 'N18.9', 'N28.82', 'N28.89', 'N32.9', 'Q64.4', 'R18.8', 'R30.0', 'R63.0', 'R63.4', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18364 | Macro: 0.01965 | Best: 0.02010 | Epochs without improvement: 23


Epoch 32: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.00287, lr=2.91e-5] 


Epoch 32 | Loss: 0.0055 | LR: 2.91E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I86', 'K31', 'K59', 'L53', 'N28', 'R50', 'R52', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'I86.1', 'J98.51', 'K59.8', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'Q64.4', 'R11.10', 'R30.0', 'R31.9', 'R50.9', 'R63.4', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.14879 | Macro: 0.01657 | Best: 0.02010 | Epochs without improvement: 24
Nuevos umbrales: Parent=0.231, Child=0.162


Epoch 33: 100%|██████████| 603/603 [01:12<00:00,  8.35it/s, loss=0.00202, lr=2.91e-5] 


Epoch 33 | Loss: 0.0026 | LR: 2.91E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I86', 'L53', 'N28', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'Q64.4', 'R11.10', 'R63.4']
Parent Accuracy: 10.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.18551 | Macro: 0.01792 | Best: 0.02010 | Epochs without improvement: 25


Epoch 34: 100%|██████████| 603/603 [01:13<00:00,  8.24it/s, loss=0.00148, lr=2.9e-5]  


Epoch 34 | Loss: 0.0019 | LR: 2.90E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'I86', 'K59', 'L53', 'N28', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'E78.5', 'F17.290', 'K59.8', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'Q64.4', 'R11.10', 'R63.4', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.18352 | Macro: 0.01839 | Best: 0.02010 | Epochs without improvement: 26


Epoch 35: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.00101, lr=2.9e-5] 


Epoch 35 | Loss: 0.0015 | LR: 2.90E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I86', 'K59', 'L53', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'K59.8', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'Q64.4', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.18786 | Macro: 0.01910 | Best: 0.02010 | Epochs without improvement: 27
Nuevos umbrales: Parent=0.198, Child=0.134


Epoch 36: 100%|██████████| 603/603 [01:12<00:00,  8.35it/s, loss=0.000924, lr=2.9e-5]


Epoch 36 | Loss: 0.0008 | LR: 2.90E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'I86', 'L53', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'L53.9', 'N18.9', 'N28.89', 'Q64.4']
Parent Accuracy: 10.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.19256 | Macro: 0.01890 | Best: 0.02010 | Epochs without improvement: 28


Epoch 37: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.000347, lr=2.89e-5]


Epoch 37 | Loss: 0.0011 | LR: 2.89E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'I86', 'K59', 'L53', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B99.9', 'E78.5', 'K59.8', 'L53.9', 'N18.9', 'N28.1', 'N28.89', 'Q64.4', 'R11.10', 'R63.4']
Parent Accuracy: 20.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.18844 | Macro: 0.01903 | Best: 0.02010 | Epochs without improvement: 29


Epoch 38: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.00067, lr=2.89e-5] 


Epoch 38 | Loss: 0.0009 | LR: 2.89E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'I86', 'K59', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'F17.290', 'I85.00', 'K59.8', 'N18.9', 'N28.89', 'N32.9', 'Q64.4', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.18767 | Macro: 0.01854 | Best: 0.02010 | Epochs without improvement: 30
Nuevos umbrales: Parent=0.192, Child=0.123


Epoch 39: 100%|██████████| 603/603 [01:12<00:00,  8.35it/s, loss=0.000549, lr=2.89e-5]


Epoch 39 | Loss: 0.0005 | LR: 2.89E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['D49', 'I86', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['N18.9', 'N28.89', 'Q64.4']
Parent Accuracy: 10.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.19397 | Macro: 0.01835 | Best: 0.02010 | Epochs without improvement: 31


Epoch 40: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.000393, lr=2.89e-5]


Epoch 40 | Loss: 0.0007 | LR: 2.89E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['D49', 'I10', 'I86', 'K31', 'K59', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['D49.4', 'D49.59', 'F17.290', 'K59.8', 'N18.9', 'N28.89', 'N32.9', 'Q64.4', 'R63.4']
Parent Accuracy: 20.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.18861 | Macro: 0.01865 | Best: 0.02010 | Epochs without improvement: 32


Epoch 41: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.00113, lr=2.88e-5] 


Epoch 41 | Loss: 0.0029 | LR: 2.88E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'E78', 'I10', 'L53', 'N18', 'N28', 'N32', 'Q64', 'R31', 'R50', 'R52', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.4', 'B95.61', 'B96.20', 'D49.4', 'D49.59', 'E11.9', 'E13.9', 'E78.5', 'F17.290', 'K59.8', 'L53.9', 'L92.9', 'N13.30', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'N39.0', 'Q61.3', 'Q64.4', 'R31.9', 'R50.9', 'R60.0', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 27.27%
F1 Validation | Micro: 0.17254 | Macro: 0.01793 | Best: 0.02010 | Epochs without improvement: 33
Nuevos umbrales: Parent=0.228, Child=0.135


Epoch 42: 100%|██████████| 603/603 [01:12<00:00,  8.35it/s, loss=0.000943, lr=2.88e-5]


Epoch 42 | Loss: 0.0005 | LR: 2.88E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'E78', 'I86', 'L53', 'N18', 'N28', 'N32', 'Q64', 'R31', 'R50', 'R52', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'D49.4', 'D49.59', 'E78.5', 'F17.290', 'L53.9', 'L92.9', 'N13.30', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'Q64.4', 'R31.9', 'R50.9', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18664 | Macro: 0.01864 | Best: 0.02010 | Epochs without improvement: 34


Epoch 43: 100%|██████████| 603/603 [01:13<00:00,  8.26it/s, loss=0.000515, lr=2.88e-5]


Epoch 43 | Loss: 0.0007 | LR: 2.88E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'L53', 'N18', 'N28', 'Q64', 'R50', 'R52', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'D49.4', 'D49.59', 'E78.5', 'F17.290', 'K59.8', 'L53.9', 'L92.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N39.0', 'Q64.4', 'R31.9', 'R50.9', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 27.27%
F1 Validation | Micro: 0.18488 | Macro: 0.01887 | Best: 0.02010 | Epochs without improvement: 35


Epoch 44: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.000233, lr=2.87e-5]


Epoch 44 | Loss: 0.0005 | LR: 2.87E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'I86', 'L53', 'N18', 'N28', 'Q64', 'R50', 'R63', 'Z87']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'D49.4', 'D49.59', 'E78.5', 'F17.290', 'L53.9', 'L92.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'Q64.4', 'R19.00', 'R31.9', 'R50.9', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18778 | Macro: 0.01884 | Best: 0.02010 | Epochs without improvement: 36
Nuevos umbrales: Parent=0.218, Child=0.129


Epoch 45: 100%|██████████| 603/603 [01:12<00:00,  8.36it/s, loss=0.000599, lr=2.87e-5]


Epoch 45 | Loss: 0.0003 | LR: 2.87E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'L53', 'N18', 'N28', 'Q64', 'R50', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'D49.59', 'E78.5', 'F17.290', 'L53.9', 'L92.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'Q64.4', 'R31.9', 'R50.9', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.19363 | Macro: 0.01870 | Best: 0.02010 | Epochs without improvement: 37


Epoch 46: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.00024, lr=2.87e-5] 


Epoch 46 | Loss: 0.0004 | LR: 2.87E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'L53', 'N18', 'N28', 'Q64', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'D49.59', 'E78.5', 'F17.290', 'L53.9', 'L92.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'Q64.4', 'R19.00', 'R50.9', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.19025 | Macro: 0.01891 | Best: 0.02010 | Epochs without improvement: 38


Epoch 47: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.000433, lr=2.86e-5]


Epoch 47 | Loss: 0.0021 | LR: 2.86E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'I86', 'K31', 'K59', 'L53', 'N18', 'N28', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'D49.59', 'E78.5', 'I85.00', 'I86.1', 'L53.9', 'L92.9', 'N17.9', 'N18.9', 'N28.89', 'Q64.4', 'R19.00', 'R31.9', 'R63.4']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.16928 | Macro: 0.01836 | Best: 0.02010 | Epochs without improvement: 39
Nuevos umbrales: Parent=0.235, Child=0.136


Epoch 48: 100%|██████████| 603/603 [01:12<00:00,  8.35it/s, loss=0.000341, lr=2.86e-5]


Epoch 48 | Loss: 0.0004 | LR: 2.86E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'D49', 'I86', 'L53', 'N18', 'N28', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['D49.59', 'E78.5', 'I85.00', 'I86.1', 'L53.9', 'N18.9', 'N28.89', 'R31.9', 'R63.0', 'R63.4']
Parent Accuracy: 10.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.17813 | Macro: 0.01853 | Best: 0.02010 | Epochs without improvement: 40


Epoch 49: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.000171, lr=2.86e-5]


Epoch 49 | Loss: 0.0005 | LR: 2.86E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I86', 'K59', 'L53', 'N18', 'N28', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['I86.1', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'Q64.4', 'R19.00', 'R31.9', 'R63.0', 'R63.4']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.17754 | Macro: 0.01928 | Best: 0.02010 | Epochs without improvement: 41


Epoch 50: 100%|██████████| 603/603 [01:13<00:00,  8.26it/s, loss=0.000515, lr=2.86e-5]


Epoch 50 | Loss: 0.0004 | LR: 2.86E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I86', 'K59', 'L53', 'N18', 'N28', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['E78.5', 'F17.290', 'I86.1', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'N32.9', 'R19.00', 'R31.9', 'R63.0', 'R63.4']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18005 | Macro: 0.01943 | Best: 0.02010 | Epochs without improvement: 42
Nuevos umbrales: Parent=0.214, Child=0.122


Epoch 51: 100%|██████████| 603/603 [01:12<00:00,  8.36it/s, loss=0.000101, lr=2.85e-5]


Epoch 51 | Loss: 0.0002 | LR: 2.85E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['D49', 'I86', 'K59', 'L53', 'N18', 'N28', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['F17.290', 'I86.1', 'L53.9', 'N18.9', 'N28.1', 'N28.82', 'N28.89', 'R19.00', 'R31.9', 'R63.0', 'R63.4']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18589 | Macro: 0.01887 | Best: 0.02010 | Epochs without improvement: 43


Epoch 52: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.00037, lr=2.85e-5] 


Epoch 52 | Loss: 0.0003 | LR: 2.85E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['I86', 'K59', 'L53', 'N18', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['F17.290', 'L53.9', 'N18.9', 'N28.89', 'R31.9', 'R63.0', 'R63.4']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18306 | Macro: 0.01976 | Best: 0.02010 | Epochs without improvement: 44


Epoch 53: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.000358, lr=2.85e-5]


Epoch 53 | Loss: 0.0002 | LR: 2.85E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['D49', 'I86', 'K59', 'L53', 'N18', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['F17.290', 'I86.1', 'L53.9', 'N18.9', 'N28.1', 'N28.89', 'R31.9', 'R63.0', 'R63.4']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18278 | Macro: 0.01944 | Best: 0.02010 | Epochs without improvement: 45
Nuevos umbrales: Parent=0.212, Child=0.120


Epoch 54: 100%|██████████| 603/603 [01:12<00:00,  8.35it/s, loss=0.000336, lr=2.84e-5]


Epoch 54 | Loss: 0.0001 | LR: 2.84E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['D49', 'I86', 'L53', 'N18', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['F17.290', 'I86.1', 'L53.9', 'N18.9', 'N28.89', 'R31.9', 'R63.4']
Parent Accuracy: 10.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18761 | Macro: 0.01869 | Best: 0.02010 | Epochs without improvement: 46


Epoch 55: 100%|██████████| 603/603 [01:12<00:00,  8.26it/s, loss=0.00011, lr=2.84e-5] 


Epoch 55 | Loss: 0.0002 | LR: 2.84E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['D49', 'I86', 'K59', 'L53', 'N18', 'N28', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['F17.290', 'L53.9', 'N18.9', 'N28.89', 'N32.9', 'R31.9', 'R63.0', 'R63.4']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.18593 | Macro: 0.01996 | Best: 0.02010 | Epochs without improvement: 47


Epoch 56: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.000367, lr=2.84e-5]


Epoch 56 | Loss: 0.0062 | LR: 2.84E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'B96', 'C64', 'D49', 'D72', 'E11', 'E78', 'E79', 'I10', 'I85', 'I86', 'K59', 'N18', 'N28', 'Q61', 'R52', 'R60', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'C64.1', 'C64.9', 'C79.51', 'D72.829', 'E11.9', 'E78.5', 'E79.0', 'F17.290', 'I50.9', 'I85.00', 'I86.1', 'K40.90', 'N17.9', 'N18.9', 'N28.1', 'N28.89', 'N32.9', 'N39.0', 'Q44.6', 'Q61.3', 'Q64.4', 'R10.10', 'R19.00', 'R53.1', 'R60.0', 'R60.9', 'R63.0', 'Z87.891']
Parent Accuracy: 20.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.13555 | Macro: 0.01688 | Best: 0.02010 | Epochs without improvement: 48
Nuevos umbrales: Parent=0.273, Child=0.160


Epoch 57: 100%|██████████| 603/603 [01:12<00:00,  8.36it/s, loss=0.000244, lr=2.83e-5]


Epoch 57 | Loss: 0.0005 | LR: 2.83E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I10', 'I86', 'K59', 'N18', 'N28', 'N32', 'Q64', 'R31', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'C79.51', 'D72.829', 'E11.9', 'E78.5', 'F17.290', 'L53.9', 'L92.9', 'N17.9', 'N18.9', 'N28.1', 'N28.89', 'N32.9', 'Q64.4', 'R18.8', 'R19.00', 'R31.9', 'R50.9', 'R63.0']
Parent Accuracy: 30.00% | Child Accuracy: 18.18%
F1 Validation | Micro: 0.15910 | Macro: 0.01729 | Best: 0.02010 | Epochs without improvement: 49


Epoch 58: 100%|██████████| 603/603 [01:13<00:00,  8.25it/s, loss=0.000151, lr=2.83e-5]


Epoch 58 | Loss: 0.0007 | LR: 2.83E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I10', 'I86', 'L53', 'N18', 'N28', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['B95.61', 'B96.20', 'B99.9', 'C79.51', 'E11.9', 'E78.5', 'F17.290', 'I86.1', 'L53.9', 'L92.9', 'N18.9', 'N28.1', 'N28.89', 'N32.9', 'Q64.4', 'R50.9', 'R63.0']
Parent Accuracy: 10.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.17684 | Macro: 0.01816 | Best: 0.02010 | Epochs without improvement: 50


Epoch 59: 100%|██████████| 603/603 [01:13<00:00,  8.24it/s, loss=9.78e-5, lr=2.83e-5] 


Epoch 59 | Loss: 0.0002 | LR: 2.83E-05

Example Validation Results:
Expected parent: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent: ['B95', 'I10', 'I86', 'L53', 'N18', 'N28', 'R52', 'R63']
Expected child: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child: ['E78.5', 'F17.290', 'I86.1', 'L53.9', 'L92.9', 'N18.9', 'N28.1', 'N28.89', 'N32.9', 'Q64.4', 'R63.0', 'Z87.891']
Parent Accuracy: 10.00% | Child Accuracy: 9.09%
F1 Validation | Micro: 0.18129 | Macro: 0.01867 | Best: 0.02010 | Epochs without improvement: 51
Early stopping at epoch 60
Loaded saved MLBs
Loaded saved tokenizer


Some weights of BertModel were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using default BERT model


AttributeError: module 'os' has no attribute 'exists'

In [None]:
# V2 modelo jerárquico
# # ====================

import os
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup, AutoConfig
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_recall_curve
from tqdm import tqdm


# ====================
#  Modelo v2
# ====================
class HierarchicalBERTv2(torch.nn.Module):
    def __init__(self, num_parents, num_children):
        super().__init__()
        config = AutoConfig.from_pretrained(Config.MODEL_NAME, output_hidden_states=True)
        self.bert = AutoModel.from_pretrained(Config.MODEL_NAME, config=config)

        hidden_size = self.bert.config.hidden_size  # This will be 768 for base models

        self.parent_classifier = torch.nn.Linear(hidden_size, num_parents)
        self.child_classifier = torch.nn.Linear(hidden_size + num_parents, num_children)
        self.dropout = torch.nn.Dropout(self.bert.config.hidden_dropout_prob)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)

        if Config.USE_FEATURE_PYRAMID:
            # Combine last 3 layers' [CLS] embeddings
            hidden_states = outputs.hidden_states[-3:]  # Get last 3 layers
            # Stack [CLS] embeddings (shape: [3, batch_size, hidden_size])
            pooled = torch.stack([state[:, 0] for state in hidden_states])
            # Weighted combination of layers (weights should sum to 1)
            pooled = torch.einsum('lbd,l->bd', pooled,
                                torch.tensor(Config.FEATURE_LAYER_WEIGHTS).to(pooled.device))
        else:
            pooled = outputs.last_hidden_state[:, 0, :]

        pooled = self.dropout(pooled)

        # Jerarquía de clasificación
        parent_logits = self.parent_classifier(pooled)
        parent_probs = torch.sigmoid(parent_logits)
        child_input = torch.cat([pooled, parent_probs], dim=1)
        child_logits = self.child_classifier(child_input)

        return parent_logits, child_logits, pooled

# ====================
#  FUNCIÓN DE PÉRDIDA MEJORADA
# ====================
def hierarchical_lossv2(parent_logits, child_logits,
                     parent_labels, child_labels,
                     parent_weights, child_weights):

    loss_parent = F.binary_cross_entropy_with_logits(
        parent_logits,
        parent_labels,
        pos_weight=parent_weights
    )

    loss_child = F.binary_cross_entropy_with_logits(
        child_logits,
        child_labels,
        pos_weight=child_weights
    )

    return (Config.HIERARCHICAL_WEIGHTS['parent'] * loss_parent +
            Config.HIERARCHICAL_WEIGHTS['child'] * loss_child)

# ====================
#  AJUSTE DINÁMICO DE UMBRALES
# ====================
def calculate_optimal_thresholds(y_true, y_probs):
    thresholds = {}
    for i in range(y_probs.shape[1]):
        if np.sum(y_true[:, i]) > 0:  # Solo clases presentes
            precision, recall, threshs = precision_recall_curve(y_true[:, i], y_probs[:, i])
            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
            best_idx = np.nanargmax(f1_scores)
            thresholds[i] = threshs[best_idx]
    return thresholds

# ====================
#  Dataset
# ====================
class HierarchicalMedicalDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child):
        self.texts = df['text'].tolist()
        self.tokenizer = tokenizer
        self.examples = []

        # Procesar etiquetas
        self.parent_labels = []
        self.child_labels = []

        for codes in df['labels'].apply(eval): # FIXME: Unsafe eval
            parents, children = set(), set()
            for code in codes:
                code = code.strip().upper()
                levels = parse_code(code)
                if len(levels) >= 1: parents.add(levels[0])
                if len(levels) >= 2: children.add(levels[1])

            self.parent_labels.append(mlb_parent.transform([parents])[0])
            self.child_labels.append(mlb_child.transform([children])[0])

        for idx in range(len(self.texts)):
            encoding = self.tokenizer(
                self.texts[idx],
                max_length=Config.MAX_LENGTH,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            self.examples.append({
                'input_ids': encoding['input_ids'].squeeze(),
                'attention_mask': encoding['attention_mask'].squeeze(),
                'parent_labels': torch.FloatTensor(self.parent_labels[idx]),
                'child_labels': torch.FloatTensor(self.child_labels[idx]),
            })

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.examples[idx]

# ====================
#  ENTRENAMIENTO
# ====================
def train(best_thresholds=Config.THRESHOLDS):
    epochs_without_improvement = 0
    early_stop = False
    best_f1 = 0

    # Cargar datos
    train_df = pd.read_csv(Config.DATA_PATHS['train'])
    val_df = pd.read_csv(Config.DATA_PATHS['val'])

    # Construir binarizadores
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Preparar datasets

    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Created new tokenizer")

    train_dataset = HierarchicalMedicalDataset(train_df, tokenizer, mlb_parent, mlb_child)
    val_dataset = HierarchicalMedicalDataset(val_df, tokenizer, mlb_parent, mlb_child)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=Config.TRAIN_BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.VAL_BATCH_SIZE)

    # Modelo y optimizador
    model = HierarchicalBERTv2(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    # Load best model if available
    if os.path.exists(f"{Config.SAVE_PATH}_2"):
        model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
        print("Loaded best model - 2")
    elif os.path.exists(Config.SAVE_PATH):
        model.load_state_dict(torch.load(Config.SAVE_PATH))
        print("Loaded best model")
    else:
        print("Starting training from scratch")

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps = Config.WARMUP_EPOCHS * len(train_loader),
        num_training_steps = Config.EPOCHS * len(train_loader)
    )

    # Bucle de entrenamiento
    scaler = torch.amp.GradScaler('cuda', enabled=Config.USE_FP16)

    # Calcular pesos de clases
    parent_counts = np.sum(train_dataset.parent_labels, axis=0)
    parent_weights = (len(train_dataset) - parent_counts) / (parent_counts + Config.CLASS_WEIGHT_SMOOTHING)
    parent_weights = torch.tensor(parent_weights).to(device)

    child_counts = np.sum(train_dataset.child_labels, axis=0)
    child_weights = (len(train_dataset) - child_counts) / (child_counts + Config.CLASS_WEIGHT_SMOOTHING)
    child_weights = torch.tensor(child_weights).to(device)

    for epoch in range(Config.EPOCHS):
        if early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

        model.train()
        total_loss = 0

        # Ajuste periódico de umbrales
        if (epoch + 1) % Config.THRESHOLD_TUNING_INTERVAL == 0:
            val_probs, val_labels = get_validation_probabilities(model, val_loader, device)

            # Calcular mejores umbrales por clase
            parent_thresholds = calculate_optimal_thresholds(
                val_labels['parent'], val_probs['parent']
            )
            child_thresholds = calculate_optimal_thresholds(
                val_labels['child'], val_probs['child']
            )

            # Actualizar umbrales globales
            best_thresholds['parent'] = np.mean(list(parent_thresholds.values()))
            best_thresholds['child'] = np.mean(list(child_thresholds.values()))
            print(f"Nuevos umbrales: Parent={best_thresholds['parent']:.3f}, Child={best_thresholds['child']:.3f}")

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for step, batch in enumerate(progress_bar):
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            parent_labels = batch['parent_labels'].to(device)
            child_labels = batch['child_labels'].to(device)

            with torch.amp.autocast('cuda', enabled=Config.USE_FP16):
                outputs = model(**inputs)

                # Línea corregida
                loss = hierarchical_loss(
                    outputs[0],  # parent_logits
                    outputs[1],  # child_logits
                    parent_labels,
                    child_labels
                )

                loss = loss / Config.GRADIENT_ACCUMULATION_STEPS
                scaler.scale(loss).backward()

                if (step + 1) % Config.GRADIENT_ACCUMULATION_STEPS == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scheduler.step()
                    optimizer.zero_grad()
                    scaler.update()

                total_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr()[0])

        # Validación
        val_metrics = evaluate(model, val_loader, device, mlb_parent, mlb_child, best_thresholds)

        print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | LR: {scheduler.get_last_lr()[0]:.2E}")
        # Store metrics

        loss_metric = val_metrics['f1_macro']
        if epoch == 0:
            best_f1 = loss_metric

        if loss_metric > (best_f1 + Config.IMPROVEMENT_MARGIN):
            print(f"Saving best model... {best_f1:.5f} -> {loss_metric:.5f}")
            torch.save(model.state_dict(), f"{Config.SAVE_PATH}_3")
            best_f1 = loss_metric
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= Config.EARLY_STOP_PATIENCE:
                early_stop = True

        metrics_data = {
            'epoch': epoch + 1,
            'loss': total_loss/len(train_loader),
            'f1_micro': val_metrics['f1_micro'],
            'f1_macro': val_metrics['f1_macro'],
            'f1_micro_parent': val_metrics['f1_micro_parent'],
            'f1_macro_parent': val_metrics['f1_macro_parent'],
            'f1_micro_child': val_metrics['f1_micro_child'],
            'f1_macro_child': val_metrics['f1_macro_child'],
            'lr': scheduler.get_last_lr()[0],
            'epochs_without_improvement': epochs_without_improvement,
            'parent_threshold': best_thresholds['parent'],
            'child_threshold': best_thresholds['child']
        }

        # Write metrics to CSV
        metrics_df = pd.DataFrame([metrics_data])
        if epoch == 0:
            metrics_df.to_csv('training_metrics.csv', mode='a', index=False)
        else:
            metrics_df.to_csv('training_metrics.csv', mode='a', header=False, index=False)

        print(f"F1 Validation | Micro: {val_metrics['f1_micro']:.5f} | Macro: {val_metrics['f1_macro']:.5f} | Best: {best_f1:.5f} | Epochs without improvement: {epochs_without_improvement + 1}")

# ====================
#  FUNCIONES AUXILIARES
# ====================
def get_validation_probabilities(model, dataloader, device):
    model.eval()
    parent_probs, child_probs = [], []
    parent_labels, child_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device)
            }
            p_logits, c_logits, _ = model(**inputs)

            parent_probs.append(torch.sigmoid(p_logits).cpu().numpy())
            child_probs.append(torch.sigmoid(c_logits).cpu().numpy())

            parent_labels.append(batch['parent_labels'].numpy())
            child_labels.append(batch['child_labels'].numpy())

    return {
        'parent': np.concatenate(parent_probs),
        'child': np.concatenate(child_probs)
    }, {
        'parent': np.concatenate(parent_labels),
        'child': np.concatenate(child_labels)
    }


# ====================
#  EVALUACIÓN
# ====================
def evaluate(model, dataloader, device, mlb_parent, mlb_child, thresholds):
    model.eval()
    parent_preds_all = []
    child_preds_all = []
    parent_labels_all = []
    child_labels_all = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            # Get labels
            parent_labels = batch['parent_labels'].numpy()
            child_labels = batch['child_labels'].numpy()

            parent_logits, child_logits, pooled = model(**inputs)

            # Convert to predictions
            parent_preds = (torch.sigmoid(parent_logits).cpu().numpy() > thresholds['parent']).astype(int)
            child_preds = (torch.sigmoid(child_logits).cpu().numpy() > thresholds['child']).astype(int)

            # Append to lists
            parent_preds_all.extend(parent_preds)
            child_preds_all.extend(child_preds)
            parent_labels_all.extend(parent_labels)
            child_labels_all.extend(child_labels)

    # Convert to numpy arrays
    parent_preds_all = np.array(parent_preds_all)
    child_preds_all = np.array(child_preds_all)
    parent_labels_all = np.array(parent_labels_all)
    child_labels_all = np.array(child_labels_all)

    # Print example comparison for parent level
    if len(parent_labels_all) > 0:
        parent_true = np.array(mlb_parent.classes_)[parent_labels_all[0].astype(bool)]
        parent_pred = np.array(mlb_parent.classes_)[parent_preds_all[0].astype(bool)]
        common_labels = len(set(parent_true) & set(parent_pred))
        total_labels = len(set(parent_true))
        accuracy_parent = common_labels / total_labels if total_labels > 0 else 0

        child_true = np.array(mlb_child.classes_)[child_labels_all[0].astype(bool)]
        child_pred = np.array(mlb_child.classes_)[child_preds_all[0].astype(bool)]
        common_labels = len(set(child_true) & set(child_pred))
        total_labels = len(set(child_true))
        accuracy_child = common_labels / total_labels if total_labels > 0 else 0

        print("Expected parent labels:", sorted(parent_true))
        print("Predicted parent labels:", sorted(parent_pred))
        print("Expected child labels:", sorted(child_true))
        print("Predicted child labels:", sorted(child_pred))

        print(f"Percentage of correct parent labels: {accuracy_parent:.2%} | {accuracy_child:.2%}")

    # Calculate F1 scores for each level
    metrics = {
        'f1_micro_parent': f1_score(parent_labels_all, parent_preds_all, average='micro' , zero_division=0),
        'f1_macro_parent': f1_score(parent_labels_all, parent_preds_all, average='macro', zero_division=0),
        'f1_micro_child': f1_score(child_labels_all, child_preds_all, average='micro', zero_division=0),
        'f1_macro_child': f1_score(child_labels_all, child_preds_all, average='macro', zero_division=0)
    }

    # Calculate weighted average F1 scores
    metrics['f1_micro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_micro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_micro_child']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    metrics['f1_macro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_macro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_macro_child']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    return metrics

# ====================
#  PREDICCIÓN
# ====================
def predict(text, model, tokenizer, mlb_parent, mlb_child, device, thresholds):
    encoding = tokenizer(
        text,
        max_length=Config.MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        parent_logits, child_logits = model(**encoding)

    # Obtener predicciones
    parent_probs = torch.sigmoid(parent_logits).cpu().numpy()
    child_probs = torch.sigmoid(child_logits).cpu().numpy()

    # Decodificar etiquetas
    parent_preds = mlb_parent.inverse_transform((parent_probs > thresholds['parent']).astype(int))
    child_preds = mlb_child.inverse_transform((child_probs > thresholds['child']).astype(int))

    # Combinar y asegurar jerarquía
    final_codes = set()
    for parent in parent_preds[0]:
        final_codes.add(parent)
        for child in child_preds[0]:
            if child.startswith(parent):
                final_codes.add(child)

    return sorted(final_codes)

# ====================
#  EJECUCIÓN
# ====================
if __name__ == "__main__":
    best_thresholds = Config.THRESHOLDS

    train(best_thresholds)

    # Cargar datos de test
    test_df = pd.read_csv(Config.DATA_PATHS['test'])
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Cargar modelo
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        print("Using default tokenizer")

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    if os.path.exists(f"{Config.SAVE_PATH}_2"):
        model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
        print("Loaded best model - 2")
    elif os.path.exists(Config.SAVE_PATH):
        model.load_state_dict(torch.load(Config.SAVE_PATH))
        print("Loaded best model")

    # Evaluar en test
    test_dataset = HierarchicalMedicalDataset(test_df, tokenizer, mlb_parent, mlb_child)
    test_loader = DataLoader(test_dataset, batch_size=Config.TEST_BATCH_SIZE)

    test_metrics = evaluate(model, test_loader, device, mlb_parent, mlb_child)
    print("\nResultados en Test:")
    print(f"Micro F1: {test_metrics['f1_micro']:.4f}")
    print(f"Macro F1: {test_metrics['f1_macro']:.4f}")

    # Ejemplo de predicción
    sample_text = "Paciente con diabetes mellitus tipo 2 y complicaciones renales..."
    prediction = predict(sample_text, model, tokenizer, mlb_parent, mlb_child, device, best_thresholds)
    print("\nPredicción de ejemplo:", prediction)

    plot_metrics()


In [None]:
# Hierarchical - V1
# ====================

import os

class HierarchicalMedicalDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child):
        self.texts = df['text'].tolist()
        self.tokenizer = tokenizer
        self.examples = []

        # Procesar etiquetas
        self.parent_labels = []
        self.child_labels = []

        for codes in df['labels'].apply(eval): # FIXME: Unsafe eval
            parents, children = set(), set()
            for code in codes:
                code = code.strip().upper()
                levels = parse_code(code)
                if len(levels) >= 1: parents.add(levels[0])
                if len(levels) >= 2: children.add(levels[1])

            self.parent_labels.append(mlb_parent.transform([parents])[0])
            self.child_labels.append(mlb_child.transform([children])[0])

        for idx in range(len(self.texts)):
            encoding = self.tokenizer(
                self.texts[idx],
                max_length=Config.MAX_LENGTH,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            self.examples.append({
                'input_ids': encoding['input_ids'].squeeze(),
                'attention_mask': encoding['attention_mask'].squeeze(),
                'parent_labels': torch.FloatTensor(self.parent_labels[idx]),
                'child_labels': torch.FloatTensor(self.child_labels[idx]),
            })

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.examples[idx]

# ====================
#  ENTRENAMIENTO
# ====================
def train():
    epochs_without_improvement = 0
    early_stop = False
    best_f1 = 0

    # Cargar datos
    train_df = pd.read_csv(Config.DATA_PATHS['train'])
    val_df = pd.read_csv(Config.DATA_PATHS['val'])

    # Construir binarizadores
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Preparar datasets

    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Created new tokenizer")

    train_dataset = HierarchicalMedicalDataset(train_df, tokenizer, mlb_parent, mlb_child)
    val_dataset = HierarchicalMedicalDataset(val_df, tokenizer, mlb_parent, mlb_child)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=Config.TRAIN_BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.VAL_BATCH_SIZE)

    # Modelo y optimizador
    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    # Load best model if available
    try:
        model.load_state_dict(torch.load(Config.SAVE_STATE_PATH))
        print("Loaded previously saved best model")
    except:
        print("Starting training from scratch")

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps = Config.WARMUP_EPOCHS * len(train_loader),
        num_training_steps = Config.EPOCHS * len(train_loader)
    )

    # Bucle de entrenamiento
    scaler = torch.amp.GradScaler('cuda', enabled=Config.USE_FP16)
    for epoch in range(Config.EPOCHS):
        if early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for step, batch in enumerate(progress_bar):
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            parent_labels = batch['parent_labels'].to(device)
            child_labels = batch['child_labels'].to(device)

            with torch.amp.autocast('cuda', enabled=Config.USE_FP16):
                outputs = model(**inputs)

                # Línea corregida
                loss = hierarchical_loss(
                    outputs[0],  # parent_logits
                    outputs[1],  # child_logits
                    parent_labels,
                    child_labels
                )

                loss = loss / Config.GRADIENT_ACCUMULATION_STEPS
                scaler.scale(loss).backward()

                if (step + 1) % Config.GRADIENT_ACCUMULATION_STEPS == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scheduler.step()
                    optimizer.zero_grad()
                    scaler.update()

                total_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr()[0])

        # Validación
        val_metrics = evaluate(model, val_loader, device,
                              mlb_parent, mlb_child)
        print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | LR: {scheduler.get_last_lr()[0]:.2E}")
        # Store metrics

        loss_metric = val_metrics['f1_macro']
        if loss_metric > (best_f1 + Config.IMPROVEMENT_MARGIN) and epoch > 0:
            print(f"Saving best model... {best_f1:.5f} -> {loss_metric:.5f}")
            torch.save(model.state_dict(), f"{Config.SAVE_PATH}_2")
            best_f1 = loss_metric
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= Config.EARLY_STOP_PATIENCE:
                early_stop = True

        metrics_data = {
            'epoch': epoch + 1,
            'loss': total_loss/len(train_loader),
            'f1_micro': val_metrics['f1_micro'],
            'f1_macro': val_metrics['f1_macro'],
            'f1_micro_parent': val_metrics['f1_micro_parent'],
            'f1_macro_parent': val_metrics['f1_macro_parent'],
            'f1_micro_child': val_metrics['f1_micro_child'],
            'f1_macro_child': val_metrics['f1_macro_child'],
            'lr': scheduler.get_last_lr()[0],
            'epochs_without_improvement': epochs_without_improvement
        }

        # Write metrics to CSV
        metrics_df = pd.DataFrame([metrics_data])
        if epoch == 0:
            metrics_df.to_csv('training_metrics.csv', mode='a', index=False)
        else:
            metrics_df.to_csv('training_metrics.csv', mode='a', header=False, index=False)

        print(f"F1 Validation | Micro: {val_metrics['f1_micro']:.5f} | Macro: {val_metrics['f1_macro']:.5f} | Best: {best_f1:.5f} | Epochs without improvement: {epochs_without_improvement + 1}")


# ====================
#  EVALUACIÓN
# ====================
def evaluate(model, dataloader, device, mlb_parent, mlb_child):
    model.eval()
    parent_preds_all = []
    child_preds_all = []
    parent_labels_all = []
    child_labels_all = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            # Get labels
            parent_labels = batch['parent_labels'].numpy()
            child_labels = batch['child_labels'].numpy()

            parent_logits, child_logits = model(**inputs)

            # Convert to predictions
            parent_preds = (torch.sigmoid(parent_logits).cpu().numpy() > Config.THRESHOLDS['parent']).astype(int)
            child_preds = (torch.sigmoid(child_logits).cpu().numpy() > Config.THRESHOLDS['child']).astype(int)

            # Append to lists
            parent_preds_all.extend(parent_preds)
            child_preds_all.extend(child_preds)
            parent_labels_all.extend(parent_labels)
            child_labels_all.extend(child_labels)

    # Convert to numpy arrays
    parent_preds_all = np.array(parent_preds_all)
    child_preds_all = np.array(child_preds_all)
    parent_labels_all = np.array(parent_labels_all)
    child_labels_all = np.array(child_labels_all)

    # Print example comparison for parent level
    if len(parent_labels_all) > 0:
        parent_true = np.array(mlb_parent.classes_)[parent_labels_all[0].astype(bool)]
        parent_pred = np.array(mlb_parent.classes_)[parent_preds_all[0].astype(bool)]
        common_labels = len(set(parent_true) & set(parent_pred))
        total_labels = len(set(parent_true))
        accuracy_parent = common_labels / total_labels if total_labels > 0 else 0

        child_true = np.array(mlb_child.classes_)[child_labels_all[0].astype(bool)]
        child_pred = np.array(mlb_child.classes_)[child_preds_all[0].astype(bool)]
        common_labels = len(set(child_true) & set(child_pred))
        total_labels = len(set(child_true))
        accuracy_child = common_labels / total_labels if total_labels > 0 else 0

        print("Expected parent labels:", sorted(parent_true))
        print("Predicted parent labels:", sorted(parent_pred))
        print("Expected child labels:", sorted(child_true))
        print("Predicted child labels:", sorted(child_pred))

        print(f"Percentage of correct parent labels: {accuracy_parent:.2%} | {accuracy_child:.2%}")

    # Calculate F1 scores for each level
    metrics = {
        'f1_micro_parent': f1_score(parent_labels_all, parent_preds_all, average='micro' , zero_division=0),
        'f1_macro_parent': f1_score(parent_labels_all, parent_preds_all, average='macro', zero_division=0),
        'f1_micro_child': f1_score(child_labels_all, child_preds_all, average='micro', zero_division=0),
        'f1_macro_child': f1_score(child_labels_all, child_preds_all, average='macro', zero_division=0)
    }

    # Calculate weighted average F1 scores
    metrics['f1_micro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_micro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_micro_child']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    metrics['f1_macro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_macro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_macro_child']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    return metrics

# ====================
#  PREDICCIÓN
# ====================
def predict(text, model, tokenizer, mlb_parent, mlb_child, device):
    encoding = tokenizer(
        text,
        max_length=Config.MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        parent_logits, child_logits = model(**encoding)

    # Obtener predicciones
    parent_probs = torch.sigmoid(parent_logits).cpu().numpy()
    child_probs = torch.sigmoid(child_logits).cpu().numpy()

    # Decodificar etiquetas
    parent_preds = mlb_parent.inverse_transform((parent_probs > Config.THRESHOLDS['parent']).astype(int))
    child_preds = mlb_child.inverse_transform((child_probs > Config.THRESHOLDS['child']).astype(int))

    # Combinar y asegurar jerarquía
    final_codes = set()
    for parent in parent_preds[0]:
        final_codes.add(parent)
        for child in child_preds[0]:
            if child.startswith(parent):
                final_codes.add(child)

    return sorted(final_codes)

# ====================
#  EJECUCIÓN
# ====================
if __name__ == "__main__":
    train()

    # Cargar datos de test
    test_df = pd.read_csv(Config.DATA_PATHS['test'])
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Cargar modelo
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        print("Using default tokenizer")

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    if os.exists(f"{Config.SAVE_PATH}2"):
        model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
        print("Loaded best model - 2")
    elif os.exists(Config.SAVE_PATH"):
        model.load_state_dict(torch.load(Config.SAVE_PATH))
        print("Loaded best model")

    # Evaluar en test
    test_dataset = HierarchicalMedicalDataset(test_df, tokenizer, mlb_parent, mlb_child)
    test_loader = DataLoader(test_dataset, batch_size=Config.TEST_BATCH_SIZE)

    test_metrics = evaluate(model, test_loader, device, mlb_parent, mlb_child)
    print("\nResultados en Test:")
    print(f"Micro F1: {test_metrics['f1_micro']:.4f}")
    print(f"Macro F1: {test_metrics['f1_macro']:.4f}")

    # Ejemplo de predicción
    sample_text = "Paciente con diabetes mellitus tipo 2 y complicaciones renales..."
    prediction = predict(sample_text, model, tokenizer, mlb_parent, mlb_child, device)
    print("\nPredicción de ejemplo:", prediction)

    plot_metrics()


In [None]:
#  V0 Entrenamiento con cie10 dataset
# ====================

# Dataset especializado para pre-entrenamiento
class PretrainDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child):
        self.tokenizer = tokenizer
        self.examples = []
        self.mlb_parent = mlb_parent
        self.mlb_child = mlb_child

        for _, row in df.iterrows():
            code = row['code'].strip()
            desc = row['description'].strip()
            levels = parse_code(code)

            # Generar múltiples variantes textuales
            # Base variant is the description itself
            variants = []

            # Extract text inside parentheses and brackets
            parentheses_matches = re.findall(r'\((.*?)\)', desc)
            bracket_matches = re.findall(r'\[(.*?)\]', desc)

            # Get text outside parentheses and brackets
            clean_text = re.sub(r'\([^)]*\)|\[[^\]]*\]', '', desc).strip()
            if clean_text != desc:
                variants.append(clean_text)

            # Add matches from parentheses and brackets
            variants.extend(parentheses_matches)
            variants.extend(bracket_matches)

            for variant in variants:
                if len(variant) > 5:
                    self.examples.append({
                        'text': variant,
                        'levels': levels
                    })
                else:
                    print(f"Skipping short variant: {variant}")

        tokenized_examples = []
        for example in self.examples:
            encoding = self.tokenizer(
                example['text'],
                max_length=Config.MAX_LENGTH,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            levels = example['levels']

            tokenized_examples.append({
                'input_ids': encoding['input_ids'].squeeze(),
                'attention_mask': encoding['attention_mask'].squeeze(),
                'parent_labels': torch.FloatTensor(self.mlb_parent.transform([[levels[0]]] if levels else [[]])[0]),
                'child_labels': torch.FloatTensor(self.mlb_child.transform([[levels[1]]] if len(levels)>1 else [[]])[0])
            })
        self.examples = tokenized_examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

def pretrain():
    # Cargar datos de códigos CIE10
    df = pd.read_csv(Config.PRETRAIN_DATA_PATH)

    mlb_parent, mlb_child = calculate_mlb_classes()

    # Inicializar componentes
    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
    dataset = PretrainDataset(df, tokenizer, mlb_parent, mlb_child)
    dataloader = DataLoader(dataset, batch_size=Config.PRETRAIN_BATCH_SIZE, shuffle=True)

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)

    # Bucle de pre-entrenamiento (similar al entrenamiento normal)
    for epoch in range(Config.PRETRAIN_EPOCHS):
        model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Pre-train Epoch {epoch+1}")

        for batch in progress_bar:
            optimizer.zero_grad()

            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device)
            }

            parent_labels = batch['parent_labels'].to(device)
            child_labels = batch['child_labels'].to(device)

            outputs = model(**inputs)

            loss = hierarchical_loss(
                outputs[0], outputs[1],
                parent_labels, child_labels
            )

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

        print(f"Pre-train Epoch {epoch+1} | Loss: {total_loss/len(dataloader):.4f}")

    # Guardar modelo pre-entrenado
    torch.save(model.state_dict(), Config.SAVE_PATH)
    tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
    print(f"Modelo pre-entrenado guardado en {Config.SAVE_PATH}")

pretrain()