In [2]:
%pip install joblib

import pandas as pd
import torch
import re
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, classification_report
from tqdm.auto import tqdm
import joblib

# ====================
#  CONFIGURACIÓN
# ====================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

class Config:
    USE_FP16 = False
    STRIDE = 128  # Overlap between windows
    MODEL_NAME = 'dccuchile/bert-base-spanish-wwm-cased'
    #MODEL_NAME = 'PlanTL-GOB-ES/bsc-bio-ehr-es'
    #MODEL_NAME = 'IIC/bsc-bio-ehr-es-caresA'
    #MODEL_NAME = 'PlanTL-GOB-ES/roberta-base-biomedical-clinical-es'
    #MODEL_NAME = 'IIC/bert-base-spanish-wwm-cased-ctebmsp'
    #MODEL_NAME = 'IIC/xlm-roberta-large-ehealth_kd'
    THRESHOLD_TUNING_INTERVAL = 3  # Cada cuántas épocas ajustar umbrales
    USE_FEATURE_PYRAMID = True # Usar Feature Pyramid Network
    FEATURE_LAYER_WEIGHTS = [0.1, 0.3, 0.6]  # Pesos para las últimas 3 capas
    CLASS_WEIGHT_SMOOTHING = 0.1  # Suavizado para pesos de clases
    EARLY_STOP_PATIENCE = 50  # Número de épocas sin mejora para parar
    IMPROVEMENT_MARGIN = 0.0005
    MAX_LENGTH = 512 # Máxima longitud de secuencia, definida en el BERT pre-entrenado
    TRAIN_BATCH_SIZE = 4
    VAL_BATCH_SIZE = 16
    TEST_BATCH_SIZE = 32
    EPOCHS = 1000
    GRADIENT_ACCUMULATION_STEPS = 1
    WARMUP_EPOCHS = 2
    HIERARCHICAL_WEIGHTS = {'parent': 1.5, 'child': 1.0}
    LEARNING_RATE = 3e-5 # 2e-2
    DATA_PATHS = {
        'train': 'codiesp_csvs/codiesp_D_source_train.csv',
        'test': 'codiesp_csvs/codiesp_D_source_test.csv',
        'val': 'codiesp_csvs/codiesp_D_source_validation.csv'
    }
    SAVE_TOKENIZER_PATH = 'snapshots/cie10_tokenizer'
    SAVE_PATH = 'snapshots/best_hierarchical_model'
    SAVE_STATE_PATH = 'snapshots/best_hierarchical_model_state.bin'
    SAVE_MLB_PARENT_PATH = 'snapshots/best_hierarchical_model_mlb_parent'
    SAVE_MLB_CHILD_PATH = 'snapshots/best_hierarchical_model_mlb_child'
    THRESHOLDS = {'parent': 0.041, 'child': 0.12}
    PRETRAIN_EPOCHS = 10
    PRETRAIN_BATCH_SIZE = 8
    PRETRAIN_DATA_PATH = '../csv_import_scripts/cie10-es-diagnoses-expanded.csv'
    FORCE_NEW_MODEL = True



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
Using device: cuda


In [3]:
#  PREPROCESAMIENTO
# ====================
def parse_code(code):
    """Divide el código en niveles jerárquicos"""
    parts = code.split('.')
    hierarchy = []
    if len(parts) >= 1:
        parent = parts[0]  # Primera categoría (ej: S62)
        hierarchy.append(parent)

        if len(parts) >= 2:
            child_part = parts[1]
            child = f"{parent}.{child_part}"  # Segunda categoría (ej: S62.14S)
            hierarchy.append(child)

    return hierarchy

def calculate_mlb_classes():
    df = pd.read_csv(Config.PRETRAIN_DATA_PATH)
    # Try to load saved MLBs first
    try:
        mlb_parent = joblib.load(Config.SAVE_MLB_PARENT_PATH)
        mlb_child = joblib.load(Config.SAVE_MLB_CHILD_PATH)
        print("Loaded saved MLBs")
        return mlb_parent, mlb_child
    except:
        print("Creating new MLBs")

        # Construir jerarquía de códigos
        all_parents = set()
        all_children = set()

        for code in df['code']:
            levels = parse_code(code.strip())
            if len(levels) >= 1: all_parents.add(levels[0])
            if len(levels) >= 2: all_children.add(levels[1])

        # Inicializar MLB
        mlb_parent = MultiLabelBinarizer().fit([all_parents])
        mlb_child = MultiLabelBinarizer().fit([all_children])

        # Save MLBs
        joblib.dump(mlb_parent, Config.SAVE_MLB_PARENT_PATH)
        joblib.dump(mlb_child, Config.SAVE_MLB_CHILD_PATH)

    print(f"Padres: {len(mlb_parent.classes_)} - Hijos: {len(mlb_child.classes_)}")
    return mlb_parent, mlb_child

In [4]:
#  PLT DE MÉTRICAS
# ====================

def plot_metrics():
    # Load metrics
    metrics_history = pd.read_csv('training_metrics.csv')

    plt.figure(figsize=(10, 6))
    plt.plot(metrics_history['epoch'], metrics_history['loss'], label='Loss')
    plt.plot(metrics_history['epoch'], metrics_history['f1_micro'], label='F1 Micro')
    plt.plot(metrics_history['epoch'], metrics_history['f1_macro'], label='F1 Macro')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.title('Training Metrics Over Time')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_metrics.png')
    plt.close()


In [5]:
#  MODELO JERÁRQUICO
# ====================

class HierarchicalBERT(torch.nn.Module):
    def __init__(self, num_parents, num_children):
        super().__init__()
        try:
            self.bert = AutoModel.from_pretrained(Config.SAVE_PATH)
            print("Loaded saved BERT model")
        except:
            self.bert = AutoModel.from_pretrained(Config.MODEL_NAME)
            print("Using default BERT model")

        hidden_size = self.bert.config.hidden_size  # This will be 768 for base models

        self.parent_classifier = torch.nn.Linear(hidden_size, num_parents)
        self.child_classifier = torch.nn.Linear(hidden_size + num_parents, num_children)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0, :]

        # Clasificación padre
        parent_logits = self.parent_classifier(pooled)

        # Clasificación hijo con contexto de padres
        parent_probs = torch.sigmoid(parent_logits)
        child_input = torch.cat([pooled, parent_probs], dim=1)
        child_logits = self.child_classifier(child_input)

        return parent_logits, child_logits, pooled

# ====================
# Función de pérdida
# ====================
def hierarchical_loss(parent_logits, child_logits,
                      parent_labels, child_labels):

    loss_parent = torch.nn.BCEWithLogitsLoss()(parent_logits, parent_labels)
    loss_child = torch.nn.BCEWithLogitsLoss()(child_logits, child_labels)

    return (Config.HIERARCHICAL_WEIGHTS['parent'] * loss_parent +
            Config.HIERARCHICAL_WEIGHTS['child'] * loss_child)

In [None]:
# v3 (Sliding Window)
# ====================

import os
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup, AutoConfig
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_recall_curve
from tqdm import tqdm


# ====================
#  Modelo v3
# ====================
class HierarchicalBERTv2(torch.nn.Module):
    def __init__(self, num_parents, num_children):
        super().__init__()
        config = AutoConfig.from_pretrained(Config.MODEL_NAME, output_hidden_states=True)
        self.bert = AutoModel.from_pretrained(Config.MODEL_NAME, config=config)

        hidden_size = self.bert.config.hidden_size  # This will be 768 for base models

        self.parent_classifier = torch.nn.Linear(hidden_size, num_parents)
        self.child_classifier = torch.nn.Linear(hidden_size + num_parents, num_children)
        self.dropout = torch.nn.Dropout(self.bert.config.hidden_dropout_prob)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)

        if Config.USE_FEATURE_PYRAMID:
            # Combine last 3 layers' [CLS] embeddings
            hidden_states = outputs.hidden_states[-3:]  # Get last 3 layers
            # Stack [CLS] embeddings (shape: [3, batch_size, hidden_size])
            pooled = torch.stack([state[:, 0] for state in hidden_states])
            # Weighted combination of layers (weights should sum to 1)
            if sum(Config.FEATURE_LAYER_WEIGHTS) != 1:
                raise ValueError("FEATURE_LAYER_WEIGHTS must sum to 1")

            # Apply weights to layers
            pooled = torch.einsum('lbd,l->bd', pooled,
                                torch.tensor(Config.FEATURE_LAYER_WEIGHTS).to(pooled.device))
        else:
            pooled = outputs.last_hidden_state[:, 0, :]

        pooled = self.dropout(pooled)

        # Jerarquía de clasificación
        parent_logits = self.parent_classifier(pooled)
        parent_probs = torch.sigmoid(parent_logits)
        child_input = torch.cat([pooled, parent_probs], dim=1)
        child_logits = self.child_classifier(child_input)

        return parent_logits, child_logits, pooled

# ====================
#  FUNCIÓN DE PÉRDIDA MEJORADA
# ====================
def hierarchical_lossv2(parent_logits, child_logits,
                     parent_labels, child_labels,
                     parent_weights, child_weights):

    loss_parent = F.binary_cross_entropy_with_logits(
        parent_logits,
        parent_labels,
        pos_weight=parent_weights
    )

    loss_child = F.binary_cross_entropy_with_logits(
        child_logits,
        child_labels,
        pos_weight=child_weights
    )

    return (Config.HIERARCHICAL_WEIGHTS['parent'] * loss_parent +
            Config.HIERARCHICAL_WEIGHTS['child'] * loss_child)

# ====================
#  AJUSTE DINÁMICO DE UMBRALES
# ====================
def calculate_optimal_thresholds(y_true, y_probs):
    thresholds = {}
    for i in range(y_probs.shape[1]):
        if np.sum(y_true[:, i]) > 0:  # Solo clases presentes
            precision, recall, threshs = precision_recall_curve(y_true[:, i], y_probs[:, i])
            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
            best_idx = np.nanargmax(f1_scores)
            thresholds[i] = threshs[best_idx]
    return thresholds

# ====================
#  DATASET v3 (Sliding Window)
# ====================
class HierarchicalMedicalDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child):
        self.texts = df['text'].tolist()
        self.tokenizer = tokenizer
        self.examples = []

        # Procesar etiquetas
        self.parent_labels = []
        self.child_labels = []

        # Same label processing as before
        for codes in df['labels'].apply(eval):
            parents, children = set(), set()
            for code in codes:
                code = code.strip().upper()
                levels = parse_code(code)
                if len(levels) >= 1: parents.add(levels[0])
                if len(levels) >= 2: children.add(levels[1])

            self.parent_labels.append(mlb_parent.transform([parents])[0])
            self.child_labels.append(mlb_child.transform([children])[0])

        # Generate sliding windows for each text
        for idx, text in enumerate(self.texts):
            # Tokenize whole text
            tokens = self.tokenizer(
                text,
                truncation=False,
                return_offsets_mapping=True,
                add_special_tokens=False
            )

            # Generate sliding windows
            window_size = Config.MAX_LENGTH - 2  # Account for [CLS] and [SEP]
            stride = Config.STRIDE

            for i in range(0, len(tokens['input_ids']), stride):
                # Extract window
                window_start = i
                window_end = min(i + window_size, len(tokens['input_ids']))

                # Add special tokens
                input_ids = (
                    [self.tokenizer.cls_token_id] +
                    tokens['input_ids'][window_start:window_end] +
                    [self.tokenizer.sep_token_id]
                )

                attention_mask = [1] * len(input_ids)

                # Pad if necessary
                padding_length = Config.MAX_LENGTH - len(input_ids)
                if padding_length > 0:
                    input_ids += [self.tokenizer.pad_token_id] * padding_length
                    attention_mask += [0] * padding_length

                self.examples.append({
                    'input_ids': torch.tensor(input_ids),
                    'attention_mask': torch.tensor(attention_mask),
                    'parent_labels': torch.FloatTensor(self.parent_labels[idx]),
                    'child_labels': torch.FloatTensor(self.child_labels[idx]),
                    'text_id': idx  # To group windows later
                })

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


# ====================
#  ENTRENAMIENTO con Sliding Window
# ====================
def train(best_thresholds=Config.THRESHOLDS):
    epochs_without_improvement = 0
    early_stop = False
    best_f1 = 0

    # Cargar datos
    train_df = pd.read_csv(Config.DATA_PATHS['train'])
    val_df = pd.read_csv(Config.DATA_PATHS['val'])

    # Construir binarizadores
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Preparar datasets

    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Created new tokenizer")

    train_dataset = HierarchicalMedicalDataset(train_df, tokenizer, mlb_parent, mlb_child)
    val_dataset = HierarchicalMedicalDataset(val_df, tokenizer, mlb_parent, mlb_child)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=Config.TRAIN_BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.VAL_BATCH_SIZE)

    # Modelo y optimizador
    model = HierarchicalBERTv2(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    # Load best model if available
    if not Config.FORCE_NEW_MODEL:
        if os.path.exists(f"{Config.SAVE_PATH}_2"):
            model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
            print("Loaded best model - 2")
        elif os.path.exists(Config.SAVE_PATH):
            model.load_state_dict(torch.load(Config.SAVE_PATH))
            print("Loaded best model")
        else:
            print("Starting training from scratch")

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps = Config.WARMUP_EPOCHS * len(train_loader),
        num_training_steps = Config.EPOCHS * len(train_loader)
    )

    # Bucle de entrenamiento
    scaler = torch.amp.GradScaler('cuda', enabled=Config.USE_FP16)

    # Calcular pesos de clases
    parent_counts = np.sum(train_dataset.parent_labels, axis=0)
    parent_weights = (len(train_dataset) - parent_counts) / (parent_counts + Config.CLASS_WEIGHT_SMOOTHING)
    parent_weights = torch.tensor(parent_weights).to(device)

    child_counts = np.sum(train_dataset.child_labels, axis=0)
    child_weights = (len(train_dataset) - child_counts) / (child_counts + Config.CLASS_WEIGHT_SMOOTHING)
    child_weights = torch.tensor(child_weights).to(device)

    for epoch in range(Config.EPOCHS):
        if early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

        model.train()
        total_loss = 0

        # Ajuste periódico de umbrales
        if (epoch + 1) % Config.THRESHOLD_TUNING_INTERVAL == 0:
            val_probs, val_labels = get_validation_probabilities(model, val_loader, device)

            # Calcular mejores umbrales por clase
            parent_thresholds = calculate_optimal_thresholds(
                val_labels['parent'], val_probs['parent']
            )
            child_thresholds = calculate_optimal_thresholds(
                val_labels['child'], val_probs['child']
            )

            # Actualizar umbrales globales
            best_thresholds['parent'] = np.mean(list(parent_thresholds.values()))
            best_thresholds['child'] = np.mean(list(child_thresholds.values()))
            print(f"Nuevos umbrales: Parent={best_thresholds['parent']:.3f}, Child={best_thresholds['child']:.3f}")

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        text_predictions = defaultdict(lambda: {'parent': [], 'child': []})
        for step, batch in enumerate(progress_bar):
            batch = {k: v.to(device) for k, v in batch.items()}

            with torch.amp.autocast('cuda', enabled=Config.USE_FP16):
                parent_logits, child_logits, _ = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask']
                )

                # Store predictions by original text
                for i, text_id in enumerate(batch['text_id'].cpu().numpy()):
                    text_predictions[text_id]['parent'].append(parent_logits[i])
                    text_predictions[text_id]['child'].append(child_logits[i])

                # Immediate window-level loss
                loss = hierarchical_lossv2(
                    parent_logits,
                    child_logits,
                    batch['parent_labels'],  # Use batch labels directly
                    batch['child_labels'],   # Not the dataset's labels
                    parent_weights,
                    child_weights
                )

            # Backpropagate
            scaler.scale(loss).backward()
            if (step + 1) % Config.GRADIENT_ACCUMULATION_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scheduler.step()
                optimizer.zero_grad()
                scaler.update()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr()[0])

        # After epoch completes, calculate aggregated loss
        agg_loss = 0
        for text_id in text_predictions:
            if text_id >= len(train_dataset.parent_labels):
                continue  # Skip invalid text_ids

            # Aggregate predictions
            parent_agg = torch.stack(text_predictions[text_id]['parent']).max(dim=0)[0]
            child_agg = torch.stack(text_predictions[text_id]['child']).max(dim=0)[0]

            # Get true labels from dataset
            parent_label = torch.FloatTensor(train_dataset.parent_labels[text_id]).to(device)
            child_label = torch.FloatTensor(train_dataset.child_labels[text_id]).to(device)

            # Calculate aggregated loss
            agg_loss += hierarchical_lossv2(
                parent_agg.unsqueeze(0),
                child_agg.unsqueeze(0),
                parent_label.unsqueeze(0),
                child_label.unsqueeze(0),
                parent_weights,
                child_weights
            )

        # Combine losses
        total_loss += agg_loss.item() / len(text_predictions)
        print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | LR: {scheduler.get_last_lr()[0]:.2E}")

        # Validación
        val_metrics = evaluate(model, val_loader, device, mlb_parent, mlb_child, best_thresholds)

        # Store metrics
        loss_metric = val_metrics['f1_macro']
        if epoch == 0:
            best_f1 = loss_metric

        if loss_metric > (best_f1 + Config.IMPROVEMENT_MARGIN):
            print(f"Saving best model... {best_f1:.5f} -> {loss_metric:.5f}")
            torch.save(model.state_dict(), f"{Config.SAVE_PATH}_3")
            best_f1 = loss_metric
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= Config.EARLY_STOP_PATIENCE:
                early_stop = True

        metrics_data = {
            'epoch': epoch + 1,
            'loss': total_loss/len(train_loader),
            'f1_micro': val_metrics['f1_micro'],
            'f1_macro': val_metrics['f1_macro'],
            'f1_micro_parent': val_metrics['f1_micro_parent'],
            'f1_macro_parent': val_metrics['f1_macro_parent'],
            'f1_micro_child': val_metrics['f1_micro_child'],
            'f1_macro_child': val_metrics['f1_macro_child'],
            'lr': scheduler.get_last_lr()[0],
            'epochs_without_improvement': epochs_without_improvement,
            'parent_threshold': best_thresholds['parent'],
            'child_threshold': best_thresholds['child']
        }

        # Write metrics to CSV
        metrics_df = pd.DataFrame([metrics_data])
        if epoch == 0:
            metrics_df.to_csv('training_metrics.csv', mode='a', index=False)
        else:
            metrics_df.to_csv('training_metrics.csv', mode='a', header=False, index=False)

        print(f"F1 Validation | Micro: {val_metrics['f1_micro']:.5f} | Macro: {val_metrics['f1_macro']:.5f} | Best: {best_f1:.5f} | Epochs without improvement: {epochs_without_improvement + 1}")

# ====================
#  FUNCIONES AUXILIARES (Con sliding windows)
# ====================
def get_validation_probabilities(model, dataloader, device):
    model.eval()
    text_predictions = defaultdict(lambda: {'parent': [], 'child': []})
    parent_labels = {}
    child_labels = {}

    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device)
            }
            text_ids = batch['text_id'].numpy()

            # Store labels by text_id
            for i, text_id in enumerate(text_ids):
                if text_id not in parent_labels:
                    parent_labels[text_id] = batch['parent_labels'][i].numpy()
                    child_labels[text_id] = batch['child_labels'][i].numpy()

            # Get predictions
            p_logits, c_logits, _ = model(**inputs)

            # Store predictions by text_id
            for i, text_id in enumerate(text_ids):
                text_predictions[text_id]['parent'].append(p_logits[i].cpu())
                text_predictions[text_id]['child'].append(c_logits[i].cpu())

    # Aggregate predictions per text
    parent_probs, child_probs = [], []
    final_parent_labels, final_child_labels = [], []

    for text_id in text_predictions:
        # Aggregate using max pooling (same as training)
        parent_agg = torch.stack(text_predictions[text_id]['parent']).max(dim=0)[0]
        child_agg = torch.stack(text_predictions[text_id]['child']).max(dim=0)[0]

        parent_probs.append(torch.sigmoid(parent_agg).numpy())
        child_probs.append(torch.sigmoid(child_agg).numpy())

        # Get original labels
        final_parent_labels.append(parent_labels[text_id])
        final_child_labels.append(child_labels[text_id])

    return {
        'parent': np.array(parent_probs),
        'child': np.array(child_probs)
    }, {
        'parent': np.array(final_parent_labels),
        'child': np.array(final_child_labels)
    }

# ====================
#  EVALUACIÓN Con sliding windows
# ====================
def evaluate(model, dataloader, device, mlb_parent, mlb_child, thresholds):
    val_probs, val_labels = get_validation_probabilities(model, dataloader, device)

    # Convert probabilities to predictions
    parent_preds = (val_probs['parent'] > thresholds['parent']).astype(int)
    child_preds = (val_probs['child'] > thresholds['child']).astype(int)

    # Print example comparison
    if len(val_labels['parent']) > 0:
        idx = 0  # First example
        parent_true = np.array(mlb_parent.classes_)[val_labels['parent'][idx].astype(bool)]
        parent_pred = np.array(mlb_parent.classes_)[parent_preds[idx].astype(bool)]

        child_true = np.array(mlb_child.classes_)[val_labels['child'][idx].astype(bool)]
        child_pred = np.array(mlb_child.classes_)[child_preds[idx].astype(bool)]

        print("\nExample Validation Results:")
        print(f"Expected parent: {sorted(parent_true)}")
        print(f"Predicted parent: {sorted(parent_pred)}")
        print(f"Expected child: {sorted(child_true)}")
        print(f"Predicted child: {sorted(child_pred)}")

        common_parent = len(set(parent_true) & set(parent_pred))
        common_child = len(set(child_true) & set(child_pred))
        print(f"Parent Accuracy: {common_parent/len(parent_true):.2%} | Child Accuracy: {common_child/len(child_true):.2%}")

    # Calculate metrics
    metrics = {
        'f1_micro_parent': f1_score(val_labels['parent'], parent_preds, average='micro', zero_division=0),
        'f1_macro_parent': f1_score(val_labels['parent'], parent_preds, average='macro', zero_division=0),
        'f1_micro_child': f1_score(val_labels['child'], child_preds, average='micro', zero_division=0),
        'f1_macro_child': f1_score(val_labels['child'], child_preds, average='macro', zero_division=0)
    }

    # Weighted averages
    total_weight = sum(Config.HIERARCHICAL_WEIGHTS.values())
    metrics['f1_micro'] = (Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_micro_parent'] +
                          Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_micro_child']) / total_weight

    metrics['f1_macro'] = (Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_macro_parent'] +
                          Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_macro_child']) / total_weight

    return metrics

# ====================
#  PREDICCIÓN
# ====================
def predict(text, model, tokenizer, mlb_parent, mlb_child, device, thresholds):
    encoding = tokenizer(
        text,
        max_length=Config.MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        parent_logits, child_logits = model(**encoding)

    # Obtener predicciones
    parent_probs = torch.sigmoid(parent_logits).cpu().numpy()
    child_probs = torch.sigmoid(child_logits).cpu().numpy()

    # Decodificar etiquetas
    parent_preds = mlb_parent.inverse_transform((parent_probs > thresholds['parent']).astype(int))
    child_preds = mlb_child.inverse_transform((child_probs > thresholds['child']).astype(int))

    # Combinar y asegurar jerarquía
    final_codes = set()
    for parent in parent_preds[0]:
        final_codes.add(parent)
        for child in child_preds[0]:
            if child.startswith(parent):
                final_codes.add(child)

    return sorted(final_codes)

# ====================
#  EJECUCIÓN
# ====================
if __name__ == "__main__":
    best_thresholds = Config.THRESHOLDS

    train(best_thresholds)

    # Cargar datos de test
    test_df = pd.read_csv(Config.DATA_PATHS['test'])
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Cargar modelo
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        print("Using default tokenizer")

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    if os.exists(f"{Config.SAVE_PATH}_2"):
        model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
        print("Loaded best model - 2")
    elif os.exists(Config.SAVE_PATH):
        model.load_state_dict(torch.load(Config.SAVE_PATH))
        print("Loaded best model")

    # Evaluar en test
    test_dataset = HierarchicalMedicalDataset(test_df, tokenizer, mlb_parent, mlb_child)
    test_loader = DataLoader(test_dataset, batch_size=Config.TEST_BATCH_SIZE)

    test_metrics = evaluate(model, test_loader, device, mlb_parent, mlb_child)
    print("\nResultados en Test:")
    print(f"Micro F1: {test_metrics['f1_micro']:.4f}")
    print(f"Macro F1: {test_metrics['f1_macro']:.4f}")

    # Ejemplo de predicción
    sample_text = "Paciente con diabetes mellitus tipo 2 y complicaciones renales..."
    prediction = predict(sample_text, model, tokenizer, mlb_parent, mlb_child, device, best_thresholds)
    print("\nPredicción de ejemplo:", prediction)

    plot_metrics()


Token indices sequence length is longer than the specified maximum sequence length for this model (828 > 512). Running this sequence through the model will result in indexing errors


Loaded saved MLBs
Loaded saved tokenizer


Some weights of BertModel were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1:  22%|██▏       | 135/603 [00:17<01:01,  7.61it/s, loss=3.46, lr=3.36e-6]


KeyboardInterrupt: 

In [None]:
#  V0 Preentrenamiento con cie10 dataset
# ====================

# Dataset especializado para pre-entrenamiento
class PretrainDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child):
        self.tokenizer = tokenizer
        self.examples = []
        self.mlb_parent = mlb_parent
        self.mlb_child = mlb_child

        for _, row in df.iterrows():
            code = row['code'].strip()
            desc = row['description'].strip()
            levels = parse_code(code)

            # Generar múltiples variantes textuales
            # Base variant is the description itself
            variants = []

            # Extract text inside parentheses and brackets
            parentheses_matches = re.findall(r'\((.*?)\)', desc)
            bracket_matches = re.findall(r'\[(.*?)\]', desc)

            # Get text outside parentheses and brackets
            clean_text = re.sub(r'\([^)]*\)|\[[^\]]*\]', '', desc).strip()
            if clean_text != desc:
                variants.append(clean_text)

            # Add matches from parentheses and brackets
            variants.extend(parentheses_matches)
            variants.extend(bracket_matches)

            for variant in variants:
                if len(variant) > 5:
                    self.examples.append({
                        'text': variant,
                        'levels': levels
                    })
                else:
                    print(f"Skipping short variant: {variant}")

        tokenized_examples = []
        for example in self.examples:
            encoding = self.tokenizer(
                example['text'],
                max_length=Config.MAX_LENGTH,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            levels = example['levels']

            tokenized_examples.append({
                'input_ids': encoding['input_ids'].squeeze(),
                'attention_mask': encoding['attention_mask'].squeeze(),
                'parent_labels': torch.FloatTensor(self.mlb_parent.transform([[levels[0]]] if levels else [[]])[0]),
                'child_labels': torch.FloatTensor(self.mlb_child.transform([[levels[1]]] if len(levels)>1 else [[]])[0])
            })
        self.examples = tokenized_examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

def pretrain():
    # Cargar datos de códigos CIE10
    df = pd.read_csv(Config.PRETRAIN_DATA_PATH)

    mlb_parent, mlb_child = calculate_mlb_classes()

    # Inicializar componentes
    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
    dataset = PretrainDataset(df, tokenizer, mlb_parent, mlb_child)
    dataloader = DataLoader(dataset, batch_size=Config.PRETRAIN_BATCH_SIZE, shuffle=True)

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)

    # Bucle de pre-entrenamiento (similar al entrenamiento normal)
    for epoch in range(Config.PRETRAIN_EPOCHS):
        model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Pre-train Epoch {epoch+1}")

        for batch in progress_bar:
            optimizer.zero_grad()

            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device)
            }

            parent_labels = batch['parent_labels'].to(device)
            child_labels = batch['child_labels'].to(device)

            outputs = model(**inputs)

            loss = hierarchical_loss(
                outputs[0], outputs[1],
                parent_labels, child_labels
            )

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

        print(f"Pre-train Epoch {epoch+1} | Loss: {total_loss/len(dataloader):.4f}")

    # Guardar modelo pre-entrenado
    torch.save(model.state_dict(), Config.SAVE_PATH)
    tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
    print(f"Modelo pre-entrenado guardado en {Config.SAVE_PATH}")

pretrain()

NameError: name 'Config' is not defined

In [5]:
# Hierarchical - V1
# ====================

import os

class HierarchicalMedicalDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child):
        self.texts = df['text'].tolist()
        self.tokenizer = tokenizer
        self.examples = []

        # Procesar etiquetas
        self.parent_labels = []
        self.child_labels = []

        for codes in df['labels'].apply(eval): # FIXME: Unsafe eval
            parents, children = set(), set()
            for code in codes:
                code = code.strip().upper()
                levels = parse_code(code)
                if len(levels) >= 1: parents.add(levels[0])
                if len(levels) >= 2: children.add(levels[1])

            self.parent_labels.append(mlb_parent.transform([parents])[0])
            self.child_labels.append(mlb_child.transform([children])[0])

        for idx in range(len(self.texts)):
            encoding = self.tokenizer(
                self.texts[idx],
                max_length=Config.MAX_LENGTH,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            self.examples.append({
                'input_ids': encoding['input_ids'].squeeze(),
                'attention_mask': encoding['attention_mask'].squeeze(),
                'parent_labels': torch.FloatTensor(self.parent_labels[idx]),
                'child_labels': torch.FloatTensor(self.child_labels[idx]),
            })

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.examples[idx]

# ====================
#  ENTRENAMIENTO
# ====================
def train():
    epochs_without_improvement = 0
    early_stop = False
    best_f1 = 0

    # Cargar datos
    train_df = pd.read_csv(Config.DATA_PATHS['train'])
    val_df = pd.read_csv(Config.DATA_PATHS['val'])

    # Construir binarizadores
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Preparar datasets

    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Created new tokenizer")

    train_dataset = HierarchicalMedicalDataset(train_df, tokenizer, mlb_parent, mlb_child)
    val_dataset = HierarchicalMedicalDataset(val_df, tokenizer, mlb_parent, mlb_child)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=Config.TRAIN_BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.VAL_BATCH_SIZE)

    # Modelo y optimizador
    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    # Load best model if available
    try:
        model.load_state_dict(torch.load(Config.SAVE_STATE_PATH))
        print("Loaded previously saved best model")
    except:
        print("Starting training from scratch")

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps = Config.WARMUP_EPOCHS * len(train_loader),
        num_training_steps = Config.EPOCHS * len(train_loader)
    )

    # Bucle de entrenamiento
    scaler = torch.amp.GradScaler('cuda', enabled=Config.USE_FP16)
    for epoch in range(Config.EPOCHS):
        if early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for step, batch in enumerate(progress_bar):
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            parent_labels = batch['parent_labels'].to(device)
            child_labels = batch['child_labels'].to(device)

            with torch.amp.autocast('cuda', enabled=Config.USE_FP16):
                outputs = model(**inputs)

                # Línea corregida
                loss = hierarchical_loss(
                    outputs[0],  # parent_logits
                    outputs[1],  # child_logits
                    parent_labels,
                    child_labels
                )

                loss = loss / Config.GRADIENT_ACCUMULATION_STEPS
                scaler.scale(loss).backward()

                if (step + 1) % Config.GRADIENT_ACCUMULATION_STEPS == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scheduler.step()
                    optimizer.zero_grad()
                    scaler.update()

                total_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr()[0])

        # Validación
        val_metrics = evaluate(model, val_loader, device,
                              mlb_parent, mlb_child)
        print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | LR: {scheduler.get_last_lr()[0]:.2E}")
        # Store metrics

        loss_metric = val_metrics['f1_macro']
        if loss_metric > (best_f1 + Config.IMPROVEMENT_MARGIN) and epoch > 0:
            print(f"Saving best model... {best_f1:.5f} -> {loss_metric:.5f}")
            torch.save(model.state_dict(), f"{Config.SAVE_PATH}_2")
            best_f1 = loss_metric
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= Config.EARLY_STOP_PATIENCE:
                early_stop = True

        metrics_data = {
            'epoch': epoch + 1,
            'loss': total_loss/len(train_loader),
            'f1_micro': val_metrics['f1_micro'],
            'f1_macro': val_metrics['f1_macro'],
            'f1_micro_parent': val_metrics['f1_micro_parent'],
            'f1_macro_parent': val_metrics['f1_macro_parent'],
            'f1_micro_child': val_metrics['f1_micro_child'],
            'f1_macro_child': val_metrics['f1_macro_child'],
            'lr': scheduler.get_last_lr()[0],
            'epochs_without_improvement': epochs_without_improvement
        }

        # Write metrics to CSV
        metrics_df = pd.DataFrame([metrics_data])
        if epoch == 0:
            metrics_df.to_csv('training_metrics.csv', mode='a', index=False)
        else:
            metrics_df.to_csv('training_metrics.csv', mode='a', header=False, index=False)

        print(f"F1 Validation | Micro: {val_metrics['f1_micro']:.5f} | Macro: {val_metrics['f1_macro']:.5f} | Best: {best_f1:.5f} | Epochs without improvement: {epochs_without_improvement + 1}")


# ====================
#  EVALUACIÓN
# ====================
def evaluate(model, dataloader, device, mlb_parent, mlb_child):
    model.eval()
    parent_preds_all = []
    child_preds_all = []
    parent_labels_all = []
    child_labels_all = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            # Get labels
            parent_labels = batch['parent_labels'].numpy()
            child_labels = batch['child_labels'].numpy()

            parent_logits, child_logits = model(**inputs)

            # Convert to predictions
            parent_preds = (torch.sigmoid(parent_logits).cpu().numpy() > Config.THRESHOLDS['parent']).astype(int)
            child_preds = (torch.sigmoid(child_logits).cpu().numpy() > Config.THRESHOLDS['child']).astype(int)

            # Append to lists
            parent_preds_all.extend(parent_preds)
            child_preds_all.extend(child_preds)
            parent_labels_all.extend(parent_labels)
            child_labels_all.extend(child_labels)

    # Convert to numpy arrays
    parent_preds_all = np.array(parent_preds_all)
    child_preds_all = np.array(child_preds_all)
    parent_labels_all = np.array(parent_labels_all)
    child_labels_all = np.array(child_labels_all)

    # Print example comparison for parent level
    if len(parent_labels_all) > 0:
        parent_true = np.array(mlb_parent.classes_)[parent_labels_all[0].astype(bool)]
        parent_pred = np.array(mlb_parent.classes_)[parent_preds_all[0].astype(bool)]
        common_labels = len(set(parent_true) & set(parent_pred))
        total_labels = len(set(parent_true))
        accuracy_parent = common_labels / total_labels if total_labels > 0 else 0

        child_true = np.array(mlb_child.classes_)[child_labels_all[0].astype(bool)]
        child_pred = np.array(mlb_child.classes_)[child_preds_all[0].astype(bool)]
        common_labels = len(set(child_true) & set(child_pred))
        total_labels = len(set(child_true))
        accuracy_child = common_labels / total_labels if total_labels > 0 else 0

        print("Expected parent labels:", sorted(parent_true))
        print("Predicted parent labels:", sorted(parent_pred))
        print("Expected child labels:", sorted(child_true))
        print("Predicted child labels:", sorted(child_pred))

        print(f"Percentage of correct parent labels: {accuracy_parent:.2%} | {accuracy_child:.2%}")

    # Calculate F1 scores for each level
    metrics = {
        'f1_micro_parent': f1_score(parent_labels_all, parent_preds_all, average='micro' , zero_division=0),
        'f1_macro_parent': f1_score(parent_labels_all, parent_preds_all, average='macro', zero_division=0),
        'f1_micro_child': f1_score(child_labels_all, child_preds_all, average='micro', zero_division=0),
        'f1_macro_child': f1_score(child_labels_all, child_preds_all, average='macro', zero_division=0)
    }

    # Calculate weighted average F1 scores
    metrics['f1_micro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_micro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_micro_child']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    metrics['f1_macro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_macro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_macro_child']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    return metrics

# ====================
#  PREDICCIÓN
# ====================
def predict(text, model, tokenizer, mlb_parent, mlb_child, device):
    encoding = tokenizer(
        text,
        max_length=Config.MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        parent_logits, child_logits = model(**encoding)

    # Obtener predicciones
    parent_probs = torch.sigmoid(parent_logits).cpu().numpy()
    child_probs = torch.sigmoid(child_logits).cpu().numpy()

    # Decodificar etiquetas
    parent_preds = mlb_parent.inverse_transform((parent_probs > Config.THRESHOLDS['parent']).astype(int))
    child_preds = mlb_child.inverse_transform((child_probs > Config.THRESHOLDS['child']).astype(int))

    # Combinar y asegurar jerarquía
    final_codes = set()
    for parent in parent_preds[0]:
        final_codes.add(parent)
        for child in child_preds[0]:
            if child.startswith(parent):
                final_codes.add(child)

    return sorted(final_codes)

# ====================
#  EJECUCIÓN
# ====================
if __name__ == "__main__":
    train()

    # Cargar datos de test
    test_df = pd.read_csv(Config.DATA_PATHS['test'])
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Cargar modelo
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        print("Using default tokenizer")

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    if os.exists(f"{Config.SAVE_PATH}2"):
        model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
        print("Loaded best model - 2")
    elif os.exists(Config.SAVE_PATH"):
        model.load_state_dict(torch.load(Config.SAVE_PATH))
        print("Loaded best model")

    # Evaluar en test
    test_dataset = HierarchicalMedicalDataset(test_df, tokenizer, mlb_parent, mlb_child)
    test_loader = DataLoader(test_dataset, batch_size=Config.TEST_BATCH_SIZE)

    test_metrics = evaluate(model, test_loader, device, mlb_parent, mlb_child)
    print("\nResultados en Test:")
    print(f"Micro F1: {test_metrics['f1_micro']:.4f}")
    print(f"Macro F1: {test_metrics['f1_macro']:.4f}")

    # Ejemplo de predicción
    sample_text = "Paciente con diabetes mellitus tipo 2 y complicaciones renales..."
    prediction = predict(sample_text, model, tokenizer, mlb_parent, mlb_child, device)
    print("\nPredicción de ejemplo:", prediction)

    plot_metrics()


SyntaxError: unterminated string literal (detected at line 326) (2590137000.py, line 326)

In [None]:
# V2 modelo jerárquico
# # ====================

import os
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup, AutoConfig
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_recall_curve
from tqdm import tqdm


# ====================
#  Modelo v2
# ====================
class HierarchicalBERTv2(torch.nn.Module):
    def __init__(self, num_parents, num_children):
        super().__init__()
        config = AutoConfig.from_pretrained(Config.MODEL_NAME, output_hidden_states=True)
        self.bert = AutoModel.from_pretrained(Config.MODEL_NAME, config=config)

        hidden_size = self.bert.config.hidden_size  # This will be 768 for base models

        self.parent_classifier = torch.nn.Linear(hidden_size, num_parents)
        self.child_classifier = torch.nn.Linear(hidden_size + num_parents, num_children)
        self.dropout = torch.nn.Dropout(self.bert.config.hidden_dropout_prob)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)

        if Config.USE_FEATURE_PYRAMID:
            # Combine last 3 layers' [CLS] embeddings
            hidden_states = outputs.hidden_states[-3:]  # Get last 3 layers
            # Stack [CLS] embeddings (shape: [3, batch_size, hidden_size])
            pooled = torch.stack([state[:, 0] for state in hidden_states])
            # Weighted combination of layers (weights should sum to 1)
            pooled = torch.einsum('lbd,l->bd', pooled,
                                torch.tensor(Config.FEATURE_LAYER_WEIGHTS).to(pooled.device))
        else:
            pooled = outputs.last_hidden_state[:, 0, :]

        pooled = self.dropout(pooled)

        # Jerarquía de clasificación
        parent_logits = self.parent_classifier(pooled)
        parent_probs = torch.sigmoid(parent_logits)
        child_input = torch.cat([pooled, parent_probs], dim=1)
        child_logits = self.child_classifier(child_input)

        return parent_logits, child_logits, pooled

# ====================
#  FUNCIÓN DE PÉRDIDA MEJORADA
# ====================
def hierarchical_lossv2(parent_logits, child_logits,
                     parent_labels, child_labels,
                     parent_weights, child_weights):

    loss_parent = F.binary_cross_entropy_with_logits(
        parent_logits,
        parent_labels,
        pos_weight=parent_weights
    )

    loss_child = F.binary_cross_entropy_with_logits(
        child_logits,
        child_labels,
        pos_weight=child_weights
    )

    return (Config.HIERARCHICAL_WEIGHTS['parent'] * loss_parent +
            Config.HIERARCHICAL_WEIGHTS['child'] * loss_child)

# ====================
#  AJUSTE DINÁMICO DE UMBRALES
# ====================
def calculate_optimal_thresholds(y_true, y_probs):
    thresholds = {}
    for i in range(y_probs.shape[1]):
        if np.sum(y_true[:, i]) > 0:  # Solo clases presentes
            precision, recall, threshs = precision_recall_curve(y_true[:, i], y_probs[:, i])
            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
            best_idx = np.nanargmax(f1_scores)
            thresholds[i] = threshs[best_idx]
    return thresholds

# ====================
#  Dataset
# ====================
class HierarchicalMedicalDataset(Dataset):
    def __init__(self, df, tokenizer, mlb_parent, mlb_child):
        self.texts = df['text'].tolist()
        self.tokenizer = tokenizer
        self.examples = []

        # Procesar etiquetas
        self.parent_labels = []
        self.child_labels = []

        for codes in df['labels'].apply(eval): # FIXME: Unsafe eval
            parents, children = set(), set()
            for code in codes:
                code = code.strip().upper()
                levels = parse_code(code)
                if len(levels) >= 1: parents.add(levels[0])
                if len(levels) >= 2: children.add(levels[1])

            self.parent_labels.append(mlb_parent.transform([parents])[0])
            self.child_labels.append(mlb_child.transform([children])[0])

        for idx in range(len(self.texts)):
            encoding = self.tokenizer(
                self.texts[idx],
                max_length=Config.MAX_LENGTH,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            self.examples.append({
                'input_ids': encoding['input_ids'].squeeze(),
                'attention_mask': encoding['attention_mask'].squeeze(),
                'parent_labels': torch.FloatTensor(self.parent_labels[idx]),
                'child_labels': torch.FloatTensor(self.child_labels[idx]),
            })

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.examples[idx]

# ====================
#  ENTRENAMIENTO
# ====================
def train(best_thresholds=Config.THRESHOLDS):
    epochs_without_improvement = 0
    early_stop = False
    best_f1 = 0

    # Cargar datos
    train_df = pd.read_csv(Config.DATA_PATHS['train'])
    val_df = pd.read_csv(Config.DATA_PATHS['val'])

    # Construir binarizadores
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Preparar datasets

    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        tokenizer.save_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Created new tokenizer")

    train_dataset = HierarchicalMedicalDataset(train_df, tokenizer, mlb_parent, mlb_child)
    val_dataset = HierarchicalMedicalDataset(val_df, tokenizer, mlb_parent, mlb_child)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=Config.TRAIN_BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.VAL_BATCH_SIZE)

    # Modelo y optimizador
    model = HierarchicalBERTv2(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    # Load best model if available
    if os.path.exists(f"{Config.SAVE_PATH}_2"):
        model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
        print("Loaded best model - 2")
    elif os.path.exists(Config.SAVE_PATH):
        model.load_state_dict(torch.load(Config.SAVE_PATH))
        print("Loaded best model")
    else:
        print("Starting training from scratch")

    optimizer = AdamW(model.parameters(), lr=Config.LEARNING_RATE)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps = Config.WARMUP_EPOCHS * len(train_loader),
        num_training_steps = Config.EPOCHS * len(train_loader)
    )

    # Bucle de entrenamiento
    scaler = torch.amp.GradScaler('cuda', enabled=Config.USE_FP16)

    # Calcular pesos de clases
    parent_counts = np.sum(train_dataset.parent_labels, axis=0)
    parent_weights = (len(train_dataset) - parent_counts) / (parent_counts + Config.CLASS_WEIGHT_SMOOTHING)
    parent_weights = torch.tensor(parent_weights).to(device)

    child_counts = np.sum(train_dataset.child_labels, axis=0)
    child_weights = (len(train_dataset) - child_counts) / (child_counts + Config.CLASS_WEIGHT_SMOOTHING)
    child_weights = torch.tensor(child_weights).to(device)

    for epoch in range(Config.EPOCHS):
        if early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

        model.train()
        total_loss = 0

        # Ajuste periódico de umbrales
        if (epoch + 1) % Config.THRESHOLD_TUNING_INTERVAL == 0:
            val_probs, val_labels = get_validation_probabilities(model, val_loader, device)

            # Calcular mejores umbrales por clase
            parent_thresholds = calculate_optimal_thresholds(
                val_labels['parent'], val_probs['parent']
            )
            child_thresholds = calculate_optimal_thresholds(
                val_labels['child'], val_probs['child']
            )

            # Actualizar umbrales globales
            best_thresholds['parent'] = np.mean(list(parent_thresholds.values()))
            best_thresholds['child'] = np.mean(list(child_thresholds.values()))
            print(f"Nuevos umbrales: Parent={best_thresholds['parent']:.3f}, Child={best_thresholds['child']:.3f}")

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for step, batch in enumerate(progress_bar):
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            parent_labels = batch['parent_labels'].to(device)
            child_labels = batch['child_labels'].to(device)

            with torch.amp.autocast('cuda', enabled=Config.USE_FP16):
                outputs = model(**inputs)

                # Línea corregida
                loss = hierarchical_loss(
                    outputs[0],  # parent_logits
                    outputs[1],  # child_logits
                    parent_labels,
                    child_labels
                )

                loss = loss / Config.GRADIENT_ACCUMULATION_STEPS
                scaler.scale(loss).backward()

                if (step + 1) % Config.GRADIENT_ACCUMULATION_STEPS == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scheduler.step()
                    optimizer.zero_grad()
                    scaler.update()

                total_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr()[0])

        # Validación
        val_metrics = evaluate(model, val_loader, device, mlb_parent, mlb_child, best_thresholds)

        print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f} | LR: {scheduler.get_last_lr()[0]:.2E}")
        # Store metrics

        loss_metric = val_metrics['f1_macro']
        if epoch == 0:
            best_f1 = loss_metric

        if loss_metric > (best_f1 + Config.IMPROVEMENT_MARGIN):
            print(f"Saving best model... {best_f1:.5f} -> {loss_metric:.5f}")
            torch.save(model.state_dict(), f"{Config.SAVE_PATH}_3")
            best_f1 = loss_metric
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= Config.EARLY_STOP_PATIENCE:
                early_stop = True

        metrics_data = {
            'epoch': epoch + 1,
            'loss': total_loss/len(train_loader),
            'f1_micro': val_metrics['f1_micro'],
            'f1_macro': val_metrics['f1_macro'],
            'f1_micro_parent': val_metrics['f1_micro_parent'],
            'f1_macro_parent': val_metrics['f1_macro_parent'],
            'f1_micro_child': val_metrics['f1_micro_child'],
            'f1_macro_child': val_metrics['f1_macro_child'],
            'lr': scheduler.get_last_lr()[0],
            'epochs_without_improvement': epochs_without_improvement,
            'parent_threshold': best_thresholds['parent'],
            'child_threshold': best_thresholds['child']
        }

        # Write metrics to CSV
        metrics_df = pd.DataFrame([metrics_data])
        if epoch == 0:
            metrics_df.to_csv('training_metrics.csv', mode='a', index=False)
        else:
            metrics_df.to_csv('training_metrics.csv', mode='a', header=False, index=False)

        print(f"F1 Validation | Micro: {val_metrics['f1_micro']:.5f} | Macro: {val_metrics['f1_macro']:.5f} | Best: {best_f1:.5f} | Epochs without improvement: {epochs_without_improvement + 1}")

# ====================
#  FUNCIONES AUXILIARES
# ====================
def get_validation_probabilities(model, dataloader, device):
    model.eval()
    parent_probs, child_probs = [], []
    parent_labels, child_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device)
            }
            p_logits, c_logits, _ = model(**inputs)

            parent_probs.append(torch.sigmoid(p_logits).cpu().numpy())
            child_probs.append(torch.sigmoid(c_logits).cpu().numpy())

            parent_labels.append(batch['parent_labels'].numpy())
            child_labels.append(batch['child_labels'].numpy())

    return {
        'parent': np.concatenate(parent_probs),
        'child': np.concatenate(child_probs)
    }, {
        'parent': np.concatenate(parent_labels),
        'child': np.concatenate(child_labels)
    }


# ====================
#  EVALUACIÓN
# ====================
def evaluate(model, dataloader, device, mlb_parent, mlb_child, thresholds):
    model.eval()
    parent_preds_all = []
    child_preds_all = []
    parent_labels_all = []
    child_labels_all = []

    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                k: v.to(device)
                for k, v in batch.items()
                if k in ['input_ids', 'attention_mask']
            }

            # Get labels
            parent_labels = batch['parent_labels'].numpy()
            child_labels = batch['child_labels'].numpy()

            parent_logits, child_logits, pooled = model(**inputs)

            # Convert to predictions
            parent_preds = (torch.sigmoid(parent_logits).cpu().numpy() > thresholds['parent']).astype(int)
            child_preds = (torch.sigmoid(child_logits).cpu().numpy() > thresholds['child']).astype(int)

            # Append to lists
            parent_preds_all.extend(parent_preds)
            child_preds_all.extend(child_preds)
            parent_labels_all.extend(parent_labels)
            child_labels_all.extend(child_labels)

    # Convert to numpy arrays
    parent_preds_all = np.array(parent_preds_all)
    child_preds_all = np.array(child_preds_all)
    parent_labels_all = np.array(parent_labels_all)
    child_labels_all = np.array(child_labels_all)

    # Print example comparison for parent level
    if len(parent_labels_all) > 0:
        parent_true = np.array(mlb_parent.classes_)[parent_labels_all[0].astype(bool)]
        parent_pred = np.array(mlb_parent.classes_)[parent_preds_all[0].astype(bool)]
        common_labels = len(set(parent_true) & set(parent_pred))
        total_labels = len(set(parent_true))
        accuracy_parent = common_labels / total_labels if total_labels > 0 else 0

        child_true = np.array(mlb_child.classes_)[child_labels_all[0].astype(bool)]
        child_pred = np.array(mlb_child.classes_)[child_preds_all[0].astype(bool)]
        common_labels = len(set(child_true) & set(child_pred))
        total_labels = len(set(child_true))
        accuracy_child = common_labels / total_labels if total_labels > 0 else 0

        print("Expected parent labels:", sorted(parent_true))
        print("Predicted parent labels:", sorted(parent_pred))
        print("Expected child labels:", sorted(child_true))
        print("Predicted child labels:", sorted(child_pred))

        print(f"Percentage of correct parent labels: {accuracy_parent:.2%} | {accuracy_child:.2%}")

    # Calculate F1 scores for each level
    metrics = {
        'f1_micro_parent': f1_score(parent_labels_all, parent_preds_all, average='micro' , zero_division=0),
        'f1_macro_parent': f1_score(parent_labels_all, parent_preds_all, average='macro', zero_division=0),
        'f1_micro_child': f1_score(child_labels_all, child_preds_all, average='micro', zero_division=0),
        'f1_macro_child': f1_score(child_labels_all, child_preds_all, average='macro', zero_division=0)
    }

    # Calculate weighted average F1 scores
    metrics['f1_micro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_micro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_micro_child']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    metrics['f1_macro'] = (
        Config.HIERARCHICAL_WEIGHTS['parent'] * metrics['f1_macro_parent'] +
        Config.HIERARCHICAL_WEIGHTS['child'] * metrics['f1_macro_child']
    ) / sum(Config.HIERARCHICAL_WEIGHTS.values())

    return metrics

# ====================
#  PREDICCIÓN
# ====================
def predict(text, model, tokenizer, mlb_parent, mlb_child, device, thresholds):
    encoding = tokenizer(
        text,
        max_length=Config.MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        parent_logits, child_logits = model(**encoding)

    # Obtener predicciones
    parent_probs = torch.sigmoid(parent_logits).cpu().numpy()
    child_probs = torch.sigmoid(child_logits).cpu().numpy()

    # Decodificar etiquetas
    parent_preds = mlb_parent.inverse_transform((parent_probs > thresholds['parent']).astype(int))
    child_preds = mlb_child.inverse_transform((child_probs > thresholds['child']).astype(int))

    # Combinar y asegurar jerarquía
    final_codes = set()
    for parent in parent_preds[0]:
        final_codes.add(parent)
        for child in child_preds[0]:
            if child.startswith(parent):
                final_codes.add(child)

    return sorted(final_codes)

# ====================
#  EJECUCIÓN
# ====================
if __name__ == "__main__":
    best_thresholds = Config.THRESHOLDS

    train(best_thresholds)

    # Cargar datos de test
    test_df = pd.read_csv(Config.DATA_PATHS['test'])
    mlb_parent, mlb_child = calculate_mlb_classes()

    # Cargar modelo
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        tokenizer = AutoTokenizer.from_pretrained(Config.SAVE_TOKENIZER_PATH)
        print("Loaded saved tokenizer")
    except:
        tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)
        print("Using default tokenizer")

    model = HierarchicalBERT(
        len(mlb_parent.classes_),
        len(mlb_child.classes_)
    ).to(device)

    if os.path.exists(f"{Config.SAVE_PATH}_2"):
        model.load_state_dict(torch.load(f"{Config.SAVE_PATH}_2"))
        print("Loaded best model - 2")
    elif os.path.exists(Config.SAVE_PATH):
        model.load_state_dict(torch.load(Config.SAVE_PATH))
        print("Loaded best model")

    # Evaluar en test
    test_dataset = HierarchicalMedicalDataset(test_df, tokenizer, mlb_parent, mlb_child)
    test_loader = DataLoader(test_dataset, batch_size=Config.TEST_BATCH_SIZE)

    test_metrics = evaluate(model, test_loader, device, mlb_parent, mlb_child)
    print("\nResultados en Test:")
    print(f"Micro F1: {test_metrics['f1_micro']:.4f}")
    print(f"Macro F1: {test_metrics['f1_macro']:.4f}")

    # Ejemplo de predicción
    sample_text = "Paciente con diabetes mellitus tipo 2 y complicaciones renales..."
    prediction = predict(sample_text, model, tokenizer, mlb_parent, mlb_child, device, best_thresholds)
    print("\nPredicción de ejemplo:", prediction)

    plot_metrics()


Loaded saved MLBs
Loaded saved tokenizer


Some weights of BertModel were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded best model - 2


Epoch 1: 100%|██████████| 63/63 [00:11<00:00,  5.30it/s, loss=0.00711, lr=1e-5]   


Expected parent labels: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent labels: ['C64', 'D49', 'I10', 'N13', 'N28', 'N32', 'R10', 'R31', 'R58']
Expected child labels: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child labels: ['B99.9', 'C64.9', 'C67.9', 'C78.7', 'C79.51', 'D49.0', 'D49.4', 'D64.9', 'D72.829', 'E11.9', 'E78.5', 'E80.7', 'F17.200', 'F17.210', 'F17.290', 'I48.91', 'I82.90', 'K65.9', 'L98.9', 'M19.90', 'M54.5', 'N13.30', 'N18.9', 'N20.0', 'N28.1', 'N28.89', 'N28.9', 'N32.9', 'N39.0', 'R10.9', 'R23.1', 'R31.0', 'R31.29', 'R31.9', 'R50.9', 'R53.1', 'R59.0', 'R59.9', 'R60.9', 'R63.0', 'R63.4', 'R80.9', 'T14.8', 'T14.90']
Percentage of correct parent labels: 30.00% | 45.45%
Epoch 1 | Loss: 0.0056 | LR: 1.00E-05
F1 Validation | Micro: 0.27902 | Macro: 0.02075 | Best: 0.02075 | Epochs without improvement: 2


Epoch 2: 100%|██████████| 63/63 [00:11<00:00,  5.42it/s, loss=0.00477, lr=2e-5]   


Expected parent labels: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent labels: ['C64', 'C79', 'D49', 'I10', 'N13', 'N28', 'N32', 'R31', 'R58', 'R59']
Expected child labels: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child labels: ['B99.9', 'C64.9', 'C67.9', 'C78.7', 'C79.51', 'D49.0', 'D49.4', 'D64.9', 'E11.9', 'E78.5', 'E80.7', 'F17.200', 'F17.210', 'F17.290', 'I48.91', 'I82.90', 'K65.9', 'M19.90', 'M54.5', 'N13.30', 'N18.9', 'N20.0', 'N28.1', 'N28.89', 'N28.9', 'N32.9', 'N39.0', 'R10.9', 'R19.00', 'R31.0', 'R31.29', 'R31.9', 'R50.9', 'R53.1', 'R59.0', 'R59.9', 'R60.9', 'R63.0', 'R63.4', 'R80.9', 'T14.8', 'T14.90']
Percentage of correct parent labels: 30.00% | 45.45%
Epoch 2 | Loss: 0.0052 | LR: 2.00E-05
F1 Validation | Micro: 0.28390 | Macro: 0.02119 | Best: 0.02075 | Epochs without improvement: 3
Nuevos umbrales: Parent=0.036, Child=0.011


Epoch 3: 100%|██████████| 63/63 [00:11<00:00,  5.45it/s, loss=0.00471, lr=2e-5]


Expected parent labels: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent labels: ['C64', 'D49', 'I10', 'N13', 'N28', 'N32', 'N39', 'R31', 'R52', 'R58']
Expected child labels: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child labels: ['A15.9', 'A41.9', 'B19.20', 'B99.9', 'C64.9', 'C67.9', 'C77.9', 'C78.7', 'C79.51', 'D49.0', 'D49.4', 'D49.59', 'D50.9', 'D64.9', 'D72.829', 'E11.9', 'E78.00', 'E78.5', 'E80.7', 'F17.200', 'F17.210', 'F17.290', 'F32.9', 'I48.91', 'I82.90', 'J44.9', 'K52.9', 'K65.9', 'L53.9', 'L98.9', 'M19.90', 'M54.5', 'M89.9', 'N13.30', 'N18.9', 'N20.0', 'N26.1', 'N28.1', 'N28.89', 'N28.9', 'N32.89', 'N32.9', 'N39.0', 'N40.0', 'Q61.3', 'R10.9', 'R11.10', 'R14.0', 'R19.00', 'R19.7', 'R23.1', 'R30.0', 'R31.0', 'R31.29', 'R31.9', 'R50.9', 'R53.1', 'R57.9', 'R59.0', 'R59.9', 'R60.0', 'R60.9', 'R63.0', 'R63.4', 'R80.9', 'T14.8', 'T14.90', 'Z51.5', 'Z87.891', 'Z90.710']
P

Epoch 4: 100%|██████████| 63/63 [00:11<00:00,  5.35it/s, loss=0.00283, lr=2e-5]


Expected parent labels: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent labels: ['C64', 'C79', 'D49', 'I10', 'N13', 'N20', 'N23', 'N28', 'N32', 'N39', 'R10', 'R31', 'R52', 'R58', 'R59']
Expected child labels: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child labels: ['A15.9', 'A41.9', 'B19.20', 'B96.20', 'B99.9', 'C16.9', 'C64.2', 'C64.9', 'C67.9', 'C77.9', 'C78.7', 'C79.51', 'D49.0', 'D49.2', 'D49.4', 'D49.59', 'D50.9', 'D64.9', 'D72.829', 'E11.9', 'E78.00', 'E78.5', 'E80.7', 'F17.200', 'F17.210', 'F17.290', 'F32.9', 'I12.9', 'I48.91', 'I82.90', 'J44.9', 'K52.9', 'K59.00', 'K65.9', 'L92.9', 'L98.9', 'M19.90', 'M54.5', 'M89.9', 'N13.30', 'N18.9', 'N20.0', 'N26.1', 'N28.1', 'N28.89', 'N28.9', 'N32.89', 'N32.9', 'N39.0', 'N40.0', 'N50.89', 'N80.9', 'Q61.3', 'R10.13', 'R10.9', 'R11.10', 'R14.0', 'R18.8', 'R19.00', 'R19.7', 'R23.1', 'R30.0', 'R31.0', 'R31.29', 'R31.9', 'R35.0', 'R5

Epoch 5: 100%|██████████| 63/63 [00:11<00:00,  5.36it/s, loss=0.00487, lr=1.99e-5]


Expected parent labels: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent labels: ['C64', 'C79', 'D49', 'I10', 'M54', 'N13', 'N20', 'N23', 'N28', 'N32', 'N39', 'R10', 'R31', 'R52', 'R58', 'R59']
Expected child labels: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child labels: ['A15.9', 'A41.9', 'B19.20', 'B96.20', 'B99.9', 'C64.2', 'C64.9', 'C67.9', 'C77.9', 'C78.7', 'C79.51', 'D49.0', 'D49.2', 'D49.4', 'D49.59', 'D50.9', 'D64.9', 'D72.829', 'E11.9', 'E78.00', 'E78.5', 'E80.7', 'F17.200', 'F17.210', 'F17.290', 'F32.9', 'I48.91', 'I82.90', 'J44.9', 'K52.9', 'K59.00', 'K65.9', 'L98.9', 'M19.90', 'M54.5', 'M89.9', 'N13.30', 'N18.9', 'N20.0', 'N26.1', 'N28.1', 'N28.89', 'N28.9', 'N32.89', 'N32.9', 'N39.0', 'N40.0', 'N80.9', 'Q61.3', 'R10.13', 'R10.9', 'R11.10', 'R14.0', 'R16.0', 'R18.8', 'R19.00', 'R19.7', 'R23.1', 'R30.0', 'R31.0', 'R31.29', 'R31.9', 'R35.0', 'R50.9', 'R53.1', 'R53.8

Epoch 6: 100%|██████████| 63/63 [00:11<00:00,  5.46it/s, loss=0.00658, lr=1.99e-5]


Expected parent labels: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent labels: ['C64', 'D49', 'I10', 'N13', 'N28', 'N32', 'R31', 'R58']
Expected child labels: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child labels: ['B19.20', 'B99.9', 'C64.9', 'C67.9', 'C77.9', 'C78.7', 'C79.51', 'D49.0', 'D49.4', 'D49.59', 'D50.9', 'D64.9', 'D72.829', 'E11.9', 'E78.00', 'E78.5', 'E80.7', 'F17.200', 'F17.210', 'F17.290', 'I48.91', 'I82.90', 'J44.9', 'K52.9', 'K65.9', 'L98.9', 'M19.90', 'M54.5', 'N13.30', 'N18.9', 'N20.0', 'N26.1', 'N28.1', 'N28.89', 'N28.9', 'N32.89', 'N32.9', 'N39.0', 'N40.0', 'Q61.3', 'R10.9', 'R11.10', 'R19.00', 'R19.7', 'R23.1', 'R31.0', 'R31.29', 'R31.9', 'R50.9', 'R53.1', 'R59.0', 'R59.9', 'R60.0', 'R60.9', 'R63.0', 'R63.4', 'R80.9', 'T14.8', 'T14.90', 'Z51.5']
Percentage of correct parent labels: 30.00% | 45.45%
Epoch 6 | Loss: 0.0035 | LR: 1.99E-05
F1 Validation | Mi

Epoch 7: 100%|██████████| 63/63 [00:11<00:00,  5.37it/s, loss=0.00589, lr=1.99e-5]


Expected parent labels: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent labels: ['C64', 'C79', 'D49', 'I10', 'N13', 'N28', 'N32', 'R10', 'R31', 'R58', 'R59']
Expected child labels: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child labels: ['B19.20', 'B99.9', 'C64.2', 'C64.9', 'C67.9', 'C77.9', 'C78.7', 'C79.51', 'D49.0', 'D49.2', 'D49.4', 'D49.59', 'D50.9', 'D64.9', 'D72.829', 'E11.9', 'E78.00', 'E78.5', 'E80.7', 'F17.200', 'F17.210', 'F17.290', 'I12.9', 'I48.91', 'I82.90', 'J44.9', 'K52.9', 'K59.00', 'K65.9', 'L98.9', 'M19.90', 'M54.5', 'M89.9', 'N13.30', 'N18.9', 'N20.0', 'N26.1', 'N28.1', 'N28.89', 'N28.9', 'N32.89', 'N32.9', 'N39.0', 'N40.0', 'Q61.3', 'R10.13', 'R10.9', 'R11.10', 'R14.0', 'R18.8', 'R19.00', 'R19.7', 'R23.1', 'R30.0', 'R31.0', 'R31.29', 'R31.9', 'R35.0', 'R50.9', 'R53.1', 'R57.9', 'R59.0', 'R59.9', 'R60.0', 'R60.9', 'R63.0', 'R63.4', 'R80.9', 'T14.8', 'T14.9

Epoch 8: 100%|██████████| 63/63 [00:11<00:00,  5.41it/s, loss=0.0034, lr=1.99e-5] 


Expected parent labels: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent labels: ['C64', 'D49', 'I10', 'M54', 'N13', 'N28', 'N32', 'R10', 'R31', 'R52', 'R59']
Expected child labels: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child labels: ['B99.9', 'C64.9', 'C67.9', 'C77.9', 'C78.7', 'C79.51', 'D49.0', 'D49.4', 'D49.59', 'D50.9', 'D64.9', 'D72.829', 'E11.9', 'E78.00', 'E78.5', 'E80.7', 'F17.200', 'F17.210', 'F17.290', 'I48.91', 'I82.90', 'J44.9', 'K52.9', 'K59.00', 'K65.9', 'L98.9', 'M19.90', 'M54.5', 'N13.30', 'N18.9', 'N20.0', 'N26.1', 'N28.1', 'N28.89', 'N28.9', 'N32.89', 'N32.9', 'N39.0', 'N40.0', 'Q61.3', 'R10.13', 'R10.9', 'R11.10', 'R14.0', 'R19.00', 'R19.7', 'R23.1', 'R30.0', 'R31.0', 'R31.29', 'R31.9', 'R50.9', 'R53.1', 'R57.9', 'R59.0', 'R59.9', 'R60.9', 'R63.0', 'R63.4', 'R80.9', 'T14.8', 'T14.90', 'Z51.5', 'Z90.710']
Percentage of correct parent labels: 30.00% | 54.

Epoch 9: 100%|██████████| 63/63 [00:11<00:00,  5.62it/s, loss=0.00407, lr=1.99e-5]


Expected parent labels: ['D18', 'K26', 'K59', 'N13', 'N20', 'N23', 'N28', 'N39', 'Q62', 'R31']
Predicted parent labels: ['C64', 'D49', 'I10', 'N13', 'N28', 'N32', 'N39', 'R31', 'R58']
Expected child labels: ['D18.09', 'K26.9', 'K59.00', 'N13.5', 'N20.0', 'N28.0', 'N28.89', 'N28.9', 'N39.0', 'Q62.11', 'R31.9']
Predicted child labels: ['B99.9', 'C64.2', 'C64.9', 'C67.9', 'C77.9', 'C78.7', 'C79.51', 'D49.0', 'D49.4', 'D49.59', 'D50.9', 'D64.9', 'D72.829', 'E11.9', 'E78.00', 'E78.5', 'E80.7', 'F17.200', 'F17.210', 'F17.290', 'I48.91', 'I82.90', 'J44.9', 'K52.9', 'K65.9', 'L98.9', 'M19.90', 'M54.5', 'N13.30', 'N18.9', 'N20.0', 'N26.1', 'N28.1', 'N28.89', 'N28.9', 'N32.89', 'N32.9', 'N39.0', 'N40.0', 'Q61.3', 'R10.9', 'R11.10', 'R19.00', 'R19.7', 'R23.1', 'R30.0', 'R31.0', 'R31.29', 'R31.9', 'R50.9', 'R53.1', 'R57.9', 'R59.0', 'R59.9', 'R60.9', 'R63.0', 'R63.4', 'R80.9', 'T14.8', 'T14.90', 'Z51.5', 'Z90.710']
Percentage of correct parent labels: 40.00% | 45.45%
Epoch 9 | Loss: 0.0030 | LR: 1

Epoch 10:  38%|███▊      | 24/63 [00:04<00:07,  5.43it/s, loss=0.00448, lr=1.99e-5]


KeyboardInterrupt: 